@@ -269,6 +269,7 @@ def train_discriminator(self, current_step):
269269 real_images = upfirdn2d .filter2d (real_images , f / f .sum ())
270270 fake_images = upfirdn2d .filter2d (fake_images , f / f .sum ())
271271
272+ # shuffle real and fake images (APA)
272273 if self .AUG .apply_apa :
273274 real_images = apa_aug .apply_apa_aug (real_images , fake_images .detach (), self .aa_p , self .local_rank )
274275
@@ -405,6 +406,7 @@ def train_discriminator(self, current_step):
405406 lecam_loss = torch .tensor (0. , device = self .local_rank )
406407 dis_acml_loss += self .LOSS .lecam_lambda * lecam_loss
407408
409+ # apply r1_reg inside of training loop
408410 if self .LOSS .apply_r1_reg and not self .is_stylegan :
409411 self .r1_penalty = losses .cal_r1_reg (adv_output = real_dict ["adv_output" ], images = real_images , device = self .local_rank )
410412 dis_acml_loss += self .LOSS .r1_lambda * self .r1_penalty
@@ -440,6 +442,7 @@ def train_discriminator(self, current_step):
440442 else :
441443 self .OPTIMIZATION .d_optimizer .step ()
442444
445+ # apply r1_reg outside of training loop
443446 if self .LOSS .apply_r1_reg and self .LOSS .r1_place == "outside_loop" and \
444447 (self .OPTIMIZATION .d_updates_per_step * current_step + step_index ) % self .STYLEGAN .d_reg_interval == 0 :
445448 self .OPTIMIZATION .d_optimizer .zero_grad ()
@@ -487,6 +490,8 @@ def train_discriminator(self, current_step):
487490 if self .LOSS .apply_wc :
488491 for p in self .Dis .parameters ():
489492 p .data .clamp_ (- self .LOSS .wc_bound , self .LOSS .wc_bound )
493+
494+ # empty cache to discard used memory
490495 if self .RUN .empty_cache :
491496 torch .cuda .empty_cache ()
492497 return real_cond_loss , dis_acml_loss
@@ -547,8 +552,8 @@ def train_generator(self, current_step):
547552 # calculate adv_output, embed, proxy, and cls_output using the discriminator
548553 fake_dict = self .Dis (fake_images_ , fake_labels )
549554
555+ # accumulate discriminator output information for logging
550556 if self .AUG .apply_ada or self .AUG .apply_apa :
551- # accumulate discriminator output informations for logging
552557 self .dis_sign_fake += torch .tensor ((fake_dict ["adv_output" ].sign ().sum ().item (),
553558 self .OPTIMIZATION .batch_size ),
554559 device = self .local_rank )
@@ -599,6 +604,7 @@ def train_generator(self, current_step):
599604 fake_zcr_loss = - 1 * self .l2_loss (fake_images , fake_images_eps )
600605 gen_acml_loss += self .LOSS .g_lambda * fake_zcr_loss
601606
607+ # compute information loss for InfoGAN
602608 if self .MODEL .info_type in ["discrete" , "both" ]:
603609 dim = self .MODEL .info_dim_discrete_c
604610 self .info_discrete_loss = 0.0
@@ -652,6 +658,7 @@ def train_generator(self, current_step):
652658 style_mixing_p = self .cfgs .STYLEGAN .style_mixing_p ,
653659 stylegan_update_emas = False ,
654660 cal_trsp_cost = True if self .LOSS .apply_lo else False )
661+
655662 # blur images for stylegan3-r
656663 if self .MODEL .backbone == "stylegan3" and self .STYLEGAN .stylegan3_cfg == "stylegan3-r" and self .blur_init_sigma != "N/A" :
657664 blur_sigma = max (1 - (self .effective_batch_size * current_step ) / (self .blur_fade_kimg * 1e3 ), 0 ) * self .blur_init_sigma
@@ -667,6 +674,8 @@ def train_generator(self, current_step):
667674 # if ema is True: update parameters of the Gen_ema in adaptive way
668675 if self .MODEL .apply_g_ema :
669676 self .ema .update (current_step )
677+
678+ # empty cache to discard used memory
670679 if self .RUN .empty_cache :
671680 torch .cuda .empty_cache ()
672681 return gen_acml_loss
0 commit comments