
Commit fc12c05

Author: Beat Buesser (committed)

Update missing parameter application

Signed-off-by: Beat Buesser <[email protected]>

1 parent 5aefcc7 · commit fc12c05

2 files changed: +21 additions, -2 deletions

art/estimators/certification/derandomized_smoothing/pytorch.py
Lines changed: 9 additions & 1 deletion

@@ -155,6 +155,7 @@ def fit(  # pylint: disable=W0221
         batch_size: int = 128,
         nb_epochs: int = 10,
         training_mode: bool = True,
+        drop_last: bool = False,
         scheduler: Optional[Any] = None,
         **kwargs,
     ) -> None:
@@ -166,6 +167,9 @@ def fit(  # pylint: disable=W0221
         :param batch_size: Size of batches.
         :param nb_epochs: Number of epochs to use for training.
         :param training_mode: `True` for model set to training mode and `'False` for model set to evaluation mode.
+        :param drop_last: Set to ``True`` to drop the last incomplete batch, if the dataset size is not divisible by
+                          the batch size. If ``False`` and the size of dataset is not divisible by the batch size, then
+                          the last batch will be smaller. (default: ``False``)
         :param scheduler: Learning rate scheduler to run at the start of every epoch.
         :param kwargs: Dictionary of framework-specific arguments. This parameter is not currently supported for PyTorch
                        and providing it takes no effect.
@@ -186,7 +190,11 @@ def fit(  # pylint: disable=W0221
         # Check label shape
         y_preprocessed = self.reduce_labels(y_preprocessed)
 
-        num_batch = int(np.ceil(len(x_preprocessed) / float(batch_size)))
+        num_batch = len(x_preprocessed) / float(batch_size)
+        if drop_last:
+            num_batch = int(np.floor(num_batch))
+        else:
+            num_batch = int(np.ceil(num_batch))
         ind = np.arange(len(x_preprocessed))
 
         # Start training
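
Both hunks in this commit apply the same change: instead of always rounding the batch count up, fit() now floors it when drop_last=True, so a trailing partial batch is skipped. A minimal standalone sketch of that logic follows (count_batches is a hypothetical helper written for illustration, not part of ART; the sample counts are made up):

import numpy as np

def count_batches(n_samples: int, batch_size: int, drop_last: bool = False) -> int:
    # Mirrors the patched logic: floor drops a trailing partial batch,
    # ceil keeps it as a smaller final batch.
    num_batch = n_samples / float(batch_size)
    return int(np.floor(num_batch)) if drop_last else int(np.ceil(num_batch))

# 1000 samples at batch_size=128 leave a remainder of 104 samples:
assert count_batches(1000, 128, drop_last=True) == 7   # partial batch dropped
assert count_batches(1000, 128, drop_last=False) == 8  # partial batch kept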

art/estimators/certification/randomized_smoothing/pytorch.py
Lines changed: 12 additions & 1 deletion

@@ -136,6 +136,7 @@ def fit(  # pylint: disable=W0221
         batch_size: int = 128,
         nb_epochs: int = 10,
         training_mode: bool = True,
+        drop_last: bool = False,
         scheduler: Optional[Any] = None,
         **kwargs,
     ) -> None:
@@ -148,6 +149,9 @@ def fit(  # pylint: disable=W0221
         :param batch_size: Size of batches.
         :param nb_epochs: Number of epochs to use for training.
         :param training_mode: `True` for model set to training mode and `'False` for model set to evaluation mode.
+        :param drop_last: Set to ``True`` to drop the last incomplete batch, if the dataset size is not divisible by
+                          the batch size. If ``False`` and the size of dataset is not divisible by the batch size, then
+                          the last batch will be smaller. (default: ``False``)
         :param scheduler: Learning rate scheduler to run at the start of every epoch.
         :param kwargs: Dictionary of framework-specific arguments. This parameter is not currently supported for PyTorch
                        and providing it takes no effect.
@@ -168,7 +172,11 @@ def fit(  # pylint: disable=W0221
         # Check label shape
         y_preprocessed = self.reduce_labels(y_preprocessed)
 
-        num_batch = int(np.ceil(len(x_preprocessed) / float(batch_size)))
+        num_batch = len(x_preprocessed) / float(batch_size)
+        if drop_last:
+            num_batch = int(np.floor(num_batch))
+        else:
+            num_batch = int(np.ceil(num_batch))
         ind = np.arange(len(x_preprocessed))
         std = torch.tensor(self.scale).to(self._device)
 
@@ -217,6 +225,9 @@ def fit(  # pylint: disable=W0221
 
                 self._optimizer.step()
 
+            if scheduler is not None:
+                scheduler.step()
+
     def predict(self, x: np.ndarray, batch_size: int = 128, **kwargs) -> np.ndarray:  # type: ignore
         """
         Perform prediction of the given classifier for a batch of inputs, taking an expectation over transformations.
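
For context, a usage sketch of the patched fit() path. This is a hedged example: the tiny model and the commented-out wrapper arguments are assumptions for illustration (constructor signatures vary across ART versions); only drop_last and scheduler exercise the behaviour changed in this commit.

import torch
import torch.nn as nn
from torch.optim.lr_scheduler import StepLR

# Tiny stand-in network; a real run would wrap a proper classifier.
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
scheduler = StepLR(optimizer, step_size=10, gamma=0.1)

# Hypothetical wrapper call -- argument names below are assumptions, not verbatim ART API:
# from art.estimators.certification.randomized_smoothing import PyTorchRandomizedSmoothing
# classifier = PyTorchRandomizedSmoothing(model=model, loss=nn.CrossEntropyLoss(),
#                                         optimizer=optimizer, input_shape=(1, 28, 28),
#                                         nb_classes=10, scale=0.25)
#
# With drop_last=True the trailing partial batch is skipped every epoch, and the
# scheduler now steps once per epoch after the optimizer updates:
# classifier.fit(x_train, y_train, batch_size=128, nb_epochs=40,
#                drop_last=True, scheduler=scheduler)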
