77 'call_parse' ]
88
99# %% ../nbs/06_script.ipynb
10- import inspect ,argparse ,shutil
10+ import inspect ,argparse ,shutil ,types
11+
1112from functools import wraps ,partial
1213from .imports import *
1314from .utils import *
1415from .docments import docments
16+ from typing import get_origin , get_args , Union
1517
1618# %% ../nbs/06_script.ipynb
1719def store_true ():
@@ -40,28 +42,31 @@ class Param:
4042 "A parameter in a function used in `anno_parser` or `call_parse`"
4143 def __init__ (self , help = "" , type = None , opt = True , action = None , nargs = None , const = None ,
4244 choices = None , required = None , default = None , version = None ):
43- if type in ( store_true , bool ) : type ,action ,default = None ,'store_true' ,False
45+ if type == store_true : type ,action ,default = None ,'store_true' ,False
4446 if type == store_false : type ,action ,default = None ,'store_false' ,True
4547 if type and isinstance (type ,typing .Type ) and issubclass (type ,enum .Enum ) and not choices : choices = list (type )
4648 help = help or ""
49+ self .negated = False
4750 store_attr ()
4851
4952 def set_default (self , d ):
5053 if self .action == "version" :
5154 if d != inspect .Parameter .empty : self .version = d
5255 self .opt = True
5356 return
54- if self .default is None :
57+ if self .type in (bool , bool_arg ) and self .action is None :
58+ self .type = None
59+ if d is True : self .action ,self .default ,self .negated = 'store_false' ,True ,True
60+ else : self .action ,self .default = 'store_true' ,False
61+ elif self .default is None :
5562 if d == inspect .Parameter .empty : self .opt = False
5663 else : self .default = d
57- if self .default is not None :
58- self .help += f" (default: { self .default } )"
64+ if self .default is not None : self .help += f" (default: { self .default } )"
5965
6066 @property
6167 def pre (self ): return '--' if self .opt else ''
6268 @property
63- def kwargs (self ): return {k :v for k ,v in self .__dict__ .items ()
64- if v is not None and k != 'opt' and k [0 ]!= '_' }
69+ def kwargs (self ): return {k :v for k ,v in self .__dict__ .items () if v is not None and k not in ('opt' ,'negated' ) and k [0 ]!= '_' }
6570 def __repr__ (self ):
6671 if not self .help and self .type is None : return ""
6772 if not self .help and self .type is not None : return f"{ clean_type_str (self .type )} "
@@ -76,15 +81,37 @@ def __init__(self, prog, indent_increment=2):
7681 def _expand_help (self , action ): return self ._get_help_string (action )
7782
7883# %% ../nbs/06_script.ipynb
79- def anno_parser (func , # Function to get arguments from
80- prog :str = None ): # The name of the program
84+ def _is_union (t ): return get_origin (t ) in (Union , types .UnionType ) if hasattr (types , 'UnionType' ) else get_origin (t ) is Union
85+
86+ def _union_parser (types ):
87+ "Return a parser that tries each type in sequence"
88+ def _parse (v ):
89+ for t in types :
90+ if t is type (None ): continue
91+ try : return t (v )
92+ except : pass
93+ raise ValueError (f"Could not parse { v !r} as any of { types } " )
94+ return _parse
95+
96+ def _union_type (t ):
97+ "Get parser for Union types, or None if not a Union"
98+ if not _is_union (t ): return None
99+ return _union_parser (get_args (t ))
100+
101+ # %% ../nbs/06_script.ipynb
102+ def anno_parser (func , prog :str = None ):
81103 "Look at params (annotated with `Param`) in func and return an `ArgumentParser`"
82104 p = argparse .ArgumentParser (description = func .__doc__ , prog = prog , formatter_class = _HelpFormatter )
83105 for k ,v in docments (func , full = True , returns = False , eval_str = True ).items ():
84106 param = v .anno
85- if not isinstance (param ,Param ): param = Param (v .docment , v .anno )
107+ if not isinstance (param , Param ):
108+ anno = _union_type (v .anno ) or v .anno
109+ param = Param (v .docment , anno )
86110 param .set_default (v .default )
87- p .add_argument (f"{ param .pre } { k } " , ** param .kwargs )
111+ name = f"no-{ k } " if param .negated else k
112+ kw = param .kwargs
113+ if param .negated : kw ['dest' ] = k
114+ p .add_argument (f"{ param .pre } { name } " , ** kw )
88115 p .add_argument (f"--pdb" , help = argparse .SUPPRESS , action = 'store_true' )
89116 p .add_argument (f"--xtra" , help = argparse .SUPPRESS , type = str )
90117 return p
@@ -98,7 +125,10 @@ def args_from_prog(func, prog):
98125 args = {progsp [i ]:progsp [i + 1 ] for i in range (0 , len (progsp ), 2 )}
99126 annos = type_hints (func )
100127 for k ,v in args .items ():
101- t = annos .get (k , Param ()).type
128+ anno = annos .get (k )
129+ t = getattr (anno , 'type' , anno )
130+ if t in (bool , bool_arg ): t = str2bool
131+ elif isinstance (anno , Param ) and anno .action in ('store_true' , 'store_false' ): t = str2bool
102132 if t : args [k ] = t (v )
103133 return args
104134
0 commit comments