|
1 | 1 | """Module for the R3Refinement callback.""" |
2 | 2 |
|
3 | 3 | import torch |
4 | | -from torch import nn |
5 | | -from torch.nn.modules.loss import _Loss |
6 | 4 | from .refinement_interface import RefinementInterface |
7 | 5 | from ...label_tensor import LabelTensor |
8 | 6 | from ...utils import check_consistency |
|
11 | 9 |
|
class R3Refinement(RefinementInterface):
    """
    PINA Implementation of the R3 Refinement Callback.

    This callback implements the R3 (Retain-Resample-Release) routine for
    sampling new points based on adaptive search.
    The algorithm incrementally accumulates collocation points in regions
    of high PDE residuals, and releases those with low residuals.
    Points are sampled uniformly in all regions where sampling is needed.

    .. seealso::

        Original Reference: Daw, Arka, et al. *Mitigating Propagation
        Failures in Physics-informed Neural Networks
        using Retain-Resample-Release (R3) Sampling. (2023)*.
        DOI: `10.48550/arXiv.2207.02338
        <https://doi.org/10.48550/arXiv.2207.02338>`_

    :Example:

        >>> r3_callback = R3Refinement(sample_every=5)
    """

    def __init__(
        self,
        sample_every,
        residual_loss=torch.nn.L1Loss,
        condition_to_update=None,
    ):
        """
        Initialization of the :class:`R3Refinement` callback.

        :param int sample_every: The sampling frequency.
        :param residual_loss: The loss function to compute the residuals.
            Default is :class:`~torch.nn.L1Loss`.
        :type residual_loss: LossInterface |
            :class:`~torch.nn.modules.loss._Loss`
        :param condition_to_update: The conditions to update during the
            refinement process. If None, all conditions will be updated.
            Default is None.
        :type condition_to_update: list(str) | tuple(str) | str
        :raises ValueError: If the condition_to_update is neither a string nor
            an iterable of strings.
        :raises TypeError: If the residual_loss is not a subclass of
            :class:`~torch.nn.Module`.
        """
        super().__init__(sample_every, condition_to_update)

        # Ensure the given loss is a supported loss class (not an instance)
        # before instantiating it below.
        check_consistency(
            residual_loss,
            (LossInterface, torch.nn.modules.loss._Loss),
            subclass=True,
        )

        # Per-point residuals are required, so reduction is disabled.
        self.loss_fn = residual_loss(reduction="none")

    def sample(self, current_points, condition_name, solver):
        """
        Sample new points based on the R3 refinement strategy.

        :param current_points: The current points in the domain.
        :type current_points: LabelTensor | torch.Tensor
        :param str condition_name: The name of the condition to update.
        :param PINNInterface solver: The solver using this callback.
        :return: The new samples generated by the R3 strategy.
        :rtype: LabelTensor
        """
        # Move the current points to the training device and enable
        # gradients, which the residual computation requires.
        device = solver.trainer.strategy.root_device
        condition = solver.problem.conditions[condition_name]
        current_points = current_points.to(device).requires_grad_(True)

        # Compute per-point residuals for the given condition, averaged
        # over all non-batch dimensions (i.e. over all output fields).
        target = solver.compute_residual(current_points, condition.equation)
        residuals = self.loss_fn(target, torch.zeros_like(target)).mean(
            dim=tuple(range(1, target.ndim))
        )

        # Retrieve the sampling domain and the initial population size.
        domain_name = solver.problem.conditions[condition_name].domain
        domain = solver.problem.domains[domain_name]
        num_old_points = self.initial_population_size[condition_name]

        # Retain points whose residual is above the mean; release the rest
        # and resample uniformly to restore the original population size.
        mask = (residuals > residuals.mean()).flatten()
        if mask.any():
            high_residual_pts = current_points[mask]
            high_residual_pts.labels = current_points.labels
            # Pass the "random" sampling mode explicitly, consistently with
            # the full-resample branch below.
            samples = domain.sample(
                num_old_points - len(high_residual_pts), "random"
            )
            return LabelTensor.cat([high_residual_pts, samples.to(device)])

        # No point exceeds the mean residual: resample the whole population.
        return domain.sample(num_old_points, "random")
0 commit comments