Skip to content

Commit 9bd60f8

Browse files
committed
allow inclusive bounds in adapter constrain
1 parent 051f998 commit 9bd60f8

File tree

2 files changed

+55
-9
lines changed

2 files changed

+55
-9
lines changed

bayesflow/adapters/adapter.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -233,12 +233,17 @@ def constrain(
233233
lower: int | float | np.ndarray = None,
234234
upper: int | float | np.ndarray = None,
235235
method: str = "default",
236+
inclusive: str = "default",
237+
epsilon: float = 1e-16,
236238
):
237239
if isinstance(keys, str):
238240
keys = [keys]
239241

240242
transform = MapTransform(
241-
transform_map={key: Constrain(lower=lower, upper=upper, method=method) for key in keys}
243+
transform_map={
244+
key: Constrain(lower=lower, upper=upper, method=method, inclusive=inclusive, epsilon=epsilon)
245+
for key in keys
246+
}
242247
)
243248
self.transforms.append(transform)
244249
return self

bayesflow/adapters/transforms/constrain.py

Lines changed: 49 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
@serializable(package="bayesflow.adapters")
1717
class Constrain(ElementwiseTransform):
1818
"""
19-
Constrains neural network predictions of a data variable to specificied bounds.
19+
Constrains neural network predictions of a data variable to specified bounds.
2020
2121
Parameters:
2222
String containing the name of the data variable to be transformed e.g. "sigma". See examples below.
@@ -28,14 +28,22 @@ class Constrain(ElementwiseTransform):
2828
- Double bounded methods: sigmoid, expit, (default = sigmoid)
2929
- Lower bound only methods: softplus, exp, (default = softplus)
3030
- Upper bound only methods: softplus, exp, (default = softplus)
31-
31+
inclusive: Indicates which bounds are inclusive (or exclusive).
32+
- "lower": Lower bound is inclusive, upper bound is exclusive.
33+
- "upper": Lower bound is exclusive, upper bound is inclusive.
34+
- "both": Lower and upper bounds are inclusive.
35+
- "none": Lower and upper bounds are exclusive.
36+
- "default": Inclusive bounds are determined by the method.
37+
- Double bounded methods are lower inclusive and upper exclusive.
38+
- Single bounded methods are inclusive at the specified bound.
39+
epsilon: Small value to ensure inclusive bounds are not violated.
3240
3341
3442
Examples:
3543
1) Let sigma be the standard deviation of a normal distribution,
3644
then sigma should always be greater than zero.
3745
38-
Useage:
46+
Usage:
3947
adapter = (
4048
bf.Adapter()
4149
.constrain("sigma", lower=0)
@@ -45,14 +53,19 @@ class Constrain(ElementwiseTransform):
4553
[0,1] then we would constrain the neural network to estimate p in the following way.
4654
4755
Usage:
48-
adapter = (
49-
bf.Adapter()
50-
.constrain("p", lower=0, upper=1, method = "sigmoid")
51-
)
56+
>>> import bayesflow as bf
57+
>>> adapter = bf.Adapter()
58+
>>> adapter.constrain("p", lower=0, upper=1, method="sigmoid", inclusive="both")
5259
"""
5360

5461
def __init__(
55-
self, *, lower: int | float | np.ndarray = None, upper: int | float | np.ndarray = None, method: str = "default"
62+
self,
63+
*,
64+
lower: int | float | np.ndarray = None,
65+
upper: int | float | np.ndarray = None,
66+
method: str = "default",
67+
inclusive: str = "default",
68+
epsilon: float = 1e-16,
5669
):
5770
super().__init__()
5871

@@ -64,6 +77,9 @@ def __init__(
6477
if np.any(lower >= upper):
6578
raise ValueError("The lower bound must be strictly less than the upper bound.")
6679

80+
if inclusive == "default":
81+
inclusive = "lower"
82+
6783
match method:
6884
case "default" | "sigmoid" | "expit" | "logit":
6985

@@ -78,6 +94,9 @@ def unconstrain(x):
7894
raise TypeError(f"Expected a method name, got {other!r}.")
7995
elif lower is not None:
8096
# lower bounded case
97+
if inclusive == "default":
98+
inclusive = "lower"
99+
81100
match method:
82101
case "default" | "softplus":
83102

@@ -99,6 +118,9 @@ def unconstrain(x):
99118
raise TypeError(f"Expected a method name, got {other!r}.")
100119
else:
101120
# upper bounded case
121+
if inclusive == "default":
122+
inclusive = "upper"
123+
102124
match method:
103125
case "default" | "softplus":
104126

@@ -119,6 +141,25 @@ def unconstrain(x):
119141
case other:
120142
raise TypeError(f"Expected a method name, got {other!r}.")
121143

144+
match inclusive:
145+
case "lower":
146+
if lower is None:
147+
raise ValueError("Inclusive bounds must be specified.")
148+
lower -= epsilon
149+
case "upper":
150+
if upper is None:
151+
raise ValueError("Inclusive bounds must be specified.")
152+
upper += epsilon
153+
case True | "both":
154+
if lower is None or upper is None:
155+
raise ValueError("Inclusive bounds must be specified.")
156+
lower -= epsilon
157+
upper += epsilon
158+
case False | None | "none":
159+
pass
160+
case other:
161+
raise ValueError(f"Unsupported value for 'inclusive': {other!r}.")
162+
122163
self.lower = lower
123164
self.upper = upper
124165

0 commit comments

Comments
 (0)