@@ -337,22 +337,18 @@ def test_get_batched_jittered_initial_points():
337337
338338 # No jitter
339339 ips = _get_batched_jittered_initial_points (
340- model = model , chains = 1 , random_seed = 1 , initvals = None , jitter = False , logp_fn = logp_fn
340+ model = model , chains = 1 , random_seed = 1 , initvals = None , jitter = False
341341 )
342342 assert np .all (ips [0 ] == 0 )
343343
344344 # Single chain
345- ips = _get_batched_jittered_initial_points (
346- model = model , chains = 1 , random_seed = 1 , initvals = None , logp_fn = logp_fn
347- )
345+ ips = _get_batched_jittered_initial_points (model = model , chains = 1 , random_seed = 1 , initvals = None )
348346
349347 assert ips [0 ].shape == (2 , 3 )
350348 assert np .all (ips [0 ] != 0 )
351349
352350 # Multiple chains
353- ips = _get_batched_jittered_initial_points (
354- model = model , chains = 2 , random_seed = 1 , initvals = None , logp_fn = logp_fn
355- )
351+ ips = _get_batched_jittered_initial_points (model = model , chains = 2 , random_seed = 1 , initvals = None )
356352
357353 assert ips [0 ].shape == (2 , 2 , 3 )
358354 assert np .all (ips [0 ][0 ] != ips [0 ][1 ])
0 commit comments