Skip to content

Commit 5373cd3

Browse files
committed
Actually return the computed cost/error.
1 parent 2730d01 commit 5373cd3

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

DeepFried2/layers/Module.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,10 +64,11 @@ def accumulate_gradients(self, data_in, data_tgt, loss):
6464
grads_updates = [(grad, grad + symb_grad) for grad, symb_grad in zip(grads, symb_grads)]
6565
self.fn_accum_grads = _th.function(
6666
inputs=[symb_in, symb_tgt],
67+
outputs=symb_err,
6768
updates=grads_updates
6869
)
6970

70-
self.fn_accum_grads(data_in, data_tgt)
71+
return self.fn_accum_grads(data_in, data_tgt)
7172

7273
def get_stat_updates(self):
7374
return []

0 commit comments

Comments
 (0)