# SPDX-License-Identifier: Apache-2.0
+import copy
import hashlib
-from typing import IO, Any, List, MutableSequence, Optional, Tuple, Union, cast
+import logging
+from collections import namedtuple
+from typing import Any, Dict, IO, List, MutableSequence, Optional, Tuple, Union, cast

from ruamel import yaml
from schema_salad.exceptions import ValidationException
-from schema_salad.utils import json_dumps
+from schema_salad.sourceline import SourceLine
+from schema_salad.utils import aslist, json_dumps

import cwl_utils.parser
import cwl_utils.parser.cwl_v1_0 as cwl

CONTENT_LIMIT: int = 64 * 1024

+_logger = logging.getLogger("cwl_utils")
+
+SrcSink = namedtuple("SrcSink", ["src", "sink", "linkMerge", "message"])
+
+
+def _compare_records(
+    src: cwl.RecordSchema, sink: cwl.RecordSchema, strict: bool = False
+) -> bool:
+    """
+    Compare two records, ensuring they have compatible fields.
+
+    This handles normalizing record names, which will be relative to workflow
+    step, so that they can be compared.
+    """
+    srcfields = {cwl.shortname(field.name): field.type for field in (src.fields or {})}
+    sinkfields = {
+        cwl.shortname(field.name): field.type for field in (sink.fields or {})
+    }
+    for key in sinkfields.keys():
+        if (
+            not can_assign_src_to_sink(
+                srcfields.get(key, "null"), sinkfields.get(key, "null"), strict
+            )
+            and sinkfields.get(key) is not None
+        ):
+            _logger.info(
+                "Record comparison failure for %s and %s\n"
+                "Did not match fields for %s: %s and %s",
+                cast(
+                    Union[cwl.InputRecordSchema, cwl.CommandOutputRecordSchema], src
+                ).name,
+                cast(
+                    Union[cwl.InputRecordSchema, cwl.CommandOutputRecordSchema], sink
+                ).name,
+                key,
+                srcfields.get(key),
+                sinkfields.get(key),
+            )
+            return False
+    return True
+

def _compare_type(type1: Any, type2: Any) -> bool:
    if isinstance(type1, cwl.ArraySchema) and isinstance(type2, cwl.ArraySchema):
@@ -38,6 +83,115 @@ def _compare_type(type1: Any, type2: Any) -> bool:
    return bool(type1 == type2)


+def can_assign_src_to_sink(src: Any, sink: Any, strict: bool = False) -> bool:
+    """
+    Check for identical type specifications, ignoring extra keys like inputBinding.
+
+    src: admissible source types
+    sink: admissible sink types
+
+    In non-strict comparison, at least one source type must match one sink type,
+    except for 'null'.
+    In strict comparison, all source types must match at least one sink type.
+    """
+    if src == "Any" or sink == "Any":
+        return True
+    if isinstance(src, cwl.ArraySchema) and isinstance(sink, cwl.ArraySchema):
+        return can_assign_src_to_sink(src.items, sink.items, strict)
+    if isinstance(src, cwl.RecordSchema) and isinstance(sink, cwl.RecordSchema):
+        return _compare_records(src, sink, strict)
+    if isinstance(src, MutableSequence):
+        if strict:
+            for this_src in src:
+                if not can_assign_src_to_sink(this_src, sink):
+                    return False
+            return True
+        for this_src in src:
+            if this_src != "null" and can_assign_src_to_sink(this_src, sink):
+                return True
+        return False
+    if isinstance(sink, MutableSequence):
+        for this_sink in sink:
+            if can_assign_src_to_sink(src, this_sink):
+                return True
+        return False
+    return bool(src == sink)
+
+
+def check_all_types(
+    src_dict: Dict[str, Any],
+    sinks: MutableSequence[Union[cwl.WorkflowStepInput, cwl.WorkflowOutputParameter]],
+    type_dict: Dict[str, Any],
+) -> Dict[str, List[SrcSink]]:
+    """Given a list of sinks, check if their types match with the types of their sources."""
+    validation: Dict[str, List[SrcSink]] = {"warning": [], "exception": []}
+    for sink in sinks:
+        if isinstance(sink, cwl.WorkflowOutputParameter):
+            sourceName = "outputSource"
+            sourceField = sink.outputSource
+        elif isinstance(sink, cwl.WorkflowStepInput):
+            sourceName = "source"
+            sourceField = sink.source
+        else:
+            continue
+        if sourceField is not None:
+            if isinstance(sourceField, MutableSequence):
+                linkMerge = sink.linkMerge or (
+                    "merge_nested" if len(sourceField) > 1 else None
+                )
+                srcs_of_sink = []
+                for parm_id in sourceField:
+                    srcs_of_sink += [src_dict[parm_id]]
+            else:
+                parm_id = cast(str, sourceField)
+                if parm_id not in src_dict:
+                    raise SourceLine(sink, sourceName, ValidationException).makeError(
+                        f"{sourceName} not found: {parm_id}"
+                    )
+                srcs_of_sink = [src_dict[parm_id]]
+                linkMerge = None
+            for src in srcs_of_sink:
+                check_result = check_types(
+                    type_dict[cast(str, src.id)],
+                    type_dict[cast(str, sink.id)],
+                    linkMerge,
+                    getattr(sink, "valueFrom", None),
+                )
+                if check_result == "warning":
+                    validation["warning"].append(SrcSink(src, sink, linkMerge, None))
+                elif check_result == "exception":
+                    validation["exception"].append(SrcSink(src, sink, linkMerge, None))
+    return validation
+
+
+def check_types(
+    srctype: Any,
+    sinktype: Any,
+    linkMerge: Optional[str],
+    valueFrom: Optional[str] = None,
+) -> str:
+    """
+    Check if the source and sink types are correct.
+
+    Acceptable types are "pass", "warning", or "exception".
+    """
+    if valueFrom is not None:
+        return "pass"
+    if linkMerge is None:
+        if can_assign_src_to_sink(srctype, sinktype, strict=True):
+            return "pass"
+        if can_assign_src_to_sink(srctype, sinktype, strict=False):
+            return "warning"
+        return "exception"
+    if linkMerge == "merge_nested":
+        return check_types(
+            cwl.ArraySchema(items=srctype, type="array"), sinktype, None, None
+        )
+    if linkMerge == "merge_flattened":
+        return check_types(merge_flatten_type(srctype), sinktype, None, None)
+    raise ValidationException(f"Invalid value {linkMerge} for linkMerge field.")
+
+
def content_limit_respected_read_bytes(f: IO[bytes]) -> bytes:
    """
    Read file content up to 64 kB as a byte array.
@@ -96,6 +250,59 @@ def merge_flatten_type(src: Any) -> Any:
    return cwl.ArraySchema(type="array", items=src)


+def type_for_step_input(
+    step: cwl.WorkflowStep,
+    in_: cwl.WorkflowStepInput,
+) -> Any:
+    """Determine the type for the given step input."""
+    if in_.valueFrom is not None:
+        return "Any"
+    step_run = cwl_utils.parser.utils.load_step(step)
+    cwl_utils.parser.utils.convert_stdstreams_to_files(step_run)
+    if step_run and step_run.inputs:
+        for step_input in step_run.inputs:
+            if (
+                cast(str, step_input.id).split("#")[-1]
+                == cast(str, in_.id).split("#")[-1]
+            ):
+                input_type = step_input.type
+                if step.scatter is not None and in_.id in aslist(step.scatter):
+                    input_type = cwl.ArraySchema(items=input_type, type="array")
+                return input_type
+    return "Any"
+
+
+def type_for_step_output(
+    step: cwl.WorkflowStep,
+    sourcename: str,
+) -> Any:
+    """Determine the type for the given step output."""
+    step_run = cwl_utils.parser.utils.load_step(step)
+    cwl_utils.parser.utils.convert_stdstreams_to_files(step_run)
+    if step_run and step_run.outputs:
+        for step_output in step_run.outputs:
+            if (
+                step_output.id.split("#")[-1].split("/")[-1]
+                == sourcename.split("#")[-1].split("/")[-1]
+            ):
+                output_type = step_output.type
+                if step.scatter is not None:
+                    if step.scatterMethod == "nested_crossproduct":
+                        for _ in range(len(aslist(step.scatter))):
+                            output_type = cwl.ArraySchema(
+                                items=output_type, type="array"
+                            )
+                    else:
+                        output_type = cwl.ArraySchema(items=output_type, type="array")
+                return output_type
+    raise ValidationException(
+        "param {} not found in {}.".format(
+            sourcename,
+            yaml.main.round_trip_dump(cwl.save(step_run)),
+        )
+    )
+
+
def type_for_source(
    process: Union[cwl.CommandLineTool, cwl.Workflow, cwl.ExpressionTool],
    sourcenames: Union[str, List[str]],
@@ -142,7 +349,7 @@ def type_for_source(
        return cwl.ArraySchema(items=new_type, type="array")
    elif linkMerge == "merge_flattened":
        return merge_flatten_type(new_type)
-    elif isinstance(sourcenames, List):
+    elif isinstance(sourcenames, List) and len(sourcenames) > 1:
        return cwl.ArraySchema(items=new_type, type="array")
    else:
        return new_type
@@ -181,26 +388,14 @@ def param_for_source_id(
                        == step.id.split("#")[-1]
                        and step.out
                    ):
+                        step_run = cwl_utils.parser.utils.load_step(step)
+                        cwl_utils.parser.utils.convert_stdstreams_to_files(step_run)
                        for outp in step.out:
                            outp_id = outp if isinstance(outp, str) else outp.id
                            if (
                                outp_id.split("#")[-1].split("/")[-1]
                                == sourcename.split("#")[-1].split("/")[-1]
                            ):
-                                step_run = step.run
-                                if isinstance(step.run, str):
-                                    step_run = cwl_utils.parser.load_document_by_uri(
-                                        path=target.loadingOptions.fetcher.urljoin(
-                                            base_url=cast(
-                                                str, target.loadingOptions.fileuri
-                                            ),
-                                            url=step.run,
-                                        ),
-                                        loadingOptions=target.loadingOptions,
-                                    )
-                                cwl_utils.parser.utils.convert_stdstreams_to_files(
-                                    step_run
-                                )
                                if step_run and step_run.outputs:
                                    for output in step_run.outputs:
                                        if (
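The snippet below is a usage sketch, not part of the patch. It assumes the module shown in this diff is importable as cwl_utils.parser.cwl_v1_0_utils (that import path is an assumption; adjust it to wherever the file lives in your checkout) and illustrates how the new helpers classify source/sink type pairs.

# Usage sketch only (not part of the diff). The import path below is an
# assumption; point it at wherever this module lives in your checkout.
import cwl_utils.parser.cwl_v1_0 as cwl
from cwl_utils.parser import cwl_v1_0_utils as utils

# Identical scalar types are assignable under both strict and lenient rules.
assert utils.can_assign_src_to_sink("File", "File", strict=True)

# A nullable union source feeding a plain "File" sink passes only the lenient
# check, so check_types reports a "warning" rather than "pass".
assert utils.can_assign_src_to_sink(["null", "File"], "File", strict=False)
assert not utils.can_assign_src_to_sink(["null", "File"], "File", strict=True)
assert utils.check_types(["null", "File"], "File", linkMerge=None) == "warning"

# With linkMerge: merge_nested the source type is wrapped in an array before
# the comparison, so a "File" source is accepted by a File[] sink.
file_array = cwl.ArraySchema(items="File", type="array")
assert utils.check_types("File", file_array, linkMerge="merge_nested") == "pass"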