1616@serializable (package = "bayesflow.adapters" )
1717class Constrain (ElementwiseTransform ):
1818 """
19- Constrains neural network predictions of a data variable to specificied bounds.
19+ Constrains neural network predictions of a data variable to specified bounds.
2020
2121 Parameters:
2222 String containing the name of the data variable to be transformed e.g. "sigma". See examples below.
@@ -28,14 +28,22 @@ class Constrain(ElementwiseTransform):
2828 - Double bounded methods: sigmoid, expit, (default = sigmoid)
2929 - Lower bound only methods: softplus, exp, (default = softplus)
3030 - Upper bound only methods: softplus, exp, (default = softplus)
31-
31+ inclusive: Indicates which bounds are inclusive (or exclusive).
32+ - "lower": Lower bound is inclusive, upper bound is exclusive.
33+ - "upper": Lower bound is exclusive, upper bound is inclusive.
34+ - "both": Lower and upper bounds are inclusive.
35+ - "none": Lower and upper bounds are exclusive.
36+ - "default": Inclusive bounds are determined by the method.
37+ - Double bounded methods are lower inclusive and upper exclusive.
38+ - Single bounded methods are inclusive at the specified bound.
39+ epsilon: Small value to ensure inclusive bounds are not violated.
3240
3341
3442 Examples:
3543 1) Let sigma be the standard deviation of a normal distribution,
3644 then sigma should always be greater than zero.
3745
38- Useage :
46+ Usage :
3947 adapter = (
4048 bf.Adapter()
4149 .constrain("sigma", lower=0)
@@ -45,14 +53,19 @@ class Constrain(ElementwiseTransform):
4553 [0,1] then we would constrain the neural network to estimate p in the following way.
4654
4755 Usage:
48- adapter = (
49- bf.Adapter()
50- .constrain("p", lower=0, upper=1, method = "sigmoid")
51- )
56+ >>> import bayesflow as bf
57+ >>> adapter = bf.Adapter()
58+ >>> adapter.constrain("p", lower=0, upper=1, method="sigmoid", inclusive="both")
5259 """
5360
5461 def __init__ (
55- self , * , lower : int | float | np .ndarray = None , upper : int | float | np .ndarray = None , method : str = "default"
62+ self ,
63+ * ,
64+ lower : int | float | np .ndarray = None ,
65+ upper : int | float | np .ndarray = None ,
66+ method : str = "default" ,
67+ inclusive : str = "default" ,
68+ epsilon : float = 1e-16 ,
5669 ):
5770 super ().__init__ ()
5871
@@ -64,6 +77,9 @@ def __init__(
6477 if np .any (lower >= upper ):
6578 raise ValueError ("The lower bound must be strictly less than the upper bound." )
6679
80+ if inclusive == "default" :
81+ inclusive = "lower"
82+
6783 match method :
6884 case "default" | "sigmoid" | "expit" | "logit" :
6985
@@ -78,6 +94,9 @@ def unconstrain(x):
7894 raise TypeError (f"Expected a method name, got { other !r} ." )
7995 elif lower is not None :
8096 # lower bounded case
97+ if inclusive == "default" :
98+ inclusive = "lower"
99+
81100 match method :
82101 case "default" | "softplus" :
83102
@@ -99,6 +118,9 @@ def unconstrain(x):
99118 raise TypeError (f"Expected a method name, got { other !r} ." )
100119 else :
101120 # upper bounded case
121+ if inclusive == "default" :
122+ inclusive = "upper"
123+
102124 match method :
103125 case "default" | "softplus" :
104126
@@ -119,6 +141,25 @@ def unconstrain(x):
119141 case other :
120142 raise TypeError (f"Expected a method name, got { other !r} ." )
121143
144+ match inclusive :
145+ case "lower" :
146+ if lower is None :
147+ raise ValueError ("Inclusive bounds must be specified." )
148+ lower -= epsilon
149+ case "upper" :
150+ if upper is None :
151+ raise ValueError ("Inclusive bounds must be specified." )
152+ upper += epsilon
153+ case True | "both" :
154+ if lower is None or upper is None :
155+ raise ValueError ("Inclusive bounds must be specified." )
156+ lower -= epsilon
157+ upper += epsilon
158+ case False | None | "none" :
159+ pass
160+ case other :
161+ raise ValueError (f"Unsupported value for 'inclusive': { other !r} ." )
162+
122163 self .lower = lower
123164 self .upper = upper
124165
0 commit comments