
Commit 4bfed67

Merge pull request #2180 from f4str/torch-dataloaders
Optimize PyTorch Classifiers and Object Detectors
2 parents 011ab1e + f3fcf19 commit 4bfed67

File tree: 9 files changed (+541, -559 lines)


art/estimators/certification/randomized_smoothing/pytorch.py

Lines changed: 15 additions & 23 deletions
```diff
@@ -23,10 +23,9 @@
 from __future__ import absolute_import, division, print_function, unicode_literals
 
 import logging
-from typing import List, Optional, Tuple, Union, Any, TYPE_CHECKING
+from typing import List, Optional, Tuple, Union, TYPE_CHECKING
 
 import warnings
-import random
 from tqdm import tqdm
 import numpy as np
 
@@ -137,7 +136,7 @@ def fit( # pylint: disable=W0221
         nb_epochs: int = 10,
         training_mode: bool = True,
         drop_last: bool = False,
-        scheduler: Optional[Any] = None,
+        scheduler: Optional["torch.optim.lr_scheduler._LRScheduler"] = None,
         **kwargs,
     ) -> None:
         """
@@ -157,6 +156,7 @@ def fit( # pylint: disable=W0221
             and providing it takes no effect.
         """
         import torch
+        from torch.utils.data import TensorDataset, DataLoader
 
         # Set model mode
         self._model.train(mode=training_mode)
@@ -172,36 +172,28 @@ def fit( # pylint: disable=W0221
         # Check label shape
         y_preprocessed = self.reduce_labels(y_preprocessed)
 
-        num_batch = len(x_preprocessed) / float(batch_size)
-        if drop_last:
-            num_batch = int(np.floor(num_batch))
-        else:
-            num_batch = int(np.ceil(num_batch))
-        ind = np.arange(len(x_preprocessed))
-        std = torch.tensor(self.scale).to(self._device)
-
-        x_preprocessed = torch.from_numpy(x_preprocessed).to(self._device)
-        y_preprocessed = torch.from_numpy(y_preprocessed).to(self._device)
+        # Create dataloader
+        x_tensor = torch.from_numpy(x_preprocessed)
+        y_tensor = torch.from_numpy(y_preprocessed)
+        dataset = TensorDataset(x_tensor, y_tensor)
+        dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True, drop_last=drop_last)
 
         # Start training
         for _ in tqdm(range(nb_epochs)):
-            # Shuffle the examples
-            random.shuffle(ind)
-
-            # Train for one epoch
-            for m in range(num_batch):
-                i_batch = x_preprocessed[ind[m * batch_size : (m + 1) * batch_size]]
-                o_batch = y_preprocessed[ind[m * batch_size : (m + 1) * batch_size]]
+            for x_batch, y_batch in dataloader:
+                # Move inputs to device
+                x_batch = x_batch.to(self._device)
+                y_batch = y_batch.to(self._device)
 
                 # Add random noise for randomized smoothing
-                i_batch = i_batch + torch.randn_like(i_batch, device=self._device) * std
+                x_batch += torch.randn_like(x_batch) * self.scale
 
                 # Zero the parameter gradients
                 self._optimizer.zero_grad()
 
                 # Perform prediction
                 try:
-                    model_outputs = self._model(i_batch)
+                    model_outputs = self._model(x_batch)
                 except ValueError as err:
                     if "Expected more than 1 value per channel when training" in str(err):
                         logger.exception(
@@ -211,7 +203,7 @@ def fit( # pylint: disable=W0221
                     raise err
 
                 # Form the loss function
-                loss = self._loss(model_outputs[-1], o_batch)
+                loss = self._loss(model_outputs[-1], y_batch)
 
                 # Do training
                 if self._use_amp:  # pragma: no cover
```
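
The substance of this change: batch construction, shuffling, and `drop_last` handling are delegated to `torch.utils.data.DataLoader`, and only the current batch is moved to `self._device`, whereas the old code transferred the full dataset to the GPU up front and indexed into it by hand. A minimal standalone sketch of the same pattern, assuming a toy model, optimizer, and noise scale (placeholder names, not ART internals):

```python
# Sketch of the DataLoader-based noisy training loop adopted above.
# `model`, `optimizer`, `loss_fn`, `scale`, and the data shapes are
# illustrative placeholders, not ART attributes.
import torch
from torch.utils.data import TensorDataset, DataLoader

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = torch.nn.Linear(784, 10).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
loss_fn = torch.nn.CrossEntropyLoss()
scale = 0.25  # std-dev of the smoothing noise

x = torch.randn(512, 784)          # stand-in for x_preprocessed
y = torch.randint(0, 10, (512,))   # stand-in for y_preprocessed

# Shuffling and last-batch handling are delegated to the DataLoader,
# replacing the manual index bookkeeping the diff removes.
dataloader = DataLoader(TensorDataset(x, y), batch_size=128, shuffle=True, drop_last=False)

for _ in range(10):
    for x_batch, y_batch in dataloader:
        # Batches stay on the CPU until here, so only one batch at a
        # time occupies device memory (the point of the optimization).
        x_batch = x_batch.to(device)
        y_batch = y_batch.to(device)

        # Gaussian noise for randomized smoothing
        x_batch = x_batch + torch.randn_like(x_batch) * scale

        optimizer.zero_grad()
        loss = loss_fn(model(x_batch), y_batch)
        loss.backward()
        optimizer.step()
```

Note that `torch.randn_like(x_batch)` inherits the batch's device and dtype, so the separate `std` tensor and the explicit `device=` argument of the old code are no longer needed.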

art/estimators/classification/pytorch.py

Lines changed: 24 additions & 31 deletions
```diff
@@ -24,7 +24,6 @@
 import copy
 import logging
 import os
-import random
 import time
 from typing import Any, Dict, List, Optional, Tuple, Union, TYPE_CHECKING
 
@@ -309,26 +308,27 @@ def predict( # pylint: disable=W0221
         :return: Array of predictions of shape `(nb_inputs, nb_classes)`.
         """
         import torch
+        from torch.utils.data import TensorDataset, DataLoader
 
         # Set model mode
         self._model.train(mode=training_mode)
 
         # Apply preprocessing
         x_preprocessed, _ = self._apply_preprocessing(x, y=None, fit=False)
 
-        results_list = []
+        # Create dataloader
+        x_tensor = torch.from_numpy(x_preprocessed)
+        dataset = TensorDataset(x_tensor)
+        dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=False)
 
-        # Run prediction with batch processing
-        num_batch = int(np.ceil(len(x_preprocessed) / float(batch_size)))
-        for m in range(num_batch):
-            # Batch indexes
-            begin, end = (
-                m * batch_size,
-                min((m + 1) * batch_size, x_preprocessed.shape[0]),
-            )
+        results_list = []
+        for (x_batch,) in dataloader:
+            # Move inputs to device
+            x_batch = x_batch.to(self._device)
 
+            # Run prediction
             with torch.no_grad():
-                model_outputs = self._model(torch.from_numpy(x_preprocessed[begin:end]).to(self._device))
+                model_outputs = self._model(x_batch)
             output = model_outputs[-1]
             output = output.detach().cpu().numpy().astype(np.float32)
             if len(output.shape) == 1:
```
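
`predict` follows the same pattern with `shuffle=False`, so the collected outputs stay aligned with the input order. A one-tensor `TensorDataset` yields single-element tuples, hence the `for (x_batch,) in dataloader` unpacking. A standalone sketch of this inference loop, with a toy model and illustrative shapes (placeholders, not ART internals):

```python
# Sketch of the DataLoader-based batched inference adopted in predict().
import numpy as np
import torch
from torch.utils.data import TensorDataset, DataLoader

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = torch.nn.Linear(784, 10).to(device)  # placeholder model

x = torch.randn(300, 784)  # stand-in for the preprocessed inputs

# shuffle=False keeps batch order deterministic, so concatenated
# outputs line up with the original inputs.
dataloader = DataLoader(TensorDataset(x), batch_size=128, shuffle=False)

results_list = []
for (x_batch,) in dataloader:
    x_batch = x_batch.to(device)
    with torch.no_grad():  # no gradients needed for prediction
        output = model(x_batch)
    results_list.append(output.detach().cpu().numpy().astype(np.float32))

predictions = np.concatenate(results_list)  # shape (300, 10)
```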
```diff
@@ -373,7 +373,7 @@ def fit( # pylint: disable=W0221
         nb_epochs: int = 10,
         training_mode: bool = True,
         drop_last: bool = False,
-        scheduler: Optional[Any] = None,
+        scheduler: Optional["torch.optim.lr_scheduler._LRScheduler"] = None,
         **kwargs,
     ) -> None:
         """
@@ -393,6 +393,7 @@ def fit( # pylint: disable=W0221
             and providing it takes no effect.
         """
         import torch
+        from torch.utils.data import TensorDataset, DataLoader
 
         # Set model mode
         self._model.train(mode=training_mode)
@@ -408,32 +409,25 @@ def fit( # pylint: disable=W0221
         # Check label shape
         y_preprocessed = self.reduce_labels(y_preprocessed)
 
-        num_batch = len(x_preprocessed) / float(batch_size)
-        if drop_last:
-            num_batch = int(np.floor(num_batch))
-        else:
-            num_batch = int(np.ceil(num_batch))
-        ind = np.arange(len(x_preprocessed))
-
-        x_preprocessed = torch.from_numpy(x_preprocessed).to(self._device)
-        y_preprocessed = torch.from_numpy(y_preprocessed).to(self._device)
+        # Create dataloader
+        x_tensor = torch.from_numpy(x_preprocessed)
+        y_tensor = torch.from_numpy(y_preprocessed)
+        dataset = TensorDataset(x_tensor, y_tensor)
+        dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True, drop_last=drop_last)
 
         # Start training
         for _ in range(nb_epochs):
-            # Shuffle the examples
-            random.shuffle(ind)
-
-            # Train for one epoch
-            for m in range(num_batch):
-                i_batch = x_preprocessed[ind[m * batch_size : (m + 1) * batch_size]]
-                o_batch = y_preprocessed[ind[m * batch_size : (m + 1) * batch_size]]
+            for x_batch, y_batch in dataloader:
+                # Move inputs to device
+                x_batch = x_batch.to(self._device)
+                y_batch = y_batch.to(self._device)
 
                 # Zero the parameter gradients
                 self._optimizer.zero_grad()
 
                 # Perform prediction
                 try:
-                    model_outputs = self._model(i_batch)
+                    model_outputs = self._model(x_batch)
                 except ValueError as err:
                     if "Expected more than 1 value per channel when training" in str(err):
                         logger.exception(
@@ -443,15 +437,14 @@ def fit( # pylint: disable=W0221
                     raise err
 
                 # Form the loss function
-                loss = self._loss(model_outputs[-1], o_batch)
+                loss = self._loss(model_outputs[-1], y_batch)
 
                 # Do training
                 if self._use_amp:  # pragma: no cover
                     from apex import amp  # pylint: disable=E0611
 
                     with amp.scale_loss(loss, self._optimizer) as scaled_loss:
                         scaled_loss.backward()
-
                 else:
                     loss.backward()
```
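
Both `fit` methods also tighten the `scheduler` parameter from `Optional[Any]` to `Optional["torch.optim.lr_scheduler._LRScheduler"]`, so any `_LRScheduler` subclass satisfies the hint. A hedged usage sketch with a toy model and random data (purely illustrative; only `PyTorchClassifier` and the `scheduler` argument shown in this diff are real ART API):

```python
# Illustrative use of fit() with the newly typed scheduler argument.
import numpy as np
import torch
from art.estimators.classification import PyTorchClassifier

model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(784, 10))
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

classifier = PyTorchClassifier(
    model=model,
    loss=torch.nn.CrossEntropyLoss(),
    optimizer=optimizer,
    input_shape=(1, 28, 28),
    nb_classes=10,
)

# Any torch.optim.lr_scheduler._LRScheduler subclass, e.g. StepLR,
# now matches the type hint instead of an untyped Any.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)

# Toy data; ART accepts index labels (one-hot also works).
x_train = np.random.rand(256, 1, 28, 28).astype(np.float32)
y_train = np.random.randint(0, 10, size=256)

classifier.fit(x_train, y_train, batch_size=64, nb_epochs=10, scheduler=scheduler)
```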
