
Commit 6d1d4ef

add batching support for self-adaptive pinns
1 parent 1ed1491 commit 6d1d4ef

2 files changed: +207 -139 lines changed

pina/solver/physics_informed_solver/self_adaptive_pinn.py

Lines changed: 189 additions & 116 deletions
@@ -15,15 +15,20 @@ class Weights(torch.nn.Module):
     :class:`SelfAdaptivePINN` solver.
     """
 
-    def __init__(self, func):
+    def __init__(self, func, num_points):
         """
         Initialization of the :class:`Weights` class.
 
         :param torch.nn.Module func: the mask model.
+        :param int num_points: the number of input points.
         """
         super().__init__()
+
+        # Check consistency
         check_consistency(func, torch.nn.Module)
-        self.sa_weights = torch.nn.Parameter(torch.Tensor())
+
+        # Initialize the weights as a learnable parameter
+        self.sa_weights = torch.nn.Parameter(torch.zeros(num_points, 1))
         self.func = func
 
     def forward(self):
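
Note: the `Weights` module now allocates its parameter eagerly, one weight per input point, instead of starting from an empty tensor that was resized later in `on_train_start`. A minimal standalone sketch of the pattern in plain torch (the `PointWeights` name and the sigmoid mask are illustrative assumptions, not part of this commit):

import torch

class PointWeights(torch.nn.Module):
    """One learnable weight per collocation point, passed through a mask."""

    def __init__(self, func, num_points):
        super().__init__()
        # Fixed-size parameter: the full weight table exists up front,
        # so mini-batches can index into it later.
        self.sa_weights = torch.nn.Parameter(torch.zeros(num_points, 1))
        self.func = func

    def forward(self):
        # The mask keeps the effective weights in a useful range.
        return self.func(self.sa_weights)

weights = PointWeights(func=torch.nn.Sigmoid(), num_points=8)
print(weights().shape)  # torch.Size([8, 1])

Because the table's size is fixed at construction time, a batch only needs the global indices of its points to pick out the matching weights.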
@@ -140,73 +145,163 @@ def __init__(
             If ``None``, the :class:`torch.nn.MSELoss` loss is used.
             Default is `None`.
         """
-        # check consistency weitghs_function
+        # Check consistency
         check_consistency(weight_function, torch.nn.Module)
 
-        # create models for weights
-        weights_dict = {}
-        for condition_name in problem.conditions:
-            weights_dict[condition_name] = Weights(weight_function)
-        weights_dict = torch.nn.ModuleDict(weights_dict)
+        # Define a ModuleDict for the weights
+        weights = {}
+        for cond, data in problem.input_pts.items():
+            weights[cond] = Weights(func=weight_function, num_points=len(data))
+        weights = torch.nn.ModuleDict(weights)
 
         super().__init__(
-            models=[model, weights_dict],
+            models=[model, weights],
             problem=problem,
             optimizers=[optimizer_model, optimizer_weights],
             schedulers=[scheduler_model, scheduler_weights],
             weighting=weighting,
             loss=loss,
         )
 
-        self._vectorial_loss = deepcopy(self.loss)
-        self._vectorial_loss.reduction = "none"
+        # Extract the reduction method from the loss function
+        self._reduction = self._loss_fn.reduction
 
-    def forward(self, x):
-        """
-        Forward pass.
+        # Set the loss function to return non-aggregated losses
+        self._loss_fn = type(self._loss_fn)(reduction="none")
 
-        :param LabelTensor x: Input tensor.
-        :return: The output of the neural network.
-        :rtype: LabelTensor
+    def training_step(self, batch, batch_idx, **kwargs):
         """
-        return self.model(x)
-
-    def training_step(self, batch):
-        """
-        Solver training step, overridden to perform manual optimization.
+        Solver training step. It computes the optimization cycle and aggregates
+        the losses using the ``weighting`` attribute.
 
         :param list[tuple[str, dict]] batch: A batch of data. Each element is a
             tuple containing a condition name and a dictionary of points.
-        :return: The aggregated loss.
-        :rtype: LabelTensor
+        :param int batch_idx: The index of the current batch.
+        :param dict kwargs: Additional keyword arguments passed to
+            ``optimization_cycle``.
+        :return: The loss of the training step.
+        :rtype: torch.Tensor
         """
         # Weights optimization
         self.optimizer_weights.instance.zero_grad()
-        loss = super().training_step(batch)
+        loss = self._optimization_cycle(
+            batch=batch, batch_idx=batch_idx, **kwargs
+        )
         self.manual_backward(-loss)
         self.optimizer_weights.instance.step()
         self.scheduler_weights.instance.step()
 
         # Model optimization
         self.optimizer_model.instance.zero_grad()
-        loss = super().training_step(batch)
+        loss = self._optimization_cycle(
+            batch=batch, batch_idx=batch_idx, **kwargs
+        )
         self.manual_backward(loss)
         self.optimizer_model.instance.step()
         self.scheduler_model.instance.step()
 
+        # Log the loss
+        self.store_log("train_loss", loss, self.get_batch_size(batch))
+
+        return loss
+
+    @torch.set_grad_enabled(True)
+    def validation_step(self, batch, **kwargs):
+        """
+        The validation step for the Self-Adaptive PINN solver. It returns the
+        average residual computed with the ``loss`` function not aggregated.
+
+        :param list[tuple[str, dict]] batch: A batch of data. Each element is a
+            tuple containing a condition name and a dictionary of points.
+        :param dict kwargs: Additional keyword arguments passed to
+            ``optimization_cycle``.
+        :return: The loss of the validation step.
+        :rtype: torch.Tensor
+        """
+        losses = self.optimization_cycle(batch=batch, **kwargs)
+
+        # Aggregate losses for each condition
+        for cond, loss in losses.items():
+            losses[cond] = self._apply_reduction(loss=losses[cond])
+
+        loss = (sum(losses.values()) / len(losses)).as_subclass(torch.Tensor)
+        self.store_log("val_loss", loss, self.get_batch_size(batch))
         return loss
 
+    @torch.set_grad_enabled(True)
+    def test_step(self, batch, **kwargs):
+        """
+        The test step for the Self-Adaptive PINN solver. It returns the average
+        residual computed with the ``loss`` function not aggregated.
+
+        :param list[tuple[str, dict]] batch: A batch of data. Each element is a
+            tuple containing a condition name and a dictionary of points.
+        :param dict kwargs: Additional keyword arguments passed to
+            ``optimization_cycle``.
+        :return: The loss of the test step.
+        :rtype: torch.Tensor
+        """
+        losses = self.optimization_cycle(batch=batch, **kwargs)
+
+        # Aggregate losses for each condition
+        for cond, loss in losses.items():
+            losses[cond] = self._apply_reduction(loss=losses[cond])
+
+        loss = (sum(losses.values()) / len(losses)).as_subclass(torch.Tensor)
+        self.store_log("test_loss", loss, self.get_batch_size(batch))
+        return loss
+
+    def loss_phys(self, samples, equation):
+        """
+        Computes the physics loss for the physics-informed solver based on the
+        provided samples and equation.
+
+        :param LabelTensor samples: The samples to evaluate the physics loss.
+        :param EquationInterface equation: The governing equation.
+        :return: The computed physics loss.
+        :rtype: LabelTensor
+        """
+        residuals = self.compute_residual(samples, equation)
+        return self._loss_fn(residuals, torch.zeros_like(residuals))
+
+    def loss_data(self, input, target):
+        """
+        Compute the data loss for the Self-Adaptive PINN solver by evaluating
+        the loss between the network's output and the true solution. This method
+        should not be overridden, if not intentionally.
+
+        :param input: The input to the neural network.
+        :type input: LabelTensor | torch.Tensor
+        :param target: The target to compare with the network's output.
+        :type target: LabelTensor | torch.Tensor
+        :return: The supervised loss, averaged over the number of observations.
+        :rtype: LabelTensor | torch.Tensor
+        """
+        return self._loss_fn(self.forward(input), target)
+
+    def forward(self, x):
+        """
+        Forward pass.
+
+        :param x: Input tensor.
+        :type x: torch.Tensor | LabelTensor
+        :return: The output of the neural network.
+        :rtype: torch.Tensor | LabelTensor
+        """
+        return self.model(x)
+
     def configure_optimizers(self):
         """
         Optimizer configuration.
 
         :return: The optimizers and the schedulers
         :rtype: tuple[list[Optimizer], list[Scheduler]]
         """
-        # If the problem is an InverseProblem, add the unknown parameters
-        # to the parameters to be optimized
+        # Hook the optimizers to the models
         self.optimizer_model.hook(self.model.parameters())
-        self.optimizer_weights.hook(self.weights_dict.parameters())
+        self.optimizer_weights.hook(self.weights.parameters())
+
+        # Add unknown parameters to optimization list in case of InverseProblem
         if isinstance(self.problem, InverseProblem):
             self.optimizer_model.instance.add_param_group(
                 {
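
Note: `training_step` implements the self-adaptive min-max update with the two optimizers: the same batched loss is first maximized with respect to the per-point weights (hence `manual_backward(-loss)`), then recomputed and minimized with respect to the model. A toy sketch of that two-optimizer pattern in plain torch (the scalar toy loss is an assumption for illustration):

import torch

# Toy saddle problem: one "model" parameter, one self-adaptive weight.
model_param = torch.nn.Parameter(torch.tensor(1.0))
weight_param = torch.nn.Parameter(torch.tensor(0.0))

opt_model = torch.optim.SGD([model_param], lr=0.1)
opt_weights = torch.optim.SGD([weight_param], lr=0.1)

def weighted_loss():
    # A residual penalty scaled by a masked self-adaptive weight.
    return torch.sigmoid(weight_param) * model_param.pow(2)

for _ in range(100):
    # Weights optimization: gradient ascent, implemented by negating the loss
    opt_weights.zero_grad()
    (-weighted_loss()).backward()
    opt_weights.step()

    # Model optimization: ordinary gradient descent on a fresh forward pass
    opt_model.zero_grad()
    weighted_loss().backward()
    opt_model.step()

print(model_param.item())  # driven toward 0 while the weight grows

As in the diff, the loss is recomputed between the two updates, so each optimizer differentiates its own forward pass.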
@@ -216,110 +311,88 @@ def configure_optimizers(self):
                 ]
             }
         )
+
+        # Hook the schedulers to the optimizers
         self.scheduler_model.hook(self.optimizer_model)
         self.scheduler_weights.hook(self.optimizer_weights)
+
         return (
             [self.optimizer_model.instance, self.optimizer_weights.instance],
             [self.scheduler_model.instance, self.scheduler_weights.instance],
         )
 
-    def on_train_start(self):
+    def _optimization_cycle(self, batch, batch_idx, **kwargs):
         """
-        This method is called at the start of the training process to set the
-        self-adaptive weights as parameters of the mask model.
-
-        :raises NotImplementedError: If the batch size is not ``None``.
-        """
-        if self.trainer.batch_size is not None:
-            raise NotImplementedError(
-                "SelfAdaptivePINN only works with full "
-                "batch size, set batch_size=None inside "
-                "the Trainer to use the solver."
-            )
-        device = torch.device(
-            self.trainer._accelerator_connector._accelerator_flag
-        )
+        Aggregate the loss for each condition in the batch.
 
-        # Initialize the self adaptive weights only for training points
-        for (
-            condition_name,
-            tensor,
-        ) in self.trainer.data_module.train_dataset.input.items():
-            self.weights_dict[condition_name].sa_weights.data = torch.rand(
-                (tensor.shape[0], 1), device=device
+        :param list[tuple[str, dict]] batch: A batch of data. Each element is a
+            tuple containing a condition name and a dictionary of points.
+        :param int batch_idx: The index of the current batch.
+        :param dict kwargs: Additional keyword arguments passed to
+            ``optimization_cycle``.
+        :return: The losses computed for all conditions in the batch, casted
+            to a subclass of :class:`torch.Tensor`. It should return a dict
+            containing the condition name and the associated scalar loss.
+        :rtype: dict
+        """
+        # Compute non-aggregated residuals
+        residuals = self.optimization_cycle(batch)
+
+        # Compute losses
+        losses = {}
+        for cond, res in residuals.items():
+
+            weight_tensor = self.weights[cond]()
+
+            # Get the correct indices for the weights. Modulus is used according
+            # to the number of points in the condition, as in the PinaDataset.
+            len_res = len(res)
+            idx = torch.arange(
+                batch_idx * len_res,
+                (batch_idx + 1) * len_res,
+                device=res.device,
+            ) % len(self.problem.input_pts[cond])
+
+            # Apply the weights to the residuals
+            losses[cond] = self._apply_reduction(
+                loss=(res * weight_tensor[idx])
             )
-        return super().on_train_start()
 
-    def on_load_checkpoint(self, checkpoint):
-        """
-        Override of the Pytorch Lightning ``on_load_checkpoint`` method to
-        handle checkpoints for Self-Adaptive Weights. This method should not be
-        overridden, if not intentionally.
-
-        :param dict checkpoint: Pytorch Lightning checkpoint dict.
-        """
-        # First initialize self-adaptive weights with correct shape,
-        # then load the values from the checkpoint.
-        for condition_name, _ in self.problem.input_pts.items():
-            shape = checkpoint["state_dict"][
-                f"_pina_models.1.{condition_name}.sa_weights"
-            ].shape
-            self.weights_dict[condition_name].sa_weights.data = torch.rand(
-                shape
+            # Store log
+            self.store_log(
+                f"{cond}_loss", losses[cond].item(), self.get_batch_size(batch)
            )
-        return super().on_load_checkpoint(checkpoint)
 
-    def loss_phys(self, samples, equation):
-        """
-        Computes the physics loss for the physics-informed solver based on the
-        provided samples and equation.
+        # Clamp unknown parameters in InverseProblem (if needed)
+        self._clamp_params()
 
-        :param LabelTensor samples: The samples to evaluate the physics loss.
-        :param EquationInterface equation: The governing equation.
-        :return: The computed physics loss.
-        :rtype: LabelTensor
-        """
-        residual = self.compute_residual(samples, equation)
-        weights = self.weights_dict[self.current_condition_name].forward()
-        loss_value = self._vectorial_loss(
-            torch.zeros_like(residual, requires_grad=True), residual
-        )
-        return self._vect_to_scalar(weights * loss_value)
-
-    def loss_data(self, input, target):
-        """
-        Compute the data loss for the PINN solver by evaluating the loss
-        between the network's output and the true solution. This method should
-        not be overridden, if not intentionally.
+        # Aggregate
+        loss = self.weighting.aggregate(losses).as_subclass(torch.Tensor)
 
-        :param input: The input to the neural network.
-        :type input: LabelTensor
-        :param target: The target to compare with the network's output.
-        :type target: LabelTensor
-        :return: The supervised loss, averaged over the number of observations.
-        :rtype: LabelTensor
-        """
-        return self._loss_fn(self.forward(input), target)
+        return loss
 
-    def _vect_to_scalar(self, loss_value):
+    def _apply_reduction(self, loss):
         """
-        Computation of the scalar loss.
+        Apply the specified reduction to the loss. The reduction is deferred
+        until the end of the optimization cycle to allow self-adaptive weights
+        to be applied to each point beforehand.
 
-        :param LabelTensor loss_value: the tensor of pointwise losses.
-        :raises RuntimeError: If the loss reduction is not ``mean`` or ``sum``.
-        :return: The computed scalar loss.
-        :rtype: LabelTensor
-        """
-        if self.loss.reduction == "mean":
-            ret = torch.mean(loss_value)
-        elif self.loss.reduction == "sum":
-            ret = torch.sum(loss_value)
-        else:
-            raise RuntimeError(
-                f"Invalid reduction, got {self.loss.reduction} "
-                "but expected mean or sum."
-            )
-        return ret
+        :param torch.Tensor loss: The loss tensor to be reduced.
+        :return: The reduced loss tensor.
+        :rtype: torch.Tensor
+        :raises ValueError: If the reduction method is neither "mean" nor "sum".
+        """
+        # Apply the specified reduction method
+        if self._reduction == "mean":
+            return loss.mean()
+        if self._reduction == "sum":
+            return loss.sum()
+
+        # Raise an error if the reduction method is not recognized
+        raise ValueError(
+            f"Unknown reduction: {self._reduction}."
+            " Supported reductions are 'mean' and 'sum'."
        )
 
     @property
     def model(self):
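
Note: because the `PinaDataset` serves points sequentially, `_optimization_cycle` can recover each batch's global point indices from `batch_idx` alone; the modulus handles wrap-around when the number of points in a condition is not a multiple of the batch size. A self-contained sketch of the indexing (the sizes are made up):

import torch

num_points = 10  # total points in one condition
batch_size = 4

for batch_idx in range(3):
    idx = torch.arange(
        batch_idx * batch_size, (batch_idx + 1) * batch_size
    ) % num_points
    print(batch_idx, idx.tolist())

# 0 [0, 1, 2, 3]
# 1 [4, 5, 6, 7]
# 2 [8, 9, 0, 1]  <- wraps around

The recovered indices select the matching rows of `weight_tensor`, so per-point residuals are weighted before `_apply_reduction` collapses them with the originally requested `mean` or `sum`; this is also why the constructor re-instantiates the loss with `reduction="none"`.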
@@ -332,7 +405,7 @@ def model(self):
         return self.models[0]
 
     @property
-    def weights_dict(self):
+    def weights(self):
         """
         The self-adaptive weights.
 
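Note: with the full-batch guard in `on_train_start` removed, the solver accepts a finite batch size. A hypothetical usage sketch (the import paths and `Trainer` arguments are assumptions inferred from this file's location and the removed error message, not checked against the PINA documentation):

from pina import Trainer
from pina.solver import SelfAdaptivePINN

# `problem` and `model` stand for a user-defined PINA problem and a
# torch.nn.Module; their construction is omitted here.
solver = SelfAdaptivePINN(problem=problem, model=model)

# Before this commit the solver required batch_size=None (full batch) and
# raised NotImplementedError otherwise; mini-batching is now supported.
trainer = Trainer(solver=solver, batch_size=32, max_epochs=1000)
trainer.train()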