Skip to content

Commit cc3a5ac

Browse files
authored
Merge pull request #93 from lucasb-eyer/extra-updates
Add ability for layers to give extra updates.
2 parents (10179b5 + 809a5f7), commit cc3a5ac

File tree

2 files changed

14 additions (+14) and 3 deletions (-3): 17 lines changed

2 files changed

14 additions (+14) and 3 deletions (-3): 17 lines changed

DeepFried2/Container.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,9 @@ def parameters(self, *a, **kw):
3434
def get_extra_outputs(self):
3535
return list(_chain.from_iterable(m.get_extra_outputs() for m in self.modules))
3636

37+
def get_extra_updates(self):
38+
return list(_chain.from_iterable(m.get_extra_updates() for m in self.modules))
39+
3740
def get_stat_updates(self):
3841
return list(_chain.from_iterable(m.get_stat_updates() for m in self.modules))
3942

DeepFried2/Module.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,9 +70,11 @@ def forward(self, data):
7070
symb_in = tensors_for_ndarrays(data, 'X')
7171
symb_out = self(symb_in)
7272
extra_out = self.get_extra_outputs()
73+
extra_up = self.get_extra_updates()
7374
fn = self._fn_forward[self._mode] = df.th.function(
7475
inputs=flatten(symb_in),
75-
outputs=flatten(symb_out) + flatten(extra_out)
76+
outputs=flatten(symb_out) + flatten(extra_out),
77+
updates=flatten(extra_up, types=list),
7678
)
7779
fn._df2_extra = extra_out
7880

@@ -87,6 +89,7 @@ def accumulate_gradients(self, data_in, data_tgt, crit):
8789
symb_out = self(symb_in)
8890
symb_cost = crit(symb_out, symb_tgt)
8991
extra_out = self.get_extra_outputs() + crit.get_extra_outputs()
92+
extra_up = self.get_extra_updates()
9093

9194
params = self.parameters(learnable_only=True)
9295
symb_grads = df.th.grad(cost=symb_cost, wrt=[p.param for p in params])
@@ -95,7 +98,7 @@ def accumulate_gradients(self, data_in, data_tgt, crit):
9598
fn = self._fn_accum_grads[self._mode, id(crit)] = df.th.function(
9699
inputs=flatten(symb_in) + flatten(symb_tgt),
97100
outputs=flatten(symb_cost) + flatten(extra_out),
98-
updates=grads_updates
101+
updates=grads_updates + flatten(extra_up, types=list),
99102
)
100103
fn._df2_extra = extra_out
101104

@@ -114,6 +117,10 @@ def get_extra_outputs(self):
114117
"""
115118
return []
116119

120+
def get_extra_updates(self):
121+
"""NOTE: MUST BE LIST OF TUPLES (because of how flatten is called)"""
122+
return []
123+
117124
def _collect_extra_outputs(self, fn, vals):
118125
# The number of non-extra outputs.
119126
nout = len(vals) - len(fn._df2_extra)
@@ -148,6 +155,7 @@ def accumulate_statistics(self, data_in):
148155
# If there's no layer collecting statistics, we don't need to
149156
# compile and call a function. This prevents theano errors.
150157
return
158+
extra_up = self.get_extra_updates()
151159

152160
# Need to make sure there's only one update per variable for the
153161
# case where we've got the same module instance at multiple places
@@ -167,7 +175,7 @@ def accumulate_statistics(self, data_in):
167175

168176
self._fn_accum_stats[self._mode] = df.th.function(
169177
inputs=flatten(symb_in),
170-
updates=stat_updates
178+
updates=stat_updates + flatten(extra_up, types=list)
171179
)
172180

173181
self._fn_accum_stats[self._mode](*flatten(data_in))

0 commit comments

Comments
 (0)