Skip to content

Commit 4e58b8c

Browse files
committed
fixes error with size mismatch
1 parent 2e48e80 commit 4e58b8c

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

train_variational_autoencoder_pytorch.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def forward(self, x, n_samples=1):
7171
scale = self.softplus(scale_arg)
7272
eps = torch.randn((loc.shape[0], n_samples, loc.shape[-1]), device=loc.device)
7373
z = loc + scale * eps # reparameterization
74-
log_q_z = self.log_q_z(loc, scale, z).sum(-1)
74+
log_q_z = self.log_q_z(loc, scale, z).sum(-1, keepdim=True)
7575
return z, log_q_z
7676

7777

0 commit comments

Comments
 (0)