@@ -7,9 +7,12 @@ class Module:
     def __init__(self):
         self.training_mode = True

-        self.fn_forward = None
-        self.fn_accum_grads = None
-        self.fn_accum_stats = None
+        # The functions are stored in a dictionary whose keys correspond to the
+        # values that `self.training_mode` can take. That way, it would be
+        # trivial to extend to further modes, and the code avoids many branches.
+        self.fn_forward = {}
+        self.fn_accum_grads = {}
+        self.fn_accum_stats = {}

     #def __hash__(self):
     #    raise NotImplementedError("You *need* to reimplement hash, even if it's just python's default. See the documentation for more info.")
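The dictionary keyed by `self.training_mode` amounts to lazy memoization: one compiled function per mode, built on first use. A minimal, Theano-free sketch of the pattern (all names here are illustrative; `expensive_compile` stands in for the costly `_th.function` call):

class ModeCachedModule:
    def __init__(self):
        self.training_mode = True
        self.fn_forward = {}  # maps each mode value to its compiled function

    def expensive_compile(self, mode):
        # Hypothetical stand-in for _th.function: imagine seconds of work here.
        return (lambda x: 2.0 * x) if mode else (lambda x: x)

    def forward(self, data):
        # A single dict lookup replaces per-mode if/else branches; adding a
        # third mode would need no new code here, just another dictionary key.
        if self.training_mode not in self.fn_forward:
            self.fn_forward[self.training_mode] = self.expensive_compile(self.training_mode)
        return self.fn_forward[self.training_mode](data)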
@@ -44,15 +47,18 @@ def symb_forward(self, symb_input):
         raise NotImplementedError

     def forward(self, data):
-        if self.fn_forward is None:
+        if self.training_mode not in self.fn_forward:
             symb_in = _T.TensorType(_th.config.floatX, (False,) * data.ndim)('X')
             symb_out = self.symb_forward(symb_in)
-            self.fn_forward = _th.function(inputs=[symb_in], outputs=symb_out)
+            self.fn_forward[self.training_mode] = _th.function(
+                inputs=[symb_in],
+                outputs=symb_out
+            )

-        return self.fn_forward(data)
+        return self.fn_forward[self.training_mode](data)

     def accumulate_gradients(self, data_in, data_tgt, loss):
-        if self.fn_accum_grads is None:
+        if self.training_mode not in self.fn_accum_grads:
             symb_in = _T.TensorType(_th.config.floatX, (False,) * data_in.ndim)('X')
             symb_tgt = _T.TensorType(_th.config.floatX, (False,) * data_tgt.ndim)('T')
             symb_out = self.symb_forward(symb_in)
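For reference, `_T.TensorType(_th.config.floatX, (False,) * data.ndim)('X')` creates a symbolic input whose rank matches the concrete data and whose axes are all marked non-broadcastable; for 2-D input it is the same type as the stock matrix constructor. A small sketch, assuming a working Theano install:

import theano
import theano.tensor as T

x1 = T.TensorType(theano.config.floatX, (False, False))('X')
x2 = T.matrix('X')  # matrix() also uses floatX and (False, False) broadcastability
assert x1.type == x2.type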
@@ -62,19 +68,19 @@ def accumulate_gradients(self, data_in, data_tgt, loss):
             symb_grads = _th.grad(cost=symb_err, wrt=params)

             grads_updates = [(grad, grad + symb_grad) for grad, symb_grad in zip(grads, symb_grads)]
-            self.fn_accum_grads = _th.function(
+            self.fn_accum_grads[self.training_mode] = _th.function(
                 inputs=[symb_in, symb_tgt],
                 outputs=symb_err,
                 updates=grads_updates
             )

-        return self.fn_accum_grads(data_in, data_tgt)
+        return self.fn_accum_grads[self.training_mode](data_in, data_tgt)

     def get_stat_updates(self):
         return []

     def accumulate_statistics(self, data_in):
-        if self.fn_accum_stats is None:
+        if self.training_mode not in self.fn_accum_stats:
             symb_in = _T.TensorType(_th.config.floatX, (False,) * data_in.ndim)('X')
             self.symb_forward(symb_in)

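The `grads_updates` pairs rely on Theano's `updates` mechanism: on every call, the compiled function replaces each shared accumulator `grad` with `grad + symb_grad`, so gradients add up across calls until they are explicitly reset. A minimal sketch of that mechanism, assuming a working Theano install:

import numpy as np
import theano
import theano.tensor as T

# A shared variable acts as the persistent accumulator, like `grad` above.
acc = theano.shared(np.asarray(0.0, dtype=theano.config.floatX))
x = T.scalar('x')
step = theano.function(inputs=[x], updates=[(acc, acc + x)])

step(2.0)
step(3.0)
assert float(acc.get_value()) == 5.0  # 2.0 + 3.0 accumulated across calls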
@@ -84,9 +90,9 @@ def accumulate_statistics(self, data_in):
                 # compile and call a function. This prevents theano errors.
                 return

-            self.fn_accum_stats = _th.function(
+            self.fn_accum_stats[self.training_mode] = _th.function(
                 inputs=[symb_in],
                 updates=stat_updates
             )

-        self.fn_accum_stats(data_in)
+        self.fn_accum_stats[self.training_mode](data_in)
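With these changes, a concrete `Module` subclass compiles at most one function per mode and per method, and switching modes back and forth reuses the cached functions instead of recompiling. A hedged usage trace; `net` and `batch` are hypothetical, standing for a subclass instance and a NumPy array of dtype floatX:

net.training_mode = True
net.forward(batch)    # first call in this mode: builds the graph, compiles, runs
net.training_mode = False
net.forward(batch)    # compiles the evaluation-mode variant once
net.training_mode = True
net.forward(batch)    # cache hit: the stored training-mode function is reused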