Skip to content

Commit f07d26f

Browse files
committed
Fix bug on Pytorch for GPU
1 parent 2c586aa commit f07d26f

File tree

1 file changed

+6
-3
lines changed

1 file changed

+6
-3
lines changed

art/classifiers/pytorch.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -225,21 +225,24 @@ def hook(grad):
225225
self._model.zero_grad()
226226
if label is None:
227227
for i in range(self.nb_classes):
228-
torch.autograd.backward(preds[:, i], torch.Tensor([1.] * len(preds[:, 0])), retain_graph=True)
228+
torch.autograd.backward(preds[:, i], torch.Tensor([1.] * len(preds[:, 0])).to(self._device),
229+
retain_graph=True)
229230

230231
grads = np.swapaxes(np.array(grads), 0, 1)
231232
grads = self._apply_processing_gradient(grads)
232233

233234
elif isinstance(label, (int, np.integer)):
234-
torch.autograd.backward(preds[:, label], torch.Tensor([1.] * len(preds[:, 0])), retain_graph=True)
235+
torch.autograd.backward(preds[:, label], torch.Tensor([1.] * len(preds[:, 0])).to(self._device),
236+
retain_graph=True)
235237

236238
grads = np.swapaxes(np.array(grads), 0, 1)
237239
grads = self._apply_processing_gradient(grads)
238240

239241
else:
240242
unique_label = list(np.unique(label))
241243
for i in unique_label:
242-
torch.autograd.backward(preds[:, i], torch.Tensor([1.] * len(preds[:, 0])), retain_graph=True)
244+
torch.autograd.backward(preds[:, i], torch.Tensor([1.] * len(preds[:, 0])).to(self._device),
245+
retain_graph=True)
243246

244247
grads = np.swapaxes(np.array(grads), 0, 1)
245248
lst = [unique_label.index(i) for i in label]

0 commit comments

Comments
 (0)