 @serializable(package="bayesflow.adapters")
 class Constrain(ElementwiseTransform):
     """
-    Constrains neural network predictions of a data variable to specificied bounds.
+    Constrains neural network predictions of a data variable to specified bounds.
 
     Parameters:
     String containing the name of the data variable to be transformed e.g. "sigma". See examples below.
@@ -28,14 +28,21 @@ class Constrain(ElementwiseTransform):
         - Double bounded methods: sigmoid, expit, (default = sigmoid)
         - Lower bound only methods: softplus, exp, (default = softplus)
         - Upper bound only methods: softplus, exp, (default = softplus)
-
+    inclusive: Indicates which bounds are inclusive (or exclusive).
+        - "both" (default): Both lower and upper bounds are inclusive.
+        - "lower": Lower bound is inclusive, upper bound is exclusive.
+        - "upper": Lower bound is exclusive, upper bound is inclusive.
+        - "none": Both lower and upper bounds are exclusive.
+    epsilon: Small value to ensure inclusive bounds are not violated.
+        Current default is 1e-15 as this ensures finite outcomes
+        with the default transformations applied to data exactly at the boundaries.
 
 
     Examples:
     1) Let sigma be the standard deviation of a normal distribution,
     then sigma should always be greater than zero.
 
-    Useage:
+    Usage:
     adapter = (
         bf.Adapter()
         .constrain("sigma", lower=0)
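For context, a minimal sketch of what a double-bounded "sigmoid" method and a lower-bounded "softplus" method typically compute; the actual constrain/unconstrain closures are defined earlier in this file and are not shown in this diff:

    import numpy as np

    def sigmoid_constrain(z, lower, upper):
        # squash an unconstrained value z into the interval (lower, upper)
        return lower + (upper - lower) / (1.0 + np.exp(-z))

    def softplus_constrain(z, lower):
        # map an unconstrained value z to values above lower
        return lower + np.log1p(np.exp(z))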
@@ -45,14 +52,19 @@ class Constrain(ElementwiseTransform):
     [0,1] then we would constrain the neural network to estimate p in the following way.
 
     Usage:
-    adapter = (
-        bf.Adapter()
-        .constrain("p", lower=0, upper=1, method="sigmoid")
-    )
+    >>> import bayesflow as bf
+    >>> adapter = bf.Adapter()
+    >>> adapter.constrain("p", lower=0, upper=1, method="sigmoid", inclusive="both")
     """
 
     def __init__(
-        self, *, lower: int | float | np.ndarray = None, upper: int | float | np.ndarray = None, method: str = "default"
+        self,
+        *,
+        lower: int | float | np.ndarray = None,
+        upper: int | float | np.ndarray = None,
+        method: str = "default",
+        inclusive: str = "both",
+        epsilon: float = 1e-15,
     ):
         super().__init__()
 
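A usage note on the widened signature above: everything after the bare * is keyword-only, so the transform has to be constructed with named bounds. A sketch of direct construction; the import path is an assumption, the class is only known here to live under the bayesflow.adapters package:

    # import path is an assumption based on the @serializable package above
    from bayesflow.adapters.transforms import Constrain

    t = Constrain(lower=0.0, method="softplus", inclusive="lower")
    # Constrain(0.0) would raise a TypeError, since lower/upper are keyword-only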
@@ -121,12 +133,31 @@ def unconstrain(x):
 
         self.lower = lower
         self.upper = upper
-
         self.method = method
+        self.inclusive = inclusive
+        self.epsilon = epsilon
 
         self.constrain = constrain
         self.unconstrain = unconstrain
 
+        # do this last to avoid serialization issues
+        match inclusive:
+            case "lower":
+                if lower is not None:
+                    lower = lower - epsilon
+            case "upper":
+                if upper is not None:
+                    upper = upper + epsilon
+            case True | "both":
+                if lower is not None:
+                    lower = lower - epsilon
+                if upper is not None:
+                    upper = upper + epsilon
+            case False | None | "none":
+                pass
+            case other:
+                raise ValueError(f"Unsupported value for 'inclusive': {other!r}.")
+
     @classmethod
     def from_config(cls, config: dict, custom_objects=None) -> "Constrain":
         return cls(**config)
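To make the effect of the new match block concrete: with the default inclusive="both", a [0, 1] constraint is widened to roughly [-1e-15, 1 + 1e-15] before the bound-dependent transforms use it, so data sitting exactly on a boundary no longer maps to an infinite unconstrained value. A small numeric sketch, assuming a plain scaled logit as a stand-in for the double-bounded unconstrain (the real closures are defined earlier in __init__ and are not shown here):

    import numpy as np

    lower, upper, epsilon = 0.0, 1.0, 1e-15
    lower -= epsilon  # what `case True | "both"` does to the lower bound
    upper += epsilon  # ... and to the upper bound

    def logit_unconstrain(x):
        # illustrative stand-in: scaled logit onto the widened interval
        p = (x - lower) / (upper - lower)
        return np.log(p) - np.log1p(-p)

    print(logit_unconstrain(0.0))  # finite (about -34.5) instead of -inf
    print(logit_unconstrain(1.0))  # finite (about +34.5) instead of +inf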
@@ -136,6 +167,8 @@ def get_config(self) -> dict:
             "lower": self.lower,
             "upper": self.upper,
             "method": self.method,
+            "inclusive": self.inclusive,
+            "epsilon": self.epsilon,
         }
 
     def forward(self, data: np.ndarray, **kwargs) -> np.ndarray:
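Because get_config stores the original, un-shifted lower and upper together with inclusive and epsilon, from_config (which is just cls(**config)) rebuilds an equivalent transform and re-applies the epsilon adjustment exactly once; this appears to be what the "do this last to avoid serialization issues" comment is protecting. A round-trip sketch, with the same assumed import path as above:

    from bayesflow.adapters.transforms import Constrain  # assumed import path

    original = Constrain(lower=0, upper=1, method="sigmoid", inclusive="both")
    config = original.get_config()
    # {"lower": 0, "upper": 1, "method": "sigmoid", "inclusive": "both", "epsilon": 1e-15}
    restored = Constrain.from_config(config)  # equivalent to Constrain(**config)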