Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 62 additions & 48 deletions pina/callback/refinement/r3_refinement.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
"""Module for the R3Refinement callback."""

import torch
from torch import nn
from torch.nn.modules.loss import _Loss
from .refinement_interface import RefinementInterface
from ...label_tensor import LabelTensor
from ...utils import check_consistency
Expand All @@ -11,78 +9,94 @@

class R3Refinement(RefinementInterface):
"""
PINA Implementation of an R3 Refinement Callback.
PINA Implementation of the R3 Refinement Callback.

This callback implements the R3 (Retain-Resample-Release) routine for
sampling new points based on adaptive search.
The algorithm incrementally accumulates collocation points in regions
of high PDE residuals, and releases those with low residuals.
Points are sampled uniformly in all regions where sampling is needed.

.. seealso::

Original Reference: Daw, Arka, et al. *Mitigating Propagation
Failures in Physics-informed Neural Networks
using Retain-Resample-Release (R3) Sampling. (2023)*.
DOI: `10.48550/arXiv.2207.02338
<https://doi.org/10.48550/arXiv.2207.02338>`_

:Example:

>>> r3_callback = R3Refinement(sample_every=5)
"""

def __init__(
self, sample_every, residual_loss=nn.L1Loss, condition_to_update=None
self,
sample_every,
residual_loss=torch.nn.L1Loss,
condition_to_update=None,
):
"""
This callback implements the R3 (Retain-Resample-Release) routine for
sampling new points based on adaptive search.
The algorithm incrementally accumulates collocation points in regions
of high PDE residuals, and releases those with low residuals.
Points are sampled uniformly in all regions where sampling is needed.

.. seealso::

Original Reference: Daw, Arka, et al. *Mitigating Propagation
Failures in Physics-informed Neural Networks
using Retain-Resample-Release (R3) Sampling. (2023)*.
DOI: `10.48550/arXiv.2207.02338
<https://doi.org/10.48550/arXiv.2207.02338>`_

:param int sample_every: Frequency for sampling.
:param loss: Loss function
:type loss: LossInterface | ~torch.nn.modules.loss._Loss
Initialization of the :class:`R3Refinement` callback.

:param int sample_every: The sampling frequency.
:param loss: The loss function to compute the residuals.
Default is :class:`~torch.nn.L1Loss`.
:type loss: LossInterface | :class:`~torch.nn.modules.loss._Loss`
:param condition_to_update: The conditions to update during the
refinement process. If None, all conditions with a conditions will
be updated. Default is None.
refinement process. If None, all conditions will be updated.
Default is None.
:type condition_to_update: list(str) | tuple(str) | str
:raises ValueError: If the condition_to_update is not a string or
iterable of strings.
:raises ValueError: If the condition_to_update is neither a string nor
an iterable of strings.
:raises TypeError: If the residual_loss is not a subclass of
torch.nn.Module.


Example:
>>> r3_callback = R3Refinement(sample_every=5)
:class:`~torch.nn.Module`.
"""
super().__init__(sample_every, condition_to_update)
# check consistency loss
check_consistency(residual_loss, (LossInterface, _Loss), subclass=True)

# Check consistency
check_consistency(
residual_loss,
(LossInterface, torch.nn.modules.loss._Loss),
subclass=True,
)

# Save loss function
self.loss_fn = residual_loss(reduction="none")

def sample(self, current_points, condition_name, solver):
"""
Sample new points based on the R3 refinement strategy.

:param current_points: Current points in the domain.
:param condition_name: Name of the condition to update.
:param PINNInterface solver: The solver object.
:return: New points sampled based on the R3 strategy.
:param current_points: The current points in the domain.
:type current_points: LabelTensor | torch.Tensor
:param str condition_name: The name of the condition to update.
:param PINNInterface solver: The solver using this callback.
:return: The new samples generated by the R3 strategy.
:rtype: LabelTensor
"""
# Compute residuals for the given condition (average over fields)
# Retrieve condition and current points
device = solver.trainer.strategy.root_device
condition = solver.problem.conditions[condition_name]
target = solver.compute_residual(
current_points.requires_grad_(True), condition.equation
)
current_points = current_points.to(device).requires_grad_(True)

# Compute residuals for the given condition (averaged over all fields)
target = solver.compute_residual(current_points, condition.equation)
residuals = self.loss_fn(target, torch.zeros_like(target)).mean(
dim=tuple(range(1, target.ndim))
)

# Prepare new points
labels = current_points.labels
# Retrieve domain and initial population size
domain_name = solver.problem.conditions[condition_name].domain
domain = solver.problem.domains[domain_name]
num_old_points = self.initial_population_size[condition_name]

# Select points with residual above the mean
mask = (residuals > residuals.mean()).flatten()
if mask.any():
high_residual_pts = current_points[mask]
high_residual_pts.labels = current_points.labels
samples = domain.sample(num_old_points - len(high_residual_pts))
return LabelTensor.cat([high_residual_pts, samples.to(device)])

if mask.any(): # Use high-residual points
pts = current_points[mask]
pts.labels = labels
retain_pts = len(pts)
samples = domain.sample(num_old_points - retain_pts, "random")
return LabelTensor.cat([pts, samples])
return domain.sample(num_old_points, "random")