Skip to content

Commit 93d7435

Browse files
committed
Merge pull request #12 from lucasb-eyer/bn-param-names
Fix names of BN shared variables.
2 parents a5c5a35 + 9dccbc6 commit 93d7435

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

DeepFried2/layers/BatchNormalization.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,16 +11,16 @@ class BatchNormalization(Module):
1111
def __init__(self, n_features, eps=None):
1212
Module.__init__(self)
1313

14-
self.weight, self.grad_weight = create_param_and_grad(n_features, const(1), 'W_BN')
15-
self.bias, self.grad_bias = create_param_and_grad(n_features, const(0), 'b_BN')
14+
self.weight, self.grad_weight = create_param_and_grad(n_features, const(1), name='W_BN_{}'.format(n_features))
15+
self.bias, self.grad_bias = create_param_and_grad(n_features, const(0), name='b_BN_{}'.format(n_features))
1616

17-
self.inference_weight = create_param(n_features, const(1), 'W_BN_inf')
18-
self.inference_bias = create_param(n_features, const(0), 'b_BN_inf')
17+
self.inference_weight = create_param(n_features, const(1), name='W_BN_{}_inf'.format(n_features))
18+
self.inference_bias = create_param(n_features, const(0), name='b_BN_{}_inf'.format(n_features))
1919

2020
# These are buffers for collecting the minibatch statistics.
21-
self.buffer_variance = create_param(n_features, const(1), 'BN_var')
22-
self.buffer_mean = create_param(n_features, const(0), 'BN_mean')
23-
self.buffer_counts = _th.shared(_np.asarray(0, dtype=_th.config.floatX))
21+
self.buffer_variance = create_param(n_features, const(1), name='BN_var_{}'.format(n_features))
22+
self.buffer_mean = create_param(n_features, const(0), name='BN_mean_{}'.format(n_features))
23+
self.buffer_counts = _th.shared(_np.asarray(0, dtype=_th.config.floatX), name='BN_count_{}'.format(n_features))
2424

2525
self.eps = eps or 1e-5
2626

0 commit comments

Comments
 (0)