File tree Expand file tree Collapse file tree 1 file changed +18
-4
lines changed
pytorch_lightning/trainer Expand file tree Collapse file tree 1 file changed +18
-4
lines changed Original file line number Diff line number Diff line change @@ -766,8 +766,16 @@ def use_type(x):
766766 use_type = arg_types [0 ]
767767
768768 if arg == 'gpus' or arg == 'tpu_cores' :
769- use_type = Trainer ._allowed_type
770- arg_default = Trainer ._arg_default
769+ use_type = Trainer ._gpus_allowed_type
770+ arg_default = Trainer ._gpus_arg_default
771+
772+ # hack for types in (int, float)
773+ if len (arg_types ) == 2 and int in set (arg_types ) and float in set (arg_types ):
774+ use_type = Trainer ._int_or_float_type
775+
776+ # hack for track_grad_norm
777+ if arg == 'track_grad_norm' :
778+ use_type = float
771779
772780 parser .add_argument (
773781 f'--{ arg } ' ,
@@ -780,18 +788,24 @@ def use_type(x):
780788
781789 return parser
782790
783- def _allowed_type (x ) -> Union [int , str ]:
791+ def _gpus_allowed_type (x ) -> Union [int , str ]:
784792 if ',' in x :
785793 return str (x )
786794 else :
787795 return int (x )
788796
789- def _arg_default (x ) -> Union [int , str ]:
797+ def _gpus_arg_default (x ) -> Union [int , str ]:
790798 if ',' in x :
791799 return str (x )
792800 else :
793801 return int (x )
794802
803+ def _int_or_float_type (x ) -> Union [int , float ]:
804+ if '.' in str (x ):
805+ return float (x )
806+ else :
807+ return int (x )
808+
795809 @classmethod
796810 def parse_argparser (cls , arg_parser : Union [ArgumentParser , Namespace ]) -> Namespace :
797811 """Parse CLI arguments, required for custom bool types."""
You can’t perform that action at this time.
0 commit comments