@@ -225,21 +225,24 @@ def hook(grad):
         self._model.zero_grad()
         if label is None:
             for i in range(self.nb_classes):
-                torch.autograd.backward(preds[:, i], torch.Tensor([1.] * len(preds[:, 0])), retain_graph=True)
+                torch.autograd.backward(preds[:, i], torch.Tensor([1.] * len(preds[:, 0])).to(self._device),
+                                        retain_graph=True)

             grads = np.swapaxes(np.array(grads), 0, 1)
             grads = self._apply_processing_gradient(grads)

         elif isinstance(label, (int, np.integer)):
-            torch.autograd.backward(preds[:, label], torch.Tensor([1.] * len(preds[:, 0])), retain_graph=True)
+            torch.autograd.backward(preds[:, label], torch.Tensor([1.] * len(preds[:, 0])).to(self._device),
+                                    retain_graph=True)

             grads = np.swapaxes(np.array(grads), 0, 1)
             grads = self._apply_processing_gradient(grads)

         else:
             unique_label = list(np.unique(label))
             for i in unique_label:
-                torch.autograd.backward(preds[:, i], torch.Tensor([1.] * len(preds[:, 0])), retain_graph=True)
+                torch.autograd.backward(preds[:, i], torch.Tensor([1.] * len(preds[:, 0])).to(self._device),
+                                        retain_graph=True)

             grads = np.swapaxes(np.array(grads), 0, 1)
             lst = [unique_label.index(i) for i in label]