Skip to content

Commit 78db0a1

Browse files
committed
freshen cwltool/load_tool.py
1 parent 822fe18 commit 78db0a1

File tree

1 file changed

+60
-25
lines changed

1 file changed

+60
-25
lines changed

cwltool/load_tool.py

Lines changed: 60 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
from .loghandler import _logger
4141
from .process import Process, get_schema, shortname
4242
from .update import ALLUPDATES
43-
from .utils import ResolverType, visit_class
43+
from .utils import CWLObjectType, ResolverType, visit_class
4444

4545
jobloaderctx = {
4646
"cwl": "https://w3id.org/cwl/cwl#",
@@ -108,7 +108,7 @@ def resolve_tool_uri(
108108

109109

110110
def fetch_document(
111-
argsworkflow: Union[str, Dict[str, Any]],
111+
argsworkflow: Union[str, CWLObjectType],
112112
loadingContext: Optional[LoadingContext] = None,
113113
) -> Tuple[LoadingContext, CommentedMap, str]:
114114
"""Retrieve a CWL document."""
@@ -132,17 +132,23 @@ def fetch_document(
132132
)
133133
workflowobj = cast(CommentedMap, loadingContext.loader.fetch(fileuri))
134134
return loadingContext, workflowobj, uri
135-
if isinstance(argsworkflow, dict):
136-
uri = argsworkflow["id"] if argsworkflow.get("id") else "_:" + str(uuid.uuid4())
137-
workflowobj = cast(CommentedMap, cmap(argsworkflow, fn=uri))
135+
if isinstance(argsworkflow, MutableMapping):
136+
uri = (
137+
cast(str, argsworkflow["id"])
138+
if argsworkflow.get("id")
139+
else "_:" + str(uuid.uuid4())
140+
)
141+
workflowobj = cast(
142+
CommentedMap, cmap(cast(Dict[str, Any], argsworkflow), fn=uri)
143+
)
138144
loadingContext.loader.idx[uri] = workflowobj
139145
return loadingContext, workflowobj, uri
140146
raise ValidationException("Must be URI or object: '%s'" % argsworkflow)
141147

142148

143149
def _convert_stdstreams_to_files(
144150
workflowobj: Union[
145-
MutableMapping[str, Any], MutableSequence[Union[Dict[str, Any], str, int]], str
151+
CWLObjectType, MutableSequence[Union[CWLObjectType, str, int]], str
146152
]
147153
) -> None:
148154
if isinstance(workflowobj, MutableMapping):
@@ -156,7 +162,9 @@ def _convert_stdstreams_to_files(
156162
outputs = workflowobj.get("outputs", [])
157163
if not isinstance(outputs, CommentedSeq):
158164
raise ValidationException('"outputs" section is not ' "valid.")
159-
for out in workflowobj.get("outputs", []):
165+
for out in cast(
166+
MutableSequence[CWLObjectType], workflowobj.get("outputs", [])
167+
):
160168
if not isinstance(out, CommentedMap):
161169
raise ValidationException(
162170
"Output '{}' is not a valid " "OutputParameter.".format(out)
@@ -181,7 +189,9 @@ def _convert_stdstreams_to_files(
181189
workflowobj[streamtype] = filename
182190
out["type"] = "File"
183191
out["outputBinding"] = cmap({"glob": filename})
184-
for inp in workflowobj.get("inputs", []):
192+
for inp in cast(
193+
MutableSequence[CWLObjectType], workflowobj.get("inputs", [])
194+
):
185195
if inp.get("type") == "stdin":
186196
if "inputBinding" in inp:
187197
raise ValidationException(
@@ -195,21 +205,38 @@ def _convert_stdstreams_to_files(
195205
)
196206
else:
197207
workflowobj["stdin"] = (
198-
"$(inputs.%s.path)" % inp["id"].rpartition("#")[2]
208+
"$(inputs.%s.path)"
209+
% cast(str, inp["id"]).rpartition("#")[2]
199210
)
200211
inp["type"] = "File"
201212
else:
202213
for entry in workflowobj.values():
203-
_convert_stdstreams_to_files(entry)
214+
_convert_stdstreams_to_files(
215+
cast(
216+
Union[
217+
CWLObjectType,
218+
MutableSequence[Union[CWLObjectType, str, int]],
219+
str,
220+
],
221+
entry,
222+
)
223+
)
204224
if isinstance(workflowobj, MutableSequence):
205225
for entry in workflowobj:
206-
_convert_stdstreams_to_files(entry)
226+
_convert_stdstreams_to_files(
227+
cast(
228+
Union[
229+
CWLObjectType,
230+
MutableSequence[Union[CWLObjectType, str, int]],
231+
str,
232+
],
233+
entry,
234+
)
235+
)
207236

208237

209238
def _add_blank_ids(
210-
workflowobj: Union[
211-
MutableMapping[str, Any], MutableSequence[Union[MutableMapping[str, Any], str]]
212-
]
239+
workflowobj: Union[CWLObjectType, MutableSequence[Union[CWLObjectType, str]]]
213240
) -> None:
214241
if isinstance(workflowobj, MutableMapping):
215242
if (
@@ -220,10 +247,20 @@ def _add_blank_ids(
220247
):
221248
workflowobj["run"]["id"] = str(uuid.uuid4())
222249
for entry in workflowobj.values():
223-
_add_blank_ids(entry)
250+
_add_blank_ids(
251+
cast(
252+
Union[CWLObjectType, MutableSequence[Union[CWLObjectType, str]]],
253+
entry,
254+
)
255+
)
224256
if isinstance(workflowobj, MutableSequence):
225257
for entry in workflowobj:
226-
_add_blank_ids(entry)
258+
_add_blank_ids(
259+
cast(
260+
Union[CWLObjectType, MutableSequence[Union[CWLObjectType, str]]],
261+
entry,
262+
)
263+
)
227264

228265

229266
def resolve_and_validate_document(
@@ -263,8 +300,8 @@ def resolve_and_validate_document(
263300
if not cwlVersion and fileuri != uri:
264301
# The tool we're loading is a fragment of a bigger file. Get
265302
# the document root element and look for cwlVersion there.
266-
metadata = fetch_document(fileuri, loadingContext)[1] # type: Dict[str, Any]
267-
cwlVersion = metadata.get("cwlVersion")
303+
metadata = cast(CWLObjectType, fetch_document(fileuri, loadingContext)[1])
304+
cwlVersion = cast(str, metadata.get("cwlVersion"))
268305
if not cwlVersion:
269306
raise ValidationException(
270307
"No cwlVersion found. "
@@ -427,7 +464,7 @@ def make_tool(
427464

428465

429466
def load_tool(
430-
argsworkflow: Union[str, Dict[str, Any]],
467+
argsworkflow: Union[str, CWLObjectType],
431468
loadingContext: Optional[LoadingContext] = None,
432469
) -> Process:
433470

@@ -441,19 +478,17 @@ def load_tool(
441478

442479

443480
def resolve_overrides(
444-
ov, # type: IdxResultType
445-
ov_uri, # type: str
446-
baseurl, # type: str
447-
): # type: (...) -> List[Dict[str, Any]]
481+
ov: IdxResultType, ov_uri: str, baseurl: str,
482+
) -> List[CWLObjectType]:
448483
ovloader = Loader(overrides_ctx)
449484
ret, _ = ovloader.resolve_all(ov, baseurl)
450485
if not isinstance(ret, CommentedMap):
451486
raise Exception("Expected CommentedMap, got %s" % type(ret))
452487
cwl_docloader = get_schema("v1.0")[0]
453488
cwl_docloader.resolve_all(ret, ov_uri)
454-
return cast(List[Dict[str, Any]], ret["http://commonwl.org/cwltool#overrides"])
489+
return cast(List[CWLObjectType], ret["http://commonwl.org/cwltool#overrides"])
455490

456491

457-
def load_overrides(ov, base_url): # type: (str, str) -> List[Dict[str, Any]]
492+
def load_overrides(ov: str, base_url: str) -> List[CWLObjectType]:
458493
ovloader = Loader(overrides_ctx)
459494
return resolve_overrides(ovloader.fetch(ov), ov, base_url)

0 commit comments

Comments
 (0)