11import numpy as np
22import keras
3- from tests .utils import check_combination_simulator_adapter , check_approximator_multivariate_normal_score
3+ from tests .utils import check_combination_simulator_adapter
44
55
66def test_approximator_sample (approximator , simulator , batch_size , adapter ):
77 check_combination_simulator_adapter (simulator , adapter )
8- # as long as MultivariateNormalScore is unstable, skip
9- check_approximator_multivariate_normal_score (approximator )
108
119 num_batches = 4
1210 data = simulator .sample ((num_batches * batch_size ,))
@@ -23,8 +21,6 @@ def test_approximator_sample(approximator, simulator, batch_size, adapter):
2321
2422def test_approximator_sample_keep_conditions (approximator , simulator , batch_size , adapter ):
2523 check_combination_simulator_adapter (simulator , adapter )
26- # as long as MultivariateNormalScore is unstable, skip
27- check_approximator_multivariate_normal_score (approximator )
2824
2925 num_batches = 4
3026 data = simulator .sample ((num_batches * batch_size ,))
@@ -39,14 +35,18 @@ def test_approximator_sample_keep_conditions(approximator, simulator, batch_size
3935
4036 assert isinstance (samples_and_conditions , dict )
4137
42- adapted_samples_and_conditions = adapter (samples_and_conditions , strict = False )
38+ # remove inference_variables from sample output and apply adapter
39+ inference_variables_keys = approximator .sample (num_samples = num_samples , conditions = data ).keys ()
40+ for key in inference_variables_keys :
41+ samples_and_conditions .pop (key )
42+ adapted_conditions = adapter (samples_and_conditions , strict = False )
4343
44- assert any (k in adapted_samples_and_conditions for k in approximator .CONDITION_KEYS ), (
44+ assert any (k in adapted_conditions for k in approximator .CONDITION_KEYS ), (
4545 f"adapter(approximator.sample(..., keep_conditions=True)) must return at least one of"
46- f"{ approximator .CONDITION_KEYS !r} . Keys are { adapted_samples_and_conditions .keys ()} ."
46+ f"{ approximator .CONDITION_KEYS !r} . Keys are { adapted_conditions .keys ()} ."
4747 )
4848
49- for key , value in adapted_samples_and_conditions .items ():
49+ for key , value in adapted_conditions .items ():
5050 assert value .shape [:2 ] == (num_batches * batch_size , num_samples ), (
5151 f"{ key } should have shape ({ num_batches * batch_size } , { num_samples } , ...) but has { value .shape } ."
5252 )
0 commit comments