@@ -15,15 +15,20 @@ class Weights(torch.nn.Module):
     :class:`SelfAdaptivePINN` solver.
     """
 
-    def __init__(self, func):
+    def __init__(self, func, num_points):
         """
         Initialization of the :class:`Weights` class.
 
         :param torch.nn.Module func: the mask model.
+        :param int num_points: the number of input points.
         """
         super().__init__()
+
+        # Check consistency
         check_consistency(func, torch.nn.Module)
-        self.sa_weights = torch.nn.Parameter(torch.Tensor())
+
+        # Initialize the weights as a learnable parameter
+        self.sa_weights = torch.nn.Parameter(torch.zeros(num_points, 1))
         self.func = func
 
     def forward(self):
@@ -140,73 +145,163 @@ def __init__(
             If ``None``, the :class:`torch.nn.MSELoss` loss is used.
             Default is ``None``.
         """
-        # check consistency weitghs_function
+        # Check consistency
         check_consistency(weight_function, torch.nn.Module)
 
-        # create models for weights
-        weights_dict = {}
-        for condition_name in problem.conditions:
-            weights_dict[condition_name] = Weights(weight_function)
-        weights_dict = torch.nn.ModuleDict(weights_dict)
+        # Define a ModuleDict for the weights
+        weights = {}
+        for cond, data in problem.input_pts.items():
+            weights[cond] = Weights(func=weight_function, num_points=len(data))
+        weights = torch.nn.ModuleDict(weights)
 
         super().__init__(
-            models=[model, weights_dict],
+            models=[model, weights],
             problem=problem,
             optimizers=[optimizer_model, optimizer_weights],
             schedulers=[scheduler_model, scheduler_weights],
             weighting=weighting,
             loss=loss,
         )
 
-        self._vectorial_loss = deepcopy(self.loss)
-        self._vectorial_loss.reduction = "none"
+        # Extract the reduction method from the loss function
+        self._reduction = self._loss_fn.reduction
 
-    def forward(self, x):
-        """
-        Forward pass.
+        # Set the loss function to return non-aggregated losses
+        self._loss_fn = type(self._loss_fn)(reduction="none")
 
-        :param LabelTensor x: Input tensor.
-        :return: The output of the neural network.
-        :rtype: LabelTensor
+    def training_step(self, batch, batch_idx, **kwargs):
         """
-        return self.model(x)
-
-    def training_step(self, batch):
-        """
-        Solver training step, overridden to perform manual optimization.
+        Solver training step. It computes the optimization cycle and aggregates
+        the losses using the ``weighting`` attribute.
 
         :param list[tuple[str, dict]] batch: A batch of data. Each element is a
             tuple containing a condition name and a dictionary of points.
-        :return: The aggregated loss.
-        :rtype: LabelTensor
+        :param int batch_idx: The index of the current batch.
+        :param dict kwargs: Additional keyword arguments passed to
+            ``optimization_cycle``.
+        :return: The loss of the training step.
+        :rtype: torch.Tensor
         """
         # Weights optimization
         self.optimizer_weights.instance.zero_grad()
-        loss = super().training_step(batch)
+        loss = self._optimization_cycle(
+            batch=batch, batch_idx=batch_idx, **kwargs
+        )
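+        # Ascent step for the self-adaptive weights: backpropagating the
+        # negated loss maximizes the weighted loss w.r.t. the weights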
         self.manual_backward(-loss)
         self.optimizer_weights.instance.step()
         self.scheduler_weights.instance.step()
 
         # Model optimization
         self.optimizer_model.instance.zero_grad()
-        loss = super().training_step(batch)
+        loss = self._optimization_cycle(
+            batch=batch, batch_idx=batch_idx, **kwargs
+        )
         self.manual_backward(loss)
         self.optimizer_model.instance.step()
         self.scheduler_model.instance.step()
 
+        # Log the loss
+        self.store_log("train_loss", loss, self.get_batch_size(batch))
+
+        return loss
+
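+    # Gradients are re-enabled during evaluation: computing the physics
+    # residual requires autograd derivatives of the network outputs.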
+    @torch.set_grad_enabled(True)
+    def validation_step(self, batch, **kwargs):
+        """
+        The validation step for the Self-Adaptive PINN solver. It returns the
+        average residual computed with the non-aggregated ``loss`` function.
+
+        :param list[tuple[str, dict]] batch: A batch of data. Each element is a
+            tuple containing a condition name and a dictionary of points.
+        :param dict kwargs: Additional keyword arguments passed to
+            ``optimization_cycle``.
+        :return: The loss of the validation step.
+        :rtype: torch.Tensor
+        """
+        losses = self.optimization_cycle(batch=batch, **kwargs)
+
+        # Aggregate losses for each condition
+        for cond, loss in losses.items():
+            losses[cond] = self._apply_reduction(loss=losses[cond])
+
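+        # Average the per-condition losses into a single scalar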
+        loss = (sum(losses.values()) / len(losses)).as_subclass(torch.Tensor)
+        self.store_log("val_loss", loss, self.get_batch_size(batch))
         return loss
 
+    @torch.set_grad_enabled(True)
+    def test_step(self, batch, **kwargs):
+        """
+        The test step for the Self-Adaptive PINN solver. It returns the average
+        residual computed with the non-aggregated ``loss`` function.
+
+        :param list[tuple[str, dict]] batch: A batch of data. Each element is a
+            tuple containing a condition name and a dictionary of points.
+        :param dict kwargs: Additional keyword arguments passed to
+            ``optimization_cycle``.
+        :return: The loss of the test step.
+        :rtype: torch.Tensor
+        """
+        losses = self.optimization_cycle(batch=batch, **kwargs)
+
+        # Aggregate losses for each condition
+        for cond, loss in losses.items():
+            losses[cond] = self._apply_reduction(loss=losses[cond])
+
+        loss = (sum(losses.values()) / len(losses)).as_subclass(torch.Tensor)
+        self.store_log("test_loss", loss, self.get_batch_size(batch))
+        return loss
+
+    def loss_phys(self, samples, equation):
+        """
+        Computes the physics loss for the physics-informed solver based on the
+        provided samples and equation.
+
+        :param LabelTensor samples: The samples to evaluate the physics loss.
+        :param EquationInterface equation: The governing equation.
+        :return: The computed physics loss.
+        :rtype: LabelTensor
+        """
+        residuals = self.compute_residual(samples, equation)
+        return self._loss_fn(residuals, torch.zeros_like(residuals))
+
+    def loss_data(self, input, target):
+        """
+        Compute the data loss for the Self-Adaptive PINN solver by evaluating
+        the loss between the network's output and the true solution. This
+        method should not be overridden unless intentionally.
+
+        :param input: The input to the neural network.
+        :type input: LabelTensor | torch.Tensor
+        :param target: The target to compare with the network's output.
+        :type target: LabelTensor | torch.Tensor
+        :return: The pointwise supervised loss, reduced at the end of the
+            optimization cycle.
+        :rtype: LabelTensor | torch.Tensor
+        """
+        return self._loss_fn(self.forward(input), target)
+
+    def forward(self, x):
+        """
+        Forward pass.
+
+        :param x: Input tensor.
+        :type x: torch.Tensor | LabelTensor
+        :return: The output of the neural network.
+        :rtype: torch.Tensor | LabelTensor
+        """
+        return self.model(x)
+
     def configure_optimizers(self):
         """
         Optimizer configuration.
 
         :return: The optimizers and the schedulers.
         :rtype: tuple[list[Optimizer], list[Scheduler]]
         """
-        # If the problem is an InverseProblem, add the unknown parameters
-        # to the parameters to be optimized
+        # Hook the optimizers to the models
         self.optimizer_model.hook(self.model.parameters())
-        self.optimizer_weights.hook(self.weights_dict.parameters())
+        self.optimizer_weights.hook(self.weights.parameters())
+
+        # Add unknown parameters to optimization list in case of InverseProblem
         if isinstance(self.problem, InverseProblem):
             self.optimizer_model.instance.add_param_group(
                 {
@@ -216,110 +311,88 @@ def configure_optimizers(self):
                     ]
                 }
             )
+
+        # Hook the schedulers to the optimizers
         self.scheduler_model.hook(self.optimizer_model)
         self.scheduler_weights.hook(self.optimizer_weights)
+
         return (
             [self.optimizer_model.instance, self.optimizer_weights.instance],
             [self.scheduler_model.instance, self.scheduler_weights.instance],
         )
 
-    def on_train_start(self):
+    def _optimization_cycle(self, batch, batch_idx, **kwargs):
         """
-        This method is called at the start of the training process to set the
-        self-adaptive weights as parameters of the mask model.
-
-        :raises NotImplementedError: If the batch size is not ``None``.
-        """
-        if self.trainer.batch_size is not None:
-            raise NotImplementedError(
-                "SelfAdaptivePINN only works with full "
-                "batch size, set batch_size=None inside "
-                "the Trainer to use the solver."
-            )
-        device = torch.device(
-            self.trainer._accelerator_connector._accelerator_flag
-        )
+        Aggregate the loss for each condition in the batch.
 
-        # Initialize the self adaptive weights only for training points
-        for (
-            condition_name,
-            tensor,
-        ) in self.trainer.data_module.train_dataset.input.items():
-            self.weights_dict[condition_name].sa_weights.data = torch.rand(
-                (tensor.shape[0], 1), device=device
+        :param list[tuple[str, dict]] batch: A batch of data. Each element is a
+            tuple containing a condition name and a dictionary of points.
+        :param int batch_idx: The index of the current batch.
+        :param dict kwargs: Additional keyword arguments passed to
+            ``optimization_cycle``.
+        :return: The aggregated loss for the batch, cast to a subclass of
+            :class:`torch.Tensor`.
+        :rtype: torch.Tensor
+        """
+        # Compute non-aggregated residuals
+        residuals = self.optimization_cycle(batch)
+
+        # Compute losses
+        losses = {}
+        for cond, res in residuals.items():
+
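+            # Evaluate the mask model over the learnable self-adaptive weights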
+            weight_tensor = self.weights[cond]()
+
+            # Get the correct indices for the weights. The modulus wraps the
+            # indices around the number of points in the condition, following
+            # the same ordering used by the PinaDataset.
+            len_res = len(res)
+            idx = torch.arange(
+                batch_idx * len_res,
+                (batch_idx + 1) * len_res,
+                device=res.device,
+            ) % len(self.problem.input_pts[cond])
+
+            # Apply the weights to the residuals
+            losses[cond] = self._apply_reduction(
+                loss=(res * weight_tensor[idx])
             )
-        return super().on_train_start()
 
-    def on_load_checkpoint(self, checkpoint):
-        """
-        Override of the Pytorch Lightning ``on_load_checkpoint`` method to
-        handle checkpoints for Self-Adaptive Weights. This method should not be
-        overridden, if not intentionally.
-
-        :param dict checkpoint: Pytorch Lightning checkpoint dict.
-        """
-        # First initialize self-adaptive weights with correct shape,
-        # then load the values from the checkpoint.
-        for condition_name, _ in self.problem.input_pts.items():
-            shape = checkpoint["state_dict"][
-                f"_pina_models.1.{condition_name}.sa_weights"
-            ].shape
-            self.weights_dict[condition_name].sa_weights.data = torch.rand(
-                shape
+            # Store log
+            self.store_log(
+                f"{cond}_loss", losses[cond].item(), self.get_batch_size(batch)
             )
-        return super().on_load_checkpoint(checkpoint)
 
-    def loss_phys(self, samples, equation):
-        """
-        Computes the physics loss for the physics-informed solver based on the
-        provided samples and equation.
+        # Clamp unknown parameters in InverseProblem (if needed)
+        self._clamp_params()
 
-        :param LabelTensor samples: The samples to evaluate the physics loss.
-        :param EquationInterface equation: The governing equation.
-        :return: The computed physics loss.
-        :rtype: LabelTensor
-        """
-        residual = self.compute_residual(samples, equation)
-        weights = self.weights_dict[self.current_condition_name].forward()
-        loss_value = self._vectorial_loss(
-            torch.zeros_like(residual, requires_grad=True), residual
-        )
-        return self._vect_to_scalar(weights * loss_value)
-
-    def loss_data(self, input, target):
-        """
-        Compute the data loss for the PINN solver by evaluating the loss
-        between the network's output and the true solution. This method should
-        not be overridden, if not intentionally.
+        # Aggregate
+        loss = self.weighting.aggregate(losses).as_subclass(torch.Tensor)
 
-        :param input: The input to the neural network.
-        :type input: LabelTensor
-        :param target: The target to compare with the network's output.
-        :type target: LabelTensor
-        :return: The supervised loss, averaged over the number of observations.
-        :rtype: LabelTensor
-        """
-        return self._loss_fn(self.forward(input), target)
+        return loss
 
-    def _vect_to_scalar(self, loss_value):
+    def _apply_reduction(self, loss):
         """
-        Computation of the scalar loss.
+        Apply the specified reduction to the loss. The reduction is deferred
+        until the end of the optimization cycle to allow self-adaptive weights
+        to be applied to each point beforehand.
 
-        :param LabelTensor loss_value: the tensor of pointwise losses.
-        :raises RuntimeError: If the loss reduction is not ``mean`` or ``sum``.
-        :return: The computed scalar loss.
-        :rtype: LabelTensor
-        """
-        if self.loss.reduction == "mean":
-            ret = torch.mean(loss_value)
-        elif self.loss.reduction == "sum":
-            ret = torch.sum(loss_value)
-        else:
-            raise RuntimeError(
-                f"Invalid reduction, got {self.loss.reduction} "
-                "but expected mean or sum."
-            )
-        return ret
+        :param torch.Tensor loss: The loss tensor to be reduced.
+        :return: The reduced loss tensor.
+        :rtype: torch.Tensor
+        :raises ValueError: If the reduction method is neither "mean" nor "sum".
+        """
+        # Apply the specified reduction method
+        if self._reduction == "mean":
+            return loss.mean()
+        if self._reduction == "sum":
+            return loss.sum()
+
+        # Raise an error if the reduction method is not recognized
+        raise ValueError(
+            f"Unknown reduction: {self._reduction}."
+            " Supported reductions are 'mean' and 'sum'."
+        )
 
     @property
     def model(self):
@@ -332,7 +405,7 @@ def model(self):
         return self.models[0]
 
     @property
-    def weights_dict(self):
+    def weights(self):
         """
         The self-adaptive weights.
 