
Commit de47d69

GiovanniCanalidario-coscia authored and committed
batching for rbapinns
1 parent 3778ef7 commit de47d69

File tree

2 files changed: +206 -82 lines


pina/solver/physics_informed_solver/rba_pinn.py

Lines changed: 179 additions & 55 deletions
@@ -1,6 +1,5 @@
 """Module for the Residual-Based Attention PINN solver."""
 
-from copy import deepcopy
 import torch
 
 from .pinn import PINN
@@ -98,6 +97,8 @@ def __init__(
         :param float gamma: The decay parameter in the update of the weights
             of the residuals. Must be between ``0`` and ``1``.
             Default is ``0.999``.
+        :raises: ValueError if `gamma` is not in the range (0, 1).
+        :raises: ValueError if `eta` is not greater than 0.
         """
         super().__init__(
             model=model,
@@ -111,78 +112,201 @@ def __init__(
         # check consistency
         check_consistency(eta, (float, int))
         check_consistency(gamma, float)
-        assert (
-            0 < gamma < 1
-        ), f"Invalid range: expected 0 < gamma < 1, got {gamma=}"
+
+        # Validate range for gamma
+        if not 0 < gamma < 1:
+            raise ValueError(
+                f"Invalid range: expected 0 < gamma < 1, but got {gamma}"
+            )
+
+        # Validate range for eta
+        if eta <= 0:
+            raise ValueError(f"Invalid range: expected eta > 0, but got {eta}")
+
+        # Initialize parameters
         self.eta = eta
         self.gamma = gamma
 
-        # initialize weights
+        # Initialize the weight of each point to 0
         self.weights = {}
-        for condition_name in problem.conditions:
-            self.weights[condition_name] = 0
+        for cond, data in self.problem.input_pts.items():
+            buffer_tensor = torch.zeros((len(data), 1), device=self.device)
+            self.register_buffer(f"weight_{cond}", buffer_tensor)
+            self.weights[cond] = getattr(self, f"weight_{cond}")
+
+        # Extract the reduction method from the loss function
+        self._reduction = self._loss_fn.reduction
 
-        # define vectorial loss
-        self._vectorial_loss = deepcopy(self.loss)
-        self._vectorial_loss.reduction = "none"
+        # Set the loss function to return non-aggregated losses
+        self._loss_fn = type(self._loss_fn)(reduction="none")
 
-    # for now RBAPINN is implemented only for batch_size = None
-    def on_train_start(self):
+    def training_step(self, batch, batch_idx, **kwargs):
         """
-        Hook method called at the beginning of training.
+        Solver training step. It computes the optimization cycle and aggregates
+        the losses using the ``weighting`` attribute.
 
-        :raises NotImplementedError: If the batch size is not ``None``.
+        :param list[tuple[str, dict]] batch: A batch of data. Each element is a
+            tuple containing a condition name and a dictionary of points.
+        :param int batch_idx: The index of the current batch.
+        :param dict kwargs: Additional keyword arguments passed to
+            ``optimization_cycle``.
+        :return: The loss of the training step.
+        :rtype: torch.Tensor
         """
-        if self.trainer.batch_size is not None:
-            raise NotImplementedError(
-                "RBAPINN only works with full batch "
-                "size, set batch_size=None inside the "
-                "Trainer to use the solver."
-            )
-        return super().on_train_start()
+        loss = self._optimization_cycle(
+            batch=batch, batch_idx=batch_idx, **kwargs
+        )
+        self.store_log("train_loss", loss, self.get_batch_size(batch))
+        return loss
+
+    @torch.set_grad_enabled(True)
+    def validation_step(self, batch, **kwargs):
+        """
+        The validation step for the PINN solver. It returns the average residual
+        computed with the ``loss`` function not aggregated.
+
+        :param list[tuple[str, dict]] batch: A batch of data. Each element is a
+            tuple containing a condition name and a dictionary of points.
+        :param dict kwargs: Additional keyword arguments passed to
+            ``optimization_cycle``.
+        :return: The loss of the validation step.
+        :rtype: torch.Tensor
+        """
+        losses = self.optimization_cycle(batch=batch, **kwargs)
+
+        # Aggregate losses for each condition
+        for cond, loss in losses.items():
+            losses[cond] = self._apply_reduction(loss=losses[cond])
+
+        loss = (sum(losses.values()) / len(losses)).as_subclass(torch.Tensor)
+        self.store_log("val_loss", loss, self.get_batch_size(batch))
+        return loss
+
+    @torch.set_grad_enabled(True)
+    def test_step(self, batch, **kwargs):
+        """
+        The test step for the PINN solver. It returns the average residual
+        computed with the ``loss`` function not aggregated.
+
+        :param list[tuple[str, dict]] batch: A batch of data. Each element is a
+            tuple containing a condition name and a dictionary of points.
+        :param dict kwargs: Additional keyword arguments passed to
+            ``optimization_cycle``.
+        :return: The loss of the test step.
+        :rtype: torch.Tensor
+        """
+        losses = self.optimization_cycle(batch=batch, **kwargs)
+
+        # Aggregate losses for each condition
+        for cond, loss in losses.items():
+            losses[cond] = self._apply_reduction(loss=losses[cond])
 
-    def _vect_to_scalar(self, loss_value):
+        loss = (sum(losses.values()) / len(losses)).as_subclass(torch.Tensor)
+        self.store_log("test_loss", loss, self.get_batch_size(batch))
+        return loss
+
+    def _optimization_cycle(self, batch, batch_idx, **kwargs):
         """
-        Computation of the scalar loss.
+        Aggregate the loss for each condition in the batch.
 
-        :param LabelTensor loss_value: the tensor of pointwise losses.
-        :raises RuntimeError: If the loss reduction is not ``mean`` or ``sum``.
-        :return: The computed scalar loss.
-        :rtype: LabelTensor
+        :param list[tuple[str, dict]] batch: A batch of data. Each element is a
+            tuple containing a condition name and a dictionary of points.
+        :param int batch_idx: The index of the current batch.
+        :param dict kwargs: Additional keyword arguments passed to
+            ``optimization_cycle``.
+        :return: The losses computed for all conditions in the batch, casted
+            to a subclass of :class:`torch.Tensor`. It should return a dict
+            containing the condition name and the associated scalar loss.
+        :rtype: dict
         """
-        if self.loss.reduction == "mean":
-            ret = torch.mean(loss_value)
-        elif self.loss.reduction == "sum":
-            ret = torch.sum(loss_value)
-        else:
-            raise RuntimeError(
-                f"Invalid reduction, got {self.loss.reduction} "
-                "but expected mean or sum."
+        # compute non-aggregated residuals
+        residuals = self.optimization_cycle(batch)
+
+        # update weights based on residuals
+        self._update_weights(batch, batch_idx, residuals)
+
+        # compute losses
+        losses = {}
+        for cond, res in residuals.items():
+
+            # Get the correct indices for the weights. Modulus is used according
+            # to the number of points in the condition, as in the PinaDataset.
+            len_res = len(res)
+            idx = torch.arange(
+                batch_idx * len_res,
+                (batch_idx + 1) * len_res,
+                device=res.device,
+            ) % len(self.problem.input_pts[cond])
+
+            losses[cond] = self._apply_reduction(
+                loss=(res * self.weights[cond][idx])
             )
-        return ret
 
-    def loss_phys(self, samples, equation):
+            # store log
+            self.store_log(
+                f"{cond}_loss", losses[cond].item(), self.get_batch_size(batch)
+            )
+
+        # clamp unknown parameters in InverseProblem (if needed)
+        self._clamp_params()
+
+        # aggregate
+        loss = self.weighting.aggregate(losses).as_subclass(torch.Tensor)
+
+        return loss
+
+    def _update_weights(self, batch, batch_idx, residuals):
         """
-        Computes the physics loss for the physics-informed solver based on the
-        provided samples and equation.
+        Update weights based on residuals.
 
-        :param LabelTensor samples: The samples to evaluate the physics loss.
-        :param EquationInterface equation: The governing equation.
-        :return: The computed physics loss.
-        :rtype: LabelTensor
+        :param list[tuple[str, dict]] batch: A batch of data. Each element is a
+            tuple containing a condition name and a dictionary of points.
+        :param int batch_idx: The index of the current batch.
+        :param dict residuals: A dictionary containing the residuals for each
+            condition. The keys are the condition names and the values are the
+            residuals as tensors.
         """
-        residual = self.compute_residual(samples=samples, equation=equation)
-        cond = self.current_condition_name
+        # Iterate over each condition in the batch
+        for cond, data in batch:
 
-        r_norm = (
-            self.eta
-            * torch.abs(residual)
-            / (torch.max(torch.abs(residual)) + 1e-12)
-        )
-        self.weights[cond] = (self.gamma * self.weights[cond] + r_norm).detach()
+            # Compute normalized residuals
+            res = residuals[cond]
+            res_abs = res.abs()
+            r_norm = (self.eta * res_abs) / (res_abs.max() + 1e-12)
 
-        loss_value = self._vectorial_loss(
-            torch.zeros_like(residual, requires_grad=True), residual
-        )
+            # Get the correct indices for the weights. Modulus is used according
+            # to the number of points in the condition, as in the PinaDataset.
+            len_pts = len(data["input"])
+            idx = torch.arange(
+                batch_idx * len_pts,
+                (batch_idx + 1) * len_pts,
+                device=res.device,
+            ) % len(self.problem.input_pts[cond])
 
-        return self._vect_to_scalar(self.weights[cond] ** 2 * loss_value)
+            # Update weights
+            weights = self.weights[cond]
+            update = self.gamma * weights[idx] + r_norm
+            weights[idx] = update.detach()
+
+    def _apply_reduction(self, loss):
+        """
+        Apply the specified reduction to the loss. The reduction is deferred
+        until the end of the optimization cycle to allow residual-based weights
+        to be applied to each point beforehand.
+
+        :param torch.Tensor loss: The loss tensor to be reduced.
+        :return: The reduced loss tensor.
+        :rtype: torch.Tensor
+        :raises ValueError: If the reduction method is neither "mean" nor "sum".
+        """
+        # Apply the specified reduction method
+        if self._reduction == "mean":
+            return loss.mean()
+        if self._reduction == "sum":
+            return loss.sum()

+        # Raise an error if the reduction method is not recognized
+        raise ValueError(
+            f"Unknown reduction: {self._reduction}."
+            " Supported reductions are 'mean' and 'sum'."
+        )
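
A minimal standalone sketch of the batched RBA update introduced above may be helpful. It reproduces, outside of PINA, the modulus-based index mapping and the update rule w <- gamma * w + eta * |r| / (max|r| + 1e-12); the point count, batch size, and random residuals are invented for illustration.

import torch

eta, gamma = 0.001, 0.999            # RBAPINN defaults
n_points, batch_size = 20, 8         # hypothetical condition size and batch
weights = torch.zeros(n_points, 1)   # per-point weights, initialized to 0

for batch_idx in range(3):
    # Stand-in for the PDE residuals of the current mini-batch
    res = torch.randn(batch_size, 1)

    # Map batch positions back to global point indices, wrapping with the
    # modulus as _update_weights does for the PinaDataset
    idx = torch.arange(
        batch_idx * batch_size, (batch_idx + 1) * batch_size
    ) % n_points

    # Normalized residuals, then an exponential-moving-average style update
    res_abs = res.abs()
    r_norm = (eta * res_abs) / (res_abs.max() + 1e-12)
    weights[idx] = (gamma * weights[idx] + r_norm).detach()

print(weights.squeeze())

With batch_idx = 2 the indices run 16..23 and wrap to 16..19, 0..3, which is exactly why the buffer covers every collocation point while each step only touches one slice.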

tests/test_solver/test_rba_pinn.py

Lines changed: 27 additions & 27 deletions
@@ -42,10 +42,14 @@
 @pytest.mark.parametrize("eta", [1, 0.001])
 @pytest.mark.parametrize("gamma", [0.5, 0.9])
 def test_constructor(problem, eta, gamma):
-    with pytest.raises(AssertionError):
-        solver = RBAPINN(model=model, problem=problem, gamma=1.5)
     solver = RBAPINN(model=model, problem=problem, eta=eta, gamma=gamma)
 
+    with pytest.raises(ValueError):
+        solver = RBAPINN(model=model, problem=problem, gamma=1.5)
+
+    with pytest.raises(ValueError):
+        solver = RBAPINN(model=model, problem=problem, eta=-0.1)
+
     assert solver.accepted_conditions_types == (
         InputTargetCondition,
         InputEquationCondition,
@@ -54,30 +58,18 @@ def test_constructor(problem, eta, gamma):
 
 
 @pytest.mark.parametrize("problem", [problem, inverse_problem])
-def test_wrong_batch(problem):
-    with pytest.raises(NotImplementedError):
-        solver = RBAPINN(model=model, problem=problem)
-        trainer = Trainer(
-            solver=solver,
-            max_epochs=2,
-            accelerator="cpu",
-            batch_size=10,
-            train_size=1.0,
-            val_size=0.0,
-            test_size=0.0,
-        )
-        trainer.train()
-
-
-@pytest.mark.parametrize("problem", [problem, inverse_problem])
+@pytest.mark.parametrize("batch_size", [None, 1, 5, 20])
 @pytest.mark.parametrize("compile", [True, False])
-def test_solver_train(problem, compile):
-    solver = RBAPINN(model=model, problem=problem)
+@pytest.mark.parametrize(
+    "loss", [torch.nn.L1Loss(reduction="sum"), torch.nn.MSELoss()]
+)
+def test_solver_train(problem, batch_size, loss, compile):
+    solver = RBAPINN(model=model, problem=problem, loss=loss)
     trainer = Trainer(
         solver=solver,
         max_epochs=2,
         accelerator="cpu",
-        batch_size=None,
+        batch_size=batch_size,
         train_size=1.0,
         val_size=0.0,
         test_size=0.0,
@@ -89,14 +81,18 @@ def test_solver_train(problem, compile):
 
 
 @pytest.mark.parametrize("problem", [problem, inverse_problem])
+@pytest.mark.parametrize("batch_size", [None, 1, 5, 20])
 @pytest.mark.parametrize("compile", [True, False])
-def test_solver_validation(problem, compile):
-    solver = RBAPINN(model=model, problem=problem)
+@pytest.mark.parametrize(
+    "loss", [torch.nn.L1Loss(reduction="sum"), torch.nn.MSELoss()]
+)
+def test_solver_validation(problem, batch_size, loss, compile):
+    solver = RBAPINN(model=model, problem=problem, loss=loss)
     trainer = Trainer(
         solver=solver,
         max_epochs=2,
         accelerator="cpu",
-        batch_size=None,
+        batch_size=batch_size,
         train_size=0.9,
         val_size=0.1,
         test_size=0.0,
@@ -108,14 +104,18 @@ def test_solver_validation(problem, compile):
 
 
 @pytest.mark.parametrize("problem", [problem, inverse_problem])
+@pytest.mark.parametrize("batch_size", [None, 1, 5, 20])
 @pytest.mark.parametrize("compile", [True, False])
-def test_solver_test(problem, compile):
-    solver = RBAPINN(model=model, problem=problem)
+@pytest.mark.parametrize(
+    "loss", [torch.nn.L1Loss(reduction="sum"), torch.nn.MSELoss()]
+)
+def test_solver_test(problem, batch_size, loss, compile):
+    solver = RBAPINN(model=model, problem=problem, loss=loss)
     trainer = Trainer(
         solver=solver,
         max_epochs=2,
         accelerator="cpu",
-        batch_size=None,
+        batch_size=batch_size,
        train_size=0.7,
        val_size=0.2,
        test_size=0.1,
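
As a rough usage sketch mirroring the tests above: after this commit a finite batch_size is accepted, where previously only batch_size=None was allowed. The import paths and the model/problem objects are assumptions, not taken from the diff.

from pina.solver import RBAPINN   # import path assumed
from pina.trainer import Trainer  # import path assumed

# `model` and `problem` are placeholders for a PINA network and problem
solver = RBAPINN(model=model, problem=problem, eta=0.001, gamma=0.999)
trainer = Trainer(
    solver=solver,
    max_epochs=2,
    accelerator="cpu",
    batch_size=10,  # mini-batching now supported; None still means full batch
    train_size=1.0,
    val_size=0.0,
    test_size=0.0,
)
trainer.train()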
