@@ -81,7 +81,7 @@ def forward(self, data):
         return self._collect_extra_outputs(fn, outs)

     def accumulate_gradients(self, data_in, data_tgt, crit):
-        if self._mode not in self._fn_accum_grads:
+        if (self._mode, id(crit)) not in self._fn_accum_grads:
             symb_in = tensors_for_ndarrays(data_in, 'X')
             symb_tgt = tensors_for_ndarrays(data_tgt, 'T')
             symb_out = self(symb_in)
@@ -92,14 +92,14 @@ def accumulate_gradients(self, data_in, data_tgt, crit):
             symb_grads = df.th.grad(cost=symb_cost, wrt=[p.param for p in params])
             grads_updates = [(p.grad, p.grad + symb_grad) for p, symb_grad in zip(params, symb_grads)]

-            fn = self._fn_accum_grads[self._mode] = df.th.function(
+            fn = self._fn_accum_grads[self._mode, id(crit)] = df.th.function(
                 inputs=flatten(symb_in) + flatten(symb_tgt),
                 outputs=flatten(symb_cost) + flatten(extra_out),
                 updates=grads_updates
             )
             fn._df2_extra = extra_out

-        fn = self._fn_accum_grads[self._mode]
+        fn = self._fn_accum_grads[self._mode, id(crit)]
         args = flatten(data_in) + flatten(data_tgt)
         outs = fn(*args)
         return self._collect_extra_outputs(fn, outs)