Skip to content

Commit 7d01fef

Browse files
committed
fix(stan): rng for generated quantities
Generated quantities have their own random number generator. The seed for this generator did not depend on the global seed of the model, so that random stream for two different sampler runs were reusing the same randomness.
1 parent 8f15f8e commit 7d01fef

File tree

2 files changed

+41
-3
lines changed

2 files changed

+41
-3
lines changed

src/stan.rs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -531,8 +531,8 @@ impl Model for StanModel {
531531

532532
fn new_trace<'a, S: Settings, R: rand::Rng + ?Sized>(
533533
&'a self,
534-
_rng: &mut R,
535-
chain: u64,
534+
rng: &mut R,
535+
_chain: u64,
536536
settings: &S,
537537
) -> anyhow::Result<Self::DrawStorage<'a, S>> {
538538
let draws = settings.hint_num_tune() + settings.hint_num_draws();
@@ -541,7 +541,8 @@ impl Model for StanModel {
541541
.iter()
542542
.map(|var| Vec::with_capacity(var.size * draws))
543543
.collect();
544-
let rng = self.model.new_rng(chain as u32)?;
544+
let seed = rng.next_u32();
545+
let rng = self.model.new_rng(seed)?;
545546
let buffer = vec![0f64; self.model.param_num(true, true)];
546547
Ok(StanTrace {
547548
model: self,

tests/test_stan.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,43 @@ def test_stan_model():
2727
trace.posterior.a # noqa: B018
2828

2929

30+
@pytest.mark.stan
31+
def test_seed():
32+
model = """
33+
data {}
34+
parameters {
35+
real a;
36+
}
37+
model {
38+
a ~ normal(0, 1);
39+
}
40+
generated quantities {
41+
real b = normal_rng(0, 1);
42+
}
43+
"""
44+
45+
compiled_model = nutpie.compile_stan_model(code=model)
46+
trace = nutpie.sample(compiled_model, seed=42)
47+
trace2 = nutpie.sample(compiled_model, seed=42)
48+
trace3 = nutpie.sample(compiled_model, seed=43)
49+
50+
assert np.allclose(trace.posterior.a, trace2.posterior.a)
51+
assert np.allclose(trace.posterior.b, trace2.posterior.b)
52+
53+
assert not np.allclose(trace.posterior.a, trace3.posterior.a)
54+
assert not np.allclose(trace.posterior.b, trace3.posterior.b)
55+
# Check that all chains are pairwise different
56+
for i in range(len(trace.posterior.a)):
57+
for j in range(i + 1, len(trace.posterior.a)):
58+
assert not np.allclose(trace.posterior.a[i], trace.posterior.a[j])
59+
assert not np.allclose(trace.posterior.b[i], trace.posterior.b[j])
60+
# Check that all chains are pairwise different between seeds
61+
for i in range(len(trace.posterior.a)):
62+
for j in range(len(trace3.posterior.a)):
63+
assert not np.allclose(trace.posterior.a[i], trace3.posterior.a[j])
64+
assert not np.allclose(trace.posterior.b[i], trace3.posterior.b[j])
65+
66+
3067
@pytest.mark.stan
3168
def test_stan_model_data():
3269
model = """

0 commit comments

Comments
 (0)