
Commit 2622b33

update pytorch classifier

1 parent 6765ffb commit 2622b33

3 files changed: +219 −112 lines

art/attacks/carlini_unittest.py

Lines changed: 0 additions & 42 deletions
@@ -227,45 +227,3 @@ def test_ptclassifier(self):
[3 context lines and 42 deleted lines (original lines 230-271); the line contents were not captured in this diff view]
art/classifiers/pytorch.py

Lines changed: 9 additions & 1 deletion
@@ -68,7 +68,7 @@ def predict(self, inputs, logits=False):
         # if not logits:
         #     exp = np.exp(preds - np.max(preds, axis=1, keepdims=True))
         #     preds = exp / np.sum(exp, axis=1, keepdims=True)
-        (logit_output, output) = self._model(torch.from_numpy(inputs))
+        (logit_output, output) = self._model(torch.from_numpy(inputs).float())

         if logits:
             preds = logit_output.detach().numpy()
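This cast is the heart of the change: torch.from_numpy preserves the NumPy dtype, so a default float64 array becomes a torch.float64 tensor, while the wrapped model's parameters are float32 and the forward pass fails with a dtype-mismatch RuntimeError. A minimal sketch of the failure and the fix, using a stand-in nn.Linear rather than the repository's actual model (which returns a (logits, probabilities) pair):

import numpy as np
import torch
import torch.nn as nn

model = nn.Linear(4, 2)                # parameters default to torch.float32

inputs = np.random.rand(3, 4)          # NumPy defaults to float64
x = torch.from_numpy(inputs)           # dtype carries over: torch.float64
# model(x)                             # RuntimeError: expected Float, found Double

x = torch.from_numpy(inputs).float()   # cast to float32 to match the weights
out = model(x)                         # forward pass succeeds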
@@ -114,6 +114,9 @@ def fit(self, inputs, outputs, batch_size=128, nb_epochs=10):
                 i_batch = torch.from_numpy(inputs[ind[m*batch_size:]])
                 o_batch = torch.from_numpy(outputs[ind[m * batch_size:]])

+                # Cast to float
+                i_batch = i_batch.float()
+
                 # Zero the parameter gradients
                 self._optimizer.zero_grad()

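Note that only the input batch is cast; the labels keep their original dtype, which integer-based losses such as nn.CrossEntropyLoss expect. A sketch of the batching pattern this hunk touches, with full-batch slice bounds assumed for illustration (the model/optimizer wiring is elided):

import numpy as np
import torch

inputs = np.random.rand(256, 4)               # float64 feature array
outputs = np.random.randint(0, 2, size=256)   # integer class labels

batch_size = 128
ind = np.random.permutation(inputs.shape[0])  # shuffled indices, as in fit()

for m in range(inputs.shape[0] // batch_size):
    sl = ind[m * batch_size:(m + 1) * batch_size]
    i_batch = torch.from_numpy(inputs[sl])    # still float64 here
    o_batch = torch.from_numpy(outputs[sl])   # stays int64 for the loss

    # Cast to float, mirroring the added lines
    i_batch = i_batch.float()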
@@ -137,6 +140,7 @@ def class_gradient(self, inputs, logits=False):
         """
         # Convert the inputs to Tensors
         x = torch.from_numpy(inputs)
+        x = x.float()
         x.requires_grad = True

         # Compute the gradient and return
@@ -177,6 +181,7 @@ def loss_gradient(self, inputs, labels):
         """
         # Convert the inputs to Tensors
         inputs_t = torch.from_numpy(inputs)
+        inputs_t = inputs_t.float()
         inputs_t.requires_grad = True

         # Convert the labels to Tensors
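Here the ordering matters: casting before requires_grad = True makes the float32 tensor the autograd leaf, so the gradient lands on the same tensor that is read back after backward(). A sketch of the loss-gradient pattern under the same stand-in-model assumption:

import numpy as np
import torch
import torch.nn as nn

model = nn.Linear(4, 2)
loss_fn = nn.CrossEntropyLoss()

inputs = np.random.rand(3, 4)                # float64 NumPy batch
labels = torch.tensor([0, 1, 1])

inputs_t = torch.from_numpy(inputs).float()  # cast first, so the float32
inputs_t.requires_grad = True                # tensor is the autograd leaf

loss = loss_fn(model(inputs_t), labels)
loss.backward()
grads = inputs_t.grad.numpy()                # d(loss)/d(inputs), shape (3, 4)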
@@ -219,3 +224,6 @@ def loss_gradient(self, inputs, labels):
         # return results


+
+
+