Skip to content

Commit 08d5474

Browse files
committed
Remove skip instructions across approximator tests regarding MVNScore
1 parent a5ce7e7 commit 08d5474

File tree

3 files changed

+10
-16
lines changed

3 files changed

+10
-16
lines changed

tests/test_approximators/test_fit.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import pytest
44
import io
55
from contextlib import redirect_stdout
6-
from tests.utils import check_approximator_multivariate_normal_score
76

87

98
@pytest.mark.skip(reason="not implemented")
@@ -20,9 +19,6 @@ def test_fit(amortizer, dataset):
2019

2120

2221
def test_loss_progress(approximator, train_dataset, validation_dataset):
23-
# as long as MultivariateNormalScore is unstable, skip fit progress test
24-
check_approximator_multivariate_normal_score(approximator)
25-
2622
approximator.compile(optimizer="AdamW")
2723
num_epochs = 3
2824

tests/test_approximators/test_log_prob.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,10 @@
11
import keras
22
import numpy as np
3-
from tests.utils import check_combination_simulator_adapter, check_approximator_multivariate_normal_score
3+
from tests.utils import check_combination_simulator_adapter
44

55

66
def test_approximator_log_prob(approximator, simulator, batch_size, adapter):
77
check_combination_simulator_adapter(simulator, adapter)
8-
# as long as MultivariateNormalScore is unstable, skip
9-
check_approximator_multivariate_normal_score(approximator)
108

119
num_batches = 4
1210
data = simulator.sample((num_batches * batch_size,))

tests/test_approximators/test_sample.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,10 @@
11
import numpy as np
22
import keras
3-
from tests.utils import check_combination_simulator_adapter, check_approximator_multivariate_normal_score
3+
from tests.utils import check_combination_simulator_adapter
44

55

66
def test_approximator_sample(approximator, simulator, batch_size, adapter):
77
check_combination_simulator_adapter(simulator, adapter)
8-
# as long as MultivariateNormalScore is unstable, skip
9-
check_approximator_multivariate_normal_score(approximator)
108

119
num_batches = 4
1210
data = simulator.sample((num_batches * batch_size,))
@@ -23,8 +21,6 @@ def test_approximator_sample(approximator, simulator, batch_size, adapter):
2321

2422
def test_approximator_sample_keep_conditions(approximator, simulator, batch_size, adapter):
2523
check_combination_simulator_adapter(simulator, adapter)
26-
# as long as MultivariateNormalScore is unstable, skip
27-
check_approximator_multivariate_normal_score(approximator)
2824

2925
num_batches = 4
3026
data = simulator.sample((num_batches * batch_size,))
@@ -39,14 +35,18 @@ def test_approximator_sample_keep_conditions(approximator, simulator, batch_size
3935

4036
assert isinstance(samples_and_conditions, dict)
4137

42-
adapted_samples_and_conditions = adapter(samples_and_conditions, strict=False)
38+
# remove inference_variables from sample output and apply adapter
39+
inference_variables_keys = approximator.sample(num_samples=num_samples, conditions=data).keys()
40+
for key in inference_variables_keys:
41+
samples_and_conditions.pop(key)
42+
adapted_conditions = adapter(samples_and_conditions, strict=False)
4343

44-
assert any(k in adapted_samples_and_conditions for k in approximator.CONDITION_KEYS), (
44+
assert any(k in adapted_conditions for k in approximator.CONDITION_KEYS), (
4545
f"adapter(approximator.sample(..., keep_conditions=True)) must return at least one of"
46-
f"{approximator.CONDITION_KEYS!r}. Keys are {adapted_samples_and_conditions.keys()}."
46+
f"{approximator.CONDITION_KEYS!r}. Keys are {adapted_conditions.keys()}."
4747
)
4848

49-
for key, value in adapted_samples_and_conditions.items():
49+
for key, value in adapted_conditions.items():
5050
assert value.shape[:2] == (num_batches * batch_size, num_samples), (
5151
f"{key} should have shape ({num_batches * batch_size}, {num_samples}, ...) but has {value.shape}."
5252
)

0 commit comments

Comments
 (0)