diff --git a/pina/callback/refinement/r3_refinement.py b/pina/callback/refinement/r3_refinement.py index c90b2953e..863dedfc1 100644 --- a/pina/callback/refinement/r3_refinement.py +++ b/pina/callback/refinement/r3_refinement.py @@ -1,8 +1,6 @@ """Module for the R3Refinement callback.""" import torch -from torch import nn -from torch.nn.modules.loss import _Loss from .refinement_interface import RefinementInterface from ...label_tensor import LabelTensor from ...utils import check_consistency @@ -11,78 +9,94 @@ class R3Refinement(RefinementInterface): """ - PINA Implementation of an R3 Refinement Callback. + PINA Implementation of the R3 Refinement Callback. + + This callback implements the R3 (Retain-Resample-Release) routine for + sampling new points based on adaptive search. + The algorithm incrementally accumulates collocation points in regions + of high PDE residuals, and releases those with low residuals. + Points are sampled uniformly in all regions where sampling is needed. + + .. seealso:: + + Original Reference: Daw, Arka, et al. *Mitigating Propagation + Failures in Physics-informed Neural Networks + using Retain-Resample-Release (R3) Sampling. (2023)*. + DOI: `10.48550/arXiv.2207.02338 + <https://doi.org/10.48550/arXiv.2207.02338>`_ + + :Example: + + >>> r3_callback = R3Refinement(sample_every=5) """ def __init__( - self, sample_every, residual_loss=nn.L1Loss, condition_to_update=None + self, + sample_every, + residual_loss=torch.nn.L1Loss, + condition_to_update=None, ): """ - This callback implements the R3 (Retain-Resample-Release) routine for - sampling new points based on adaptive search. - The algorithm incrementally accumulates collocation points in regions - of high PDE residuals, and releases those with low residuals. - Points are sampled uniformly in all regions where sampling is needed. - - .. seealso:: - - Original Reference: Daw, Arka, et al. *Mitigating Propagation - Failures in Physics-informed Neural Networks - using Retain-Resample-Release (R3) Sampling. (2023)*. 
- DOI: `10.48550/arXiv.2207.02338 - <https://doi.org/10.48550/arXiv.2207.02338>`_ - - :param int sample_every: Frequency for sampling. - :param loss: Loss function - :type loss: LossInterface | ~torch.nn.modules.loss._Loss + Initialization of the :class:`R3Refinement` callback. + + :param int sample_every: The sampling frequency. + :param residual_loss: The loss function to compute the residuals. + Default is :class:`~torch.nn.L1Loss`. + :type residual_loss: LossInterface | :class:`~torch.nn.modules.loss._Loss` :param condition_to_update: The conditions to update during the - refinement process. If None, all conditions with a conditions will - be updated. Default is None. + refinement process. If None, all conditions will be updated. + Default is None. :type condition_to_update: list(str) | tuple(str) | str - :raises ValueError: If the condition_to_update is not a string or - iterable of strings. + :raises ValueError: If the condition_to_update is neither a string nor + an iterable of strings. :raises TypeError: If the residual_loss is not a subclass of - torch.nn.Module. - - - Example: - >>> r3_callback = R3Refinement(sample_every=5) + :class:`~torch.nn.Module`. """ super().__init__(sample_every, condition_to_update) - # check consistency loss - check_consistency(residual_loss, (LossInterface, _Loss), subclass=True) + + # Check consistency + check_consistency( + residual_loss, + (LossInterface, torch.nn.modules.loss._Loss), + subclass=True, + ) + + # Save loss function self.loss_fn = residual_loss(reduction="none") def sample(self, current_points, condition_name, solver): """ Sample new points based on the R3 refinement strategy. - :param current_points: Current points in the domain. - :param condition_name: Name of the condition to update. - :param PINNInterface solver: The solver object. - :return: New points sampled based on the R3 strategy. + :param current_points: The current points in the domain. + :type current_points: LabelTensor | torch.Tensor + :param str condition_name: The name of the condition to update. 
+ :param PINNInterface solver: The solver using this callback. + :return: The new samples generated by the R3 strategy. :rtype: LabelTensor """ - # Compute residuals for the given condition (average over fields) + # Retrieve condition and current points + device = solver.trainer.strategy.root_device condition = solver.problem.conditions[condition_name] - target = solver.compute_residual( - current_points.requires_grad_(True), condition.equation - ) + current_points = current_points.to(device).requires_grad_(True) + + # Compute residuals for the given condition (averaged over all fields) + target = solver.compute_residual(current_points, condition.equation) residuals = self.loss_fn(target, torch.zeros_like(target)).mean( dim=tuple(range(1, target.ndim)) ) - # Prepare new points - labels = current_points.labels + # Retrieve domain and initial population size domain_name = solver.problem.conditions[condition_name].domain domain = solver.problem.domains[domain_name] num_old_points = self.initial_population_size[condition_name] + + # Select points with residual above the mean mask = (residuals > residuals.mean()).flatten() + if mask.any(): + high_residual_pts = current_points[mask] + high_residual_pts.labels = current_points.labels + samples = domain.sample(num_old_points - len(high_residual_pts)) + return LabelTensor.cat([high_residual_pts, samples.to(device)]) - if mask.any(): # Use high-residual points - pts = current_points[mask] - pts.labels = labels - retain_pts = len(pts) - samples = domain.sample(num_old_points - retain_pts, "random") - return LabelTensor.cat([pts, samples]) return domain.sample(num_old_points, "random")