@@ -151,9 +151,9 @@ def _train_generator(self, parameters, snapshots):
151151 :return: The residual loss and the generator loss.
152152 :rtype: tuple[torch.Tensor, torch.Tensor]
153153 """
154- optimizer = self .optimizer_generator
155- optimizer .zero_grad ()
154+ self .optimizer_generator .instance .zero_grad ()
156155
156+ # Generate a batch of images
157157 generated_snapshots = self .sample (parameters )
158158
159159 # generator loss
@@ -165,7 +165,8 @@ def _train_generator(self, parameters, snapshots):
165165
166166 # backward step
167167 g_loss .backward ()
168- optimizer .step ()
168+ self .optimizer_generator .instance .step ()
169+ self .scheduler_generator .instance .step ()
169170
170171 return r_loss , g_loss
171172
@@ -196,8 +197,7 @@ def _train_discriminator(self, parameters, snapshots):
196197 :return: The residual loss and the generator loss.
197198 :rtype: tuple[torch.Tensor, torch.Tensor]
198199 """
199- optimizer = self .optimizer_discriminator
200- optimizer .zero_grad ()
200+ self .optimizer_discriminator .instance .zero_grad ()
201201
202202 # Generate a batch of images
203203 generated_snapshots = self .sample (parameters )
@@ -213,7 +213,8 @@ def _train_discriminator(self, parameters, snapshots):
213213
214214 # backward step
215215 d_loss .backward ()
216- optimizer .step ()
216+ self .optimizer_discriminator .instance .step ()
217+ self .scheduler_discriminator .instance .step ()
217218
218219 return d_loss_real , d_loss_fake , d_loss
219220
@@ -345,7 +346,7 @@ def optimizer_generator(self):
345346 :return: The optimizer for the generator.
346347 :rtype: Optimizer
347348 """
348- return self .optimizers [0 ]. instance
349+ return self .optimizers [0 ]
349350
350351 @property
351352 def optimizer_discriminator (self ):
@@ -355,7 +356,7 @@ def optimizer_discriminator(self):
355356 :return: The optimizer for the discriminator.
356357 :rtype: Optimizer
357358 """
358- return self .optimizers [1 ]. instance
359+ return self .optimizers [1 ]
359360
360361 @property
361362 def scheduler_generator (self ):
@@ -365,7 +366,7 @@ def scheduler_generator(self):
365366 :return: The scheduler for the generator.
366367 :rtype: Scheduler
367368 """
368- return self .schedulers [0 ]. instance
369+ return self .schedulers [0 ]
369370
370371 @property
371372 def scheduler_discriminator (self ):
@@ -375,4 +376,4 @@ def scheduler_discriminator(self):
375376 :return: The scheduler for the discriminator.
376377 :rtype: Scheduler
377378 """
378- return self .schedulers [1 ]. instance
379+ return self .schedulers [1 ]
0 commit comments