Commit a822d01

Merge pull request #338 from bayesflow-org/allow-inclusive-bounds
Allow inclusive bounds in `Adapter.constrain()`
2 parents: 5bf9a54 + 97f7a72

File tree: 3 files changed (+104 -10 lines)

bayesflow/adapters/adapter.py

Lines changed: 6 additions & 1 deletion
```diff
@@ -233,12 +233,17 @@ def constrain(
         lower: int | float | np.ndarray = None,
         upper: int | float | np.ndarray = None,
         method: str = "default",
+        inclusive: str = "both",
+        epsilon: float = 1e-15,
     ):
         if isinstance(keys, str):
             keys = [keys]
 
         transform = MapTransform(
-            transform_map={key: Constrain(lower=lower, upper=upper, method=method) for key in keys}
+            transform_map={
+                key: Constrain(lower=lower, upper=upper, method=method, inclusive=inclusive, epsilon=epsilon)
+                for key in keys
+            }
         )
         self.transforms.append(transform)
         return self
```
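The new arguments are passed straight through to `Constrain`, so they can be used directly from the fluent `Adapter` API. A minimal usage sketch (hypothetical variable names; arguments and defaults taken from the signature above):

```python
import numpy as np
import bayesflow as bf

adapter = (
    bf.Adapter()
    # strictly positive scale parameter: defaults behave as before
    .constrain("sigma", lower=0)
    # probability that may sit exactly at 0 or 1: keep both bounds inclusive
    .constrain("p", lower=0, upper=1, method="sigmoid", inclusive="both", epsilon=1e-15)
)

data = {"sigma": np.random.exponential(1, size=(32, 1)), "p": np.ones((32, 1))}
result = adapter(data)  # maps both variables to unconstrained space without producing inf
```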

bayesflow/adapters/transforms/constrain.py

Lines changed: 42 additions & 9 deletions
```diff
@@ -16,7 +16,7 @@
 @serializable(package="bayesflow.adapters")
 class Constrain(ElementwiseTransform):
     """
-    Constrains neural network predictions of a data variable to specificied bounds.
+    Constrains neural network predictions of a data variable to specified bounds.
 
     Parameters:
     String containing the name of the data variable to be transformed e.g. "sigma". See examples below.
@@ -28,14 +28,21 @@ class Constrain(ElementwiseTransform):
     - Double bounded methods: sigmoid, expit, (default = sigmoid)
     - Lower bound only methods: softplus, exp, (default = softplus)
     - Upper bound only methods: softplus, exp, (default = softplus)
-
+    inclusive: Indicates which bounds are inclusive (or exclusive).
+        - "both" (default): Both lower and upper bounds are inclusive.
+        - "lower": Lower bound is inclusive, upper bound is exclusive.
+        - "upper": Lower bound is exclusive, upper bound is inclusive.
+        - "none": Both lower and upper bounds are exclusive.
+    epsilon: Small value to ensure inclusive bounds are not violated.
+        Current default is 1e-15 as this ensures finite outcomes
+        with the default transformations applied to data exactly at the boundaries.
 
 
     Examples:
     1) Let sigma be the standard deviation of a normal distribution,
     then sigma should always be greater than zero.
-    Useage:
+    Usage:
     adapter = (
         bf.Adapter()
         .constrain("sigma", lower=0)
@@ -45,14 +52,19 @@ class Constrain(ElementwiseTransform):
     [0,1] then we would constrain the neural network to estimate p in the following way.
 
     Usage:
-    adapter = (
-        bf.Adapter()
-        .constrain("p", lower=0, upper=1, method = "sigmoid")
-    )
+    >>> import bayesflow as bf
+    >>> adapter = bf.Adapter()
+    >>> adapter.constrain("p", lower=0, upper=1, method="sigmoid", inclusive="both")
     """
 
     def __init__(
-        self, *, lower: int | float | np.ndarray = None, upper: int | float | np.ndarray = None, method: str = "default"
+        self,
+        *,
+        lower: int | float | np.ndarray = None,
+        upper: int | float | np.ndarray = None,
+        method: str = "default",
+        inclusive: str = "both",
+        epsilon: float = 1e-15,
     ):
         super().__init__()
 
@@ -121,12 +133,31 @@ def unconstrain(x):
 
         self.lower = lower
         self.upper = upper
-
         self.method = method
+        self.inclusive = inclusive
+        self.epsilon = epsilon
 
         self.constrain = constrain
         self.unconstrain = unconstrain
 
+        # do this last to avoid serialization issues
+        match inclusive:
+            case "lower":
+                if lower is not None:
+                    lower = lower - epsilon
+            case "upper":
+                if upper is not None:
+                    upper = upper + epsilon
+            case True | "both":
+                if lower is not None:
+                    lower = lower - epsilon
+                if upper is not None:
+                    upper = upper + epsilon
+            case False | None | "none":
+                pass
+            case other:
+                raise ValueError(f"Unsupported value for 'inclusive': {other!r}.")
+
     @classmethod
     def from_config(cls, config: dict, custom_objects=None) -> "Constrain":
         return cls(**config)
@@ -136,6 +167,8 @@ def get_config(self) -> dict:
             "lower": self.lower,
             "upper": self.upper,
             "method": self.method,
+            "inclusive": self.inclusive,
+            "epsilon": self.epsilon,
         }
 
     def forward(self, data: np.ndarray, **kwargs) -> np.ndarray:
```
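The epsilon shift matters for data that sits exactly on an inclusive bound, because the default unconstraining transforms diverge there. The sketch below is an illustration only: it assumes an inverse-softplus-style transform for a lower bound, as suggested by the "softplus" default in the docstring, and is not BayesFlow's internal code.

```python
import numpy as np

def inverse_softplus(y):
    # inverse of softplus(x) = log(1 + exp(x)); diverges to -inf as y -> 0
    return np.log(np.expm1(y))

x = np.array([0.0, 0.5, 2.0])  # data with a value exactly at the lower bound 0
lower, epsilon = 0.0, 1e-15

with np.errstate(divide="ignore"):
    exclusive = inverse_softplus(x - lower)          # inclusive="none": -inf at the bound
inclusive = inverse_softplus(x - (lower - epsilon))  # inclusive="lower"/"both": finite

print(exclusive[0], inclusive[0])  # -inf vs. a large negative but finite value
```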

tests/test_adapters/test_adapters.py

Lines changed: 56 additions & 0 deletions
```diff
@@ -31,3 +31,59 @@ def test_serialize_deserialize(adapter, custom_objects, random_data):
     deserialized_processed = deserialized(random_data)
     for key, value in processed.items():
         assert np.allclose(value, deserialized_processed[key])
+
+
+def test_constrain():
+    # check if constraint-implied transforms are applied correctly
+    import numpy as np
+    import warnings
+    from bayesflow.adapters import Adapter
+
+    data = {
+        "x_lower_cont": np.random.exponential(1, size=(32, 1)),
+        "x_upper_cont": -np.random.exponential(1, size=(32, 1)),
+        "x_both_cont": np.random.beta(0.5, 0.5, size=(32, 1)),
+        "x_lower_disc1": np.zeros(shape=(32, 1)),
+        "x_lower_disc2": np.zeros(shape=(32, 1)),
+        "x_upper_disc1": np.ones(shape=(32, 1)),
+        "x_upper_disc2": np.ones(shape=(32, 1)),
+        "x_both_disc1": np.vstack((np.zeros(shape=(16, 1)), np.ones(shape=(16, 1)))),
+        "x_both_disc2": np.vstack((np.zeros(shape=(16, 1)), np.ones(shape=(16, 1)))),
+    }
+
+    adapter = (
+        Adapter()
+        .constrain("x_lower_cont", lower=0)
+        .constrain("x_upper_cont", upper=0)
+        .constrain("x_both_cont", lower=0, upper=1)
+        .constrain("x_lower_disc1", lower=0, inclusive="lower")
+        .constrain("x_lower_disc2", lower=0, inclusive="none")
+        .constrain("x_upper_disc1", upper=1, inclusive="upper")
+        .constrain("x_upper_disc2", upper=1, inclusive="none")
+        .constrain("x_both_disc1", lower=0, upper=1, inclusive="both")
+        .constrain("x_both_disc2", lower=0, upper=1, inclusive="none")
+    )
+
+    with warnings.catch_warnings():
+        warnings.simplefilter("ignore", RuntimeWarning)
+        result = adapter(data)
+
+    # continuous variables should not have boundary issues
+    assert result["x_lower_cont"].min() < 0.0
+    assert result["x_upper_cont"].max() > 0.0
+    assert result["x_both_cont"].min() < 0.0
+    assert result["x_both_cont"].max() > 1.0
+
+    # discrete variables at the boundaries should not have issues
+    # if inclusive is set properly
+    assert np.isfinite(result["x_lower_disc1"].min())
+    assert np.isfinite(result["x_upper_disc1"].max())
+    assert np.isfinite(result["x_both_disc1"].min())
+    assert np.isfinite(result["x_both_disc1"].max())
+
+    # discrete variables at the boundaries should have issues
+    # if inclusive is not set properly
+    assert np.isneginf(result["x_lower_disc2"][0])
+    assert np.isinf(result["x_upper_disc2"][0])
+    assert np.isneginf(result["x_both_disc2"][0])
+    assert np.isinf(result["x_both_disc2"][-1])
```
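The test can be condensed into a quick check of the user-facing behavior. A small sketch mirroring the assertions above (hypothetical variable names):

```python
import warnings
import numpy as np
from bayesflow.adapters import Adapter

data = {"a": np.zeros((8, 1)), "b": np.zeros((8, 1))}  # values exactly at the lower bound 0

adapter = (
    Adapter()
    .constrain("a", lower=0, inclusive="lower")  # inclusive bound: boundary values stay finite
    .constrain("b", lower=0, inclusive="none")   # exclusive bound: boundary values diverge
)

with warnings.catch_warnings():
    warnings.simplefilter("ignore", RuntimeWarning)
    result = adapter(data)

assert np.isfinite(result["a"]).all()
assert np.isneginf(result["b"][0])
```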
