Skip to content

Commit ac0437b

Browse files
Fix adaptive refinement (#571)
--------- Co-authored-by: Dario Coscia <[email protected]>
1 parent 6b355b4 commit ac0437b

File tree

10 files changed

+326
-216
lines changed

10 files changed

+326
-216
lines changed

docs/source/_rst/_code.rst

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,8 @@ Callbacks
238238

239239
Processing callback <callback/processing_callback.rst>
240240
Optimizer callback <callback/optimizer_callback.rst>
241-
Refinment callback <callback/adaptive_refinment_callback.rst>
241+
R3 Refinement callback <callback/refinement/r3_refinement.rst>
242+
Refinement Interface callback <callback/refinement/refinement_interface.rst>
242243
Weighting callback <callback/linear_weight_update_callback.rst>
243244

244245
Losses and Weightings

docs/source/_rst/callback/adaptive_refinment_callback.rst renamed to docs/source/_rst/callback/refinement/r3_refinement.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
Refinement callbacks
22
=======================
33

4-
.. currentmodule:: pina.callback.adaptive_refinement_callback
4+
.. currentmodule:: pina.callback.refinement
55
.. autoclass:: R3Refinement
66
:members:
77
:show-inheritance:
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
Refinement Interface
2+
=======================
3+
4+
.. currentmodule:: pina.callback.refinement
5+
.. autoclass:: RefinementInterface
6+
:members:
7+
:show-inheritance:

pina/callback/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,13 @@
22

33
__all__ = [
44
"SwitchOptimizer",
5-
"R3Refinement",
65
"MetricTracker",
76
"PINAProgressBar",
87
"LinearWeightUpdate",
8+
"R3Refinement",
99
]
1010

1111
from .optimizer_callback import SwitchOptimizer
12-
from .adaptive_refinement_callback import R3Refinement
1312
from .processing_callback import MetricTracker, PINAProgressBar
1413
from .linear_weight_update_callback import LinearWeightUpdate
14+
from .refinement import R3Refinement

pina/callback/adaptive_refinement_callback.py

Lines changed: 0 additions & 181 deletions
This file was deleted.
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
"""
2+
Module for Pina Refinement callbacks.
3+
"""
4+
5+
__all__ = [
6+
"RefinementInterface",
7+
"R3Refinement",
8+
]
9+
10+
from .refinement_interface import RefinementInterface
11+
from .r3_refinement import R3Refinement
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
"""Module for the R3Refinement callback."""
2+
3+
import torch
4+
from torch import nn
5+
from torch.nn.modules.loss import _Loss
6+
from .refinement_interface import RefinementInterface
7+
from ...label_tensor import LabelTensor
8+
from ...utils import check_consistency
9+
from ...loss import LossInterface
10+
11+
12+
class R3Refinement(RefinementInterface):
    """
    PINA Implementation of an R3 Refinement Callback.
    """

    def __init__(
        self, sample_every, residual_loss=nn.L1Loss, condition_to_update=None
    ):
        """
        This callback implements the R3 (Retain-Resample-Release) routine for
        sampling new points based on adaptive search.
        The algorithm incrementally accumulates collocation points in regions
        of high PDE residuals, and releases those with low residuals.
        Points are sampled uniformly in all regions where sampling is needed.

        .. seealso::

            Original Reference: Daw, Arka, et al. *Mitigating Propagation
            Failures in Physics-informed Neural Networks
            using Retain-Resample-Release (R3) Sampling. (2023)*.
            DOI: `10.48550/arXiv.2207.02338
            <https://doi.org/10.48550/arXiv.2207.02338>`_

        :param int sample_every: Frequency for sampling.
        :param residual_loss: Loss function used to compute the pointwise
            residual magnitude. It is instantiated with
            ``reduction="none"`` so that per-point values are preserved.
        :type residual_loss: LossInterface | ~torch.nn.modules.loss._Loss
        :param condition_to_update: The conditions to update during the
            refinement process. If None, all eligible conditions will
            be updated. Default is None.
        :type condition_to_update: list(str) | tuple(str) | str
        :raises ValueError: If the condition_to_update is not a string or
            iterable of strings.
        :raises TypeError: If the residual_loss is not a subclass of
            torch.nn.Module.

        Example:
            >>> r3_callback = R3Refinement(sample_every=5)
        """
        super().__init__(sample_every, condition_to_update)
        # Validate the loss class before instantiating it: it must be one of
        # PINA's LossInterface subclasses or a torch loss class.
        check_consistency(residual_loss, (LossInterface, _Loss), subclass=True)
        # reduction="none" keeps one residual value per collocation point.
        self.loss_fn = residual_loss(reduction="none")

    def sample(self, current_points, condition_name, solver):
        """
        Sample new points based on the R3 refinement strategy.

        :param current_points: Current points in the domain.
        :param condition_name: Name of the condition to update.
        :param PINNInterface solver: The solver object.
        :return: New points sampled based on the R3 strategy.
        :rtype: LabelTensor
        """
        # Compute residuals for the given condition (average over fields).
        # requires_grad_ is needed because residual computation differentiates
        # the network output with respect to the input points.
        condition = solver.problem.conditions[condition_name]
        target = solver.compute_residual(
            current_points.requires_grad_(True), condition.equation
        )
        residuals = self.loss_fn(target, torch.zeros_like(target)).mean(
            dim=tuple(range(1, target.ndim))
        )

        # Prepare new points: the refreshed population keeps the same total
        # size as the initial one for this condition.
        labels = current_points.labels
        domain_name = solver.problem.conditions[condition_name].domain
        domain = solver.problem.domains[domain_name]
        num_old_points = self.initial_population_size[condition_name]
        # Retain points whose residual exceeds the mean residual (R3 rule).
        mask = (residuals > residuals.mean()).flatten()

        if mask.any():  # Use high-residual points
            pts = current_points[mask]
            pts.labels = labels
            retain_pts = len(pts)
            # NOTE(review): assumes retain_pts <= num_old_points; if the
            # population ever grows past its initial size, the sample count
            # below would be negative — confirm against RefinementInterface.
            samples = domain.sample(num_old_points - retain_pts, "random")
            return LabelTensor.cat([pts, samples])
        # No point is above the mean residual: resample the full population.
        return domain.sample(num_old_points, "random")

0 commit comments

Comments
 (0)