@@ -178,62 +178,19 @@ def training_step(self, batch):
178178 :return: The sum of the loss functions.
179179 :rtype: LabelTensor
180180 """
181- self . optimizer_model . instance . zero_grad ()
181+ # Weights optimization
182182 self .optimizer_weights .instance .zero_grad ()
183183 loss = super ().training_step (batch )
184- self .optimizer_model . instance . step ( )
184+ self .manual_backward ( - loss )
185185 self .optimizer_weights .instance .step ()
186- return loss
187-
188- def loss_phys (self , samples , equation ):
189- """
190- Computes the physics loss for the SAPINN solver based on given
191- samples and equation.
192-
193- :param LabelTensor samples: The samples to evaluate the physics loss.
194- :param EquationInterface equation: The governing equation
195- representing the physics.
196- :return: The physics loss calculated based on given
197- samples and equation.
198- :rtype: torch.Tensor
199- """
200- # Train the weights
201- weighted_loss = self ._loss_phys (samples , equation )
202- loss_value = - weighted_loss .as_subclass (torch .Tensor )
203- self .manual_backward (loss_value )
204-
205- # Detach samples from the existing computational graph and
206- # create a new one by setting requires_grad to True.
207- # In alternative set `retain_graph=True`.
208- samples = samples .detach ()
209- samples .requires_grad_ () # = True
210-
211- # Train the model
212- weighted_loss = self ._loss_phys (samples , equation )
213- loss_value = weighted_loss .as_subclass (torch .Tensor )
214- self .manual_backward (loss_value )
215186
216- return loss_value
217-
218- def loss_data (self , input_pts , output_pts ):
219- """
220- Computes the data loss for the SAPINN solver based on input and
221- output. It computes the loss between the
222- network output against the true solution.
187+ # Model optimization
188+ self .optimizer_model .instance .zero_grad ()
189+ loss = super ().training_step (batch )
190+ self .manual_backward (loss )
191+ self .optimizer_model .instance .step ()
223192
224- :param LabelTensor input_pts: The input to the neural networks.
225- :param LabelTensor output_pts: The true solution to compare the
226- network solution.
227- :return: The computed data loss.
228- :rtype: torch.Tensor
229- """
230- residual = self .forward (input_pts ) - output_pts
231- loss = self ._vectorial_loss (
232- torch .zeros_like (residual , requires_grad = True ), residual
233- )
234- loss_value = self ._vect_to_scalar (loss ).as_subclass (torch .Tensor )
235- self .manual_backward (loss_value )
236- return loss_value
193+ return loss
237194
238195 def configure_optimizers (self ):
239196 """
@@ -330,7 +287,7 @@ def on_load_checkpoint(self, checkpoint):
330287 )
331288 return super ().on_load_checkpoint (checkpoint )
332289
333- def _loss_phys (self , samples , equation ):
290+ def loss_phys (self , samples , equation ):
334291 """
335292 Computation of the physical loss for SelfAdaptive PINN solver.
336293
0 commit comments