Skip to content

Commit 89259b6

Browse files
Update callbacks and tests (#482)
--------- Co-authored-by: giovanni <[email protected]>
1 parent 795e4a4 commit 89259b6

File tree

8 files changed

+288
-253
lines changed

8 files changed

+288
-253
lines changed
Lines changed: 151 additions & 144 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,24 @@
11
"""PINA Callbacks Implementations"""
22

3+
import importlib.metadata
34
import torch
45
from lightning.pytorch.callbacks import Callback
56
from ..label_tensor import LabelTensor
67
from ..utils import check_consistency
78

89

910
class R3Refinement(Callback):
11+
"""
12+
PINA Implementation of an R3 Refinement Callback.
13+
"""
1014

1115
def __init__(self, sample_every):
1216
"""
13-
PINA Implementation of an R3 Refinement Callback.
14-
1517
This callback implements the R3 (Retain-Resample-Release) routine for
1618
sampling new points based on adaptive search.
1719
The algorithm incrementally accumulates collocation points in regions
18-
of high PDE residuals, and releases those
19-
with low residuals. Points are sampled uniformly in all regions
20-
where sampling is needed.
20+
of high PDE residuals, and releases those with low residuals.
21+
Points are sampled uniformly in all regions where sampling is needed.
2122
2223
.. seealso::
2324
@@ -33,142 +34,148 @@ def __init__(self, sample_every):
3334
Example:
3435
>>> r3_callback = R3Refinement(sample_every=5)
3536
"""
36-
super().__init__()
37-
38-
# sample every
39-
check_consistency(sample_every, int)
40-
self._sample_every = sample_every
41-
self._const_pts = None
42-
43-
def _compute_residual(self, trainer):
44-
"""
45-
Computes the residuals for a PINN object.
46-
47-
:return: the total loss, and pointwise loss.
48-
:rtype: tuple
49-
"""
50-
51-
# extract the solver and device from trainer
52-
solver = trainer.solver
53-
device = trainer._accelerator_connector._accelerator_flag
54-
precision = trainer.precision
55-
if precision == "64-true":
56-
precision = torch.float64
57-
elif precision == "32-true":
58-
precision = torch.float32
59-
else:
60-
raise RuntimeError(
61-
"Currently R3Refinement is only implemented "
62-
"for precision '32-true' and '64-true', set "
63-
"Trainer precision to match one of the "
64-
"available precisions."
65-
)
66-
67-
# compute residual
68-
res_loss = {}
69-
tot_loss = []
70-
for location in self._sampling_locations: # TODO fix for new collector
71-
condition = solver.problem.conditions[location]
72-
pts = solver.problem.input_pts[location]
73-
# send points to correct device
74-
pts = pts.to(device=device, dtype=precision)
75-
pts = pts.requires_grad_(True)
76-
pts.retain_grad()
77-
# PINN loss: equation evaluated only for sampling locations
78-
target = condition.equation.residual(pts, solver.forward(pts))
79-
res_loss[location] = torch.abs(target).as_subclass(torch.Tensor)
80-
tot_loss.append(torch.abs(target))
81-
82-
print(tot_loss)
83-
84-
return torch.vstack(tot_loss), res_loss
85-
86-
def _r3_routine(self, trainer):
87-
"""
88-
R3 refinement main routine.
89-
90-
:param Trainer trainer: PINA Trainer.
91-
"""
92-
# compute residual (all device possible)
93-
tot_loss, res_loss = self._compute_residual(trainer)
94-
tot_loss = tot_loss.as_subclass(torch.Tensor)
95-
96-
# !!!!!! From now everything is performed on CPU !!!!!!
97-
98-
# average loss
99-
avg = (tot_loss.mean()).to("cpu")
100-
old_pts = {} # points to be retained
101-
for location in self._sampling_locations:
102-
pts = trainer._model.problem.input_pts[location]
103-
labels = pts.labels
104-
pts = pts.cpu().detach().as_subclass(torch.Tensor)
105-
residuals = res_loss[location].cpu()
106-
mask = (residuals > avg).flatten()
107-
if any(mask): # append residuals greater than average
108-
pts = (pts[mask]).as_subclass(LabelTensor)
109-
pts.labels = labels
110-
old_pts[location] = pts
111-
numb_pts = self._const_pts[location] - len(old_pts[location])
112-
# sample new points
113-
trainer._model.problem.discretise_domain(
114-
numb_pts, "random", locations=[location]
115-
)
116-
117-
else: # if no res greater than average, samples all uniformly
118-
numb_pts = self._const_pts[location]
119-
# sample new points
120-
trainer._model.problem.discretise_domain(
121-
numb_pts, "random", locations=[location]
122-
)
123-
# adding previous population points
124-
trainer._model.problem.add_points(old_pts)
125-
126-
# update dataloader
127-
trainer._create_or_update_loader()
128-
129-
def on_train_start(self, trainer, _):
130-
"""
131-
Callback function called at the start of training.
132-
133-
This method extracts the locations for sampling from the problem
134-
conditions and calculates the total population.
135-
136-
:param trainer: The trainer object managing the training process.
137-
:type trainer: pytorch_lightning.Trainer
138-
:param _: Placeholder argument (not used).
139-
140-
:return: None
141-
:rtype: None
142-
"""
143-
# extract locations for sampling
144-
problem = trainer.solver.problem
145-
locations = []
146-
for condition_name in problem.conditions:
147-
condition = problem.conditions[condition_name]
148-
if hasattr(condition, "location"):
149-
locations.append(condition_name)
150-
self._sampling_locations = locations
151-
152-
# extract total population
153-
const_pts = {} # for each location, store the # of pts to keep constant
154-
for location in self._sampling_locations:
155-
pts = trainer._model.problem.input_pts[location]
156-
const_pts[location] = len(pts)
157-
self._const_pts = const_pts
158-
159-
def on_train_epoch_end(self, trainer, __):
160-
"""
161-
Callback function called at the end of each training epoch.
162-
163-
This method triggers the R3 routine for refinement if the current
164-
epoch is a multiple of `_sample_every`.
165-
166-
:param trainer: The trainer object managing the training process.
167-
:type trainer: pytorch_lightning.Trainer
168-
:param __: Placeholder argument (not used).
169-
170-
:return: None
171-
:rtype: None
172-
"""
173-
if trainer.current_epoch % self._sample_every == 0:
174-
self._r3_routine(trainer)
37+
raise NotImplementedError(
38+
"R3Refinement callback is being refactored in the pina "
39+
f"{importlib.metadata.metadata('pina-mathlab')['Version']} "
40+
"version. Please use version 0.1 if R3Refinement is required."
41+
)
42+
43+
# super().__init__()
44+
45+
# # sample every
46+
# check_consistency(sample_every, int)
47+
# self._sample_every = sample_every
48+
# self._const_pts = None
49+
50+
# def _compute_residual(self, trainer):
51+
# """
52+
# Computes the residuals for a PINN object.
53+
54+
# :return: the total loss, and pointwise loss.
55+
# :rtype: tuple
56+
# """
57+
58+
# # extract the solver and device from trainer
59+
# solver = trainer.solver
60+
# device = trainer._accelerator_connector._accelerator_flag
61+
# precision = trainer.precision
62+
# if precision == "64-true":
63+
# precision = torch.float64
64+
# elif precision == "32-true":
65+
# precision = torch.float32
66+
# else:
67+
# raise RuntimeError(
68+
# "Currently R3Refinement is only implemented "
69+
# "for precision '32-true' and '64-true', set "
70+
# "Trainer precision to match one of the "
71+
# "available precisions."
72+
# )
73+
74+
# # compute residual
75+
# res_loss = {}
76+
# tot_loss = []
77+
# for location in self._sampling_locations:
78+
# condition = solver.problem.conditions[location]
79+
# pts = solver.problem.input_pts[location]
80+
# # send points to correct device
81+
# pts = pts.to(device=device, dtype=precision)
82+
# pts = pts.requires_grad_(True)
83+
# pts.retain_grad()
84+
# # PINN loss: equation evaluated only for sampling locations
85+
# target = condition.equation.residual(pts, solver.forward(pts))
86+
# res_loss[location] = torch.abs(target).as_subclass(torch.Tensor)
87+
# tot_loss.append(torch.abs(target))
88+
89+
# print(tot_loss)
90+
91+
# return torch.vstack(tot_loss), res_loss
92+
93+
# def _r3_routine(self, trainer):
94+
# """
95+
# R3 refinement main routine.
96+
97+
# :param Trainer trainer: PINA Trainer.
98+
# """
99+
# # compute residual (all device possible)
100+
# tot_loss, res_loss = self._compute_residual(trainer)
101+
# tot_loss = tot_loss.as_subclass(torch.Tensor)
102+
103+
# # !!!!!! From now everything is performed on CPU !!!!!!
104+
105+
# # average loss
106+
# avg = (tot_loss.mean()).to("cpu")
107+
# old_pts = {} # points to be retained
108+
# for location in self._sampling_locations:
109+
# pts = trainer._model.problem.input_pts[location]
110+
# labels = pts.labels
111+
# pts = pts.cpu().detach().as_subclass(torch.Tensor)
112+
# residuals = res_loss[location].cpu()
113+
# mask = (residuals > avg).flatten()
114+
# if any(mask): # append residuals greater than average
115+
# pts = (pts[mask]).as_subclass(LabelTensor)
116+
# pts.labels = labels
117+
# old_pts[location] = pts
118+
# numb_pts = self._const_pts[location] - len(old_pts[location])
119+
# # sample new points
120+
# trainer._model.problem.discretise_domain(
121+
# numb_pts, "random", locations=[location]
122+
# )
123+
124+
# else: # if no res greater than average, samples all uniformly
125+
# numb_pts = self._const_pts[location]
126+
# # sample new points
127+
# trainer._model.problem.discretise_domain(
128+
# numb_pts, "random", locations=[location]
129+
# )
130+
# # adding previous population points
131+
# trainer._model.problem.add_points(old_pts)
132+
133+
# # update dataloader
134+
# trainer._create_or_update_loader()
135+
136+
# def on_train_start(self, trainer, _):
137+
# """
138+
# Callback function called at the start of training.
139+
140+
# This method extracts the locations for sampling from the problem
141+
# conditions and calculates the total population.
142+
143+
# :param trainer: The trainer object managing the training process.
144+
# :type trainer: pytorch_lightning.Trainer
145+
# :param _: Placeholder argument (not used).
146+
147+
# :return: None
148+
# :rtype: None
149+
# """
150+
# # extract locations for sampling
151+
# problem = trainer.solver.problem
152+
# locations = []
153+
# for condition_name in problem.conditions:
154+
# condition = problem.conditions[condition_name]
155+
# if hasattr(condition, "location"):
156+
# locations.append(condition_name)
157+
# self._sampling_locations = locations
158+
159+
# # extract total population
160+
# const_pts = {} # for each location, store the pts to keep constant
161+
# for location in self._sampling_locations:
162+
# pts = trainer._model.problem.input_pts[location]
163+
# const_pts[location] = len(pts)
164+
# self._const_pts = const_pts
165+
166+
# def on_train_epoch_end(self, trainer, __):
167+
# """
168+
# Callback function called at the end of each training epoch.
169+
170+
# This method triggers the R3 routine for refinement if the current
171+
# epoch is a multiple of `_sample_every`.
172+
173+
# :param trainer: The trainer object managing the training process.
174+
# :type trainer: pytorch_lightning.Trainer
175+
# :param __: Placeholder argument (not used).
176+
177+
# :return: None
178+
# :rtype: None
179+
# """
180+
# if trainer.current_epoch % self._sample_every == 0:
181+
# self._r3_routine(trainer)

pina/callback/linear_weight_update_callback.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -37,12 +37,13 @@ def __init__(
3737
check_consistency(self.initial_value, (float, int), subclass=False)
3838
check_consistency(self.target_value, (float, int), subclass=False)
3939

40-
def on_train_start(self, trainer, solver):
40+
def on_train_start(self, trainer, pl_module):
4141
"""
4242
Initialize the weight of the condition to the specified `initial_value`.
4343
44-
:param Trainer trainer: a pina:class:`Trainer` instance.
45-
:param SolverInterface solver: a pina:class:`SolverInterface` instance.
44+
:param Trainer trainer: A :class:`~pina.trainer.Trainer` instance.
45+
:param SolverInterface pl_module: A
46+
:class:`~pina.solver.solver.SolverInterface` instance.
4647
"""
4748
# Check that the target epoch is valid
4849
if not 0 < self.target_epoch <= trainer.max_epochs:
@@ -52,7 +53,7 @@ def on_train_start(self, trainer, solver):
5253
)
5354

5455
# Check that the condition is a problem condition
55-
if self.condition_name not in solver.problem.conditions:
56+
if self.condition_name not in pl_module.problem.conditions:
5657
raise ValueError(
5758
f"`{self.condition_name}` must be a problem condition."
5859
)
@@ -66,20 +67,21 @@ def on_train_start(self, trainer, solver):
6667
)
6768

6869
# Check that the weighting schema is ScalarWeighting
69-
if not isinstance(solver.weighting, ScalarWeighting):
70+
if not isinstance(pl_module.weighting, ScalarWeighting):
7071
raise ValueError("The weighting schema must be ScalarWeighting.")
7172

7273
# Initialize the weight of the condition
73-
solver.weighting.weights[self.condition_name] = self.initial_value
74+
pl_module.weighting.weights[self.condition_name] = self.initial_value
7475

75-
def on_train_epoch_start(self, trainer, solver):
76+
def on_train_epoch_start(self, trainer, pl_module):
7677
"""
7778
Adjust at each epoch the weight of the condition.
7879
79-
:param Trainer trainer: a pina:class:`Trainer` instance.
80-
:param SolverInterface solver: a pina:class:`SolverInterface` instance.
80+
:param Trainer trainer: A :class:`~pina.trainer.Trainer` instance.
81+
:param SolverInterface pl_module: A
82+
:class:`~pina.solver.solver.SolverInterface` instance.
8183
"""
8284
if 0 < trainer.current_epoch <= self.target_epoch:
83-
solver.weighting.weights[self.condition_name] += (
85+
pl_module.weighting.weights[self.condition_name] += (
8486
self.target_value - self.initial_value
8587
) / (self.target_epoch - 1)

0 commit comments

Comments
 (0)