@@ -25,16 +25,33 @@ def test_data(bernoulli_model):
2525 assert 0.2 < out2 ["theta" ].mean () < 0.3
2626
2727
28- def test_seed (bernoulli_model ):
29- out1 = bernoulli_model .pathfinder (BERNOULLI_DATA , seed = 123 )
30- out2 = bernoulli_model .pathfinder (BERNOULLI_DATA , seed = 123 )
31-
32- np .testing .assert_equal (out1 ["theta" ], out2 ["theta" ])
28+ @pytest .mark .parametrize ("num_paths" , [1 , 4 ])
29+ @pytest .mark .parametrize ("psis_resample" , [True , False ])
30+ def test_seed (bernoulli_model , num_paths , psis_resample ):
31+ out1 = bernoulli_model .pathfinder (
32+ BERNOULLI_DATA ,
33+ seed = 123 ,
34+ num_paths = num_paths ,
35+ psis_resample = psis_resample ,
36+ )
37+ out2 = bernoulli_model .pathfinder (
38+ BERNOULLI_DATA ,
39+ seed = 123 ,
40+ num_paths = num_paths ,
41+ psis_resample = psis_resample ,
42+ )
43+ assert out1 .data .shape [1 ] == 5
44+ np .testing .assert_equal (np .sort (out1 .data , axis = 0 ), np .sort (out2 .data , axis = 0 ))
3345
34- out3 = bernoulli_model .pathfinder (BERNOULLI_DATA , seed = 456 )
46+ out3 = bernoulli_model .pathfinder (
47+ BERNOULLI_DATA ,
48+ seed = 456 ,
49+ num_paths = num_paths ,
50+ psis_resample = psis_resample ,
51+ )
3552
3653 with pytest .raises (AssertionError ):
37- np .testing .assert_equal (out1 ["theta" ], out3 ["theta" ])
54+ np .testing .assert_equal (np . sort ( out1 ["theta" ]), np . sort ( out3 ["theta" ]) )
3855
3956
4057def test_output_sizes (bernoulli_model ):
@@ -133,20 +150,18 @@ def test_inits(multimodal_model, temp_json):
133150 assert np .all (out3 ["mu" ] < 0 )
134151
135152
136- def test_inits_mode (multimodal_model ):
153+ @pytest .mark .parametrize ("num_paths" , [1 , 4 ])
154+ @pytest .mark .parametrize ("psis_resample" , [True , False ])
155+ def test_inits_mode (multimodal_model , num_paths , psis_resample ):
137156 # initializing at mode means theres nowhere to go
138157 init1 = {"mu" : - 100 }
139158
140159 with pytest .raises (
141160 RuntimeError , match = "None of the LBFGS iterations completed successfully"
142161 ):
143- multimodal_model .pathfinder (inits = init1 )
144-
145- # also for single path
146- with pytest .raises (
147- RuntimeError , match = "None of the LBFGS iterations completed successfully"
148- ):
149- multimodal_model .pathfinder (inits = init1 , num_paths = 1 , psis_resample = False )
162+ multimodal_model .pathfinder (
163+ inits = init1 , num_paths = num_paths , psis_resample = psis_resample
164+ )
150165
151166
152167def test_bad_data (bernoulli_model ):
0 commit comments