from .Module import Module

import numpy as _np
import theano as _th
import theano.tensor as _T
6+
7+
class BatchNormalization(Module):
    """Batch normalization layer.

    In training mode, inputs are normalized with the current mini-batch's
    mean/variance, then scaled and shifted by the learnable `weight` (gamma)
    and `bias` (beta). Per-batch statistics are accumulated into running
    buffers via `get_stat_updates`. Calling `evaluate` folds the running
    statistics and the learned parameters into a single fused multiply-add
    (`inference_weight`, `inference_bias`) used for inference.
    """

    def __init__(self, n_features, eps=None):
        """
        Args:
            n_features: number of features (channels) being normalized.
            eps: small constant added to the variance for numerical
                stability; defaults to 1e-5 when None.
        """
        Module.__init__(self)

        floatX = _th.config.floatX

        # Learnable scale (gamma) and shift (beta), plus gradient buffers.
        self.weight = _th.shared(_np.ones(shape=(n_features,), dtype=floatX))
        self.bias = _th.shared(_np.zeros(shape=(n_features,), dtype=floatX))
        self.grad_weight = _th.shared(_np.zeros(shape=(n_features,), dtype=floatX))
        self.grad_bias = _th.shared(_np.zeros(shape=(n_features,), dtype=floatX))

        # Fused affine parameters used at inference time; recomputed from the
        # running statistics whenever `evaluate` is called.
        self.inference_weight = _th.shared(_np.ones(shape=(n_features,), dtype=floatX))
        self.inference_bias = _th.shared(_np.zeros(shape=(n_features,), dtype=floatX))

        # Running statistics, accumulated as a plain (equal-weight) average
        # over all batches seen since the last `training()` call — not an
        # exponential moving average.
        self.buffer_variance = _th.shared(_np.ones(shape=(n_features,), dtype=floatX))
        self.buffer_mean = _th.shared(_np.zeros(shape=(n_features,), dtype=floatX))
        self.buffer_counts = _th.shared(_np.asarray(0., dtype=floatX))

        # BUGFIX: the original `eps or 1e-5` silently replaced any falsy
        # explicit value (e.g. eps=0) with the default; only substitute the
        # default when eps is actually None.
        self.eps = 1e-5 if eps is None else eps

        # Symbolic statistics of the most recent training-mode forward pass;
        # consumed by `get_stat_updates`.
        self.batch_mean = None
        self.batch_var = None

    def symb_forward(self, symb_input):
        """Build the symbolic output for `symb_input`.

        Handles 2D input (batch, features) and 4D input
        (batch, channels, height, width); statistics are reduced over every
        axis except the feature/channel axis.
        """
        d_shuffle = ('x', 0)
        axis = (0,)

        if symb_input.ndim == 4:
            # Convolutional input: also reduce over the spatial axes and
            # broadcast the per-channel parameters across them.
            d_shuffle += ('x', 'x')
            axis += (2, 3)

        if self.training_mode:
            # Remember the batch statistics so get_stat_updates can fold
            # them into the running buffers.
            self.batch_mean = _T.mean(symb_input, axis=axis)
            self.batch_var = _T.var(symb_input, axis=axis)

            normalized = (symb_input - self.batch_mean.dimshuffle(*d_shuffle)) \
                / _T.sqrt(self.batch_var + self.eps).dimshuffle(*d_shuffle)
            return normalized * self.weight.dimshuffle(*d_shuffle) \
                + self.bias.dimshuffle(*d_shuffle)
        else:
            # Inference: single fused multiply-add with the parameters
            # precomputed in `evaluate`.
            return symb_input * self.inference_weight.dimshuffle(*d_shuffle) \
                + self.inference_bias.dimshuffle(*d_shuffle)

    def get_stat_updates(self):
        """Return Theano update pairs accumulating the running statistics.

        Each buffer is updated as an equal-weight running average:
        ``buf <- (buf * n + batch_stat) / (n + 1)``.

        Raises:
            AssertionError: if no training-mode forward pass has been done
                since the last `training()` call.
        """
        assert (self.batch_mean is not None) and (self.batch_var is not None), \
            "You need to do a forward pass first"

        counts = self.buffer_counts
        return [
            (self.buffer_mean,
             (self.buffer_mean * counts + self.batch_mean) / (counts + 1.0)),
            (self.buffer_variance,
             (self.buffer_variance * counts + self.batch_var) / (counts + 1.0)),
            (self.buffer_counts, counts + 1.0),
        ]

    def training(self):
        """Switch to training mode and restart statistics accumulation."""
        Module.training(self)
        # With the count at zero, the next get_stat_updates overwrites the
        # buffers with the first batch's statistics, so stale values from a
        # previous run are discarded.
        self.buffer_counts.set_value(0)
        self.batch_mean = None
        self.batch_var = None

    def evaluate(self):
        """Switch to evaluation mode.

        Folds the running statistics into a fused affine transform:
        ``y = x * w / sqrt(var + eps) + (b - w * mean / sqrt(var + eps))``.
        """
        Module.evaluate(self)
        inv_std_weight = self.weight.get_value() \
            / _np.sqrt(self.buffer_variance.get_value() + self.eps)
        self.inference_weight.set_value(inv_std_weight)
        self.inference_bias.set_value(
            self.bias.get_value() - inv_std_weight * self.buffer_mean.get_value())