@@ -231,6 +231,65 @@ def default_loss(ip: LinearNDInterpolator) -> np.ndarray:
231
231
return losses
232
232
233
233
234
+ def thresholded_loss_function (
235
+ lower_threshold : float | None = None ,
236
+ upper_threshold : float | None = None ,
237
+ priority_factor : float = 0.1 ,
238
+ ) -> Callable [[LinearNDInterpolator ], np .ndarray ]:
239
+ """
240
+ Factory function to create a custom loss function that deprioritizes
241
+ values above an upper threshold and below a lower threshold.
242
+
243
+ Parameters
244
+ ----------
245
+ lower_threshold : float, optional
246
+ The lower threshold for deprioritizing values. If None (default),
247
+ there is no lower threshold.
248
+ upper_threshold : float, optional
249
+ The upper threshold for deprioritizing values. If None (default),
250
+ there is no upper threshold.
251
+ priority_factor : float, default: 0.1
252
+ The factor by which the loss is multiplied for values outside
253
+ the specified thresholds.
254
+
255
+ Returns
256
+ -------
257
+ custom_loss : Callable[[LinearNDInterpolator], np.ndarray]
258
+ A custom loss function that can be used with Learner2D.
259
+ """
260
+
261
+ def custom_loss (ip : LinearNDInterpolator ) -> np .ndarray :
262
+ """Loss function that deprioritizes values outside an upper and lower threshold.
263
+
264
+ Parameters
265
+ ----------
266
+ ip : `scipy.interpolate.LinearNDInterpolator` instance
267
+
268
+ Returns
269
+ -------
270
+ losses : numpy.ndarray
271
+ Loss per triangle in ``ip.tri``.
272
+ """
273
+ losses = default_loss (ip )
274
+
275
+ if lower_threshold is not None or upper_threshold is not None :
276
+ simplices = ip .tri .simplices
277
+ values = ip .values [simplices ]
278
+ if lower_threshold is not None :
279
+ mask_lower = (values < lower_threshold ).all (axis = (1 , - 1 ))
280
+ if mask_lower .any ():
281
+ losses [mask_lower ] *= priority_factor
282
+
283
+ if upper_threshold is not None :
284
+ mask_upper = (values > upper_threshold ).all (axis = (1 , - 1 ))
285
+ if mask_upper .any ():
286
+ losses [mask_upper ] *= priority_factor
287
+
288
+ return losses
289
+
290
+ return custom_loss
291
+
292
+
234
293
def choose_point_in_triangle (triangle : np .ndarray , max_badness : int ) -> np .ndarray :
235
294
"""Choose a new point in inside a triangle.
236
295
0 commit comments