Skip to content

Commit 5befe35

Browse files
committed
add tests for adapter.constrain
1 parent e280e49 commit 5befe35

File tree

1 file changed

+39
-0
lines changed

1 file changed

+39
-0
lines changed

tests/test_adapters/test_adapters.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,3 +31,42 @@ def test_serialize_deserialize(adapter, custom_objects, random_data):
3131
deserialized_processed = deserialized(random_data)
3232
for key, value in processed.items():
3333
assert np.allclose(value, deserialized_processed[key])
34+
35+
36+
def test_constrain():
37+
import numpy as np
38+
import warnings
39+
from bayesflow.adapters import Adapter
40+
41+
data = {
42+
"x1": np.random.exponential(1, size=(32, 1)),
43+
"x2": -np.random.exponential(1, size=(32, 1)),
44+
"x3": np.random.beta(0.5, 0.5, size=(32, 1)),
45+
"x4": np.vstack((np.zeros(shape=(16, 1)), np.ones(shape=(16, 1)))),
46+
"x5": np.zeros(shape=(32, 1)),
47+
"x6": np.zeros(shape=(32, 1)),
48+
}
49+
50+
adapter = (
51+
Adapter()
52+
.constrain("x1", lower=0)
53+
.constrain("x2", upper=0)
54+
.constrain("x3", lower=0, upper=1)
55+
.constrain("x4", lower=0, upper=1, inclusive="both")
56+
.constrain("x5", lower=0, inclusive="none")
57+
.constrain("x6", upper=0, inclusive="none")
58+
)
59+
60+
with warnings.catch_warnings():
61+
warnings.simplefilter("ignore", RuntimeWarning)
62+
result = adapter(data)
63+
64+
# checks if transformations indeed have been applied
65+
assert result["x1"].min() < 0.0
66+
assert result["x2"].max() > 0.0
67+
assert result["x3"].min() < 0.0
68+
assert result["x3"].max() > 1.0
69+
assert np.isfinite(result["x4"].min())
70+
assert np.isfinite(result["x4"].max())
71+
assert np.isneginf(result["x5"][0])
72+
assert np.isinf(result["x6"][0])

0 commit comments

Comments
 (0)