Skip to content

Commit 97f7a72

Browse files
committed
improve adapter.constrain tests
1 parent 5befe35 commit 97f7a72

File tree

1 file changed

+38
-21
lines changed

1 file changed

+38
-21
lines changed

tests/test_adapters/test_adapters.py

Lines changed: 38 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -34,39 +34,56 @@ def test_serialize_deserialize(adapter, custom_objects, random_data):
3434

3535

3636
def test_constrain():
37+
# check if constraint-implied transforms are applied correctly
3738
import numpy as np
3839
import warnings
3940
from bayesflow.adapters import Adapter
4041

4142
data = {
42-
"x1": np.random.exponential(1, size=(32, 1)),
43-
"x2": -np.random.exponential(1, size=(32, 1)),
44-
"x3": np.random.beta(0.5, 0.5, size=(32, 1)),
45-
"x4": np.vstack((np.zeros(shape=(16, 1)), np.ones(shape=(16, 1)))),
46-
"x5": np.zeros(shape=(32, 1)),
47-
"x6": np.zeros(shape=(32, 1)),
43+
"x_lower_cont": np.random.exponential(1, size=(32, 1)),
44+
"x_upper_cont": -np.random.exponential(1, size=(32, 1)),
45+
"x_both_cont": np.random.beta(0.5, 0.5, size=(32, 1)),
46+
"x_lower_disc1": np.zeros(shape=(32, 1)),
47+
"x_lower_disc2": np.zeros(shape=(32, 1)),
48+
"x_upper_disc1": np.ones(shape=(32, 1)),
49+
"x_upper_disc2": np.ones(shape=(32, 1)),
50+
"x_both_disc1": np.vstack((np.zeros(shape=(16, 1)), np.ones(shape=(16, 1)))),
51+
"x_both_disc2": np.vstack((np.zeros(shape=(16, 1)), np.ones(shape=(16, 1)))),
4852
}
4953

5054
adapter = (
5155
Adapter()
52-
.constrain("x1", lower=0)
53-
.constrain("x2", upper=0)
54-
.constrain("x3", lower=0, upper=1)
55-
.constrain("x4", lower=0, upper=1, inclusive="both")
56-
.constrain("x5", lower=0, inclusive="none")
57-
.constrain("x6", upper=0, inclusive="none")
56+
.constrain("x_lower_cont", lower=0)
57+
.constrain("x_upper_cont", upper=0)
58+
.constrain("x_both_cont", lower=0, upper=1)
59+
.constrain("x_lower_disc1", lower=0, inclusive="lower")
60+
.constrain("x_lower_disc2", lower=0, inclusive="none")
61+
.constrain("x_upper_disc1", upper=1, inclusive="upper")
62+
.constrain("x_upper_disc2", upper=1, inclusive="none")
63+
.constrain("x_both_disc1", lower=0, upper=1, inclusive="both")
64+
.constrain("x_both_disc2", lower=0, upper=1, inclusive="none")
5865
)
5966

6067
with warnings.catch_warnings():
6168
warnings.simplefilter("ignore", RuntimeWarning)
6269
result = adapter(data)
6370

64-
# checks if transformations indeed have been applied
65-
assert result["x1"].min() < 0.0
66-
assert result["x2"].max() > 0.0
67-
assert result["x3"].min() < 0.0
68-
assert result["x3"].max() > 1.0
69-
assert np.isfinite(result["x4"].min())
70-
assert np.isfinite(result["x4"].max())
71-
assert np.isneginf(result["x5"][0])
72-
assert np.isinf(result["x6"][0])
71+
# continuous variables should not have boundary issues
72+
assert result["x_lower_cont"].min() < 0.0
73+
assert result["x_upper_cont"].max() > 0.0
74+
assert result["x_both_cont"].min() < 0.0
75+
assert result["x_both_cont"].max() > 1.0
76+
77+
# discrete variables at the boundaries should not have issues
78+
# if inclusive is set properly
79+
assert np.isfinite(result["x_lower_disc1"].min())
80+
assert np.isfinite(result["x_upper_disc1"].max())
81+
assert np.isfinite(result["x_both_disc1"].min())
82+
assert np.isfinite(result["x_both_disc1"].max())
83+
84+
# discrete variables at the boundaries should have issues
85+
# if inclusive is not set properly
86+
assert np.isneginf(result["x_lower_disc2"][0])
87+
assert np.isinf(result["x_upper_disc2"][0])
88+
assert np.isneginf(result["x_both_disc2"][0])
89+
assert np.isinf(result["x_both_disc2"][-1])

0 commit comments

Comments
 (0)