diff --git a/pina/solver/garom.py b/pina/solver/garom.py index 2f763a700..b854ce786 100644 --- a/pina/solver/garom.py +++ b/pina/solver/garom.py @@ -151,9 +151,9 @@ def _train_generator(self, parameters, snapshots): :return: The residual loss and the generator loss. :rtype: tuple[torch.Tensor, torch.Tensor] """ - optimizer = self.optimizer_generator - optimizer.zero_grad() + self.optimizer_generator.instance.zero_grad() + # Generate a batch of images generated_snapshots = self.sample(parameters) # generator loss @@ -165,7 +165,8 @@ def _train_generator(self, parameters, snapshots): # backward step g_loss.backward() - optimizer.step() + self.optimizer_generator.instance.step() + self.scheduler_generator.instance.step() return r_loss, g_loss @@ -196,8 +197,7 @@ def _train_discriminator(self, parameters, snapshots): :return: The residual loss and the generator loss. :rtype: tuple[torch.Tensor, torch.Tensor] """ - optimizer = self.optimizer_discriminator - optimizer.zero_grad() + self.optimizer_discriminator.instance.zero_grad() # Generate a batch of images generated_snapshots = self.sample(parameters) @@ -213,7 +213,8 @@ def _train_discriminator(self, parameters, snapshots): # backward step d_loss.backward() - optimizer.step() + self.optimizer_discriminator.instance.step() + self.scheduler_discriminator.instance.step() return d_loss_real, d_loss_fake, d_loss @@ -345,7 +346,7 @@ def optimizer_generator(self): :return: The optimizer for the generator. :rtype: Optimizer """ - return self.optimizers[0].instance + return self.optimizers[0] @property def optimizer_discriminator(self): @@ -355,7 +356,7 @@ def optimizer_discriminator(self): :return: The optimizer for the discriminator. :rtype: Optimizer """ - return self.optimizers[1].instance + return self.optimizers[1] @property def scheduler_generator(self): @@ -365,7 +366,7 @@ def scheduler_generator(self): :return: The scheduler for the generator. :rtype: Scheduler """ - return self.schedulers[0].instance + return self.schedulers[0] @property def scheduler_discriminator(self): @@ -375,4 +376,4 @@ def scheduler_discriminator(self): :return: The scheduler for the discriminator. :rtype: Scheduler """ - return self.schedulers[1].instance + return self.schedulers[1] diff --git a/pina/solver/physics_informed_solver/competitive_pinn.py b/pina/solver/physics_informed_solver/competitive_pinn.py index 058c53f40..a61755148 100644 --- a/pina/solver/physics_informed_solver/competitive_pinn.py +++ b/pina/solver/physics_informed_solver/competitive_pinn.py @@ -130,11 +130,15 @@ def training_step(self, batch): loss = super().training_step(batch) self.manual_backward(loss) self.optimizer_model.instance.step() + self.scheduler_model.instance.step() + # train discriminator self.optimizer_discriminator.instance.zero_grad() loss = super().training_step(batch) self.manual_backward(-loss) self.optimizer_discriminator.instance.step() + self.scheduler_discriminator.instance.step() + return loss def loss_phys(self, samples, equation): diff --git a/pina/solver/physics_informed_solver/self_adaptive_pinn.py b/pina/solver/physics_informed_solver/self_adaptive_pinn.py index a6310d515..1c534b01a 100644 --- a/pina/solver/physics_informed_solver/self_adaptive_pinn.py +++ b/pina/solver/physics_informed_solver/self_adaptive_pinn.py @@ -188,12 +188,14 @@ def training_step(self, batch): loss = super().training_step(batch) self.manual_backward(-loss) self.optimizer_weights.instance.step() + self.scheduler_weights.instance.step() # Model optimization self.optimizer_model.instance.zero_grad() loss = super().training_step(batch) self.manual_backward(loss) self.optimizer_model.instance.step() + self.scheduler_model.instance.step() return loss