Skip to content

Commit 72180e3

Browse files
committed
Test approximator.sample with keep_conditions
1 parent c37e2e6 commit 72180e3

File tree

1 file changed

+35
-0
lines changed

1 file changed

+35
-0
lines changed

tests/test_approximators/test_sample.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import numpy as np
12
import keras
23
from tests.utils import check_combination_simulator_adapter, check_approximator_multivariate_normal_score
34

@@ -18,3 +19,37 @@ def test_approximator_sample(approximator, simulator, batch_size, adapter):
1819
samples = approximator.sample(num_samples=2, conditions=data)
1920

2021
assert isinstance(samples, dict)
22+
23+
24+
def test_approximator_sample_keep_conditions(approximator, simulator, batch_size, adapter):
25+
check_combination_simulator_adapter(simulator, adapter)
26+
# as long as MultivariateNormalScore is unstable, skip
27+
check_approximator_multivariate_normal_score(approximator)
28+
29+
num_batches = 4
30+
data = simulator.sample((num_batches * batch_size,))
31+
32+
batch = adapter(data)
33+
batch = keras.tree.map_structure(keras.ops.convert_to_tensor, batch)
34+
batch_shapes = keras.tree.map_structure(keras.ops.shape, batch)
35+
approximator.build(batch_shapes)
36+
37+
num_samples = 2
38+
samples_and_conditions = approximator.sample(num_samples=num_samples, conditions=data, keep_conditions=True)
39+
40+
assert isinstance(samples_and_conditions, dict)
41+
42+
adapted_samples_and_conditions = adapter(samples_and_conditions, strict=False)
43+
44+
assert any(k in adapted_samples_and_conditions for k in approximator.CONDITION_KEYS), (
45+
f"adapter(approximator.sample(..., keep_conditions=True)) must return at least one of"
46+
f"{approximator.CONDITION_KEYS!r}. Keys are {adapted_samples_and_conditions.keys()}."
47+
)
48+
49+
for key, value in adapted_samples_and_conditions.items():
50+
assert value.shape[:2] == (num_batches * batch_size, num_samples), (
51+
f"{key} should have shape ({num_batches * batch_size}, {num_samples}, ...) but has {value.shape}."
52+
)
53+
54+
if key in approximator.CONDITION_KEYS:
55+
assert np.all(np.ptp(value, axis=1) == 0), "Not all values are the same along axis 1"

0 commit comments

Comments
 (0)