|
3 | 3 | from mozi.utils.theano_utils import shared_zeros |
4 | 4 | from mozi.weight_init import UniformWeight |
5 | 5 | import theano.tensor as T |
| 6 | +import theano |
| 7 | +floatX = theano.config.floatX |
6 | 8 |
|
class BatchNormalization(Template):

    def __init__(self, input_shape, epsilon=1e-6, mode=0, gamma_init=None, memory=0.9):
        '''
        Batch normalization layer.

        REFERENCE:
            Batch Normalization: Accelerating Deep Network Training by
            Reducing Internal Covariate Shift
            http://arxiv.org/pdf/1502.03167v3.pdf

        PARAMS:
            input_shape: shape of the learnable scale (gamma) and shift (beta)
                parameters, one per normalized feature.
            epsilon: small constant added to the std denominator to prevent
                division by zero.
            mode: kept for backward compatibility with older callers; unused.
            gamma_init: initializer callable for gamma. Defaults to a fresh
                UniformWeight() built per instance — a UniformWeight() default
                argument would be evaluated once at definition time and shared
                by every layer.
            memory: moving-average weight. With the latest batch statistic y_t
                and running value x_t, the update is
                    x_tp1 = memory * y_t + (1 - memory) * x_t,
                so the larger the memory, the more weight is put on the
                contemporary batch.
        '''
        self.input_shape = input_shape
        self.epsilon = epsilon
        self.mem = memory

        # Avoid the shared-mutable-default pitfall: construct the initializer
        # here rather than in the signature.
        if gamma_init is None:
            gamma_init = UniformWeight()
        self.gamma = gamma_init(self.input_shape, name='gamma')
        self.beta = shared_zeros(self.input_shape, name='beta')

        # Running estimates of the batch mean/std, consumed by _test_fprop.
        self.moving_mean = 0
        self.moving_std = 0

        self.params = [self.gamma, self.beta]


    def _train_fprop(self, state_below):
        '''Normalize with the current batch statistics and update the
        moving averages used at test time.'''
        miu = state_below.mean(axis=0)
        std = T.std(state_below, axis=0)
        Z = (state_below - miu) / (std + self.epsilon)

        # BUG FIX: these updates previously used `+=`, i.e.
        #     x_tp1 = x_t + (mem * y_t + (1 - mem) * x_t)
        # which adds the old running value twice instead of the convex
        # combination documented in __init__. Plain assignment implements
        #     x_tp1 = mem * y_t + (1 - mem) * x_t.
        self.moving_mean = self.mem * miu + (1 - self.mem) * self.moving_mean
        self.moving_std = self.mem * std + (1 - self.mem) * self.moving_std

        return self.gamma * Z + self.beta


    def _test_fprop(self, state_below):
        '''Normalize with the moving-average statistics accumulated during
        training, then apply the learned scale and shift.'''
        Z = (state_below - self.moving_mean) / (self.moving_std + self.epsilon)
        return self.gamma * Z + self.beta
71 | 51 |
|
72 | 52 |
|
73 | 53 | class LRN(Template): |
74 | 54 | """ |
| 55 | + Adapted from pylearn2 |
75 | 56 | Local Response Normalization |
76 | 57 | """ |
77 | 58 |
|
|
0 commit comments