Skip to content

Commit 88f0258

Browse files
authored
Add Learner2D loss function 'thresholded_loss_factory' (#437)
* Add Learner2D loss function 'thresholded_loss_factory' * Fix * simplify * Rename and test * test
1 parent 067444f commit 88f0258

File tree

2 files changed

+60
-0
lines changed

2 files changed

+60
-0
lines changed

adaptive/learner/learner2D.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,65 @@ def default_loss(ip: LinearNDInterpolator) -> np.ndarray:
231231
return losses
232232

233233

234+
def thresholded_loss_function(
235+
lower_threshold: float | None = None,
236+
upper_threshold: float | None = None,
237+
priority_factor: float = 0.1,
238+
) -> Callable[[LinearNDInterpolator], np.ndarray]:
239+
"""
240+
Factory function to create a custom loss function that deprioritizes
241+
values above an upper threshold and below a lower threshold.
242+
243+
Parameters
244+
----------
245+
lower_threshold : float, optional
246+
The lower threshold for deprioritizing values. If None (default),
247+
there is no lower threshold.
248+
upper_threshold : float, optional
249+
The upper threshold for deprioritizing values. If None (default),
250+
there is no upper threshold.
251+
priority_factor : float, default: 0.1
252+
The factor by which the loss is multiplied for values outside
253+
the specified thresholds.
254+
255+
Returns
256+
-------
257+
custom_loss : Callable[[LinearNDInterpolator], np.ndarray]
258+
A custom loss function that can be used with Learner2D.
259+
"""
260+
261+
def custom_loss(ip: LinearNDInterpolator) -> np.ndarray:
262+
"""Loss function that deprioritizes values outside an upper and lower threshold.
263+
264+
Parameters
265+
----------
266+
ip : `scipy.interpolate.LinearNDInterpolator` instance
267+
268+
Returns
269+
-------
270+
losses : numpy.ndarray
271+
Loss per triangle in ``ip.tri``.
272+
"""
273+
losses = default_loss(ip)
274+
275+
if lower_threshold is not None or upper_threshold is not None:
276+
simplices = ip.tri.simplices
277+
values = ip.values[simplices]
278+
if lower_threshold is not None:
279+
mask_lower = (values < lower_threshold).all(axis=(1, -1))
280+
if mask_lower.any():
281+
losses[mask_lower] *= priority_factor
282+
283+
if upper_threshold is not None:
284+
mask_upper = (values > upper_threshold).all(axis=(1, -1))
285+
if mask_upper.any():
286+
losses[mask_upper] *= priority_factor
287+
288+
return losses
289+
290+
return custom_loss
291+
292+
234293
def choose_point_in_triangle(triangle: np.ndarray, max_badness: int) -> np.ndarray:
235294
"""Choose a new point in inside a triangle.
236295

adaptive/tests/test_learners.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
adaptive.learner.learner2D.uniform_loss,
5454
adaptive.learner.learner2D.minimize_triangle_surface_loss,
5555
adaptive.learner.learner2D.resolution_loss_function(),
56+
adaptive.learner.learner2D.thresholded_loss_function(upper_threshold=0.5),
5657
),
5758
),
5859
LearnerND: (

0 commit comments

Comments
 (0)