@@ -29,14 +29,13 @@ class Constrain(ElementwiseTransform):
2929 - Lower bound only methods: softplus, exp, (default = softplus)
3030 - Upper bound only methods: softplus, exp, (default = softplus)
3131 inclusive: Indicates which bounds are inclusive (or exclusive).
32+ - "both" (default): Both lower and upper bounds are inclusive.
3233 - "lower": Lower bound is inclusive, upper bound is exclusive.
3334 - "upper": Lower bound is exclusive, upper bound is inclusive.
34- - "both": Lower and upper bounds are inclusive.
35- - "none": Lower and upper bounds are exclusive.
36- - "default": Inclusive bounds are determined by the method.
37- - Double bounded methods are lower inclusive and upper exclusive.
38- - Single bounded methods are inclusive at the specified bound.
35+ - "none": Both lower and upper bounds are exclusive.
3936 epsilon: Small value to ensure inclusive bounds are not violated.
37+ Current default is 1e-15 as this ensures finite outcomes
38+ with the default transformations applied to data exactly at the boundaries.
4039
4140
4241 Examples:
@@ -64,8 +63,8 @@ def __init__(
6463 lower : int | float | np .ndarray = None ,
6564 upper : int | float | np .ndarray = None ,
6665 method : str = "default" ,
67- inclusive : str = "default " ,
68- epsilon : float = 1e-16 ,
66+ inclusive : str = "both " ,
67+ epsilon : float = 1e-15 ,
6968 ):
7069 super ().__init__ ()
7170
@@ -77,9 +76,6 @@ def __init__(
7776 if np .any (lower >= upper ):
7877 raise ValueError ("The lower bound must be strictly less than the upper bound." )
7978
80- if inclusive == "default" :
81- inclusive = "lower"
82-
8379 match method :
8480 case "default" | "sigmoid" | "expit" | "logit" :
8581
@@ -94,9 +90,6 @@ def unconstrain(x):
9490 raise TypeError (f"Expected a method name, got { other !r} ." )
9591 elif lower is not None :
9692 # lower bounded case
97- if inclusive == "default" :
98- inclusive = "lower"
99-
10093 match method :
10194 case "default" | "softplus" :
10295
@@ -118,9 +111,6 @@ def unconstrain(x):
118111 raise TypeError (f"Expected a method name, got { other !r} ." )
119112 else :
120113 # upper bounded case
121- if inclusive == "default" :
122- inclusive = "upper"
123-
124114 match method :
125115 case "default" | "softplus" :
126116
@@ -153,18 +143,16 @@ def unconstrain(x):
153143 # do this last to avoid serialization issues
154144 match inclusive :
155145 case "lower" :
156- if lower is None :
157- raise ValueError ("Inclusive bounds must be specified." )
158- lower = lower - epsilon
146+ if lower is not None :
147+ lower = lower - epsilon
159148 case "upper" :
160- if upper is None :
161- raise ValueError ("Inclusive bounds must be specified." )
162- upper = upper + epsilon
149+ if upper is not None :
150+ upper = upper + epsilon
163151 case True | "both" :
164- if lower is None or upper is None :
165- raise ValueError ( "Inclusive bounds must be specified." )
166- lower = lower - epsilon
167- upper = upper + epsilon
152+ if lower is not None :
153+ lower = lower - epsilon
154+ if upper is not None :
155+ upper = upper + epsilon
168156 case False | None | "none" :
169157 pass
170158 case other :
0 commit comments