
Commit bce965c

Irina Nicolae authored and committed
Fix PyTorch CUDA issue (tentative)
1 parent 52505e4 commit bce965c

File tree

1 file changed: +20 -20 lines changed

art/classifiers/pytorch.py

Lines changed: 20 additions & 20 deletions
@@ -42,7 +42,6 @@ def __init__(self, clip_values, model, loss, optimizer, input_shape, nb_classes,
         super(PyTorchClassifier, self).__init__(clip_values=clip_values, channel_index=channel_index, defences=defences,
                                                 preprocessing=preprocessing)
 
-        # self._nb_classes = list(model.modules())[-1 if use_logits else -2].out_features
         self._nb_classes = nb_classes
         self._input_shape = input_shape
         self._model = PyTorchClassifier.ModelWrapper(model)
@@ -57,8 +56,8 @@ def __init__(self, clip_values, model, loss, optimizer, input_shape, nb_classes,
 
         # Use GPU if possible
         import torch
-        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-        self._model.to(device)
+        self._device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+        self._model.to(self._device)
 
     def predict(self, x, logits=False):
         """
@@ -85,13 +84,14 @@ def predict(self, x, logits=False):
         # if not logits:
         #     exp = np.exp(preds - np.max(preds, axis=1, keepdims=True))
         #     preds = exp / np.sum(exp, axis=1, keepdims=True)
-        model_outputs = self._model(torch.from_numpy(x_).float())
+
+        model_outputs = self._model(torch.from_numpy(x_).to(self._device).float())
         (logit_output, output) = (model_outputs[-2], model_outputs[-1])
 
         if logits:
-            preds = logit_output.detach().numpy()
+            preds = logit_output.detach().cpu().numpy()
         else:
-            preds = output.detach().numpy()
+            preds = output.detach().cpu().numpy()
 
         return preds
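Calling .numpy() directly on a CUDA tensor raises an error, which is why .cpu() now precedes it. A small illustrative sketch, assuming a randomly generated logits tensor rather than real model output:

import torch

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
logits = torch.randn(2, 10, device=device, requires_grad=True)

# Detach from the autograd graph and copy back to host memory before converting.
preds = logits.detach().cpu().numpy()
print(preds.shape)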

@@ -130,11 +130,11 @@ def fit(self, x, y, batch_size=128, nb_epochs=10):
             # Train for one epoch
             for m in range(num_batch):
                 if m < num_batch - 1:
-                    i_batch = torch.from_numpy(x_[ind[m * batch_size:(m + 1) * batch_size]])
-                    o_batch = torch.from_numpy(y_[ind[m * batch_size:(m + 1) * batch_size]])
+                    i_batch = torch.from_numpy(x_[ind[m * batch_size:(m + 1) * batch_size]]).to(self._device)
+                    o_batch = torch.from_numpy(y_[ind[m * batch_size:(m + 1) * batch_size]]).to(self._device)
                 else:
-                    i_batch = torch.from_numpy(x_[ind[m * batch_size:]])
-                    o_batch = torch.from_numpy(y_[ind[m * batch_size:]])
+                    i_batch = torch.from_numpy(x_[ind[m * batch_size:]]).to(self._device)
+                    o_batch = torch.from_numpy(y_[ind[m * batch_size:]]).to(self._device)
 
                 # Cast to float
                 i_batch = i_batch.float()
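Each NumPy minibatch is now converted to a tensor and moved to the model's device before the forward pass. A hedged sketch of that training-loop pattern; the toy model, loss, optimizer, and random data below are illustrative assumptions, not part of the diff:

import numpy as np
import torch
import torch.nn as nn

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = nn.Linear(4, 3).to(device)              # hypothetical toy classifier
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

x_ = np.random.rand(32, 4).astype(np.float32)
y_ = np.random.randint(0, 3, size=32).astype(np.int64)
batch_size = 8

for m in range(len(x_) // batch_size):
    # Convert each NumPy batch and place it on the same device as the model.
    i_batch = torch.from_numpy(x_[m * batch_size:(m + 1) * batch_size]).to(device)
    o_batch = torch.from_numpy(y_[m * batch_size:(m + 1) * batch_size]).to(device)

    optimizer.zero_grad()
    loss = loss_fn(model(i_batch), o_batch)
    loss.backward()
    optimizer.step()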
@@ -170,7 +170,7 @@ def class_gradient(self, x, label=None, logits=False):
             raise ValueError('Label %s is out of range.' % label)
 
         # Convert the inputs to Tensors
-        x_ = torch.from_numpy(self._apply_processing(x))
+        x_ = torch.from_numpy(self._apply_processing(x)).to(self._device)
         x_ = x_.float()
         x_.requires_grad = True
 
@@ -191,8 +191,8 @@ def class_gradient(self, x, label=None, logits=False):
         # Compute the gradient
         if label is not None:
             self._model.zero_grad()
-            torch.autograd.backward(preds[:, label], torch.FloatTensor([1] * len(preds[:, 0])), retain_graph=True)
-            grds = x_.grad.numpy().copy()
+            torch.autograd.backward(preds[:, label], torch.Tensor([1.] * len(preds[:, 0])), retain_graph=True)
+            grds = x_.grad.cpu().numpy().copy()
             x_.grad.data.zero_()
 
             grds = np.expand_dims(self._apply_processing_gradient(grds), axis=1)
@@ -201,8 +201,8 @@ def class_gradient(self, x, label=None, logits=False):
             grds = []
             self._model.zero_grad()
             for i in range(self.nb_classes):
-                torch.autograd.backward(preds[:, i], torch.FloatTensor([1] * len(preds[:, 0])), retain_graph=True)
-                grds.append(x_.grad.numpy().copy())
+                torch.autograd.backward(preds[:, i], torch.Tensor([1.] * len(preds[:, 0])), retain_graph=True)
+                grds.append(x_.grad.cpu().numpy().copy())
                 x_.grad.data.zero_()
 
             grds = np.swapaxes(np.array(grds), 0, 1)
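The class-gradient loop backpropagates a vector of ones through each class output in turn, copies x_.grad back to the host, and zeroes the buffer before the next class. A rough standalone sketch with a toy linear classifier; torch.ones_like is used here for the gradient vector, which is an assumption rather than the exact construction in the diff:

import torch
import torch.nn as nn

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = nn.Linear(4, 3).to(device)              # hypothetical toy classifier

x_ = torch.randn(5, 4, device=device, requires_grad=True)
preds = model(x_)

grads = []
for i in range(preds.shape[1]):
    # Backpropagate ones through class i's outputs only; keep the graph for the next class.
    torch.autograd.backward(preds[:, i], torch.ones_like(preds[:, i]), retain_graph=True)
    grads.append(x_.grad.cpu().numpy().copy())  # copy to host before resetting the buffer
    x_.grad.data.zero_()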
@@ -225,12 +225,12 @@ def loss_gradient(self, x, y):
         import torch
 
         # Convert the inputs to Tensors
-        inputs_t = torch.from_numpy(self._apply_processing(x))
+        inputs_t = torch.from_numpy(self._apply_processing(x)).to(self._device)
         inputs_t = inputs_t.float()
         inputs_t.requires_grad = True
 
         # Convert the labels to Tensors
-        labels_t = torch.from_numpy(np.argmax(y, axis=1))
+        labels_t = torch.from_numpy(np.argmax(y, axis=1)).to(self._device)
 
         # Compute the gradient and return
         model_outputs = self._model(inputs_t)
@@ -242,7 +242,7 @@ def loss_gradient(self, x, y):
 
         # Compute gradients
         loss.backward()
-        grds = inputs_t.grad.numpy().copy()
+        grds = inputs_t.grad.cpu().numpy().copy()
         grds = self._apply_processing_gradient(grds)
         assert grds.shape == x.shape
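The loss-gradient path follows the same pattern: inputs and labels move to the device, loss.backward() fills the input gradient, and the result comes back via .cpu(). An illustrative sketch under the same toy-model assumption:

import numpy as np
import torch
import torch.nn as nn

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = nn.Linear(4, 3).to(device)              # hypothetical toy classifier
loss_fn = nn.CrossEntropyLoss()

x = np.random.rand(5, 4).astype(np.float32)
y = np.eye(3)[np.random.randint(0, 3, size=5)]  # one-hot labels

inputs_t = torch.from_numpy(x).to(device)
inputs_t.requires_grad = True
labels_t = torch.from_numpy(np.argmax(y, axis=1)).to(device)

loss = loss_fn(model(inputs_t), labels_t)
loss.backward()

# Gradient w.r.t. the input, copied back to host memory as a NumPy array.
grds = inputs_t.grad.cpu().numpy().copy()
assert grds.shape == x.shape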

@@ -286,7 +286,7 @@ def get_activations(self, x, layer):
         self._model.train(False)
 
         # Run prediction
-        model_outputs = self._model(torch.from_numpy(x).float())[:-1]
+        model_outputs = self._model(torch.from_numpy(x).to(self._device).float())[:-1]
 
         if isinstance(layer, six.string_types):
             if layer not in self._layer_names:
@@ -299,7 +299,7 @@ def get_activations(self, x, layer):
         else:
             raise TypeError("Layer must be of type str or int")
 
-        return model_outputs[layer_index].detach().numpy()
+        return model_outputs[layer_index].detach().cpu().numpy()
 
     # def _forward_at(self, inputs, layer):
     #     """

0 commit comments
