Skip to content

Commit 48166cc

Browse files
GlassOfWhiskey and mr-c authored
Improve type_for_source function behaviour (#164)
* Return ArraySchema as type for multiple sources. When a `source` or `outputSource` field contains a list of sources, return an `ArraySchema` object as the type for the source.
* Added linkMerge management
* Added pickValue
* Better handling of exception messages
* Added scatter source management
* Added nested_crossproduct scatter tests
* Added flat_crossproduct test
* Added automatic loading of steps
* Added single source tests
* Added stdout to file conversion
* Added unit tests for _compare_types
Co-authored-by: Michael R. Crusoe <[email protected]>
1 parent 57e2c13 commit 48166cc

File tree

60 files changed

+2361
-179
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

60 files changed

+2361
-179
lines changed

Makefile

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -26,7 +26,9 @@ EXTRAS=
2626

2727
# `SHELL=bash` doesn't work for some, so don't use BASH-isms like
2828
# `[[` conditional expressions.
29-
PYSOURCES=$(filter-out $(MODULE)/parser/cwl_v%,$(shell find $(MODULE) -name "*.py")) $(wildcard tests/*.py) create_cwl_from_objects.py load_cwl_by_path.py setup.py
29+
PYSOURCES=$(filter-out $(MODULE)/parser/cwl_v%,$(shell find $(MODULE) -name "*.py")) \
30+
$(wildcard tests/*.py) create_cwl_from_objects.py load_cwl_by_path.py \
31+
setup.py ${MODULE}/parser/cwl_v1_?_utils.py
3032
DEVPKGS=diff_cover black pylint pep257 pydocstyle flake8 tox tox-pyenv \
3133
isort wheel autoflake flake8-bugbear pyupgrade bandit \
3234
-rtest-requirements.txt -rmypy-requirements.txt

cwl_utils/cwl_v1_0_expression_refactor.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1939,7 +1939,7 @@ def replace_step_valueFrom_expr_with_etool(
19391939
step_inp: cwl.WorkflowStepInput,
19401940
original_process: Union[cwl.CommandLineTool, cwl.ExpressionTool],
19411941
original_step_ins: List[cwl.WorkflowStepInput],
1942-
source: Union[str, List[str]],
1942+
source: Optional[Union[str, List[str]]],
19431943
replace_etool: bool,
19441944
source_type: Optional[Union[cwl.InputParameter, List[cwl.InputParameter]]] = None,
19451945
) -> None:

cwl_utils/cwl_v1_1_expression_refactor.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1931,7 +1931,7 @@ def replace_step_valueFrom_expr_with_etool(
19311931
step_inp: cwl.WorkflowStepInput,
19321932
original_process: Union[cwl.CommandLineTool, cwl.ExpressionTool],
19331933
original_step_ins: List[cwl.WorkflowStepInput],
1934-
source: Union[str, List[str]],
1934+
source: Optional[Union[str, List[str]]],
19351935
replace_etool: bool,
19361936
source_type: Optional[
19371937
Union[cwl.WorkflowInputParameter, List[cwl.WorkflowInputParameter]]

cwl_utils/cwl_v1_2_expression_refactor.py

Lines changed: 13 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -702,18 +702,22 @@ def process_workflow_inputs_and_outputs(
702702
target_type.name = None
703703
target = cwl.WorkflowInputParameter(id=None, type=target_type)
704704
if not isinstance(param2.outputSource, list):
705-
sources: Union[List[str], str] = param2.outputSource.split("#")[-1]
705+
sources = param2.outputSource.split("#")[-1]
706706
else:
707707
sources = [s.split("#")[-1] for s in param2.outputSource]
708708
source_type_items = utils.type_for_source(workflow, sources)
709-
if "null" not in source_type_items:
710-
if isinstance(source_type_items, list):
709+
if isinstance(source_type_items, cwl.ArraySchema):
710+
if isinstance(source_type_items.items, list):
711+
if "null" not in source_type_items.items:
712+
source_type_items.items.append("null")
713+
elif source_type_items.items != "null":
714+
source_type_items.items = ["null", source_type_items.items]
715+
elif isinstance(source_type_items, list):
716+
if "null" not in source_type_items:
711717
source_type_items.append("null")
712-
else:
713-
source_type_items = ["null", source_type_items]
714-
source_type = cwl.CommandInputParameter(
715-
type=cwl.ArraySchema(type="array", items=source_type_items)
716-
)
718+
elif source_type_items != "null":
719+
source_type_items = ["null", source_type_items]
720+
source_type = cwl.CommandInputParameter(type=source_type_items)
717721
replace_expr_with_etool(
718722
expression,
719723
etool_id,
@@ -2030,7 +2034,7 @@ def replace_step_valueFrom_expr_with_etool(
20302034
step_inp: cwl.WorkflowStepInput,
20312035
original_process: Union[cwl.CommandLineTool, cwl.ExpressionTool],
20322036
original_step_ins: List[cwl.WorkflowStepInput],
2033-
source: Union[str, List[str]],
2037+
source: Optional[Union[str, List[str]]],
20342038
replace_etool: bool,
20352039
source_type: Optional[
20362040
Union[cwl.WorkflowInputParameter, List[cwl.WorkflowInputParameter]]

cwl_utils/parser/cwl_v1_0_utils.py

Lines changed: 156 additions & 30 deletions
Original file line number | Diff line number | Diff line change
@@ -1,18 +1,43 @@
11
# SPDX-License-Identifier: Apache-2.0
22
import hashlib
3-
from typing import Any, IO, List, Optional, Union
3+
from typing import Any, IO, List, MutableSequence, Optional, Tuple, Union, cast
44

55
from ruamel import yaml
66
from schema_salad.exceptions import ValidationException
77
from schema_salad.utils import json_dumps
88

9+
import cwl_utils.parser
910
import cwl_utils.parser.cwl_v1_0 as cwl
11+
import cwl_utils.parser.utils
1012
from cwl_utils.errors import WorkflowException
1113

12-
1314
CONTENT_LIMIT: int = 64 * 1024
1415

1516

17+
def _compare_type(type1: Any, type2: Any) -> bool:
18+
if isinstance(type1, cwl.ArraySchema) and isinstance(type2, cwl.ArraySchema):
19+
return _compare_type(type1.items, type2.items)
20+
elif isinstance(type1, cwl.RecordSchema) and isinstance(type2, cwl.RecordSchema):
21+
fields1 = {
22+
cwl.shortname(field.name): field.type for field in (type1.fields or {})
23+
}
24+
fields2 = {
25+
cwl.shortname(field.name): field.type for field in (type2.fields or {})
26+
}
27+
if fields1.keys() != fields2.keys():
28+
return False
29+
return all((_compare_type(fields1[k], fields2[k]) for k in fields1.keys()))
30+
elif isinstance(type1, MutableSequence) and isinstance(type2, MutableSequence):
31+
if len(type1) != len(type2):
32+
return False
33+
for t1 in type1:
34+
if not any((_compare_type(t1, t2) for t2 in type2)):
35+
return False
36+
return True
37+
else:
38+
return bool(type1 == type2)
39+
40+
1641
def content_limit_respected_read_bytes(f: IO[bytes]) -> bytes:
1742
"""
1843
Read file content up to 64 kB as a byte array.
@@ -32,49 +57,102 @@ def content_limit_respected_read(f: IO[bytes]) -> str:
3257

3358

3459
def convert_stdstreams_to_files(clt: cwl.CommandLineTool) -> None:
60+
"""Convert stdout and stderr type shortcuts to files."""
3561
for out in clt.outputs:
36-
if out.type == 'stdout':
62+
if out.type == "stdout":
3763
if out.outputBinding is not None:
3864
raise ValidationException(
39-
"Not allowed to specify outputBinding when using stdout shortcut.")
65+
"Not allowed to specify outputBinding when using stdout shortcut."
66+
)
4067
if clt.stdout is None:
41-
clt.stdout = str(hashlib.sha1(json_dumps( # nosec
42-
clt.save(), sort_keys=True).encode('utf-8')).hexdigest())
43-
out.type = 'File'
68+
clt.stdout = str(
69+
hashlib.sha1( # nosec
70+
json_dumps(clt.save(), sort_keys=True).encode("utf-8")
71+
).hexdigest()
72+
)
73+
out.type = "File"
4474
out.outputBinding = cwl.CommandOutputBinding(glob=clt.stdout)
45-
elif out.type == 'stderr':
75+
elif out.type == "stderr":
4676
if out.outputBinding is not None:
4777
raise ValidationException(
48-
"Not allowed to specify outputBinding when using stderr shortcut.")
78+
"Not allowed to specify outputBinding when using stderr shortcut."
79+
)
4980
if clt.stderr is None:
50-
clt.stderr = str(hashlib.sha1(json_dumps( # nosec
51-
clt.save(), sort_keys=True).encode('utf-8')).hexdigest())
52-
out.type = 'File'
81+
clt.stderr = str(
82+
hashlib.sha1( # nosec
83+
json_dumps(clt.save(), sort_keys=True).encode("utf-8")
84+
).hexdigest()
85+
)
86+
out.type = "File"
5387
out.outputBinding = cwl.CommandOutputBinding(glob=clt.stderr)
5488

5589

90+
def merge_flatten_type(src: Any) -> Any:
91+
"""Return the merge flattened type of the source type."""
92+
if isinstance(src, MutableSequence):
93+
return [merge_flatten_type(t) for t in src]
94+
if isinstance(src, cwl.ArraySchema):
95+
return src
96+
return cwl.ArraySchema(type="array", items=src)
97+
98+
5699
def type_for_source(
57100
process: Union[cwl.CommandLineTool, cwl.Workflow, cwl.ExpressionTool],
58101
sourcenames: Union[str, List[str]],
59102
parent: Optional[cwl.Workflow] = None,
60-
) -> Union[List[Any], Any]:
103+
linkMerge: Optional[str] = None,
104+
) -> Any:
61105
"""Determine the type for the given sourcenames."""
62-
params = param_for_source_id(process, sourcenames, parent)
106+
scatter_context: List[Optional[Tuple[int, str]]] = []
107+
params = param_for_source_id(process, sourcenames, parent, scatter_context)
63108
if not isinstance(params, list):
64-
return params.type
65-
new_type: List[Any] = []
66-
for p in params:
67-
if isinstance(p, str) and p not in new_type:
68-
new_type.append(p)
69-
elif hasattr(p, "type") and p.type not in new_type:
70-
new_type.append(p.type)
71-
return new_type
109+
new_type = params.type
110+
if scatter_context[0] is not None:
111+
if scatter_context[0][1] == "nested_crossproduct":
112+
for _ in range(scatter_context[0][0]):
113+
new_type = cwl.ArraySchema(items=new_type, type="array")
114+
else:
115+
new_type = cwl.ArraySchema(items=new_type, type="array")
116+
if linkMerge == "merge_nested":
117+
new_type = cwl.ArraySchema(items=new_type, type="array")
118+
elif linkMerge == "merge_flattened":
119+
new_type = merge_flatten_type(new_type)
120+
return new_type
121+
new_type = []
122+
for p, sc in zip(params, scatter_context):
123+
if isinstance(p, str) and not any((_compare_type(t, p) for t in new_type)):
124+
cur_type = p
125+
elif hasattr(p, "type") and not any(
126+
(_compare_type(t, p.type) for t in new_type)
127+
):
128+
cur_type = p.type
129+
else:
130+
cur_type = None
131+
if cur_type is not None:
132+
if sc is not None:
133+
if sc[1] == "nested_crossproduct":
134+
for _ in range(sc[0]):
135+
cur_type = cwl.ArraySchema(items=cur_type, type="array")
136+
else:
137+
cur_type = cwl.ArraySchema(items=cur_type, type="array")
138+
new_type.append(cur_type)
139+
if len(new_type) == 1:
140+
new_type = new_type[0]
141+
if linkMerge == "merge_nested":
142+
return cwl.ArraySchema(items=new_type, type="array")
143+
elif linkMerge == "merge_flattened":
144+
return merge_flatten_type(new_type)
145+
elif isinstance(sourcenames, List):
146+
return cwl.ArraySchema(items=new_type, type="array")
147+
else:
148+
return new_type
72149

73150

74151
def param_for_source_id(
75152
process: Union[cwl.CommandLineTool, cwl.Workflow, cwl.ExpressionTool],
76153
sourcenames: Union[str, List[str]],
77154
parent: Optional[cwl.Workflow] = None,
155+
scatter_context: Optional[List[Optional[Tuple[int, str]]]] = None,
78156
) -> Union[List[cwl.InputParameter], cwl.InputParameter]:
79157
"""Find the process input parameter that matches one of the given sourcenames."""
80158
if isinstance(sourcenames, str):
@@ -85,6 +163,8 @@ def param_for_source_id(
85163
for param in process.inputs:
86164
if param.id.split("#")[-1] == sourcename.split("#")[-1]:
87165
params.append(param)
166+
if scatter_context is not None:
167+
scatter_context.append(None)
88168
targets = [process]
89169
if parent:
90170
targets.append(parent)
@@ -93,26 +173,72 @@ def param_for_source_id(
93173
for inp in target.inputs:
94174
if inp.id.split("#")[-1] == sourcename.split("#")[-1]:
95175
params.append(inp)
176+
if scatter_context is not None:
177+
scatter_context.append(None)
96178
for step in target.steps:
97-
if sourcename.split("#")[-1].split("/")[0] == step.id.split("#")[-1] and step.out:
179+
if (
180+
"/".join(sourcename.split("#")[-1].split("/")[:-1])
181+
== step.id.split("#")[-1]
182+
and step.out
183+
):
98184
for outp in step.out:
99185
outp_id = outp if isinstance(outp, str) else outp.id
100-
if outp_id.split("#")[-1].split("/")[-1] == sourcename.split("#")[-1].split("/", 1)[1]:
101-
if step.run and step.run.outputs:
102-
for output in step.run.outputs:
186+
if (
187+
outp_id.split("#")[-1].split("/")[-1]
188+
== sourcename.split("#")[-1].split("/")[-1]
189+
):
190+
step_run = step.run
191+
if isinstance(step.run, str):
192+
step_run = cwl_utils.parser.load_document_by_uri(
193+
path=target.loadingOptions.fetcher.urljoin(
194+
base_url=cast(
195+
str, target.loadingOptions.fileuri
196+
),
197+
url=step.run,
198+
),
199+
loadingOptions=target.loadingOptions,
200+
)
201+
cwl_utils.parser.utils.convert_stdstreams_to_files(
202+
step_run
203+
)
204+
if step_run and step_run.outputs:
205+
for output in step_run.outputs:
103206
if (
104-
output.id.split("#")[-1].split('/')[-1]
105-
== sourcename.split('#')[-1].split("/", 1)[1]
207+
output.id.split("#")[-1].split("/")[-1]
208+
== sourcename.split("#")[-1].split("/")[-1]
106209
):
107210
params.append(output)
211+
if scatter_context is not None:
212+
if isinstance(step.scatter, str):
213+
scatter_context.append(
214+
(
215+
1,
216+
step.scatterMethod
217+
or "dotproduct",
218+
)
219+
)
220+
elif isinstance(
221+
step.scatter, MutableSequence
222+
):
223+
scatter_context.append(
224+
(
225+
len(step.scatter),
226+
step.scatterMethod
227+
or "dotproduct",
228+
)
229+
)
230+
else:
231+
scatter_context.append(None)
108232
if len(params) == 1:
109233
return params[0]
110234
elif len(params) > 1:
111235
return params
112236
raise WorkflowException(
113-
"param {} not found in {}\n or\n {}.".format(
237+
"param {} not found in {}\n{}.".format(
114238
sourcename,
115239
yaml.main.round_trip_dump(cwl.save(process)),
116-
yaml.main.round_trip_dump(cwl.save(parent)),
240+
" or\n {}".format(yaml.main.round_trip_dump(cwl.save(parent)))
241+
if parent is not None
242+
else "",
117243
)
118244
)

0 commit comments

Comments
 (0)