
Commit d303e95

Merge pull request #1883 from Trusted-AI/development_issue_1723
Add drop_last option to method fit of PyTorchClassifier
2 parents 4e226f2 + fc12c05

File tree

3 files changed: +77 −13 lines changed

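For context, a minimal usage sketch of the new fit() options added by this commit; the toy model, data, and scheduler below are illustrative stand-ins, not part of the change:

import numpy as np
import torch.nn as nn
import torch.optim as optim
from art.estimators.classification import PyTorchClassifier

# Illustrative toy model and data (hypothetical, only to show the new arguments).
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 32), nn.BatchNorm1d(32), nn.ReLU(), nn.Linear(32, 10))
optimizer = optim.SGD(model.parameters(), lr=0.1)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)

classifier = PyTorchClassifier(
    model=model,
    loss=nn.CrossEntropyLoss(),
    optimizer=optimizer,
    input_shape=(1, 28, 28),
    nb_classes=10,
)

x_train = np.random.rand(897, 1, 28, 28).astype(np.float32)
y_train = np.random.randint(0, 10, size=897)

# 897 % 128 == 1, so the trailing batch would hold a single sample; with BatchNorm in
# training mode that raises "Expected more than 1 value per channel when training".
# drop_last=True skips that batch, and scheduler.step() is called once per epoch.
classifier.fit(x_train, y_train, batch_size=128, nb_epochs=10, drop_last=True, scheduler=scheduler)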
art/estimators/certification/derandomized_smoothing/pytorch.py

Lines changed: 18 additions & 3 deletions
@@ -155,18 +155,21 @@ def fit(  # pylint: disable=W0221
         batch_size: int = 128,
         nb_epochs: int = 10,
         training_mode: bool = True,
+        drop_last: bool = False,
         scheduler: Optional[Any] = None,
         **kwargs,
     ) -> None:
         """
         Fit the classifier on the training set `(x, y)`.
-
         :param x: Training data.
         :param y: Target values (class labels) one-hot-encoded of shape (nb_samples, nb_classes) or index labels of
                   shape (nb_samples,).
         :param batch_size: Size of batches.
         :param nb_epochs: Number of epochs to use for training.
         :param training_mode: `True` for model set to training mode and `'False` for model set to evaluation mode.
+        :param drop_last: Set to ``True`` to drop the last incomplete batch, if the dataset size is not divisible by
+                          the batch size. If ``False`` and the size of dataset is not divisible by the batch size, then
+                          the last batch will be smaller. (default: ``False``)
         :param scheduler: Learning rate scheduler to run at the start of every epoch.
         :param kwargs: Dictionary of framework-specific arguments. This parameter is not currently supported for PyTorch
                        and providing it takes no effect.
@@ -187,7 +190,11 @@ def fit(  # pylint: disable=W0221
         # Check label shape
         y_preprocessed = self.reduce_labels(y_preprocessed)
 
-        num_batch = int(np.ceil(len(x_preprocessed) / float(batch_size)))
+        num_batch = len(x_preprocessed) / float(batch_size)
+        if drop_last:
+            num_batch = int(np.floor(num_batch))
+        else:
+            num_batch = int(np.ceil(num_batch))
         ind = np.arange(len(x_preprocessed))
 
         # Start training
@@ -207,7 +214,15 @@ def fit(  # pylint: disable=W0221
                 self._optimizer.zero_grad()
 
                 # Perform prediction
-                model_outputs = self._model(i_batch)
+                try:
+                    model_outputs = self._model(i_batch)
+                except ValueError as err:
+                    if "Expected more than 1 value per channel when training" in str(err):
+                        logger.exception(
+                            "Try dropping the last incomplete batch by setting drop_last=True in "
+                            "method PyTorchClassifier.fit."
+                        )
+                    raise err
 
                 # Form the loss function
                 loss = self._loss(model_outputs[-1], o_batch)  # lgtm [py/call-to-non-callable]

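The num_batch change above comes down to a floor/ceil choice; a standalone sketch of the same arithmetic (the helper name is illustrative, not part of the commit):

import numpy as np

def num_batches(n_samples: int, batch_size: int, drop_last: bool = False) -> int:
    # Mirrors the logic added to fit(): floor drops a trailing partial batch,
    # ceil keeps it as a smaller final batch.
    ratio = n_samples / float(batch_size)
    return int(np.floor(ratio)) if drop_last else int(np.ceil(ratio))

assert num_batches(897, 128, drop_last=False) == 8  # 7 full batches + 1 batch holding a single sample
assert num_batches(897, 128, drop_last=True) == 7   # the lone trailing sample is dropped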
art/estimators/certification/randomized_smoothing/pytorch.py

Lines changed: 30 additions & 5 deletions
@@ -23,7 +23,7 @@
 from __future__ import absolute_import, division, print_function, unicode_literals
 
 import logging
-from typing import List, Optional, Tuple, Union, TYPE_CHECKING
+from typing import List, Optional, Tuple, Union, Any, TYPE_CHECKING
 
 import warnings
 import random
@@ -136,6 +136,8 @@ def fit(  # pylint: disable=W0221
         batch_size: int = 128,
         nb_epochs: int = 10,
         training_mode: bool = True,
+        drop_last: bool = False,
+        scheduler: Optional[Any] = None,
         **kwargs,
     ) -> None:
         """
@@ -147,6 +149,10 @@ def fit(  # pylint: disable=W0221
         :param batch_size: Size of batches.
         :param nb_epochs: Number of epochs to use for training.
         :param training_mode: `True` for model set to training mode and `'False` for model set to evaluation mode.
+        :param drop_last: Set to ``True`` to drop the last incomplete batch, if the dataset size is not divisible by
+                          the batch size. If ``False`` and the size of dataset is not divisible by the batch size, then
+                          the last batch will be smaller. (default: ``False``)
+        :param scheduler: Learning rate scheduler to run at the start of every epoch.
         :param kwargs: Dictionary of framework-specific arguments. This parameter is not currently supported for PyTorch
                        and providing it takes no effect.
         """
@@ -166,18 +172,26 @@ def fit(  # pylint: disable=W0221
         # Check label shape
         y_preprocessed = self.reduce_labels(y_preprocessed)
 
-        num_batch = int(np.ceil(len(x_preprocessed) / float(batch_size)))
+        num_batch = len(x_preprocessed) / float(batch_size)
+        if drop_last:
+            num_batch = int(np.floor(num_batch))
+        else:
+            num_batch = int(np.ceil(num_batch))
         ind = np.arange(len(x_preprocessed))
         std = torch.tensor(self.scale).to(self._device)
+
+        x_preprocessed = torch.from_numpy(x_preprocessed).to(self._device)
+        y_preprocessed = torch.from_numpy(y_preprocessed).to(self._device)
+
         # Start training
         for _ in tqdm(range(nb_epochs)):
             # Shuffle the examples
             random.shuffle(ind)
 
             # Train for one epoch
             for m in range(num_batch):
-                i_batch = torch.from_numpy(x_preprocessed[ind[m * batch_size : (m + 1) * batch_size]]).to(self._device)
-                o_batch = torch.from_numpy(y_preprocessed[ind[m * batch_size : (m + 1) * batch_size]]).to(self._device)
+                i_batch = x_preprocessed[ind[m * batch_size : (m + 1) * batch_size]]
+                o_batch = y_preprocessed[ind[m * batch_size : (m + 1) * batch_size]]
 
                 # Add random noise for randomized smoothing
                 i_batch = i_batch + torch.randn_like(i_batch, device=self._device) * std
@@ -186,7 +200,15 @@ def fit(  # pylint: disable=W0221
                 self._optimizer.zero_grad()
 
                 # Perform prediction
-                model_outputs = self._model(i_batch)
+                try:
+                    model_outputs = self._model(i_batch)
+                except ValueError as err:
+                    if "Expected more than 1 value per channel when training" in str(err):
+                        logger.exception(
+                            "Try dropping the last incomplete batch by setting drop_last=True in "
+                            "method PyTorchClassifier.fit."
+                        )
+                    raise err
 
                 # Form the loss function
                 loss = self._loss(model_outputs[-1], o_batch)  # lgtm [py/call-to-non-callable]
@@ -203,6 +225,9 @@ def fit(  # pylint: disable=W0221
 
             self._optimizer.step()
 
+            if scheduler is not None:
+                scheduler.step()
+
     def predict(self, x: np.ndarray, batch_size: int = 128, **kwargs) -> np.ndarray:  # type: ignore
         """
         Perform prediction of the given classifier for a batch of inputs, taking an expectation over transformations.

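Besides drop_last, this file's change moves the training data to the device once before the epoch loop and steps the scheduler once per epoch after the batch loop. A plain-PyTorch sketch of the same loop shape (names are illustrative, not ART code):

import torch
from torch import nn, optim

model = nn.Linear(10, 2)
optimizer = optim.SGD(model.parameters(), lr=0.1)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9)
loss_fn = nn.CrossEntropyLoss()

x = torch.randn(260, 10)  # 260 samples with batch size 64 -> 4 full batches + 4 leftover samples
y = torch.randint(0, 2, (260,))
batch_size, drop_last = 64, True
num_batch = len(x) // batch_size if drop_last else -(-len(x) // batch_size)  # floor vs. ceil

for epoch in range(3):
    perm = torch.randperm(len(x))                    # shuffle, as random.shuffle(ind) does in fit()
    for m in range(num_batch):
        idx = perm[m * batch_size : (m + 1) * batch_size]
        optimizer.zero_grad()
        loss = loss_fn(model(x[idx]), y[idx])
        loss.backward()
        optimizer.step()
    scheduler.step()                                 # one scheduler step per epoch, matching the new hook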
art/estimators/classification/pytorch.py

Lines changed: 29 additions & 5 deletions
@@ -362,6 +362,8 @@ def fit(  # pylint: disable=W0221
         batch_size: int = 128,
         nb_epochs: int = 10,
         training_mode: bool = True,
+        drop_last: bool = False,
+        scheduler: Optional[Any] = None,
         **kwargs,
     ) -> None:
         """
@@ -373,8 +375,12 @@ def fit(  # pylint: disable=W0221
         :param batch_size: Size of batches.
         :param nb_epochs: Number of epochs to use for training.
         :param training_mode: `True` for model set to training mode and `'False` for model set to evaluation mode.
+        :param drop_last: Set to ``True`` to drop the last incomplete batch, if the dataset size is not divisible by
+                          the batch size. If ``False`` and the size of dataset is not divisible by the batch size, then
+                          the last batch will be smaller. (default: ``False``)
+        :param scheduler: Learning rate scheduler to run at the start of every epoch.
         :param kwargs: Dictionary of framework-specific arguments. This parameter is not currently supported for PyTorch
-               and providing it takes no effect.
+                       and providing it takes no effect.
         """
         import torch  # lgtm [py/repeated-import]
 
@@ -392,24 +398,39 @@ def fit(  # pylint: disable=W0221
         # Check label shape
         y_preprocessed = self.reduce_labels(y_preprocessed)
 
-        num_batch = int(np.ceil(len(x_preprocessed) / float(batch_size)))
+        num_batch = len(x_preprocessed) / float(batch_size)
+        if drop_last:
+            num_batch = int(np.floor(num_batch))
+        else:
+            num_batch = int(np.ceil(num_batch))
         ind = np.arange(len(x_preprocessed))
 
+        x_preprocessed = torch.from_numpy(x_preprocessed).to(self._device)
+        y_preprocessed = torch.from_numpy(y_preprocessed).to(self._device)
+
         # Start training
         for _ in range(nb_epochs):
             # Shuffle the examples
             random.shuffle(ind)
 
             # Train for one epoch
             for m in range(num_batch):
-                i_batch = torch.from_numpy(x_preprocessed[ind[m * batch_size : (m + 1) * batch_size]]).to(self._device)
-                o_batch = torch.from_numpy(y_preprocessed[ind[m * batch_size : (m + 1) * batch_size]]).to(self._device)
+                i_batch = x_preprocessed[ind[m * batch_size : (m + 1) * batch_size]]
+                o_batch = y_preprocessed[ind[m * batch_size : (m + 1) * batch_size]]
 
                 # Zero the parameter gradients
                 self._optimizer.zero_grad()
 
                 # Perform prediction
-                model_outputs = self._model(i_batch)
+                try:
+                    model_outputs = self._model(i_batch)
+                except ValueError as err:
+                    if "Expected more than 1 value per channel when training" in str(err):
+                        logger.exception(
+                            "Try dropping the last incomplete batch by setting drop_last=True in "
+                            "method PyTorchClassifier.fit."
+                        )
+                    raise err
 
                 # Form the loss function
                 loss = self._loss(model_outputs[-1], o_batch)  # lgtm [py/call-to-non-callable]
@@ -426,6 +447,9 @@ def fit(  # pylint: disable=W0221
 
             self._optimizer.step()
 
+            if scheduler is not None:
+                scheduler.step()
+
     def fit_generator(self, generator: "DataGenerator", nb_epochs: int = 20, **kwargs) -> None:
         """
         Fit the classifier using the generator that yields batches as specified.

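The ValueError intercepted by the new try/except originates from BatchNorm layers in training mode receiving a batch of size 1, exactly what a trailing partial batch can produce. A minimal reproduction (illustrative, not ART code):

import torch
from torch import nn

bn_model = nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8))
bn_model.train()  # BatchNorm needs more than one sample per channel in training mode

try:
    bn_model(torch.randn(1, 4))  # a single-sample batch, like a leftover batch of size 1
except ValueError as err:
    print(err)  # "Expected more than 1 value per channel when training, ..."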