Skip to content

Commit fe53fb8

Browse files
AlexTatemr-c
authored andcommitted
Fixing ReceiveScatterOutput to allow outputs to be collected from successful steps when 1) these steps are upstream from a scattered subworkflow, 2) the workflow kill switch is activated by one of the scatter jobs, and 3) on_error==kill
1 parent 5d92d2d commit fe53fb8

File tree

1 file changed

+8
-6
lines changed

1 file changed

+8
-6
lines changed

cwltool/workflow_job.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def completed(self) -> int:
8989
"""The number of completed internal jobs."""
9090
return len(self._completed)
9191

92-
def receive_scatter_output(self, index: int, jobout: CWLObjectType, processStatus: str) -> None:
92+
def receive_scatter_output(self, index: int, runtimeContext: RuntimeContext, jobout: CWLObjectType, processStatus: str) -> None:
9393
"""Record the results of a scatter operation."""
9494
for key, val in jobout.items():
9595
self.dest[key][index] = val
@@ -102,6 +102,8 @@ def receive_scatter_output(self, index: int, jobout: CWLObjectType, processStatu
102102
if processStatus != "success":
103103
if self.processStatus != "permanentFail":
104104
self.processStatus = processStatus
105+
if runtimeContext.on_error == "kill":
106+
self.output_callback(self.dest, self.processStatus)
105107

106108
if index not in self._completed:
107109
self._completed.add(index)
@@ -156,7 +158,7 @@ def parallel_steps(
156158
except WorkflowException as exc:
157159
_logger.error("Cannot make scatter job: %s", str(exc))
158160
_logger.debug("", exc_info=True)
159-
rc.receive_scatter_output(index, {}, "permanentFail")
161+
rc.receive_scatter_output(index, runtimeContext, {}, "permanentFail")
160162
if not made_progress and rc.completed < rc.total:
161163
yield None
162164

@@ -185,7 +187,7 @@ def nested_crossproduct_scatter(
185187
if len(scatter_keys) == 1:
186188
if runtimeContext.postScatterEval is not None:
187189
sjob = runtimeContext.postScatterEval(sjob)
188-
curriedcallback = functools.partial(rc.receive_scatter_output, index)
190+
curriedcallback = functools.partial(rc.receive_scatter_output, index, runtimeContext)
189191
if sjob is not None:
190192
steps.append(process.job(sjob, curriedcallback, runtimeContext))
191193
else:
@@ -197,7 +199,7 @@ def nested_crossproduct_scatter(
197199
process,
198200
sjob,
199201
scatter_keys[1:],
200-
functools.partial(rc.receive_scatter_output, index),
202+
functools.partial(rc.receive_scatter_output, index, runtimeContext),
201203
runtimeContext,
202204
)
203205
)
@@ -257,7 +259,7 @@ def _flat_crossproduct_scatter(
257259
if len(scatter_keys) == 1:
258260
if runtimeContext.postScatterEval is not None:
259261
sjob = runtimeContext.postScatterEval(sjob)
260-
curriedcallback = functools.partial(callback.receive_scatter_output, put)
262+
curriedcallback = functools.partial(callback.receive_scatter_output, put, runtimeContext)
261263
if sjob is not None:
262264
steps.append(process.job(sjob, curriedcallback, runtimeContext))
263265
else:
@@ -307,7 +309,7 @@ def dotproduct_scatter(
307309

308310
if runtimeContext.postScatterEval is not None:
309311
sjobo = runtimeContext.postScatterEval(sjobo)
310-
curriedcallback = functools.partial(rc.receive_scatter_output, index)
312+
curriedcallback = functools.partial(rc.receive_scatter_output, index, runtimeContext)
311313
if sjobo is not None:
312314
steps.append(process.job(sjobo, curriedcallback, runtimeContext))
313315
else:

0 commit comments

Comments
 (0)