Skip to content

Commit 8b8a864

Browse files
authored
Merge pull request #98 from lucasb-eyer/bn-momentum
BN momentum
2 parents 651850b + 08b0db4 commit 8b8a864

File tree

2 files changed

+20
-7
lines changed

2 files changed

+20
-7
lines changed

DeepFried2/layers/BatchNormalization.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,15 @@
66

77

88
class BatchNormalization(df.Module):
9-
def __init__(self, n_features, eps=1e-5):
9+
def __init__(self, n_features, momentum=False, eps=1e-5):
1010
"""
1111
- `n_features` may be an integer (#features, #feature-maps for images) or a tuple.
1212
- If a single integer, it indicates the size of the 1-axis, i.e. first feature-axis.
1313
This is the only axis that will be normalized using statistics across all other axes.
1414
- If a tuple, it indicates the sizes of multiple axes (starting at 1) which are
1515
considered feature-axes and will consequently be normalized over statistics across all other axes.
16+
- `momentum` means statistics are collected as running (geometric) statistics
17+
during training. It should be the decay value within (0,1).
1618
- `eps` is a small number which is added to the variance in order to
1719
avoid computing sqrt(0) for features with zero variance.
1820
"""
@@ -30,6 +32,7 @@ def __init__(self, n_features, eps=1e-5):
3032
self.buf_var = df.th.shared(_np.full(n_features, 1, df.floatX), name='BN_var_{}'.format(n_features))
3133
self.buf_mean = df.th.shared(_np.full(n_features, 0, df.floatX), name='BN_mean_{}'.format(n_features))
3234
self.buf_count = df.th.shared(_np.asarray(0, dtype=df.floatX), name='BN_count_{}'.format(n_features))
35+
self.momentum = momentum
3336

3437
self.eps = eps or 1e-5
3538

@@ -74,6 +77,16 @@ def get_stat_updates(self):
7477
(self.buf_count, self.buf_count + 1.0),
7578
]
7679

80+
def get_extra_updates(self):
81+
if self._mode == 'train' and self.momentum != False:
82+
return [
83+
(self.buf_mean, self.momentum*self.buf_mean + (1-self.momentum)*self.batch_mean),
84+
(self.buf_var, self.momentum*self.buf_var + (1-self.momentum)*self.batch_var),
85+
(self.buf_count, 1.0),
86+
]
87+
else:
88+
return []
89+
7790
def training(self):
7891
df.Module.training(self)
7992
self.buf_count.set_value(0)

DeepFried2/zoo/resnet.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,13 @@ def symb_forward(self, symb_inputs):
1111
return s
1212

1313

14-
def block(nchan, fs=(3,3), body=None):
14+
def block(nchan, fs=(3,3), body=None, bnmom=False):
1515
return df.Sequential(
1616
df.RepeatInput(
1717
df.Sequential(
18-
df.BatchNormalization(nchan), df.ReLU(),
18+
df.BatchNormalization(nchan, bnmom), df.ReLU(),
1919
df.SpatialConvolutionCUDNN(nchan, nchan, fs, border='same', init=df.init.prelu(), bias=False),
20-
df.BatchNormalization(nchan), df.ReLU(),
20+
df.BatchNormalization(nchan, bnmom), df.ReLU(),
2121
df.SpatialConvolutionCUDNN(nchan, nchan, fs, border='same', init=df.init.prelu(), bias=False)
2222
) if body is None else body,
2323
df.Identity()
@@ -26,13 +26,13 @@ def block(nchan, fs=(3,3), body=None):
2626
)
2727

2828

29-
def block_proj(nin, nout, fs=(3,3), body=None):
29+
def block_proj(nin, nout, fs=(3,3), body=None, bnmom=False):
3030
return df.Sequential(
3131
df.RepeatInput(
3232
df.Sequential(
33-
df.BatchNormalization(nin), df.ReLU(),
33+
df.BatchNormalization(nin, bnmom), df.ReLU(),
3434
df.SpatialConvolutionCUDNN(nin, nout, fs, border='same', init=df.init.prelu(), bias=False),
35-
df.BatchNormalization(nout), df.ReLU(),
35+
df.BatchNormalization(nout, bnmom), df.ReLU(),
3636
df.SpatialConvolutionCUDNN(nout, nout, fs, border='same', init=df.init.prelu(), bias=False)
3737
) if body is None else body,
3838
df.SpatialConvolutionCUDNN(nin, nout, (1,)*len(fs)),

0 commit comments

Comments
 (0)