Skip to content

Commit 9d65060

Browse files
move samples to device
1 parent 664b058 commit 9d65060

File tree

1 file changed

+62
-48
lines changed

1 file changed

+62
-48
lines changed
Lines changed: 62 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
"""Module for the R3Refinement callback."""
22

33
import torch
4-
from torch import nn
5-
from torch.nn.modules.loss import _Loss
64
from .refinement_interface import RefinementInterface
75
from ...label_tensor import LabelTensor
86
from ...utils import check_consistency
@@ -11,78 +9,94 @@
119

1210
class R3Refinement(RefinementInterface):
1311
"""
14-
PINA Implementation of an R3 Refinement Callback.
12+
PINA Implementation of the R3 Refinement Callback.
13+
14+
This callback implements the R3 (Retain-Resample-Release) routine for
15+
sampling new points based on adaptive search.
16+
The algorithm incrementally accumulates collocation points in regions
17+
of high PDE residuals, and releases those with low residuals.
18+
Points are sampled uniformly in all regions where sampling is needed.
19+
20+
.. seealso::
21+
22+
Original Reference: Daw, Arka, et al. *Mitigating Propagation
23+
Failures in Physics-informed Neural Networks
24+
using Retain-Resample-Release (R3) Sampling. (2023)*.
25+
DOI: `10.48550/arXiv.2207.02338
26+
<https://doi.org/10.48550/arXiv.2207.02338>`_
27+
28+
:Example:
29+
30+
>>> r3_callback = R3Refinement(sample_every=5)
1531
"""
1632

1733
def __init__(
18-
self, sample_every, residual_loss=nn.L1Loss, condition_to_update=None
34+
self,
35+
sample_every,
36+
residual_loss=torch.nn.L1Loss,
37+
condition_to_update=None,
1938
):
2039
"""
21-
This callback implements the R3 (Retain-Resample-Release) routine for
22-
sampling new points based on adaptive search.
23-
The algorithm incrementally accumulates collocation points in regions
24-
of high PDE residuals, and releases those with low residuals.
25-
Points are sampled uniformly in all regions where sampling is needed.
26-
27-
.. seealso::
28-
29-
Original Reference: Daw, Arka, et al. *Mitigating Propagation
30-
Failures in Physics-informed Neural Networks
31-
using Retain-Resample-Release (R3) Sampling. (2023)*.
32-
DOI: `10.48550/arXiv.2207.02338
33-
<https://doi.org/10.48550/arXiv.2207.02338>`_
34-
35-
:param int sample_every: Frequency for sampling.
36-
:param loss: Loss function
37-
:type loss: LossInterface | ~torch.nn.modules.loss._Loss
40+
Initialization of the :class:`R3Refinement` callback.
41+
42+
:param int sample_every: The sampling frequency.
43+
:param residual_loss: The loss function to compute the residuals.
44+
Default is :class:`~torch.nn.L1Loss`.
45+
:type residual_loss: LossInterface | :class:`~torch.nn.modules.loss._Loss`
3846
:param condition_to_update: The conditions to update during the
39-
refinement process. If None, all conditions with a conditions will
40-
be updated. Default is None.
47+
refinement process. If None, all conditions will be updated.
48+
Default is None.
4149
:type condition_to_update: list(str) | tuple(str) | str
42-
:raises ValueError: If the condition_to_update is not a string or
43-
iterable of strings.
50+
:raises ValueError: If the condition_to_update is neither a string nor
51+
an iterable of strings.
4452
:raises TypeError: If the residual_loss is not a subclass of
45-
torch.nn.Module.
46-
47-
48-
Example:
49-
>>> r3_callback = R3Refinement(sample_every=5)
53+
:class:`~torch.nn.Module`.
5054
"""
5155
super().__init__(sample_every, condition_to_update)
52-
# check consistency loss
53-
check_consistency(residual_loss, (LossInterface, _Loss), subclass=True)
56+
57+
# Check consistency
58+
check_consistency(
59+
residual_loss,
60+
(LossInterface, torch.nn.modules.loss._Loss),
61+
subclass=True,
62+
)
63+
64+
# Save loss function
5465
self.loss_fn = residual_loss(reduction="none")
5566

5667
def sample(self, current_points, condition_name, solver):
5768
"""
5869
Sample new points based on the R3 refinement strategy.
5970
60-
:param current_points: Current points in the domain.
61-
:param condition_name: Name of the condition to update.
62-
:param PINNInterface solver: The solver object.
63-
:return: New points sampled based on the R3 strategy.
71+
:param current_points: The current points in the domain.
72+
:type current_points: LabelTensor | torch.Tensor
73+
:param str condition_name: The name of the condition to update.
74+
:param PINNInterface solver: The solver using this callback.
75+
:return: The new samples generated by the R3 strategy.
6476
:rtype: LabelTensor
6577
"""
66-
# Compute residuals for the given condition (average over fields)
78+
# Retrieve condition and current points
79+
device = solver.trainer.strategy.root_device
6780
condition = solver.problem.conditions[condition_name]
68-
target = solver.compute_residual(
69-
current_points.requires_grad_(True), condition.equation
70-
)
81+
current_points = current_points.to(device).requires_grad_(True)
82+
83+
# Compute residuals for the given condition (averaged over all fields)
84+
target = solver.compute_residual(current_points, condition.equation)
7185
residuals = self.loss_fn(target, torch.zeros_like(target)).mean(
7286
dim=tuple(range(1, target.ndim))
7387
)
7488

75-
# Prepare new points
76-
labels = current_points.labels
89+
# Retrieve domain and initial population size
7790
domain_name = solver.problem.conditions[condition_name].domain
7891
domain = solver.problem.domains[domain_name]
7992
num_old_points = self.initial_population_size[condition_name]
93+
94+
# Select points with residual above the mean
8095
mask = (residuals > residuals.mean()).flatten()
96+
if mask.any():
97+
high_residual_pts = current_points[mask]
98+
high_residual_pts.labels = current_points.labels
99+
samples = domain.sample(num_old_points - len(high_residual_pts))
100+
return LabelTensor.cat([high_residual_pts, samples.to(device)])
81101

82-
if mask.any(): # Use high-residual points
83-
pts = current_points[mask]
84-
pts.labels = labels
85-
retain_pts = len(pts)
86-
samples = domain.sample(num_old_points - retain_pts, "random")
87-
return LabelTensor.cat([pts, samples])
88102
return domain.sample(num_old_points, "random")

0 commit comments

Comments
 (0)