11"""Module for the Residual-Based Attention PINN solver."""
22
3- from copy import deepcopy
43import torch
54
65from .pinn import PINN
@@ -98,6 +97,8 @@ def __init__(
        :param float gamma: The decay parameter in the update of the weights
            of the residuals. Must be between ``0`` and ``1``.
            Default is ``0.999``.
+        :raises ValueError: If ``gamma`` is not in the range (0, 1).
+        :raises ValueError: If ``eta`` is not greater than ``0``.
101102 """
102103 super ().__init__ (
103104 model = model ,
@@ -111,78 +112,201 @@ def __init__(
        # check consistency
        check_consistency(eta, (float, int))
        check_consistency(gamma, float)
-        assert (
-            0 < gamma < 1
-        ), f"Invalid range: expected 0 < gamma < 1, got {gamma=}"
+
+        # Validate range for gamma
+        if not 0 < gamma < 1:
+            raise ValueError(
+                f"Invalid range: expected 0 < gamma < 1, but got {gamma}"
+            )
+
+        # Validate range for eta
+        if eta <= 0:
+            raise ValueError(f"Invalid range: expected eta > 0, but got {eta}")
+
+        # Initialize parameters
        self.eta = eta
        self.gamma = gamma

-        # initialize weights
+        # Initialize the weight of each point to 0
        self.weights = {}
-        for condition_name in problem.conditions:
-            self.weights[condition_name] = 0
+        for cond, data in self.problem.input_pts.items():
+            buffer_tensor = torch.zeros((len(data), 1), device=self.device)
+            self.register_buffer(f"weight_{cond}", buffer_tensor)
+            self.weights[cond] = getattr(self, f"weight_{cond}")
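+        # Note: ``register_buffer`` ties the weights to the module, so they
+        # follow device moves and are saved in checkpoints, while staying
+        # invisible to the optimizer: the attention weights are never
+        # updated by gradient descent.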
+
+        # Extract the reduction method from the loss function
+        self._reduction = self._loss_fn.reduction

-        # define vectorial loss
-        self._vectorial_loss = deepcopy(self.loss)
-        self._vectorial_loss.reduction = "none"
+        # Set the loss function to return non-aggregated losses
+        self._loss_fn = type(self._loss_fn)(reduction="none")
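+        # e.g. a ``torch.nn.MSELoss(reduction="mean")`` passed by the user is
+        # re-instantiated as ``MSELoss(reduction="none")``, yielding one loss
+        # value per collocation point.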

-    # for now RBAPINN is implemented only for batch_size = None
-    def on_train_start(self):
+    def training_step(self, batch, batch_idx, **kwargs):
131144 """
132- Hook method called at the beginning of training.
145+ Solver training step. It computes the optimization cycle and aggregates
146+ the losses using the ``weighting`` attribute.
133147
134- :raises NotImplementedError: If the batch size is not ``None``.
+        :param list[tuple[str, dict]] batch: A batch of data. Each element is a
+            tuple containing a condition name and a dictionary of points.
+        :param int batch_idx: The index of the current batch.
+        :param dict kwargs: Additional keyword arguments passed to
+            ``optimization_cycle``.
+        :return: The loss of the training step.
+        :rtype: torch.Tensor
135155 """
136- if self .trainer .batch_size is not None :
137- raise NotImplementedError (
138- "RBAPINN only works with full batch "
139- "size, set batch_size=None inside the "
140- "Trainer to use the solver."
141- )
142- return super ().on_train_start ()
156+ loss = self ._optimization_cycle (
157+ batch = batch , batch_idx = batch_idx , ** kwargs
158+ )
159+ self .store_log ("train_loss" , loss , self .get_batch_size (batch ))
160+ return loss
+
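+    # Validation and test normally run under ``torch.no_grad``, but PINN
+    # residuals require derivatives of the network output w.r.t. its inputs,
+    # so gradient tracking is re-enabled explicitly.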
+    @torch.set_grad_enabled(True)
+    def validation_step(self, batch, **kwargs):
+        """
+        The validation step for the PINN solver. It computes the residuals
+        with the non-aggregated ``loss`` function and returns their average.
+
+        :param list[tuple[str, dict]] batch: A batch of data. Each element is a
+            tuple containing a condition name and a dictionary of points.
+        :param dict kwargs: Additional keyword arguments passed to
+            ``optimization_cycle``.
+        :return: The loss of the validation step.
+        :rtype: torch.Tensor
+        """
+        losses = self.optimization_cycle(batch=batch, **kwargs)
+
+        # Aggregate losses for each condition
+        for cond, loss in losses.items():
+            losses[cond] = self._apply_reduction(loss=loss)
+
+        loss = (sum(losses.values()) / len(losses)).as_subclass(torch.Tensor)
+        self.store_log("val_loss", loss, self.get_batch_size(batch))
+        return loss
+
+    @torch.set_grad_enabled(True)
+    def test_step(self, batch, **kwargs):
+        """
+        The test step for the PINN solver. It computes the residuals with the
+        non-aggregated ``loss`` function and returns their average.
+
+        :param list[tuple[str, dict]] batch: A batch of data. Each element is a
+            tuple containing a condition name and a dictionary of points.
+        :param dict kwargs: Additional keyword arguments passed to
+            ``optimization_cycle``.
+        :return: The loss of the test step.
+        :rtype: torch.Tensor
+        """
+        losses = self.optimization_cycle(batch=batch, **kwargs)
+
+        # Aggregate losses for each condition
+        for cond, loss in losses.items():
+            losses[cond] = self._apply_reduction(loss=loss)

-    def _vect_to_scalar(self, loss_value):
+        loss = (sum(losses.values()) / len(losses)).as_subclass(torch.Tensor)
+        self.store_log("test_loss", loss, self.get_batch_size(batch))
+        return loss
+
+    def _optimization_cycle(self, batch, batch_idx, **kwargs):
        """
-        Computation of the scalar loss.
+        Aggregate the loss for each condition in the batch.

-        :param LabelTensor loss_value: the tensor of pointwise losses.
-        :raises RuntimeError: If the loss reduction is not ``mean`` or ``sum``.
-        :return: The computed scalar loss.
-        :rtype: LabelTensor
+        :param list[tuple[str, dict]] batch: A batch of data. Each element is a
+            tuple containing a condition name and a dictionary of points.
+        :param int batch_idx: The index of the current batch.
+        :param dict kwargs: Additional keyword arguments passed to
+            ``optimization_cycle``.
+        :return: The aggregated loss, computed over all conditions in the
+            batch and cast to a subclass of :class:`torch.Tensor`.
+        :rtype: torch.Tensor
        """
-        if self.loss.reduction == "mean":
-            ret = torch.mean(loss_value)
-        elif self.loss.reduction == "sum":
-            ret = torch.sum(loss_value)
-        else:
-            raise RuntimeError(
-                f"Invalid reduction, got {self.loss.reduction} "
-                "but expected mean or sum."
+        # compute non-aggregated residuals
+        residuals = self.optimization_cycle(batch)
+
+        # update weights based on residuals
+        self._update_weights(batch, batch_idx, residuals)
+
+        # compute losses
+        losses = {}
+        for cond, res in residuals.items():
+
+            # Compute the weight indices for this batch. The modulus wraps
+            # them around the number of points in the condition, mirroring
+            # how the PinaDataset cycles through the points.
+            len_res = len(res)
+            idx = torch.arange(
+                batch_idx * len_res,
+                (batch_idx + 1) * len_res,
+                device=res.device,
+            ) % len(self.problem.input_pts[cond])
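+            # e.g. a condition with 100 points and len_res = 30: batch_idx 3
+            # gives arange(90, 120) % 100, i.e. points 90..99 then 0..19.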
+
+            losses[cond] = self._apply_reduction(
+                loss=(res * self.weights[cond][idx])
            )
-        return ret

-    def loss_phys(self, samples, equation):
+            # store log
+            self.store_log(
+                f"{cond}_loss", losses[cond].item(), self.get_batch_size(batch)
+            )
+
+        # clamp unknown parameters in InverseProblem (if needed)
+        self._clamp_params()
+
+        # aggregate the per-condition losses via the weighting schema
+        loss = self.weighting.aggregate(losses).as_subclass(torch.Tensor)
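+        # ``as_subclass(torch.Tensor)`` strips tensor subclasses (e.g. a
+        # LabelTensor) so downstream hooks receive a plain torch.Tensor.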
+
+        return loss
+
+    def _update_weights(self, batch, batch_idx, residuals):
        """
-        Computes the physics loss for the physics-informed solver based on the
-        provided samples and equation.
+        Update weights based on residuals.

-        :param LabelTensor samples: The samples to evaluate the physics loss.
-        :param EquationInterface equation: The governing equation.
-        :return: The computed physics loss.
-        :rtype: LabelTensor
+        :param list[tuple[str, dict]] batch: A batch of data. Each element is a
+            tuple containing a condition name and a dictionary of points.
+        :param int batch_idx: The index of the current batch.
+        :param dict residuals: A dictionary containing the residuals for each
+            condition. The keys are the condition names and the values are the
+            residuals as tensors.
        """
-        residual = self.compute_residual(samples=samples, equation=equation)
-        cond = self.current_condition_name
+        # Iterate over each condition in the batch
+        for cond, data in batch:

-        r_norm = (
-            self.eta
-            * torch.abs(residual)
-            / (torch.max(torch.abs(residual)) + 1e-12)
-        )
-        self.weights[cond] = (self.gamma * self.weights[cond] + r_norm).detach()
+            # Compute normalized residuals
+            res = residuals[cond]
+            res_abs = res.abs()
+            r_norm = (self.eta * res_abs) / (res_abs.max() + 1e-12)

-        loss_value = self._vectorial_loss(
-            torch.zeros_like(residual, requires_grad=True), residual
-        )
+            # Compute the weight indices for this batch. The modulus wraps
+            # them around the number of points in the condition, mirroring
+            # how the PinaDataset cycles through the points.
+            len_pts = len(data["input"])
+            idx = torch.arange(
+                batch_idx * len_pts,
+                (batch_idx + 1) * len_pts,
+                device=res.device,
+            ) % len(self.problem.input_pts[cond])

-        return self._vect_to_scalar(self.weights[cond] ** 2 * loss_value)
+            # Update weights
+            weights = self.weights[cond]
+            update = self.gamma * weights[idx] + r_norm
+            weights[idx] = update.detach()
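+            # In effect: w <- gamma * w + eta * |r| / max(|r|); detach() keeps
+            # the attention weights out of the autograd graph.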
+
+    def _apply_reduction(self, loss):
+        """
+        Apply the specified reduction to the loss. The reduction is deferred
+        until the end of the optimization cycle to allow residual-based
+        weights to be applied to each point beforehand.
+
+        :param torch.Tensor loss: The loss tensor to be reduced.
+        :return: The reduced loss tensor.
+        :rtype: torch.Tensor
+        :raises ValueError: If the reduction method is neither "mean" nor "sum".
+        """
+        # Apply the specified reduction method
+        if self._reduction == "mean":
+            return loss.mean()
+        if self._reduction == "sum":
+            return loss.sum()
+
+        # Raise an error if the reduction method is not recognized
+        raise ValueError(
+            f"Unknown reduction: {self._reduction}."
+            " Supported reductions are 'mean' and 'sum'."
+        )