1+ import numpy as np
12import keras
23from tests .utils import check_combination_simulator_adapter , check_approximator_multivariate_normal_score
34
@@ -18,3 +19,37 @@ def test_approximator_sample(approximator, simulator, batch_size, adapter):
1819 samples = approximator .sample (num_samples = 2 , conditions = data )
1920
2021 assert isinstance (samples , dict )
22+
23+
24+ def test_approximator_sample_keep_conditions (approximator , simulator , batch_size , adapter ):
25+ check_combination_simulator_adapter (simulator , adapter )
26+ # as long as MultivariateNormalScore is unstable, skip
27+ check_approximator_multivariate_normal_score (approximator )
28+
29+ num_batches = 4
30+ data = simulator .sample ((num_batches * batch_size ,))
31+
32+ batch = adapter (data )
33+ batch = keras .tree .map_structure (keras .ops .convert_to_tensor , batch )
34+ batch_shapes = keras .tree .map_structure (keras .ops .shape , batch )
35+ approximator .build (batch_shapes )
36+
37+ num_samples = 2
38+ samples_and_conditions = approximator .sample (num_samples = num_samples , conditions = data , keep_conditions = True )
39+
40+ assert isinstance (samples_and_conditions , dict )
41+
42+ adapted_samples_and_conditions = adapter (samples_and_conditions , strict = False )
43+
44+ assert any (k in adapted_samples_and_conditions for k in approximator .CONDITION_KEYS ), (
45+ f"adapter(approximator.sample(..., keep_conditions=True)) must return at least one of"
46+ f"{ approximator .CONDITION_KEYS !r} . Keys are { adapted_samples_and_conditions .keys ()} ."
47+ )
48+
49+ for key , value in adapted_samples_and_conditions .items ():
50+ assert value .shape [:2 ] == (num_batches * batch_size , num_samples ), (
51+ f"{ key } should have shape ({ num_batches * batch_size } , { num_samples } , ...) but has { value .shape } ."
52+ )
53+
54+ if key in approximator .CONDITION_KEYS :
55+ assert np .all (np .ptp (value , axis = 1 ) == 0 ), "Not all values are the same along axis 1"
0 commit comments