@@ -70,9 +70,11 @@ def forward(self, data):
7070 symb_in = tensors_for_ndarrays (data , 'X' )
7171 symb_out = self (symb_in )
7272 extra_out = self .get_extra_outputs ()
73+ extra_up = self .get_extra_updates ()
7374 fn = self ._fn_forward [self ._mode ] = df .th .function (
7475 inputs = flatten (symb_in ),
75- outputs = flatten (symb_out ) + flatten (extra_out )
76+ outputs = flatten (symb_out ) + flatten (extra_out ),
77+ updates = flatten (extra_up , types = list ),
7678 )
7779 fn ._df2_extra = extra_out
7880
@@ -87,6 +89,7 @@ def accumulate_gradients(self, data_in, data_tgt, crit):
8789 symb_out = self (symb_in )
8890 symb_cost = crit (symb_out , symb_tgt )
8991 extra_out = self .get_extra_outputs () + crit .get_extra_outputs ()
92+ extra_up = self .get_extra_updates ()
9093
9194 params = self .parameters (learnable_only = True )
9295 symb_grads = df .th .grad (cost = symb_cost , wrt = [p .param for p in params ])
@@ -95,7 +98,7 @@ def accumulate_gradients(self, data_in, data_tgt, crit):
9598 fn = self ._fn_accum_grads [self ._mode , id (crit )] = df .th .function (
9699 inputs = flatten (symb_in ) + flatten (symb_tgt ),
97100 outputs = flatten (symb_cost ) + flatten (extra_out ),
98- updates = grads_updates
101+ updates = grads_updates + flatten ( extra_up , types = list ),
99102 )
100103 fn ._df2_extra = extra_out
101104
@@ -114,6 +117,10 @@ def get_extra_outputs(self):
114117 """
115118 return []
116119
120+ def get_extra_updates (self ):
121+ """Return extra updates passed as `updates` when this module's functions are compiled; this default returns none. NOTE: MUST BE LIST OF TUPLES (because of how flatten is called, i.e. with `types=list`)"""
122+ return []
123+
117124 def _collect_extra_outputs (self , fn , vals ):
118125 # The number of non-extra outputs.
119126 nout = len (vals ) - len (fn ._df2_extra )
@@ -148,6 +155,7 @@ def accumulate_statistics(self, data_in):
148155 # If there's no layer collecting statistics, we don't need to
149156 # compile and call a function. This prevents theano errors.
150157 return
158+ extra_up = self .get_extra_updates ()
151159
152160 # Need to make sure there's only one update per variable for the
153161 # case where we've got the same module instance at multiple places
@@ -167,7 +175,7 @@ def accumulate_statistics(self, data_in):
167175
168176 self ._fn_accum_stats [self ._mode ] = df .th .function (
169177 inputs = flatten (symb_in ),
170- updates = stat_updates
178+ updates = stat_updates + flatten ( extra_up , types = list )
171179 )
172180
173181 self ._fn_accum_stats [self ._mode ](* flatten (data_in ))
0 commit comments