Commit b62f866

Update KerasClassifier for verbose argument

Signed-off-by: Beat Buesser <[email protected]>
1 parent: b8607cf

3 files changed: +13 −17 lines changed

art/estimators/classification/keras.py

Lines changed: 12 additions & 12 deletions
```diff
@@ -559,7 +559,9 @@ def predict(  # pylint: disable=W0221

         return predictions

-    def fit(self, x: np.ndarray, y: np.ndarray, batch_size: int = 128, nb_epochs: int = 20, **kwargs) -> None:
+    def fit(
+        self, x: np.ndarray, y: np.ndarray, batch_size: int = 128, nb_epochs: int = 20, verbose: bool = False, **kwargs
+    ) -> None:
         """
         Fit the classifier on the training set `(x, y)`.

@@ -568,6 +570,7 @@ def fit(self, x: np.ndarray, y: np.ndarray, batch_size: int = 128, nb_epochs: int = 20, **kwargs) -> None:
                   shape (nb_samples,).
         :param batch_size: Size of batches.
         :param nb_epochs: Number of epochs to use for training.
+        :param verbose: Display training progress bar.
         :param kwargs: Dictionary of framework-specific arguments. These should be parameters supported by the
                `fit_generator` function in Keras and will be passed to this function as such. Including the number of
                epochs or the number of steps per epoch as part of this argument will result in an error.
@@ -582,18 +585,18 @@ def fit(self, x: np.ndarray, y: np.ndarray, batch_size: int = 128, nb_epochs: int = 20, **kwargs) -> None:
         if self._reduce_labels or y_ndim == 1:
             y_preprocessed = np.argmax(y_preprocessed, axis=1)

-        if "verbose" in kwargs:
-            kwargs["verbose"] = int(kwargs["verbose"])
-
-        self._model.fit(x=x_preprocessed, y=y_preprocessed, batch_size=batch_size, epochs=nb_epochs, **kwargs)
+        self._model.fit(
+            x=x_preprocessed, y=y_preprocessed, batch_size=batch_size, epochs=nb_epochs, verbose=int(verbose), **kwargs
+        )

-    def fit_generator(self, generator: "DataGenerator", nb_epochs: int = 20, **kwargs) -> None:
+    def fit_generator(self, generator: "DataGenerator", nb_epochs: int = 20, verbose: bool = False, **kwargs) -> None:
         """
         Fit the classifier using the generator that yields batches as specified.

         :param generator: Batch generator providing `(x, y)` for each epoch. If the generator can be used for native
                           training in Keras, it will.
         :param nb_epochs: Number of epochs to use for training.
+        :param verbose: Display training progress bar.
         :param kwargs: Dictionary of framework-specific arguments. These should be parameters supported by the
                `fit_generator` function in Keras and will be passed to this function as such. Including the number of
                epochs as part of this argument will result in an error.
@@ -603,9 +606,6 @@ def fit_generator(self, generator: "DataGenerator", nb_epochs: int = 20, **kwargs) -> None:
         # Try to use the generator as a Keras native generator, otherwise use it through the `DataGenerator` interface
         from art.preprocessing.standardisation_mean_std.numpy import StandardisationMeanStd

-        if "verbose" in kwargs:
-            kwargs["verbose"] = int(kwargs["verbose"])
-
         if isinstance(generator, KerasDataGenerator) and (
             self.preprocessing is None
             or (
@@ -618,12 +618,12 @@ def fit_generator(self, generator: "DataGenerator", nb_epochs: int = 20, **kwargs) -> None:
             )
         ):
             try:
-                self._model.fit_generator(generator.iterator, epochs=nb_epochs, **kwargs)
+                self._model.fit_generator(generator.iterator, epochs=nb_epochs, verbose=int(verbose), **kwargs)
             except ValueError:  # pragma: no cover
                 logger.info("Unable to use data generator as Keras generator. Now treating as framework-independent.")
-                super().fit_generator(generator, nb_epochs=nb_epochs, **kwargs)
+                super().fit_generator(generator, nb_epochs=nb_epochs, verbose=verbose, **kwargs)
         else:  # pragma: no cover
-            super().fit_generator(generator, nb_epochs=nb_epochs, **kwargs)
+            super().fit_generator(generator, nb_epochs=nb_epochs, verbose=verbose, **kwargs)

     def get_activations(
         self, x: np.ndarray, layer: Union[int, str], batch_size: int = 128, framework: bool = False
```
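For context, a minimal usage sketch of the new keyword (not part of the commit; the toy model, data shapes, and `clip_values` below are illustrative assumptions):

```python
# Usage sketch (not part of this commit). Model, shapes, and clip_values
# are illustrative; KerasClassifier under TF2 requires eager execution
# to be disabled, as in the ART getting-started examples.
import numpy as np
import tensorflow as tf

tf.compat.v1.disable_eager_execution()

from art.estimators.classification import KerasClassifier

model = tf.keras.Sequential(
    [
        tf.keras.layers.Dense(16, activation="relu", input_shape=(4,)),
        tf.keras.layers.Dense(3, activation="softmax"),
    ]
)
model.compile(optimizer="adam", loss="categorical_crossentropy")

classifier = KerasClassifier(model=model, clip_values=(0.0, 1.0))

x = np.random.rand(32, 4).astype(np.float32)
y = tf.keras.utils.to_categorical(np.random.randint(0, 3, 32), 3)

# `verbose` is now an explicit boolean keyword rather than a kwargs entry;
# it is converted to int before being handed to Keras.
classifier.fit(x, y, batch_size=16, nb_epochs=2, verbose=True)
```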

art/estimators/classification/pytorch.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -391,7 +391,7 @@ def fit(  # pylint: disable=W0221
                the batch size. If ``False`` and the size of dataset is not divisible by the batch size, then
                the last batch will be smaller. (default: ``False``)
         :param scheduler: Learning rate scheduler to run at the start of every epoch.
-        :param verbose: If to display the progress bar information.
+        :param verbose: Display training progress bar.
         :param kwargs: Dictionary of framework-specific arguments. This parameter is not currently supported for PyTorch
                and providing it has no effect.
         """
```

art/estimators/keras.py

Lines changed: 0 additions & 4 deletions
```diff
@@ -61,8 +61,6 @@ def predict(self, x: np.ndarray, batch_size: int = 128, **kwargs):
         :return: Predictions.
         :rtype: Format as expected by the `model`
         """
-        if "verbose" in kwargs:
-            kwargs["verbose"] = int(kwargs["verbose"])
         return NeuralNetworkMixin.predict(self, x, batch_size=batch_size, **kwargs)

     def fit(self, x: np.ndarray, y, batch_size: int = 128, nb_epochs: int = 20, **kwargs) -> None:
@@ -76,8 +74,6 @@ def fit(self, x: np.ndarray, y, batch_size: int = 128, nb_epochs: int = 20, **kwargs) -> None:
         :param batch_size: Batch size.
         :param nb_epochs: Number of training epochs.
         """
-        if "verbose" in kwargs:
-            kwargs["verbose"] = int(kwargs["verbose"])
         NeuralNetworkMixin.fit(self, x, y, batch_size=batch_size, nb_epochs=nb_epochs, **kwargs)

     def compute_loss(self, x: np.ndarray, y: np.ndarray, **kwargs) -> np.ndarray:
```
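Why the coercion can be dropped here: after this commit the conversion happens inline at the Keras call sites, starting from an explicit boolean. A small sketch of that mapping (the helper name is hypothetical, for illustration only):

```python
# Hypothetical helper illustrating the bool -> int mapping now applied
# inline at the Keras call sites: False -> 0 (silent), True -> 1
# (progress bar). Keras's third level, 2 (one line per epoch), is not
# reachable through the boolean flag.
def keras_verbose(verbose: bool) -> int:
    return int(verbose)

assert keras_verbose(False) == 0
assert keras_verbose(True) == 1
```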
