|
| 1 | +"""Module for the R3Refinement callback.""" |
| 2 | + |
| 3 | +import torch |
| 4 | +from torch import nn |
| 5 | +from torch.nn.modules.loss import _Loss |
| 6 | +from .refinement_interface import RefinementInterface |
| 7 | +from ...label_tensor import LabelTensor |
| 8 | +from ...utils import check_consistency |
| 9 | +from ...loss import LossInterface |
| 10 | + |
| 11 | + |
class R3Refinement(RefinementInterface):
    """
    PINA Implementation of an R3 Refinement Callback.
    """

    def __init__(
        self, sample_every, residual_loss=nn.L1Loss, condition_to_update=None
    ):
        """
        This callback implements the R3 (Retain-Resample-Release) routine for
        sampling new points based on adaptive search.
        The algorithm incrementally accumulates collocation points in regions
        of high PDE residuals, and releases those with low residuals.
        Points are sampled uniformly in all regions where sampling is needed.

        .. seealso::

            Original Reference: Daw, Arka, et al. *Mitigating Propagation
            Failures in Physics-informed Neural Networks
            using Retain-Resample-Release (R3) Sampling. (2023)*.
            DOI: `10.48550/arXiv.2207.02338
            <https://doi.org/10.48550/arXiv.2207.02338>`_

        :param int sample_every: Frequency for sampling.
        :param residual_loss: Loss class used to turn the PDE residual into a
            per-point score. It is instantiated with ``reduction="none"`` so
            that one value per collocation point is kept.
            Default is :class:`torch.nn.L1Loss`.
        :type residual_loss: LossInterface | ~torch.nn.modules.loss._Loss
        :param condition_to_update: The conditions to update during the
            refinement process. If None, all conditions with an equation
            will be updated. Default is None.
        :type condition_to_update: list(str) | tuple(str) | str
        :raises ValueError: If the condition_to_update is not a string or
            iterable of strings.
        :raises TypeError: If the residual_loss is not a subclass of
            LossInterface or torch.nn.modules.loss._Loss.

        Example:
            >>> r3_callback = R3Refinement(sample_every=5)
        """
        super().__init__(sample_every, condition_to_update)
        # residual_loss must be passed as an (uninstantiated) class; it is
        # validated here and instantiated without reduction to preserve the
        # per-point residual values needed for the retain/release decision.
        check_consistency(residual_loss, (LossInterface, _Loss), subclass=True)
        self.loss_fn = residual_loss(reduction="none")

    def sample(self, current_points, condition_name, solver):
        """
        Sample new points based on the R3 refinement strategy.

        :param current_points: Current points in the domain.
        :param condition_name: Name of the condition to update.
        :param PINNInterface solver: The solver object.
        :return: New points sampled based on the R3 strategy.
        :rtype: LabelTensor
        """
        # Compute residuals for the given condition (average over fields).
        condition = solver.problem.conditions[condition_name]
        target = solver.compute_residual(
            current_points.requires_grad_(True), condition.equation
        )
        # Per-point scalar score: loss against zero, averaged over every
        # non-batch dimension of the residual tensor.
        residuals = self.loss_fn(target, torch.zeros_like(target)).mean(
            dim=tuple(range(1, target.ndim))
        )

        # Prepare new points. Reuse the condition already fetched above
        # instead of a second lookup into solver.problem.conditions.
        labels = current_points.labels
        domain = solver.problem.domains[condition.domain]
        num_old_points = self.initial_population_size[condition_name]
        mask = (residuals > residuals.mean()).flatten()

        if mask.any():  # Retain high-residual points, resample the rest
            pts = current_points[mask]
            pts.labels = labels
            retain_pts = len(pts)
            # Keep the population size constant: top up with fresh uniform
            # samples from the condition's domain.
            samples = domain.sample(num_old_points - retain_pts, "random")
            return LabelTensor.cat([pts, samples])
        # No point exceeds the mean residual: release all and resample.
        return domain.sample(num_old_points, "random")