diff --git a/pina/solver/physic_informed_solver/competitive_pinn.py b/pina/solver/physic_informed_solver/competitive_pinn.py index 0225ea6a2..c52122862 100644 --- a/pina/solver/physic_informed_solver/competitive_pinn.py +++ b/pina/solver/physic_informed_solver/competitive_pinn.py @@ -160,22 +160,6 @@ def loss_phys(self, samples, equation): self._train_discriminator(samples, equation, discriminator_bets) return loss_val - def loss_data(self, input_pts, output_pts): - """ - The data loss for the CompetitivePINN solver. It computes the loss - between the network output against the true solution. - - :param LabelTensor input_tensor: The input to the neural networks. - :param LabelTensor output_tensor: The true solution to compare the - network solution. - :return: The computed data loss. - :rtype: torch.Tensor - """ - loss_val = super().loss_data(input_pts, output_pts) - # prepare for optimizer step called in training step - loss_val.backward() - return loss_val - def configure_optimizers(self): """ Optimizer configuration for the Competitive PINN solver. @@ -252,7 +236,6 @@ def _train_discriminator(self, samples, equation, discriminator_bets): ) # prepare for optimizer step called in training step self.manual_backward(loss_val) - return def _train_model(self, samples, equation, discriminator_bets): """ diff --git a/pina/solver/physic_informed_solver/self_adaptive_pinn.py b/pina/solver/physic_informed_solver/self_adaptive_pinn.py index 185643d9b..c64c49971 100644 --- a/pina/solver/physic_informed_solver/self_adaptive_pinn.py +++ b/pina/solver/physic_informed_solver/self_adaptive_pinn.py @@ -178,62 +178,19 @@ def training_step(self, batch): :return: The sum of the loss functions. :rtype: LabelTensor """ - self.optimizer_model.instance.zero_grad() + # Weights optimization self.optimizer_weights.instance.zero_grad() loss = super().training_step(batch) - self.optimizer_model.instance.step() + self.manual_backward(-loss) self.optimizer_weights.instance.step() - return loss - - def loss_phys(self, samples, equation): - """ - Computes the physics loss for the SAPINN solver based on given - samples and equation. - - :param LabelTensor samples: The samples to evaluate the physics loss. - :param EquationInterface equation: The governing equation - representing the physics. - :return: The physics loss calculated based on given - samples and equation. - :rtype: torch.Tensor - """ - # Train the weights - weighted_loss = self._loss_phys(samples, equation) - loss_value = -weighted_loss.as_subclass(torch.Tensor) - self.manual_backward(loss_value) - - # Detach samples from the existing computational graph and - # create a new one by setting requires_grad to True. - # In alternative set `retain_graph=True`. - samples = samples.detach() - samples.requires_grad_() # = True - - # Train the model - weighted_loss = self._loss_phys(samples, equation) - loss_value = weighted_loss.as_subclass(torch.Tensor) - self.manual_backward(loss_value) - return loss_value - - def loss_data(self, input_pts, output_pts): - """ - Computes the data loss for the SAPINN solver based on input and - output. It computes the loss between the - network output against the true solution. + # Model optimization + self.optimizer_model.instance.zero_grad() + loss = super().training_step(batch) + self.manual_backward(loss) + self.optimizer_model.instance.step() - :param LabelTensor input_pts: The input to the neural networks. - :param LabelTensor output_pts: The true solution to compare the - network solution. - :return: The computed data loss. - :rtype: torch.Tensor - """ - residual = self.forward(input_pts) - output_pts - loss = self._vectorial_loss( - torch.zeros_like(residual, requires_grad=True), residual - ) - loss_value = self._vect_to_scalar(loss).as_subclass(torch.Tensor) - self.manual_backward(loss_value) - return loss_value + return loss def configure_optimizers(self): """ @@ -330,7 +287,7 @@ def on_load_checkpoint(self, checkpoint): ) return super().on_load_checkpoint(checkpoint) - def _loss_phys(self, samples, equation): + def loss_phys(self, samples, equation): """ Computation of the physical loss for SelfAdaptive PINN solver.