Commit a453e4a

Add comments
1 parent 0e140d7 commit a453e4a

2 files changed: +11 −1 lines changed

src/metrics/features.py

Lines changed: 1 addition & 0 deletions
@@ -45,6 +45,7 @@ def generate_images_and_stack_features(generator, discriminator, eval_model, num
                                                device=device,
                                                stylegan_update_emas=False,
                                                cal_trsp_cost=False)
+
     with torch.no_grad():
         features, logits = eval_model.get_outputs(fake_images, quantize=quantize)
         probs = torch.nn.functional.softmax(logits, dim=1)
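
For context, `eval_model.get_outputs` returns a feature vector and class logits for each batch of generated images, and the softmax turns the logits into class probabilities. A minimal sketch of how such per-batch outputs are typically stacked for FID/IS-style metrics, with a generic `classifier` callable standing in for StudioGAN's `eval_model` (the helper below is hypothetical, not the repository's code):

```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def stack_features(classifier, image_batches):
    """Collect features and class probabilities over generated batches.

    `classifier` is a stand-in for StudioGAN's eval_model: it must return
    a (features, logits) pair for each image batch.
    """
    feats, probs = [], []
    for images in image_batches:
        features, logits = classifier(images)         # (N, D), (N, C)
        feats.append(features.cpu())
        probs.append(F.softmax(logits, dim=1).cpu())  # mirrors the diff above
    # the stacked features feed FID; the probabilities feed Inception Score
    return torch.cat(feats, dim=0), torch.cat(probs, dim=0)
```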

src/worker.py

Lines changed: 10 additions & 1 deletion
@@ -269,6 +269,7 @@ def train_discriminator(self, current_step):
                     real_images = upfirdn2d.filter2d(real_images, f / f.sum())
                     fake_images = upfirdn2d.filter2d(fake_images, f / f.sum())
 
+                # shuffle real and fake images (APA)
                 if self.AUG.apply_apa:
                     real_images = apa_aug.apply_apa_aug(real_images, fake_images.detach(), self.aa_p, self.local_rank)
 
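
The new comment marks the APA step: Adaptive Pseudo Augmentation (Jiang et al., 2021) counters discriminator overfitting by presenting some detached fakes to the discriminator as if they were real. A sketch of what `apa_aug.apply_apa_aug` plausibly does, assuming `p` corresponds to the adaptive probability `self.aa_p`; this re-implementation is illustrative only:

```python
import torch

def apply_apa_aug(real_images, fake_images, p, device):
    """Illustrative APA sketch: with probability `p`, replace each real image
    with an (already detached) fake so an overfitting discriminator is
    'deceived'. Not StudioGAN's exact implementation."""
    # per-sample Bernoulli mask deciding which reals get swapped
    swap = torch.rand(real_images.size(0), device=device) < p
    mixed = real_images.clone()
    mixed[swap] = fake_images[swap]
    return mixed
```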
@@ -405,6 +406,7 @@ def train_discriminator(self, current_step):
                         lecam_loss = torch.tensor(0., device=self.local_rank)
                     dis_acml_loss += self.LOSS.lecam_lambda*lecam_loss
 
+                    # apply r1_reg inside of training loop
                     if self.LOSS.apply_r1_reg and not self.is_stylegan:
                         self.r1_penalty = losses.cal_r1_reg(adv_output=real_dict["adv_output"], images=real_images, device=self.local_rank)
                         dis_acml_loss += self.LOSS.r1_lambda*self.r1_penalty
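
`losses.cal_r1_reg` computes the R1 gradient penalty (Mescheder et al., 2018), which penalizes the squared gradient norm of the discriminator's output with respect to real images; the next hunk applies the same penalty outside the inner update loop at a fixed `d_reg_interval`, i.e. StyleGAN2-style lazy regularization. A sketch of the penalty, assuming `images` was forwarded with `requires_grad=True`:

```python
import torch

def cal_r1_reg(adv_output, images, device):
    """R1 penalty sketch: 0.5 * E[||grad_x D(x)||^2] on real images.

    Approximates what losses.cal_r1_reg plausibly computes; `images` must
    have been forwarded through the discriminator with requires_grad=True.
    `device` is kept only to mirror the call site in the hunk above.
    """
    batch_size = images.size(0)
    grads = torch.autograd.grad(outputs=adv_output.sum(),
                                inputs=images,
                                create_graph=True)[0]
    return 0.5 * grads.reshape(batch_size, -1).pow(2).sum(dim=1).mean()
```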
@@ -440,6 +442,7 @@ def train_discriminator(self, current_step):
             else:
                 self.OPTIMIZATION.d_optimizer.step()
 
+        # apply r1_reg outside of training loop
         if self.LOSS.apply_r1_reg and self.LOSS.r1_place == "outside_loop" and \
                 (self.OPTIMIZATION.d_updates_per_step*current_step + step_index) % self.STYLEGAN.d_reg_interval == 0:
             self.OPTIMIZATION.d_optimizer.zero_grad()
@@ -487,6 +490,8 @@ def train_discriminator(self, current_step):
         if self.LOSS.apply_wc:
             for p in self.Dis.parameters():
                 p.data.clamp_(-self.LOSS.wc_bound, self.LOSS.wc_bound)
+
+        # empty cache to discard used memory
         if self.RUN.empty_cache:
             torch.cuda.empty_cache()
         return real_cond_loss, dis_acml_loss
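
`apply_wc` is WGAN-style weight clipping (Arjovsky et al., 2017): after each discriminator step, every parameter is clamped to `[-wc_bound, wc_bound]` as a crude Lipschitz constraint. A standalone sketch of the same operation:

```python
import torch
from torch import nn

def clip_discriminator_weights(discriminator: nn.Module, bound: float) -> None:
    """WGAN weight clipping sketch: clamp every parameter in-place to
    [-bound, bound]; `bound` plays the role of self.LOSS.wc_bound."""
    with torch.no_grad():
        for p in discriminator.parameters():
            p.clamp_(-bound, bound)
```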
@@ -547,8 +552,8 @@ def train_generator(self, current_step):
                 # calculate adv_output, embed, proxy, and cls_output using the discriminator
                 fake_dict = self.Dis(fake_images_, fake_labels)
 
+                # accumulate discriminator output informations for logging
                 if self.AUG.apply_ada or self.AUG.apply_apa:
-                    # accumulate discriminator output informations for logging
                     self.dis_sign_fake += torch.tensor((fake_dict["adv_output"].sign().sum().item(),
                                                         self.OPTIMIZATION.batch_size),
                                                        device=self.local_rank)
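
`dis_sign_fake` accumulates the sign of the discriminator's outputs along with the number of samples seen. Besides logging, running statistics like these are what an ADA/APA controller reads to adapt the augmentation probability via the `r_t` heuristic (Karras et al., 2020). A sketch of that heuristic with illustrative constants; StudioGAN's exact update rule and names may differ:

```python
def update_aug_probability(p, sign_sum, n_seen,
                           target=0.6, adjust_step=0.01):
    """ADA-style r_t heuristic, sketched with illustrative constants.

    r_t estimates E[sign(D(x))] from the accumulated (sign_sum, n_seen)
    pair; `p` is nudged up when the discriminator looks overconfident
    (r_t above `target`) and down otherwise, then clipped to [0, 1].
    """
    r_t = sign_sum / max(n_seen, 1)
    p += adjust_step if r_t > target else -adjust_step
    return min(max(p, 0.0), 1.0)
```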
@@ -599,6 +604,7 @@ def train_generator(self, current_step):
                     fake_zcr_loss = -1 * self.l2_loss(fake_images, fake_images_eps)
                     gen_acml_loss += self.LOSS.g_lambda * fake_zcr_loss
 
+                # compute infomation loss for InfoGAN
                 if self.MODEL.info_type in ["discrete", "both"]:
                     dim = self.MODEL.info_dim_discrete_c
                     self.info_discrete_loss = 0.0
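
When `info_type` is `"discrete"` or `"both"`, InfoGAN-style training adds categorical latent codes and maximizes a mutual-information surrogate: a cross-entropy between the Q-head's logits and the sampled codes, accumulated over the discrete code dimensions. A sketch with hypothetical tensor names (the real loop continues past the lines shown above):

```python
import torch
import torch.nn.functional as F

def info_discrete_loss(q_logits, discrete_codes, num_c, dim):
    """InfoGAN discrete loss sketch: cross-entropy between the Q-head's
    logits and the sampled categorical codes, summed over code dimensions.

    `q_logits`: (N, num_c * dim) hypothetical info-head output;
    `discrete_codes`: (N, num_c) int64 codes in [0, dim).
    """
    loss = q_logits.new_zeros(())
    for c in range(num_c):
        logits_c = q_logits[:, c * dim:(c + 1) * dim]
        loss = loss + F.cross_entropy(logits_c, discrete_codes[:, c])
    return loss
```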
@@ -652,6 +658,7 @@ def train_generator(self, current_step):
                                           style_mixing_p=self.cfgs.STYLEGAN.style_mixing_p,
                                           stylegan_update_emas=False,
                                           cal_trsp_cost=True if self.LOSS.apply_lo else False)
+
             # blur images for stylegan3-r
             if self.MODEL.backbone == "stylegan3" and self.STYLEGAN.stylegan3_cfg == "stylegan3-r" and self.blur_init_sigma != "N/A":
                 blur_sigma = max(1 - (self.effective_batch_size * current_step) / (self.blur_fade_kimg * 1e3), 0) * self.blur_init_sigma
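
For the `stylegan3-r` config, training images are blurred at the start and the blur fades out over `blur_fade_kimg` thousand images; the Gaussian taps are applied with `upfirdn2d.filter2d`, the same op seen in the first `worker.py` hunk. A self-contained sketch of the schedule with illustrative defaults (the import path follows the official StyleGAN3 layout and may differ in this repository):

```python
import torch
from torch_utils.ops import upfirdn2d  # StyleGAN3's filtering op; path assumed

def blur_images(images, current_step, effective_batch_size,
                blur_init_sigma=10.0, blur_fade_kimg=200.0):
    """stylegan3-r warm-up blur, sketched: sigma fades linearly to zero over
    blur_fade_kimg thousand images, then a Gaussian FIR kernel is applied.
    The default constants are illustrative, not the repository's."""
    kimg_seen = effective_batch_size * current_step / 1e3
    blur_sigma = max(1 - kimg_seen / blur_fade_kimg, 0) * blur_init_sigma
    if blur_sigma == 0:
        return images
    blur_size = int(blur_sigma * 3)  # truncate the kernel at ~3 sigma
    f = torch.arange(-blur_size, blur_size + 1, device=images.device)
    f = f.div(blur_sigma).square().neg().exp2()  # Gaussian taps, as in StyleGAN3
    return upfirdn2d.filter2d(images, f / f.sum())
```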
@@ -667,6 +674,8 @@ def train_generator(self, current_step):
         # if ema is True: update parameters of the Gen_ema in adaptive way
         if self.MODEL.apply_g_ema:
             self.ema.update(current_step)
+
+        # empty cache to discard used memory
         if self.RUN.empty_cache:
             torch.cuda.empty_cache()
         return gen_acml_loss
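
`self.ema.update(current_step)` maintains `Gen_ema`, an exponential moving average of the generator's weights, which is usually the copy used for evaluation. A minimal sketch that reads "adaptive" as a step-dependent warm-up of the decay; StudioGAN's actual `Ema` class may ramp differently:

```python
import copy
import torch

class SimpleEma:
    """Generator EMA sketch: a frozen copy trails the live generator.

    The warm-up makes the effective decay step-dependent, one plausible
    reading of updating Gen_ema 'in adaptive way'."""

    def __init__(self, generator, beta=0.9999):
        self.source = generator
        self.target = copy.deepcopy(generator).eval()
        for p in self.target.parameters():
            p.requires_grad_(False)
        self.beta = beta

    @torch.no_grad()
    def update(self, current_step):
        # ramp the decay up from ~0 so early steps track the source closely
        beta = min(self.beta, (1 + current_step) / (10 + current_step))
        for p_tgt, p_src in zip(self.target.parameters(),
                                self.source.parameters()):
            p_tgt.lerp_(p_src, 1.0 - beta)
```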
