Skip to content

Commit 804af32

Browse files
committed
Better handling of int values for verbosity. Adding fit_gen to tests for kwargs checks
Signed-off-by: GiulioZizzo <[email protected]>
1 parent 9c3f572 commit 804af32

File tree

3 files changed

+15
-8
lines changed

3 files changed

+15
-8
lines changed

art/estimators/classification/pytorch.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -371,20 +371,20 @@ def process_verbose(self, verbose: Optional[Union[bool, int]] = None) -> bool:
371371
Function to unify the various ways implemented in ART of displaying progress bars
372372
into a single True/False output.
373373
374-
:param verbose: If to display the progress bar information.
374+
:param verbose: If to display the progress bar information in one of a few possible formats.
375375
:return: True/False if to display the progress bars.
376376
"""
377377

378378
if verbose is not None:
379379
if isinstance(verbose, int):
380-
if verbose == 0:
380+
if verbose <= 0:
381381
display_pb = False
382382
else:
383383
display_pb = True
384384
elif isinstance(verbose, bool):
385385
display_pb = verbose
386386
else:
387-
raise ValueError("Verbose should be True/False or a 0/1 int")
387+
raise ValueError("Verbose should be True/False or an int")
388388
else:
389389
# Check if the verbose attribute is present in the current classifier
390390
if hasattr(self, "verbose"):

art/estimators/classification/tensorflow.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,7 @@ def process_verbose(self, verbose: Optional[Union[bool, int]] = None) -> bool:
270270
"""
271271
Function to unify the various ways implemented in ART of displaying progress bars
272272
into a single True/False output.
273-
:param verbose: If to display the progress bar information.
273+
:param verbose: If to display the progress bar information in one of a few possible formats.
274274
:return: True/False if to display the progress bars.
275275
"""
276276

@@ -1007,13 +1007,14 @@ def process_verbose(self, verbose: Optional[Union[bool, int]] = None) -> bool:
10071007
"""
10081008
Function to unify the various ways implemented in ART of displaying progress bars
10091009
into a single True/False output.
1010-
:param verbose: If to display the progress bar information.
1010+
1011+
:param verbose: If to display the progress bar information in one of a few possible formats.
10111012
:return: True/False if to display the progress bars.
10121013
"""
10131014

10141015
if verbose is not None:
10151016
if isinstance(verbose, int):
1016-
if verbose == 0:
1017+
if verbose <= 0:
10171018
display_pb = False
10181019
else:
10191020
display_pb = True

tests/estimators/classification/test_deeplearning_common.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -192,19 +192,25 @@ def test_functional_model(art_warning, image_dl_estimator):
192192

193193

194194
@pytest.mark.skip_framework("mxnet", "non_dl_frameworks")
195-
def test_fit_kwargs(art_warning, image_dl_estimator, get_default_mnist_subset, default_batch_size, framework):
195+
def test_fit_kwargs(
196+
art_warning, image_dl_estimator, get_default_mnist_subset, image_data_generator, default_batch_size, framework
197+
):
196198
try:
197199
(x_train_mnist, y_train_mnist), (_, _) = get_default_mnist_subset
198200

199201
def get_lr(_):
200202
return 0.01
201203

202204
# Test a valid callback
203-
classifier, _ = image_dl_estimator(from_logits=True)
205+
classifier, sess = image_dl_estimator(from_logits=True)
204206

205207
kwargs = {"callbacks": [LearningRateScheduler(get_lr)], "verbose": True}
206208
classifier.fit(x_train_mnist, y_train_mnist, batch_size=default_batch_size, nb_epochs=1, **kwargs)
207209

210+
# Check for fit_generator kwargs as well
211+
data_gen = image_data_generator(sess=sess)
212+
classifier.fit_generator(generator=data_gen, nb_epochs=1, **kwargs)
213+
208214
# Test failure for invalid parameters: does not apply to many frameworks which allow arbitrary kwargs
209215
if framework not in [
210216
"tensorflow1",

0 commit comments

Comments
 (0)