Skip to content

Commit bc60298

Browse files
committed
fixes #739
1 parent 75c87b8 commit bc60298

File tree

3 files changed

+182
-44
lines changed

3 files changed

+182
-44
lines changed

fastcore/_modidx.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -545,6 +545,9 @@
545545
'fastcore.script._HelpFormatter.__init__': ('script.html#_helpformatter.__init__', 'fastcore/script.py'),
546546
'fastcore.script._HelpFormatter._expand_help': ( 'script.html#_helpformatter._expand_help',
547547
'fastcore/script.py'),
548+
'fastcore.script._is_union': ('script.html#_is_union', 'fastcore/script.py'),
549+
'fastcore.script._union_parser': ('script.html#_union_parser', 'fastcore/script.py'),
550+
'fastcore.script._union_type': ('script.html#_union_type', 'fastcore/script.py'),
548551
'fastcore.script.anno_parser': ('script.html#anno_parser', 'fastcore/script.py'),
549552
'fastcore.script.args_from_prog': ('script.html#args_from_prog', 'fastcore/script.py'),
550553
'fastcore.script.bool_arg': ('script.html#bool_arg', 'fastcore/script.py'),

fastcore/script.py

Lines changed: 42 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,13 @@
77
'call_parse']
88

99
# %% ../nbs/06_script.ipynb
10-
import inspect,argparse,shutil
10+
import inspect,argparse,shutil,types
11+
1112
from functools import wraps,partial
1213
from .imports import *
1314
from .utils import *
1415
from .docments import docments
16+
from typing import get_origin, get_args, Union
1517

1618
# %% ../nbs/06_script.ipynb
1719
def store_true():
@@ -40,28 +42,31 @@ class Param:
4042
"A parameter in a function used in `anno_parser` or `call_parse`"
4143
def __init__(self, help="", type=None, opt=True, action=None, nargs=None, const=None,
4244
choices=None, required=None, default=None, version=None):
43-
if type in (store_true,bool): type,action,default=None,'store_true',False
45+
if type==store_true: type,action,default=None,'store_true',False
4446
if type==store_false: type,action,default=None,'store_false',True
4547
if type and isinstance(type,typing.Type) and issubclass(type,enum.Enum) and not choices: choices=list(type)
4648
help = help or ""
49+
self.negated = False
4750
store_attr()
4851

4952
def set_default(self, d):
5053
if self.action == "version":
5154
if d != inspect.Parameter.empty: self.version = d
5255
self.opt = True
5356
return
54-
if self.default is None:
57+
if self.type in (bool, bool_arg) and self.action is None:
58+
self.type = None
59+
if d is True: self.action,self.default,self.negated = 'store_false',True,True
60+
else: self.action,self.default = 'store_true',False
61+
elif self.default is None:
5562
if d == inspect.Parameter.empty: self.opt = False
5663
else: self.default = d
57-
if self.default is not None:
58-
self.help += f" (default: {self.default})"
64+
if self.default is not None: self.help += f" (default: {self.default})"
5965

6066
@property
6167
def pre(self): return '--' if self.opt else ''
6268
@property
63-
def kwargs(self): return {k:v for k,v in self.__dict__.items()
64-
if v is not None and k!='opt' and k[0]!='_'}
69+
def kwargs(self): return {k:v for k,v in self.__dict__.items() if v is not None and k not in ('opt','negated') and k[0]!='_'}
6570
def __repr__(self):
6671
if not self.help and self.type is None: return ""
6772
if not self.help and self.type is not None: return f"{clean_type_str(self.type)}"
@@ -76,15 +81,37 @@ def __init__(self, prog, indent_increment=2):
7681
def _expand_help(self, action): return self._get_help_string(action)
7782

7883
# %% ../nbs/06_script.ipynb
79-
def anno_parser(func, # Function to get arguments from
80-
prog:str=None): # The name of the program
84+
def _is_union(t): return get_origin(t) in (Union, types.UnionType) if hasattr(types, 'UnionType') else get_origin(t) is Union
85+
86+
def _union_parser(types):
87+
"Return a parser that tries each type in sequence"
88+
def _parse(v):
89+
for t in types:
90+
if t is type(None): continue
91+
try: return t(v)
92+
except: pass
93+
raise ValueError(f"Could not parse {v!r} as any of {types}")
94+
return _parse
95+
96+
def _union_type(t):
97+
"Get parser for Union types, or None if not a Union"
98+
if not _is_union(t): return None
99+
return _union_parser(get_args(t))
100+
101+
# %% ../nbs/06_script.ipynb
102+
def anno_parser(func, prog:str=None):
81103
"Look at params (annotated with `Param`) in func and return an `ArgumentParser`"
82104
p = argparse.ArgumentParser(description=func.__doc__, prog=prog, formatter_class=_HelpFormatter)
83105
for k,v in docments(func, full=True, returns=False, eval_str=True).items():
84106
param = v.anno
85-
if not isinstance(param,Param): param = Param(v.docment, v.anno)
107+
if not isinstance(param, Param):
108+
anno = _union_type(v.anno) or v.anno
109+
param = Param(v.docment, anno)
86110
param.set_default(v.default)
87-
p.add_argument(f"{param.pre}{k}", **param.kwargs)
111+
name = f"no-{k}" if param.negated else k
112+
kw = param.kwargs
113+
if param.negated: kw['dest'] = k
114+
p.add_argument(f"{param.pre}{name}", **kw)
88115
p.add_argument(f"--pdb", help=argparse.SUPPRESS, action='store_true')
89116
p.add_argument(f"--xtra", help=argparse.SUPPRESS, type=str)
90117
return p
@@ -98,7 +125,10 @@ def args_from_prog(func, prog):
98125
args = {progsp[i]:progsp[i+1] for i in range(0, len(progsp), 2)}
99126
annos = type_hints(func)
100127
for k,v in args.items():
101-
t = annos.get(k, Param()).type
128+
anno = annos.get(k)
129+
t = getattr(anno, 'type', anno)
130+
if t in (bool, bool_arg): t = str2bool
131+
elif isinstance(anno, Param) and anno.action in ('store_true', 'store_false'): t = str2bool
102132
if t: args[k] = t(v)
103133
return args
104134

0 commit comments

Comments
 (0)