@@ -121,14 +121,19 @@ def test_sample_var_names(nuts_sampler):
121121 sigma = HalfNormal ("sigma" )
122122 Normal ("y" , mu = mu , sigma = sigma , observed = y )
123123
124- # Sample with and without var_names, but always with the same seed
124+ free_RVs = [var .name for var in model .free_RVs ]
125+
125126 with model :
127+ # Sample with and without var_names, but always with the same seed
126128 idata_1 = sample (** kwargs )
127- idata_2 = sample (var_names = ["b_group" , "b_x" , "sigma" ], ** kwargs )
129+ # Remove the last free RV from the sampling
130+ idata_2 = sample (var_names = free_RVs [:- 1 ], ** kwargs )
128131
129132 assert "mu" in idata_1 .posterior
130133 assert "mu" not in idata_2 .posterior
131134
132- xr .testing .assert_allclose (idata_1 .posterior ["b_group" ], idata_2 .posterior ["b_group" ])
133- xr .testing .assert_allclose (idata_1 .posterior ["b_x" ], idata_2 .posterior ["b_x" ])
134- xr .testing .assert_allclose (idata_1 .posterior ["sigma" ], idata_2 .posterior ["sigma" ])
135+ for var in free_RVs [:- 1 ]:
136+ assert var in idata_1 .posterior
137+ assert var in idata_2 .posterior
138+
139+ xr .testing .assert_allclose (idata_1 .posterior [var ], idata_2 .posterior [var ])
0 commit comments