@@ -1,20 +1,24 @@
 import DeepFried2 as df
-from DeepFried2.utils import make_tensor_or_tensors, aslist
+from DeepFried2.utils import tensors_for_ndarrays, flatten
 
 import numpy as _np
 
-class Module:
+class Module(object):
 
     def __init__(self):
-        self.training_mode = True
+        self._mode = 'train'
 
         # The functions are stored in a dictionary whose keys correspond to the
-        # values that `self.training_mode` can take. That way, it would be
-        # trivial to extend to further modes, and the code avoids many branches.
+        # values that `self._mode` can take.
        self._fn_forward = {}
        self._fn_accum_grads = {}
        self._fn_accum_stats = {}
 
+        # These store the most recently received/produced symbolic input and
+        # output expressions, respectively, keyed by the current mode.
+        self._last_symb_inp = {}
+        self._last_symb_out = {}
+
    #def __hash__(self):
    #    raise NotImplementedError("You *need* to reimplement hash, even if it's just python's default. See the documentation for more info.")
 
@@ -38,44 +42,79 @@ def parameters(self, trainable_only=False):
         return params
 
     def evaluate(self):
-        self.training_mode = False
+        self._mode = 'eval'
 
     def training(self):
-        self.training_mode = True
+        self._mode = 'train'
 
     def symb_forward(self, symb_input):
         raise NotImplementedError("`{}` needs to implement `symb_forward` method.".format(df.utils.typename(self)))
 
+    def __call__(self, symb_input):
+        # Keep track of the symbolic inputs/outputs for layers such as `Backward`.
+        self._last_symb_inp[self._mode] = symb_input
+        self._last_symb_out[self._mode] = self.symb_forward(symb_input)
+        return self._last_symb_out[self._mode]
+
     def forward(self, data):
-        if self.training_mode not in self._fn_forward:
-            symb_in = make_tensor_or_tensors(data, 'X')
-            symb_out = self.symb_forward(symb_in)
-            self._fn_forward[self.training_mode] = df.th.function(
-                inputs=aslist(symb_in),
-                outputs=symb_out
+        if self._mode not in self._fn_forward:
+            symb_in = tensors_for_ndarrays(data, 'X')
+            symb_out = self(symb_in)
+            extra_out = self.get_extra_outputs()
+            fn = self._fn_forward[self._mode] = df.th.function(
+                inputs=flatten(symb_in),
+                outputs=flatten(symb_out) + flatten(extra_out)
             )
+            fn._df2_extra = extra_out
 
-        return self._fn_forward[self.training_mode](*aslist(data))
+        fn = self._fn_forward[self._mode]
+        outs = fn(*flatten(data))
+        return self._collect_extra_outputs(fn, outs)
 
-    def accumulate_gradients(self, data_in, data_tgt, loss):
-        if self.training_mode not in self._fn_accum_grads:
-            symb_in = make_tensor_or_tensors(data_in, 'X')
-            symb_tgt = make_tensor_or_tensors(data_tgt, 'T')
-            symb_out = self.symb_forward(symb_in)
-            symb_err = loss.full_symb_forward(symb_out, symb_tgt)
+    def accumulate_gradients(self, data_in, data_tgt, crit):
+        if self._mode not in self._fn_accum_grads:
+            symb_in = tensors_for_ndarrays(data_in, 'X')
+            symb_tgt = tensors_for_ndarrays(data_tgt, 'T')
+            symb_out = self(symb_in)
+            symb_cost = crit(symb_out, symb_tgt)
+            extra_out = self.get_extra_outputs() + crit.get_extra_outputs()
 
             params = self.parameters(trainable_only=True)
-            symb_grads = df.th.grad(cost=symb_err, wrt=[p.param for p in params])
+            symb_grads = df.th.grad(cost=symb_cost, wrt=[p.param for p in params])
             grads_updates = [(p.grad, p.grad + symb_grad) for p, symb_grad in zip(params, symb_grads)]
 
-            self._fn_accum_grads[self.training_mode] = df.th.function(
-                inputs=aslist(symb_in) + aslist(symb_tgt),
-                outputs=symb_err,
+            fn = self._fn_accum_grads[self._mode] = df.th.function(
+                inputs=flatten(symb_in) + flatten(symb_tgt),
+                outputs=flatten(symb_cost) + flatten(extra_out),
                 updates=grads_updates
             )
+            fn._df2_extra = extra_out
+
+        fn = self._fn_accum_grads[self._mode]
+        args = flatten(data_in) + flatten(data_tgt)
+        outs = fn(*args)
+        return self._collect_extra_outputs(fn, outs)
+
+    def get_extra_outputs(self):
+        """
+        Return a list of Theano expressions which will be passed as additional
+        `outputs` parameters. The computed value will be stored in the
+        expression's `val` attribute.
+
+        Guaranteed to be called after `symb_forward`.
+        """
+        return []
+
+    def _collect_extra_outputs(self, fn, vals):
+        # The number of non-extra outputs.
+        nout = len(vals) - len(fn._df2_extra)
+
+        # Store the extra outputs in each expression's `val` attribute so
+        # that the modules which asked for them can retrieve them.
+        for out, val in zip(fn._df2_extra, vals[nout:]):
+            out.val = val
 
-        args = aslist(data_in) + aslist(data_tgt)
-        return self._fn_accum_grads[self.training_mode](*args)
+        return vals[:nout] if nout > 1 else vals[0]
 
     def get_stat_updates(self):
         """
@@ -88,12 +127,12 @@ def get_stat_updates(self):
         return []
 
     def accumulate_statistics(self, data_in):
-        if self.training_mode not in self._fn_accum_stats:
-            symb_in = make_tensor_or_tensors(data_in, 'X')
+        if self._mode not in self._fn_accum_stats:
+            symb_in = tensors_for_ndarrays(data_in, 'X')
 
             # Call forward once so it can compute some variables it'll actually
             # use in the stat updates collection.
-            self.symb_forward(symb_in)
+            self(symb_in)
 
             stat_updates = self.get_stat_updates()
             if not stat_updates:
@@ -117,12 +156,12 @@ def accumulate_statistics(self, data_in):
                     print("WARNING: Dropped the following stat-update because that variable got multiple updates: {}".format(upd[0]))
             stat_updates = uniq_updates
 
-            self._fn_accum_stats[self.training_mode] = df.th.function(
-                inputs=aslist(symb_in),
+            self._fn_accum_stats[self._mode] = df.th.function(
+                inputs=flatten(symb_in),
                 updates=stat_updates
             )
 
-        self._fn_accum_stats[self.training_mode](*aslist(data_in))
+        self._fn_accum_stats[self._mode](*flatten(data_in))
 
     def clear(self):
         self._fn_forward.clear()
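
For reference, a minimal, Theano-free sketch (not part of the diff) of the mode-keyed lazy compilation that `forward`, `accumulate_gradients`, and `accumulate_statistics` all follow after this change: one compiled function per value of `self._mode`, built on first use and reused afterwards. `ModeCachedForward` and `compile_fn` are hypothetical stand-ins; in the real code, `df.th.function` plays the role of `compile_fn`.

class ModeCachedForward(object):
    def __init__(self, compile_fn):
        self._mode = 'train'           # switched by training()/evaluate()
        self._compile_fn = compile_fn  # stand-in for df.th.function
        self._fn_forward = {}          # mode -> compiled function

    def evaluate(self):
        self._mode = 'eval'

    def training(self):
        self._mode = 'train'

    def forward(self, data):
        # Compile at most once per mode; later calls reuse the cached function.
        if self._mode not in self._fn_forward:
            self._fn_forward[self._mode] = self._compile_fn(self._mode)
        return self._fn_forward[self._mode](data)

m = ModeCachedForward(lambda mode: lambda x: (mode, 2 * x))
print(m.forward(3))  # ('train', 6)
m.evaluate()
print(m.forward(3))  # ('eval', 6), compiled separately for 'eval' mode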
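
And a sketch of the new extra-outputs contract: `get_extra_outputs` returns expressions that are appended to the compiled function's output list, and `_collect_extra_outputs` stashes each computed value on the corresponding expression's `val` attribute before returning the main result. The mock classes below are hypothetical; only the `_df2_extra` attribute, the `val` attribute, and the `nout` logic mirror the diff.

class FakeExpr(object):
    # Stand-in for a Theano expression requested as an extra output.
    def __init__(self, compute):
        self.compute = compute  # produces the concrete value at call time
        self.val = None         # filled in by collect_extra_outputs

class FakeFn(object):
    # Stand-in for a compiled function: main outputs first, extras last.
    def __init__(self, main, extra):
        self.main = main
        self._df2_extra = extra

    def __call__(self, x):
        return [self.main(x)] + [e.compute(x) for e in self._df2_extra]

def collect_extra_outputs(fn, vals):
    # Same logic as Module._collect_extra_outputs in the diff above.
    nout = len(vals) - len(fn._df2_extra)
    for out, val in zip(fn._df2_extra, vals[nout:]):
        out.val = val
    return vals[:nout] if nout > 1 else vals[0]

mean = FakeExpr(lambda x: sum(x) / len(x))         # the "extra" output
fn = FakeFn(lambda x: [2 * v for v in x], [mean])  # main output doubles x

print(collect_extra_outputs(fn, fn([1.0, 2.0, 3.0])))  # [2.0, 4.0, 6.0]
print(mean.val)                                        # 2.0, stashed on the expression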