From 8d3c7bd5caf653e16ed87c20ea42d05463f9fd2e Mon Sep 17 00:00:00 2001 From: Haneol Lee Date: Tue, 17 Oct 2023 09:44:51 +0900 Subject: [PATCH] Fix: loss calculation logic --- train_gan3d.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train_gan3d.py b/train_gan3d.py index 792d31f..e759d1f 100644 --- a/train_gan3d.py +++ b/train_gan3d.py @@ -102,7 +102,7 @@ def train(self): self.d_optim.zero_grad() predict_real, predict_id, predict_ex= self.discriminator(real_data) - error_real = self.criterion_gan(predict_real, make_ones(batch_size).to(device)) + self.criterion_class(predict_ex, label_ex) + self.criterion_class(predict_id, label_ex) + error_real = self.criterion_gan(predict_real, make_ones(batch_size).to(device)) + self.criterion_class(predict_ex, label_ex) + self.criterion_class(predict_id, label_id) error_real.backward() predict_fake, fake_id, fake_ex = self.discriminator(fake_data)