Commit a5c5a35

Merge pull request #4 from lucasb-eyer/training-mode
[WIP] Compile different functions for training and prediction modes.
2 parents: dd6bbf0 + 2b30e83

1 file changed (+18, -12 lines)

DeepFried2/layers/Module.py

Lines changed: 18 additions & 12 deletions
@@ -7,9 +7,12 @@ class Module:
     def __init__(self):
         self.training_mode = True
 
-        self.fn_forward = None
-        self.fn_accum_grads = None
-        self.fn_accum_stats = None
+        # The functions are stored in a dictionary whose keys correspond to the
+        # values that `self.training_mode` can take. That way, it would be
+        # trivial to extend to further modes, and the code avoids many branches.
+        self.fn_forward = {}
+        self.fn_accum_grads = {}
+        self.fn_accum_stats = {}
 
     #def __hash__(self):
     #    raise NotImplementedError("You *need* to reimplement hash, even if it's just python's default. See the documentation for more info.")
@@ -44,15 +47,18 @@ def symb_forward(self, symb_input):
         raise NotImplementedError
 
     def forward(self, data):
-        if self.fn_forward is None:
+        if self.training_mode not in self.fn_forward:
             symb_in = _T.TensorType(_th.config.floatX, (False,) * data.ndim)('X')
             symb_out = self.symb_forward(symb_in)
-            self.fn_forward = _th.function(inputs=[symb_in], outputs=symb_out)
+            self.fn_forward[self.training_mode] = _th.function(
+                inputs=[symb_in],
+                outputs=symb_out
+            )
 
-        return self.fn_forward(data)
+        return self.fn_forward[self.training_mode](data)
 
     def accumulate_gradients(self, data_in, data_tgt, loss):
-        if self.fn_accum_grads is None:
+        if self.training_mode not in self.fn_accum_grads:
             symb_in = _T.TensorType(_th.config.floatX, (False,) * data_in.ndim)('X')
             symb_tgt = _T.TensorType(_th.config.floatX, (False,) * data_tgt.ndim)('T')
             symb_out = self.symb_forward(symb_in)
@@ -62,19 +68,19 @@ def accumulate_gradients(self, data_in, data_tgt, loss):
             symb_grads = _th.grad(cost=symb_err, wrt=params)
 
             grads_updates = [(grad, grad + symb_grad) for grad, symb_grad in zip(grads, symb_grads)]
-            self.fn_accum_grads = _th.function(
+            self.fn_accum_grads[self.training_mode] = _th.function(
                 inputs=[symb_in, symb_tgt],
                 outputs=symb_err,
                 updates=grads_updates
             )
 
-        return self.fn_accum_grads(data_in, data_tgt)
+        return self.fn_accum_grads[self.training_mode](data_in, data_tgt)
 
     def get_stat_updates(self):
         return []
 
     def accumulate_statistics(self, data_in):
-        if self.fn_accum_stats is None:
+        if self.training_mode not in self.fn_accum_stats:
             symb_in = _T.TensorType(_th.config.floatX, (False,) * data_in.ndim)('X')
             self.symb_forward(symb_in)
@@ -84,9 +90,9 @@ def accumulate_statistics(self, data_in):
                # compile and call a function. This prevents theano errors.
                return
 
            self.fn_accum_stats = _th.function(
+            self.fn_accum_stats[self.training_mode] = _th.function(
                inputs=[symb_in],
                updates=stat_updates
            )
 
-        self.fn_accum_stats(data_in)
+        self.fn_accum_stats[self.training_mode](data_in)
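
The change is easiest to read as a small per-mode compilation cache: instead of one compiled Theano function per task, each of fn_forward, fn_accum_grads, and fn_accum_stats now maps the current value of self.training_mode to its own compiled function, so switching modes compiles a fresh function on first use and reuses it afterwards. Below is a minimal, self-contained sketch of that caching pattern in plain Python; ModeCompiled and compile_fn are hypothetical names used only for illustration, with compile_fn standing in for the _th.function call in the diff.

class ModeCompiled:
    """Sketch of the per-mode cache used above; not actual DeepFried2 code."""

    def __init__(self, compile_fn):
        self.training_mode = True      # same default as Module.__init__
        self._compile_fn = compile_fn  # hypothetical stand-in for _th.function
        self._cache = {}               # keyed by the current training_mode value

    def __call__(self, *args):
        # Compile lazily, once per distinct mode, mirroring
        # `if self.training_mode not in self.fn_forward:` in the diff.
        if self.training_mode not in self._cache:
            self._cache[self.training_mode] = self._compile_fn(self.training_mode)
        return self._cache[self.training_mode](*args)

# Usage sketch: two modes yield two independently "compiled" callables.
forward = ModeCompiled(lambda mode: (lambda x: ("train" if mode else "predict", x)))
print(forward(3))               # compiles for training mode, prints ('train', 3)
forward.training_mode = False
print(forward(3))               # compiles a second function, prints ('predict', 3)

Keeping the compiled functions in a dictionary rather than in separate attributes is what the added comment means by avoiding branches: the call sites never test the mode explicitly, and a further mode would work without changing them.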
