Skip to content

Commit 2232bcc

Browse files
authored
Handle URLs for command line inputs (#1613)
* Handle URLs for command line inputs Also handle edge case of an input parameter type is wrapped in single-item list.
1 parent 0939a1d commit 2232bcc

File tree

4 files changed

+105
-23
lines changed

4 files changed

+105
-23
lines changed

cwltool/argparser.py

Lines changed: 60 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,18 @@
55
from typing import (
66
Any,
77
AnyStr,
8+
Callable,
89
Dict,
910
List,
1011
MutableMapping,
1112
MutableSequence,
1213
Optional,
1314
Sequence,
15+
Type,
1416
Union,
1517
cast,
1618
)
19+
import urllib
1720

1821
from schema_salad.ref_resolver import file_uri
1922

@@ -716,26 +719,34 @@ class FSAction(argparse.Action):
716719
objclass = None # type: str
717720

718721
def __init__(
719-
self, option_strings: List[str], dest: str, nargs: Any = None, **kwargs: Any
722+
self,
723+
option_strings: List[str],
724+
dest: str,
725+
nargs: Any = None,
726+
urljoin: Callable[[str, str], str] = urllib.parse.urljoin,
727+
base_uri: str = "",
728+
**kwargs: Any,
720729
) -> None:
721730
"""Fail if nargs is used."""
722731
if nargs is not None:
723732
raise ValueError("nargs not allowed")
733+
self.urljoin = urljoin
734+
self.base_uri = base_uri
724735
super().__init__(option_strings, dest, **kwargs)
725736

726737
def __call__(
727738
self,
728739
parser: argparse.ArgumentParser,
729740
namespace: argparse.Namespace,
730-
values: Union[AnyStr, Sequence[Any], None],
741+
values: Union[str, Sequence[Any], None],
731742
option_string: Optional[str] = None,
732743
) -> None:
733744
setattr(
734745
namespace,
735746
self.dest,
736747
{
737748
"class": self.objclass,
738-
"location": file_uri(str(os.path.abspath(cast(AnyStr, values)))),
749+
"location": self.urljoin(self.base_uri, cast(str, values)),
739750
},
740751
)
741752

@@ -744,18 +755,26 @@ class FSAppendAction(argparse.Action):
744755
objclass = None # type: str
745756

746757
def __init__(
747-
self, option_strings: List[str], dest: str, nargs: Any = None, **kwargs: Any
758+
self,
759+
option_strings: List[str],
760+
dest: str,
761+
nargs: Any = None,
762+
urljoin: Callable[[str, str], str] = urllib.parse.urljoin,
763+
base_uri: str = "",
764+
**kwargs: Any,
748765
) -> None:
749766
"""Initialize."""
750767
if nargs is not None:
751768
raise ValueError("nargs not allowed")
769+
self.urljoin = urljoin
770+
self.base_uri = base_uri
752771
super().__init__(option_strings, dest, **kwargs)
753772

754773
def __call__(
755774
self,
756775
parser: argparse.ArgumentParser,
757776
namespace: argparse.Namespace,
758-
values: Union[AnyStr, Sequence[Any], None],
777+
values: Union[str, Sequence[Any], None],
759778
option_string: Optional[str] = None,
760779
) -> None:
761780
g = getattr(namespace, self.dest)
@@ -765,7 +784,7 @@ def __call__(
765784
g.append(
766785
{
767786
"class": self.objclass,
768-
"location": file_uri(str(os.path.abspath(cast(AnyStr, values)))),
787+
"location": self.urljoin(self.base_uri, cast(str, values)),
769788
}
770789
)
771790

@@ -794,6 +813,8 @@ def add_argument(
794813
description: str = "",
795814
default: Any = None,
796815
input_required: bool = True,
816+
urljoin: Callable[[str, str], str] = urllib.parse.urljoin,
817+
base_uri: str = "",
797818
) -> None:
798819
if len(name) == 1:
799820
flag = "-"
@@ -804,27 +825,32 @@ def add_argument(
804825
# parameter required.
805826
required = default is None and input_required
806827
if isinstance(inptype, MutableSequence):
807-
if inptype[0] == "null":
828+
if len(inptype) == 1:
829+
inptype = inptype[0]
830+
elif len(inptype) == 2 and inptype[0] == "null":
831+
required = False
832+
inptype = inptype[1]
833+
elif len(inptype) == 2 and inptype[1] == "null":
808834
required = False
809-
if len(inptype) == 2:
810-
inptype = inptype[1]
811-
else:
812-
_logger.debug("Can't make command line argument from %s", inptype)
813-
return None
835+
inptype = inptype[0]
836+
else:
837+
_logger.debug("Can't make command line argument from %s", inptype)
838+
return None
814839

815840
ahelp = description.replace("%", "%%")
816-
action = None # type: Optional[Union[argparse.Action, str]]
841+
action = None # type: Optional[Union[Type[argparse.Action], str]]
817842
atype = None # type: Any
843+
typekw = {} # type: Dict[str, Any]
818844

819845
if inptype == "File":
820-
action = cast(argparse.Action, FileAction)
846+
action = FileAction
821847
elif inptype == "Directory":
822-
action = cast(argparse.Action, DirectoryAction)
848+
action = DirectoryAction
823849
elif isinstance(inptype, MutableMapping) and inptype["type"] == "array":
824850
if inptype["items"] == "File":
825-
action = cast(argparse.Action, FileAppendAction)
851+
action = FileAppendAction
826852
elif inptype["items"] == "Directory":
827-
action = cast(argparse.Action, DirectoryAppendAction)
853+
action = DirectoryAppendAction
828854
else:
829855
action = "append"
830856
elif isinstance(inptype, MutableMapping) and inptype["type"] == "enum":
@@ -851,18 +877,20 @@ def add_argument(
851877
_logger.debug("Can't make command line argument from %s", inptype)
852878
return None
853879

880+
if action in (FileAction, DirectoryAction, FileAppendAction, DirectoryAppendAction):
881+
typekw["urljoin"] = urljoin
882+
typekw["base_uri"] = base_uri
883+
854884
if inptype != "boolean":
855-
typekw = {"type": atype}
856-
else:
857-
typekw = {}
885+
typekw["type"] = atype
858886

859887
toolparser.add_argument(
860888
flag + name,
861889
required=required,
862890
help=ahelp,
863891
action=action, # type: ignore
864892
default=default,
865-
**typekw
893+
**typekw,
866894
)
867895

868896

@@ -872,6 +900,8 @@ def generate_parser(
872900
namemap: Dict[str, str],
873901
records: List[str],
874902
input_required: bool = True,
903+
urljoin: Callable[[str, str], str] = urllib.parse.urljoin,
904+
base_uri: str = "",
875905
) -> argparse.ArgumentParser:
876906
toolparser.description = tool.tool.get("doc", None)
877907
toolparser.add_argument("job_order", nargs="?", help="Job input json file")
@@ -884,7 +914,15 @@ def generate_parser(
884914
description = inp.get("doc", "")
885915
default = inp.get("default", None)
886916
add_argument(
887-
toolparser, name, inptype, records, description, default, input_required
917+
toolparser,
918+
name,
919+
inptype,
920+
records,
921+
description,
922+
default,
923+
input_required,
924+
urljoin,
925+
base_uri,
888926
)
889927

890928
return toolparser

cwltool/executors.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,7 @@ def run_jobs(
221221
fsaccess=runtime_context.make_fs_access(""),
222222
)
223223
process.parent_wf = process.provenance_object
224+
224225
jobiter = process.job(job_order_object, self.output_callback, runtime_context)
225226

226227
try:

cwltool/main.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -427,6 +427,8 @@ def init_job_order(
427427
namemap,
428428
records,
429429
input_required,
430+
loader.fetcher.urljoin,
431+
file_uri(os.getcwd()) + "/",
430432
)
431433
if args.tool_help:
432434
toolparser.print_help(cast(IO[str], stdout))

tests/test_toolargparse.py

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,37 @@
7070
outputs: []
7171
"""
7272

73+
script_d = """
74+
#!/usr/bin/env cwl-runner
75+
76+
cwlVersion: v1.0
77+
class: ExpressionTool
78+
79+
inputs:
80+
foo:
81+
type:
82+
- type: enum
83+
symbols: [cymbal1, cymbal2]
84+
85+
expression: $(inputs.foo)
86+
87+
outputs: []
88+
"""
89+
90+
script_e = """
91+
#!/usr/bin/env cwl-runner
92+
93+
cwlVersion: v1.0
94+
class: ExpressionTool
95+
96+
inputs:
97+
foo: File
98+
99+
expression: '{"bar": $(inputs.foo.location)}'
100+
101+
outputs: []
102+
"""
103+
73104
scripts_argparse_params = [
74105
("help", script_a, lambda x: ["--debug", x, "--input", get_data("tests/echo.cwl")]),
75106
("boolean", script_b, lambda x: [x, "--help"]),
@@ -79,6 +110,16 @@
79110
script_c,
80111
lambda x: [x, "--foo.one", get_data("tests/echo.cwl"), "--foo.two", "test"],
81112
),
113+
(
114+
"foo with d",
115+
script_d,
116+
lambda x: [x, "--foo", "cymbal2"],
117+
),
118+
(
119+
"foo with e",
120+
script_e,
121+
lambda x: [x, "--foo", "http://example.com"],
122+
),
82123
]
83124

84125

@@ -92,7 +133,7 @@ def test_argparse(
92133
with script_name.open(mode="w") as script:
93134
script.write(script_contents)
94135

95-
my_params = ["--outdir", str(tmp_path / "outdir")]
136+
my_params = ["--outdir", str(tmp_path / "outdir"), "--debug"]
96137
my_params.extend(params(script.name))
97138
assert main(my_params) == 0, name
98139

0 commit comments

Comments
 (0)