Skip to content

Commit f48da47

Browse files
add scheduler step for multisolvers (#526)
1 parent b958c0f commit f48da47

File tree

3 files changed

+17
-10
lines changed

3 files changed

+17
-10
lines changed

pina/solver/garom.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -151,9 +151,9 @@ def _train_generator(self, parameters, snapshots):
151151
:return: The residual loss and the generator loss.
152152
:rtype: tuple[torch.Tensor, torch.Tensor]
153153
"""
154-
optimizer = self.optimizer_generator
155-
optimizer.zero_grad()
154+
self.optimizer_generator.instance.zero_grad()
156155

156+
# Generate a batch of images
157157
generated_snapshots = self.sample(parameters)
158158

159159
# generator loss
@@ -165,7 +165,8 @@ def _train_generator(self, parameters, snapshots):
165165

166166
# backward step
167167
g_loss.backward()
168-
optimizer.step()
168+
self.optimizer_generator.instance.step()
169+
self.scheduler_generator.instance.step()
169170

170171
return r_loss, g_loss
171172

@@ -196,8 +197,7 @@ def _train_discriminator(self, parameters, snapshots):
196197
:return: The residual loss and the generator loss.
197198
:rtype: tuple[torch.Tensor, torch.Tensor]
198199
"""
199-
optimizer = self.optimizer_discriminator
200-
optimizer.zero_grad()
200+
self.optimizer_discriminator.instance.zero_grad()
201201

202202
# Generate a batch of images
203203
generated_snapshots = self.sample(parameters)
@@ -213,7 +213,8 @@ def _train_discriminator(self, parameters, snapshots):
213213

214214
# backward step
215215
d_loss.backward()
216-
optimizer.step()
216+
self.optimizer_discriminator.instance.step()
217+
self.scheduler_discriminator.instance.step()
217218

218219
return d_loss_real, d_loss_fake, d_loss
219220

@@ -345,7 +346,7 @@ def optimizer_generator(self):
345346
:return: The optimizer for the generator.
346347
:rtype: Optimizer
347348
"""
348-
return self.optimizers[0].instance
349+
return self.optimizers[0]
349350

350351
@property
351352
def optimizer_discriminator(self):
@@ -355,7 +356,7 @@ def optimizer_discriminator(self):
355356
:return: The optimizer for the discriminator.
356357
:rtype: Optimizer
357358
"""
358-
return self.optimizers[1].instance
359+
return self.optimizers[1]
359360

360361
@property
361362
def scheduler_generator(self):
@@ -365,7 +366,7 @@ def scheduler_generator(self):
365366
:return: The scheduler for the generator.
366367
:rtype: Scheduler
367368
"""
368-
return self.schedulers[0].instance
369+
return self.schedulers[0]
369370

370371
@property
371372
def scheduler_discriminator(self):
@@ -375,4 +376,4 @@ def scheduler_discriminator(self):
375376
:return: The scheduler for the discriminator.
376377
:rtype: Scheduler
377378
"""
378-
return self.schedulers[1].instance
379+
return self.schedulers[1]

pina/solver/physics_informed_solver/competitive_pinn.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,11 +130,15 @@ def training_step(self, batch):
130130
loss = super().training_step(batch)
131131
self.manual_backward(loss)
132132
self.optimizer_model.instance.step()
133+
self.scheduler_model.instance.step()
134+
133135
# train discriminator
134136
self.optimizer_discriminator.instance.zero_grad()
135137
loss = super().training_step(batch)
136138
self.manual_backward(-loss)
137139
self.optimizer_discriminator.instance.step()
140+
self.scheduler_discriminator.instance.step()
141+
138142
return loss
139143

140144
def loss_phys(self, samples, equation):

pina/solver/physics_informed_solver/self_adaptive_pinn.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,12 +188,14 @@ def training_step(self, batch):
188188
loss = super().training_step(batch)
189189
self.manual_backward(-loss)
190190
self.optimizer_weights.instance.step()
191+
self.scheduler_weights.instance.step()
191192

192193
# Model optimization
193194
self.optimizer_model.instance.zero_grad()
194195
loss = super().training_step(batch)
195196
self.manual_backward(loss)
196197
self.optimizer_model.instance.step()
198+
self.scheduler_model.instance.step()
197199

198200
return loss
199201

0 commit comments

Comments
 (0)