Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 11 additions & 10 deletions pina/solver/garom.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,9 +151,9 @@ def _train_generator(self, parameters, snapshots):
:return: The residual loss and the generator loss.
:rtype: tuple[torch.Tensor, torch.Tensor]
"""
optimizer = self.optimizer_generator
optimizer.zero_grad()
self.optimizer_generator.instance.zero_grad()

# Generate a batch of images
generated_snapshots = self.sample(parameters)

# generator loss
Expand All @@ -165,7 +165,8 @@ def _train_generator(self, parameters, snapshots):

# backward step
g_loss.backward()
optimizer.step()
self.optimizer_generator.instance.step()
self.scheduler_generator.instance.step()

return r_loss, g_loss

Expand Down Expand Up @@ -196,8 +197,7 @@ def _train_discriminator(self, parameters, snapshots):
:return: The residual loss and the generator loss.
:rtype: tuple[torch.Tensor, torch.Tensor]
"""
optimizer = self.optimizer_discriminator
optimizer.zero_grad()
self.optimizer_discriminator.instance.zero_grad()

# Generate a batch of images
generated_snapshots = self.sample(parameters)
Expand All @@ -213,7 +213,8 @@ def _train_discriminator(self, parameters, snapshots):

# backward step
d_loss.backward()
optimizer.step()
self.optimizer_discriminator.instance.step()
self.scheduler_discriminator.instance.step()

return d_loss_real, d_loss_fake, d_loss

Expand Down Expand Up @@ -345,7 +346,7 @@ def optimizer_generator(self):
:return: The optimizer for the generator.
:rtype: Optimizer
"""
return self.optimizers[0].instance
return self.optimizers[0]

@property
def optimizer_discriminator(self):
Expand All @@ -355,7 +356,7 @@ def optimizer_discriminator(self):
:return: The optimizer for the discriminator.
:rtype: Optimizer
"""
return self.optimizers[1].instance
return self.optimizers[1]

@property
def scheduler_generator(self):
Expand All @@ -365,7 +366,7 @@ def scheduler_generator(self):
:return: The scheduler for the generator.
:rtype: Scheduler
"""
return self.schedulers[0].instance
return self.schedulers[0]

@property
def scheduler_discriminator(self):
Expand All @@ -375,4 +376,4 @@ def scheduler_discriminator(self):
:return: The scheduler for the discriminator.
:rtype: Scheduler
"""
return self.schedulers[1].instance
return self.schedulers[1]
4 changes: 4 additions & 0 deletions pina/solver/physics_informed_solver/competitive_pinn.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,11 +130,15 @@ def training_step(self, batch):
loss = super().training_step(batch)
self.manual_backward(loss)
self.optimizer_model.instance.step()
self.scheduler_model.instance.step()

# train discriminator
self.optimizer_discriminator.instance.zero_grad()
loss = super().training_step(batch)
self.manual_backward(-loss)
self.optimizer_discriminator.instance.step()
self.scheduler_discriminator.instance.step()

return loss

def loss_phys(self, samples, equation):
Expand Down
2 changes: 2 additions & 0 deletions pina/solver/physics_informed_solver/self_adaptive_pinn.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,12 +188,14 @@ def training_step(self, batch):
loss = super().training_step(batch)
self.manual_backward(-loss)
self.optimizer_weights.instance.step()
self.scheduler_weights.instance.step()

# Model optimization
self.optimizer_model.instance.zero_grad()
loss = super().training_step(batch)
self.manual_backward(loss)
self.optimizer_model.instance.step()
self.scheduler_model.instance.step()

return loss

Expand Down
Loading