Skip to content

Commit 86adb6d

Browse files
committed
fix batchnorm
1 parent c3e47ed commit 86adb6d

File tree

2 files changed

+16
-9
lines changed

2 files changed

+16
-9
lines changed

bayesian_torch/layers/batchnorm.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,6 @@ def _check_input_dim(self, input):
5454
input.dim()))
5555

5656
def forward(self, input):
57-
self._check_input_dim(input[0])
5857
exponential_average_factor = 0.0
5958
if self.training and self.track_running_stats:
6059
self.num_batches_tracked += 1
@@ -63,13 +62,21 @@ def forward(self, input):
6362
else: # use exponential moving average
6463
exponential_average_factor = self.momentum
6564

66-
out = F.batch_norm(input[0], self.running_mean, self.running_var,
67-
self.weight, self.bias, self.training
68-
or not self.track_running_stats,
69-
exponential_average_factor, self.eps)
70-
kl = 0
71-
return out, kl
72-
65+
if len(input) == 2:
66+
self._check_input_dim(input[0])
67+
out = F.batch_norm(input[0], self.running_mean, self.running_var,
68+
self.weight, self.bias, self.training
69+
or not self.track_running_stats,
70+
exponential_average_factor, self.eps)
71+
kl = 0
72+
return out, kl
73+
else:
74+
out = F.batch_norm(input, self.running_mean, self.running_var,
75+
self.weight, self.bias, self.training
76+
or not self.track_running_stats,
77+
exponential_average_factor, self.eps)
78+
return out
79+
7380

7481
class BatchNorm1dLayer(nn.Module):
7582
def __init__(self,

bayesian_torch/models/bayesian/resnet_variational_large.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,7 @@ def _make_layer(self, block, planes, blocks, stride=1):
200200
posterior_mu_init=posterior_mu_init,
201201
posterior_rho_init=posterior_rho_init,
202202
bias=False),
203-
nn.BatchNorm2d(planes * block.expansion),
203+
BatchNorm2dLayer(planes * block.expansion),
204204
)
205205

206206
layers = []

0 commit comments

Comments (0)