Skip to content

Commit 54b2ad4

Browse files
committed
accepting None as an output from function that is expected to return multiple values
1 parent 9de9584 commit 54b2ad4

File tree

2 files changed

+27
-11
lines changed

2 files changed

+27
-11
lines changed

pydra/engine/task.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -200,17 +200,20 @@ def _run_task(self):
200200
self.output_ = None
201201
output = cp.loads(self.inputs._func)(**inputs)
202202
output_names = [el[0] for el in self.output_spec.fields]
203-
self.output_ = {}
204-
if len(output_names) > 1:
205-
if len(output_names) == len(output):
206-
self.output_ = dict(zip(output_names, output))
203+
if output is None:
204+
self.output_ = dict((nm, None) for nm in output_names)
205+
else:
206+
if len(output_names) == 1:
207+
# if only one element in the fields, everything should be returned together
208+
self.output_ = {output_names[0]: output}
207209
else:
208-
raise Exception(
209-
f"expected {len(self.output_spec.fields)} elements, "
210-
f"but {len(output)} were returned"
211-
)
212-
else: # if only one element in the fields, everything should be returned together
213-
self.output_[output_names[0]] = output
210+
if isinstance(output, tuple) and len(output_names) == len(output):
211+
self.output_ = dict(zip(output_names, output))
212+
else:
213+
raise Exception(
214+
f"expected {len(self.output_spec.fields)} elements, "
215+
f"but {output} were returned"
216+
)
214217

215218

216219
class ShellCommandTask(TaskBase):

pydra/engine/tests/test_task.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -309,7 +309,7 @@ def raise_exception(c, d):
309309
assert pytest.raises(Exception, bad_funk)
310310

311311

312-
def test_result_none():
312+
def test_result_none_1():
313313
""" checking if None is properly returned as the result"""
314314

315315
@mark.task
@@ -321,6 +321,19 @@ def fun_none(x):
321321
assert res.output.out is None
322322

323323

324+
def test_result_none_2():
325+
""" checking if None is properly set for all outputs """
326+
327+
@mark.task
328+
def fun_none(x) -> (ty.Any, ty.Any):
329+
return None
330+
331+
task = fun_none(name="none", x=3)
332+
res = task()
333+
assert res.output.out1 is None
334+
assert res.output.out2 is None
335+
336+
324337
def test_audit_prov(tmpdir):
325338
@mark.task
326339
def testfunc(a: int, b: float = 0.1) -> ty.NamedTuple("Output", [("out", float)]):

0 commit comments

Comments
 (0)