Skip to content

Commit be89eb0

Browse files
williamFalconBorda
andcommitted
cleaned docs, fixed argparse generator (#1075)
* Update README.md * Update README.md * Update README.md * Update README.md * Update README.md * Update README.md * Update README.md * Test deprecated API for 0.8.0 and 0.9.0 (#1071) * till 0.8 * refactor * fix tests * fix tests * deprx till 0.9 * Update trainer.py * Apply suggestions from code review Co-authored-by: William Falcon <[email protected]> * updated test * updated test * updated test * updated test * updated test * updated test * updated test * updated test * updated test * updated test * updated test * updated test * updated test Co-authored-by: Jirka Borovec <[email protected]>
1 parent 9f140b7 commit be89eb0

File tree

4 files changed

+33
-7
lines changed

4 files changed

+33
-7
lines changed

docs/source/hyperparameters.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,10 @@ modify the network. The `Trainer` can add all the available options to an Argume
1919
parser.add_argument('--layer_1_dim', type=int, default=128)
2020
parser.add_argument('--layer_2_dim', type=int, default=256)
2121
parser.add_argument('--batch_size', type=int, default=64)
22+
23+
# add all the available options to the trainer
24+
parser = pl.Trainer.add_argparse_args(parser)
25+
2226
args = parser.parse_args()
2327
2428
Now we can parametrize the LightningModule.

docs/source/introduction_guide.rst

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -588,8 +588,12 @@ modify the network. The `Trainer` can add all the available options to an Argume
588588
589589
# parametrize the network
590590
parser.add_argument('--layer_1_dim', type=int, default=128)
591-
parser.add_argument('--layer_1_dim', type=int, default=256)
591+
parser.add_argument('--layer_2_dim', type=int, default=256)
592592
parser.add_argument('--batch_size', type=int, default=64)
593+
594+
# add all the available options to the trainer
595+
parser = pl.Trainer.add_argparse_args(parser)
596+
593597
args = parser.parse_args()
594598
595599
Now we can parametrize the LightningModule.

pytorch_lightning/trainer/deprecated_api.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,13 +75,15 @@ def min_nb_epochs(self, min_epochs):
7575
@property
7676
def nb_sanity_val_steps(self):
7777
"""Back compatibility, will be removed in v0.8.0"""
78-
warnings.warn("Attribute `nb_sanity_val_steps` has renamed to `num_sanity_val_steps` since v0.5.0"
78+
warnings.warn("Attribute `nb_sanity_val_steps` has renamed to "
79+
"`num_sanity_val_steps` since v0.5.0"
7980
" and this method will be removed in v0.8.0", DeprecationWarning)
8081
return self.num_sanity_val_steps
8182

8283
@nb_sanity_val_steps.setter
8384
def nb_sanity_val_steps(self, nb):
8485
"""Back compatibility, will be removed in v0.8.0"""
85-
warnings.warn("Attribute `nb_sanity_val_steps` has renamed to `num_sanity_val_steps` since v0.5.0"
86+
warnings.warn("Attribute `nb_sanity_val_steps` has renamed to "
87+
"`num_sanity_val_steps` since v0.5.0"
8688
" and this method will be removed in v0.8.0", DeprecationWarning)
8789
self.num_sanity_val_steps = nb

pytorch_lightning/trainer/trainer.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -293,7 +293,8 @@ def __init__(
293293
self.num_sanity_val_steps = num_sanity_val_steps
294294
# Backward compatibility, TODO: remove in v0.8.0
295295
if nb_sanity_val_steps is not None:
296-
warnings.warn("Argument `nb_sanity_val_steps` has renamed to `num_sanity_val_steps` since v0.5.0"
296+
warnings.warn("Argument `nb_sanity_val_steps` has renamed to "
297+
"`num_sanity_val_steps` since v0.5.0"
297298
" and this method will be removed in v0.8.0", DeprecationWarning)
298299
self.nb_sanity_val_steps = nb_sanity_val_steps
299300
self.print_nan_grads = print_nan_grads
@@ -437,17 +438,32 @@ def slurm_job_id(self) -> int:
437438

438439
@classmethod
439440
def default_attributes(cls):
440-
return vars(cls())
441+
import inspect
442+
443+
init_signature = inspect.signature(Trainer)
444+
445+
args = {}
446+
for param_name in init_signature.parameters:
447+
value = init_signature.parameters[param_name].default
448+
args[param_name] = value
449+
450+
return args
441451

442452
@classmethod
443453
def add_argparse_args(cls, parent_parser: ArgumentParser) -> ArgumentParser:
444454
"""Extend existing argparse by default `Trainer` attributes."""
445-
parser = ArgumentParser(parents=[parent_parser])
455+
parser = ArgumentParser(parents=[parent_parser], add_help=False)
446456

447457
trainer_default_params = Trainer.default_attributes()
448458

459+
# TODO: get "help" from docstring :)
449460
for arg in trainer_default_params:
450-
parser.add_argument('--{0}'.format(arg), default=trainer_default_params[arg], dest=arg)
461+
parser.add_argument(
462+
f'--{arg}',
463+
default=trainer_default_params[arg],
464+
dest=arg,
465+
help='autogenerated by pl.Trainer'
466+
)
451467

452468
return parser
453469

0 commit comments

Comments
 (0)