Skip to content

Commit 50e0be1

Browse files
committed
use torch dataloader for randomized smoothing
Signed-off-by: Farhan Ahmed <[email protected]>
1 parent 67e89b3 commit 50e0be1

File tree

1 file changed

+14
-21
lines changed
  • art/estimators/certification/randomized_smoothing

1 file changed

+14
-21
lines changed

art/estimators/certification/randomized_smoothing/pytorch.py

Lines changed: 14 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ def fit( # pylint: disable=W0221
137137
nb_epochs: int = 10,
138138
training_mode: bool = True,
139139
drop_last: bool = False,
140-
scheduler: Optional[Any] = None,
140+
scheduler: Optional["torch.optim.lr_scheduler._LRScheduler"] = None,
141141
**kwargs,
142142
) -> None:
143143
"""
@@ -157,6 +157,7 @@ def fit( # pylint: disable=W0221
157157
and providing it takes no effect.
158158
"""
159159
import torch
160+
from torch.utils.data import TensorDataset, DataLoader
160161

161162
# Set model mode
162163
self._model.train(mode=training_mode)
@@ -172,36 +173,28 @@ def fit( # pylint: disable=W0221
172173
# Check label shape
173174
y_preprocessed = self.reduce_labels(y_preprocessed)
174175

175-
num_batch = len(x_preprocessed) / float(batch_size)
176-
if drop_last:
177-
num_batch = int(np.floor(num_batch))
178-
else:
179-
num_batch = int(np.ceil(num_batch))
180-
ind = np.arange(len(x_preprocessed))
181-
std = torch.tensor(self.scale).to(self._device)
182-
183-
x_preprocessed = torch.from_numpy(x_preprocessed).to(self._device)
184-
y_preprocessed = torch.from_numpy(y_preprocessed).to(self._device)
176+
# Create dataloader
177+
x_tensor = torch.from_numpy(x_preprocessed)
178+
y_tensor = torch.from_numpy(y_preprocessed)
179+
dataset = TensorDataset(x_tensor, y_tensor)
180+
dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True, drop_last=drop_last)
185181

186182
# Start training
187183
for _ in tqdm(range(nb_epochs)):
188-
# Shuffle the examples
189-
random.shuffle(ind)
190-
191-
# Train for one epoch
192-
for m in range(num_batch):
193-
i_batch = x_preprocessed[ind[m * batch_size : (m + 1) * batch_size]]
194-
o_batch = y_preprocessed[ind[m * batch_size : (m + 1) * batch_size]]
184+
for x_batch, y_batch in dataloader:
185+
# Move inputs to device
186+
x_batch = x_batch.to(self._device)
187+
y_batch = y_batch.to(self._device)
195188

196189
# Add random noise for randomized smoothing
197-
i_batch = i_batch + torch.randn_like(i_batch, device=self._device) * std
190+
x_batch += torch.randn_like(x_batch) * self.scale
198191

199192
# Zero the parameter gradients
200193
self._optimizer.zero_grad()
201194

202195
# Perform prediction
203196
try:
204-
model_outputs = self._model(i_batch)
197+
model_outputs = self._model(x_batch)
205198
except ValueError as err:
206199
if "Expected more than 1 value per channel when training" in str(err):
207200
logger.exception(
@@ -211,7 +204,7 @@ def fit( # pylint: disable=W0221
211204
raise err
212205

213206
# Form the loss function
214-
loss = self._loss(model_outputs[-1], o_batch)
207+
loss = self._loss(model_outputs[-1], y_batch)
215208

216209
# Do training
217210
if self._use_amp: # pragma: no cover

0 commit comments

Comments (0)