@@ -141,33 +141,35 @@ def unconstrain(x):
141141 case other :
142142 raise TypeError (f"Expected a method name, got { other !r} ." )
143143
144+ self .lower = lower
145+ self .upper = upper
146+ self .method = method
147+ self .inclusive = inclusive
148+ self .epsilon = epsilon
149+
150+ self .constrain = constrain
151+ self .unconstrain = unconstrain
152+
153+ # do this last to avoid serialization issues
144154 match inclusive :
145155 case "lower" :
146156 if lower is None :
147157 raise ValueError ("Inclusive bounds must be specified." )
148- lower -= epsilon
158+ lower = lower - epsilon
149159 case "upper" :
150160 if upper is None :
151161 raise ValueError ("Inclusive bounds must be specified." )
152- upper += epsilon
162+ upper = upper + epsilon
153163 case True | "both" :
154164 if lower is None or upper is None :
155165 raise ValueError ("Inclusive bounds must be specified." )
156- lower -= epsilon
157- upper += epsilon
166+ lower = lower - epsilon
167+ upper = upper + epsilon
158168 case False | None | "none" :
159169 pass
160170 case other :
161171 raise ValueError (f"Unsupported value for 'inclusive': { other !r} ." )
162172
163- self .lower = lower
164- self .upper = upper
165-
166- self .method = method
167-
168- self .constrain = constrain
169- self .unconstrain = unconstrain
170-
171173 @classmethod
172174 def from_config (cls , config : dict , custom_objects = None ) -> "Constrain" :
173175 return cls (** config )
@@ -177,6 +179,8 @@ def get_config(self) -> dict:
177179 "lower" : self .lower ,
178180 "upper" : self .upper ,
179181 "method" : self .method ,
182+ "inclusive" : self .inclusive ,
183+ "epsilon" : self .epsilon ,
180184 }
181185
182186 def forward (self , data : np .ndarray , ** kwargs ) -> np .ndarray :
0 commit comments