Skip to content

Commit 8cf0057

Browse files
author
Jaan Altosaar
authored
Merge pull request #15 from ischurov/size-mismatch-fix
FIX: Tensor size mismatch in VariationalMeanField.forward
2 parents 2e48e80 + 4e58b8c commit 8cf0057

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

train_variational_autoencoder_pytorch.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def forward(self, x, n_samples=1):
7171
scale = self.softplus(scale_arg)
7272
eps = torch.randn((loc.shape[0], n_samples, loc.shape[-1]), device=loc.device)
7373
z = loc + scale * eps # reparameterization
74-
log_q_z = self.log_q_z(loc, scale, z).sum(-1)
74+
log_q_z = self.log_q_z(loc, scale, z).sum(-1, keepdim=True)
7575
return z, log_q_z
7676

7777

0 commit comments

Comments
 (0)