From ba63680d4cff86974bedcde62b790864c9a81854 Mon Sep 17 00:00:00 2001 From: miaobuao Date: Mon, 22 Apr 2024 19:52:51 +0000 Subject: [PATCH] fix(cluster-gan): `valid` shape error --- models/cluster_gan/clustergan.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/models/cluster_gan/clustergan.py b/models/cluster_gan/clustergan.py index 213e4e5..38ecfbd 100644 --- a/models/cluster_gan/clustergan.py +++ b/models/cluster_gan/clustergan.py @@ -280,7 +280,7 @@ def execute(self, img): # Train Generator # ----------------- - if ((i % n_skip_iter) == 0): + if ((i % n_skip_iter) == 0) or real_imgs.shape[0] != batch_size: (enc_gen_zn, enc_gen_zc, enc_gen_zc_logits) = encoder(gen_imgs) zn_loss = mse_loss(enc_gen_zn, zn) zc_loss = xe_loss(enc_gen_zc_logits, zc_idx)