diff --git a/bnaf.py b/bnaf.py index 46f00ec..908a2b7 100644 --- a/bnaf.py +++ b/bnaf.py @@ -187,7 +187,7 @@ def forward(self, x, sum_logdets): # 3. compute output and logdet of the layer out = F.linear(x, w, self.bias) logdet = self.logg + self.weight - 0.5 * v_norm.pow(2).log() - logdet = logdet[self.mask_d.byte()] + logdet = logdet[self.mask_d.bool()] logdet = logdet.view(1, self.data_dim, out.shape[1]//self.data_dim, x.shape[1]//self.data_dim) \ .expand(x.shape[0],-1,-1,-1) # output (B, data_dim, out_dim // data_dim, in_dim // data_dim)