From b9da605b4f08f311371abc9512df8517cff2fbce Mon Sep 17 00:00:00 2001
From: Dario Coscia
Date: Sat, 1 Mar 2025 11:44:12 +0100
Subject: [PATCH] fixing competitive pinn

---
 .../competitive_pinn.py                      | 85 ++++---------------
 1 file changed, 18 insertions(+), 67 deletions(-)

diff --git a/pina/solver/physic_informed_solver/competitive_pinn.py b/pina/solver/physic_informed_solver/competitive_pinn.py
index c52122862..8eddef8ca 100644
--- a/pina/solver/physic_informed_solver/competitive_pinn.py
+++ b/pina/solver/physic_informed_solver/competitive_pinn.py
@@ -3,7 +3,7 @@
 import torch
 import copy
 
-from pina.problem import InverseProblem
+from ...problem import InverseProblem
 from .pinn_interface import PINNInterface
 from ..solver import MultiSolverInterface
 
@@ -125,10 +125,15 @@ def training_step(self, batch):
         :return: The sum of the loss functions.
         :rtype: LabelTensor
         """
+        # train model
         self.optimizer_model.instance.zero_grad()
-        self.optimizer_discriminator.instance.zero_grad()
         loss = super().training_step(batch)
+        self.manual_backward(loss)
         self.optimizer_model.instance.step()
+        # train discriminator
+        self.optimizer_discriminator.instance.zero_grad()
+        loss = super().training_step(batch)
+        self.manual_backward(-loss)
         self.optimizer_discriminator.instance.step()
         return loss
 
@@ -144,20 +149,18 @@ def loss_phys(self, samples, equation):
             samples and equation.
         :rtype: LabelTensor
         """
-        # Train the model for one step
-        with torch.no_grad():
-            discriminator_bets = self.discriminator(samples)
-        loss_val = self._train_model(samples, equation, discriminator_bets)
-
-        # Detach samples from the existing computational graph and
-        # create a new one by setting requires_grad to True.
-        # In alternative set `retain_graph=True`.
-        samples = samples.detach()
-        samples.requires_grad_()
-
-        # Train the discriminator for one step
+        # Compute discriminator bets
         discriminator_bets = self.discriminator(samples)
-        self._train_discriminator(samples, equation, discriminator_bets)
+
+        # Compute residual and multiply discriminator_bets
+        residual = self.compute_residual(samples=samples, equation=equation)
+        residual = residual * discriminator_bets
+
+        # Compute competitive residual.
+        loss_val = self.loss(
+            torch.zeros_like(residual, requires_grad=True),
+            residual,
+        )
         return loss_val
 
     def configure_optimizers(self):
@@ -213,58 +216,6 @@ def on_train_batch_end(self, outputs, batch, batch_idx):
 
         return super().on_train_batch_end(outputs, batch, batch_idx)
 
-    def _train_discriminator(self, samples, equation, discriminator_bets):
-        """
-        Trains the discriminator network of the Competitive PINN.
-
-        :param LabelTensor samples: Input samples to evaluate the physics loss.
-        :param EquationInterface equation: The governing equation representing
-            the physics.
-        :param Tensor discriminator_bets: Predictions made by the discriminator
-            network.
-        """
-        # Compute residual. Detach since discriminator weights are fixed
-        residual = self.compute_residual(
-            samples=samples, equation=equation
-        ).detach()
-
-        # Compute competitive residual, then maximise the loss
-        competitive_residual = residual * discriminator_bets
-        loss_val = -self.loss(
-            torch.zeros_like(competitive_residual, requires_grad=True),
-            competitive_residual,
-        )
-        # prepare for optimizer step called in training step
-        self.manual_backward(loss_val)
-
-    def _train_model(self, samples, equation, discriminator_bets):
-        """
-        Trains the model network of the Competitive PINN.
-
-        :param LabelTensor samples: Input samples to evaluate the physics loss.
-        :param EquationInterface equation: The governing equation representing
-            the physics.
-        :param Tensor discriminator_bets: Predictions made by the discriminator
-            network.
-        :return: The computed data loss.
-        :rtype: torch.Tensor
-        """
-        # Compute residual
-        residual = self.compute_residual(samples=samples, equation=equation)
-        with torch.no_grad():
-            loss_residual = self.loss(torch.zeros_like(residual), residual)
-
-        # Compute competitive residual. Detach discriminator_bets
-        # to optimize only the generator model
-        competitive_residual = residual * discriminator_bets.detach()
-        loss_val = self.loss(
-            torch.zeros_like(competitive_residual, requires_grad=True),
-            competitive_residual,
-        )
-        # prepare for optimizer step called in training step
-        self.manual_backward(loss_val)
-        return loss_residual
-
     @property
     def neural_net(self):
         """
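
Note: the rewritten training_step performs two independent forward/backward
passes. The model descends on the competitive loss, while the discriminator
ascends on it via manual_backward(-loss), so each optimizer only accumulates
its own gradients. Below is a minimal, self-contained sketch of this
alternating min-max pattern in plain PyTorch; the names (competitive_loss,
training_step, the toy nn.Linear networks) are illustrative assumptions, not
PINA's API.

    import torch
    from torch import nn

    def competitive_loss(model, discriminator, samples):
        # Toy residual for an "equation" u(x) = 0, weighted pointwise by
        # the discriminator's bets, mirroring loss_phys above.
        residual = model(samples) * discriminator(samples)
        return torch.mean(residual**2)

    def training_step(model, discriminator, samples, opt_model, opt_disc):
        # Model update: gradient descent on the competitive loss.
        opt_model.zero_grad()
        loss = competitive_loss(model, discriminator, samples)
        loss.backward()
        opt_model.step()

        # Discriminator update: gradient ascent, realized by descending on
        # the negated loss, recomputed on a fresh computational graph.
        opt_disc.zero_grad()
        loss = competitive_loss(model, discriminator, samples)
        (-loss).backward()
        opt_disc.step()
        return loss.detach()

    model = nn.Linear(1, 1)
    discriminator = nn.Linear(1, 1)
    opt_m = torch.optim.SGD(model.parameters(), lr=1e-3)
    opt_d = torch.optim.SGD(discriminator.parameters(), lr=1e-3)
    samples = torch.rand(16, 1)
    print(training_step(model, discriminator, samples, opt_m, opt_d))

Recomputing the loss on a fresh graph for the second pass avoids
retain_graph=True and keeps the two updates fully decoupled, which is the
same design choice the patch makes by calling super().training_step(batch)
once per optimizer.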