Skip to content

Commit 10179b5

Browse files
authored
Merge pull request #94 from lucasb-eyer/fix-swapping-crit
Make using varying criteria work until #86.
2 parents 2e869dc + fcd6cf4 commit 10179b5

File tree

2 files changed

+24
-3
lines changed

2 files changed

+24
-3
lines changed

DeepFried2/Module.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def forward(self, data):
8181
return self._collect_extra_outputs(fn, outs)
8282

8383
def accumulate_gradients(self, data_in, data_tgt, crit):
84-
if self._mode not in self._fn_accum_grads:
84+
if (self._mode, id(crit)) not in self._fn_accum_grads:
8585
symb_in = tensors_for_ndarrays(data_in, 'X')
8686
symb_tgt = tensors_for_ndarrays(data_tgt, 'T')
8787
symb_out = self(symb_in)
@@ -92,14 +92,14 @@ def accumulate_gradients(self, data_in, data_tgt, crit):
9292
symb_grads = df.th.grad(cost=symb_cost, wrt=[p.param for p in params])
9393
grads_updates = [(p.grad, p.grad + symb_grad) for p, symb_grad in zip(params, symb_grads)]
9494

95-
fn = self._fn_accum_grads[self._mode] = df.th.function(
95+
fn = self._fn_accum_grads[self._mode, id(crit)] = df.th.function(
9696
inputs=flatten(symb_in) + flatten(symb_tgt),
9797
outputs=flatten(symb_cost) + flatten(extra_out),
9898
updates=grads_updates
9999
)
100100
fn._df2_extra = extra_out
101101

102-
fn = self._fn_accum_grads[self._mode]
102+
fn = self._fn_accum_grads[self._mode, id(crit)]
103103
args = flatten(data_in) + flatten(data_tgt)
104104
outs = fn(*args)
105105
return self._collect_extra_outputs(fn, outs)

DeepFried2/tests/test_Module.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
#!/usr/bin/env python3
2+
3+
import DeepFried2 as df
4+
5+
import unittest
6+
import numpy as np
7+
8+
class TestModule(unittest.TestCase):
9+
10+
def testDifferentCriteriaInstances(self):
11+
T = np.random.randn(10,10).astype(df.floatX)
12+
c1 = df.MSECriterion()
13+
c2 = df.MADCriterion()
14+
err = 0.5
15+
16+
net = df.Identity()
17+
l1 = float(net.accumulate_gradients(T+err, T, c1))
18+
l2 = float(net.accumulate_gradients(T+err, T, c2))
19+
20+
np.testing.assert_almost_equal(l1, err**2)
21+
np.testing.assert_almost_equal(l2, abs(err))

0 commit comments

Comments
 (0)