Skip to content

Commit d787456

Browse files
fix for multiple types in trainer add_argparse_args (#3077)
* fix * fix * fix * fix * temp * fix * 0.9.0 readme * 0.9.0 readme * 0.9.0 readme Co-authored-by: William Falcon <[email protected]>
1 parent af8aceb commit d787456

File tree

1 file changed

+18
-4
lines changed

1 file changed

+18
-4
lines changed

pytorch_lightning/trainer/trainer.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -766,8 +766,16 @@ def use_type(x):
766766
use_type = arg_types[0]
767767

768768
if arg == 'gpus' or arg == 'tpu_cores':
769-
use_type = Trainer._allowed_type
770-
arg_default = Trainer._arg_default
769+
use_type = Trainer._gpus_allowed_type
770+
arg_default = Trainer._gpus_arg_default
771+
772+
# hack for types in (int, float)
773+
if len(arg_types) == 2 and int in set(arg_types) and float in set(arg_types):
774+
use_type = Trainer._int_or_float_type
775+
776+
# hack for track_grad_norm
777+
if arg == 'track_grad_norm':
778+
use_type = float
771779

772780
parser.add_argument(
773781
f'--{arg}',
@@ -780,18 +788,24 @@ def use_type(x):
780788

781789
return parser
782790

783-
def _allowed_type(x) -> Union[int, str]:
791+
def _gpus_allowed_type(x) -> Union[int, str]:
784792
if ',' in x:
785793
return str(x)
786794
else:
787795
return int(x)
788796

789-
def _arg_default(x) -> Union[int, str]:
797+
def _gpus_arg_default(x) -> Union[int, str]:
790798
if ',' in x:
791799
return str(x)
792800
else:
793801
return int(x)
794802

803+
def _int_or_float_type(x) -> Union[int, float]:
804+
if '.' in str(x):
805+
return float(x)
806+
else:
807+
return int(x)
808+
795809
@classmethod
796810
def parse_argparser(cls, arg_parser: Union[ArgumentParser, Namespace]) -> Namespace:
797811
"""Parse CLI arguments, required for custom bool types."""

0 commit comments

Comments
 (0)