Skip to content

Commit f71aedc

Browse files
author
Goose
committed
revert changes to tests
1 parent 1fb9df1 commit f71aedc

File tree

1 file changed

+3
-7
lines changed

1 file changed

+3
-7
lines changed

tests/sampling/test_jax.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -337,22 +337,18 @@ def test_get_batched_jittered_initial_points():
337337

338338
# No jitter
339339
ips = _get_batched_jittered_initial_points(
340-
model=model, chains=1, random_seed=1, initvals=None, jitter=False, logp_fn=logp_fn
340+
model=model, chains=1, random_seed=1, initvals=None, jitter=False
341341
)
342342
assert np.all(ips[0] == 0)
343343

344344
# Single chain
345-
ips = _get_batched_jittered_initial_points(
346-
model=model, chains=1, random_seed=1, initvals=None, logp_fn=logp_fn
347-
)
345+
ips = _get_batched_jittered_initial_points(model=model, chains=1, random_seed=1, initvals=None)
348346

349347
assert ips[0].shape == (2, 3)
350348
assert np.all(ips[0] != 0)
351349

352350
# Multiple chains
353-
ips = _get_batched_jittered_initial_points(
354-
model=model, chains=2, random_seed=1, initvals=None, logp_fn=logp_fn
355-
)
351+
ips = _get_batched_jittered_initial_points(model=model, chains=2, random_seed=1, initvals=None)
356352

357353
assert ips[0].shape == (2, 2, 3)
358354
assert np.all(ips[0][0] != ips[0][1])

0 commit comments

Comments
 (0)