Skip to content
This repository was archived by the owner on Nov 17, 2025. It is now read-only.

Commit fe5190e

Browse files
Make terms in sampling tests more identifiable
1 parent 295d816 commit fe5190e

File tree

1 file changed

+16
-15
lines changed

1 file changed

+16
-15
lines changed

tests/test_gibbs.py

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ def test_nbinom_normal_posterior(srng):
168168
M = 10
169169
N = 50
170170

171-
true_h = 10
171+
true_h = 100
172172
true_beta = np.array([2, 0.02, 0.2, 0.1, 1] + [0.0] * (M - 5))
173173
S = toeplitz(0.5 ** np.arange(M))
174174
X_at = srng.multivariate_normal(np.zeros(M), cov=S, size=N)
@@ -189,40 +189,41 @@ def test_nbinom_normal_posterior(srng):
189189
beta_post_vals += [beta_post_val]
190190

191191
beta_post_mean = np.mean(beta_post_vals, axis=0)
192-
assert np.allclose(beta_post_mean, true_beta, atol=3e-1)
192+
assert np.allclose(beta_post_mean, true_beta, atol=1e-1)
193193

194194

195195
def test_nbinom_dispersion_posterior(srng):
196196
M = 10
197-
N = 50
197+
N = 100
198198

199-
true_h = 10
199+
true_h = 100
200200
true_beta = np.array([2, 0.02, 0.2, 0.1, 1] + [0.1] * (M - 5))
201201
S = toeplitz(0.5 ** np.arange(M))
202202
X = srng.multivariate_normal(np.zeros(M), cov=S, size=N)
203203
p_at = at.sigmoid(-(X.dot(true_beta)))
204204
p, y = aesara.function([], [p_at, srng.nbinom(true_h, p_at)])()
205205

206-
a_val = 1.0
207-
b_val = 2e-1
206+
a_val = 70.0
207+
b_val = 0.9
208208
a = at.as_tensor(a_val)
209209
b = at.as_tensor(b_val)
210210

211211
h_samples, h_updates = aesara.scan(
212-
lambda: nbinom_dispersion_posterior(srng, at.as_tensor(true_h), p, a, b, y),
212+
lambda last_h: nbinom_dispersion_posterior(srng, last_h, p, a, b, y),
213+
outputs_info=[at.as_tensor(90.0, dtype=np.float64)],
213214
n_steps=1000,
214215
)
215216

216217
h_mean_fn = aesara.function([], h_samples.mean(), updates=h_updates)
217218

218219
h_mean_val = h_mean_fn()
219220

220-
# Make sure that the posterior `h` values aren't right around the prior
221-
# mean
222-
assert not np.allclose(h_mean_val, a_val / b_val, rtol=2e-1)
221+
# Make sure that the average posterior `h` values have increased relative
222+
# to the prior mean
223+
assert h_mean_val > a_val / b_val
223224

224225
# Make sure the posterior values are near the "true" value
225-
assert np.allclose(h_mean_val, true_h, rtol=2e-1)
226+
assert np.allclose(h_mean_val, true_h, rtol=1e-1)
226227

227228

228229
def test_bern_sigmoid_dot_match(srng):
@@ -256,9 +257,9 @@ def test_bern_sigmoid_dot_match(srng):
256257

257258
def test_bern_normal_posterior(srng):
258259
M = 10
259-
N = 50
260+
N = 100
260261

261-
true_beta = np.array([2, 0.02, 0.2, 0.1, 1] + [0.1] * (M - 5))
262+
true_beta = np.array([3, 2, 1, 0.5, 0.05] + [0.0] * (M - 5))
262263
S = toeplitz(0.5 ** np.arange(M))
263264
X_at = srng.multivariate_normal(np.zeros(M), cov=S, size=N)
264265
p_at = at.sigmoid(X_at.dot(true_beta))
@@ -271,12 +272,12 @@ def test_bern_normal_posterior(srng):
271272

272273
beta_post_vals = []
273274
beta_post_val = np.zeros(M)
274-
for i in range(3000):
275+
for i in range(1000):
275276
beta_post_val = beta_post_fn(beta_post_val)
276277
beta_post_vals += [beta_post_val]
277278

278279
beta_post_mean = np.mean(beta_post_vals, axis=0)
279-
assert np.allclose(beta_post_mean, true_beta, atol=0.7)
280+
assert np.allclose(beta_post_mean, true_beta, atol=0.5)
280281

281282

282283
def test_gamma_match(srng):

0 commit comments

Comments
 (0)