Skip to content

Commit 2b0a47d

Browse files
committed
added tests to test Python functions with no outputs
1 parent 997bbe2 commit 2b0a47d

File tree

2 files changed

+74
-38
lines changed

2 files changed

+74
-38
lines changed

pydra/compose/base/helpers.py

Lines changed: 39 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ def extract_function_inputs_and_outputs(
183183
input_types[p.name] = type_hints.get(p.name, ty.Any)
184184
if p.default is not inspect.Parameter.empty:
185185
input_defaults[p.name] = p.default
186-
if inputs:
186+
if inputs is not None:
187187
if not isinstance(inputs, dict):
188188
raise ValueError(
189189
f"Input names ({inputs}) should not be provided when "
@@ -218,45 +218,46 @@ def extract_function_inputs_and_outputs(
218218
f"value {default}"
219219
)
220220
return_type = type_hints.get("return", ty.Any)
221-
if outputs and len(outputs) > 1:
222-
if return_type is not ty.Any:
223-
if ty.get_origin(return_type) is not tuple:
224-
raise ValueError(
225-
f"Multiple outputs specified ({outputs}) but non-tuple "
226-
f"return value {return_type}"
227-
)
228-
return_types = ty.get_args(return_type)
229-
if len(return_types) != len(outputs):
230-
raise ValueError(
231-
f"Length of the outputs ({outputs}) does not match that "
232-
f"of the return types ({return_types})"
233-
)
234-
output_types = dict(zip(outputs, return_types))
235-
else:
236-
output_types = {o: ty.Any for o in outputs}
237-
if isinstance(outputs, dict):
238-
for output_name, output in outputs.items():
239-
if isinstance(output, Out) and output.type is ty.Any:
240-
output.type = output_types[output_name]
221+
if outputs:
222+
if len(outputs) > 1:
223+
if return_type is not ty.Any:
224+
if ty.get_origin(return_type) is not tuple:
225+
raise ValueError(
226+
f"Multiple outputs specified ({outputs}) but non-tuple "
227+
f"return value {return_type}"
228+
)
229+
return_types = ty.get_args(return_type)
230+
if len(return_types) != len(outputs):
231+
raise ValueError(
232+
f"Length of the outputs ({outputs}) does not match that "
233+
f"of the return types ({return_types})"
234+
)
235+
output_types = dict(zip(outputs, return_types))
236+
else:
237+
output_types = {o: ty.Any for o in outputs}
238+
if isinstance(outputs, dict):
239+
for output_name, output in outputs.items():
240+
if isinstance(output, Out) and output.type is ty.Any:
241+
output.type = output_types[output_name]
242+
else:
243+
outputs = output_types
241244
else:
242-
outputs = output_types
243-
244-
elif outputs:
245-
if isinstance(outputs, dict):
246-
output_name, output = next(iter(outputs.items()))
247-
elif isinstance(outputs, list):
248-
output_name = outputs[0]
249-
output = ty.Any
250-
if isinstance(output, Out):
251-
if output.type is ty.Any:
252-
output.type = return_type
253-
elif output is ty.Any:
254-
output = return_type
255-
outputs = {output_name: output}
256-
elif return_type is not None:
257-
outputs = {"out": return_type}
245+
if isinstance(outputs, dict):
246+
output_name, output = next(iter(outputs.items()))
247+
elif isinstance(outputs, list):
248+
output_name = outputs[0]
249+
output = ty.Any
250+
if isinstance(output, Out):
251+
if output.type is ty.Any:
252+
output.type = return_type
253+
elif output is ty.Any:
254+
output = return_type
255+
outputs = {output_name: output}
258256
else:
259-
outputs = {}
257+
if return_type not in (None, type(None)):
258+
outputs = {"out": return_type}
259+
else:
260+
outputs = {}
260261
return inputs, outputs
261262

262263

pydra/compose/tests/test_python_fields.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -424,3 +424,38 @@ def TestFunc(a: A):
424424

425425
outputs = TestFunc(a=A(x=7))()
426426
assert outputs.out == 7
427+
428+
429+
def test_no_outputs1():
430+
"""Test function tasks with object inputs"""
431+
432+
@python.define
433+
def TestFunc(a: A) -> None:
434+
pass
435+
436+
outputs = TestFunc(a=A(x=7))()
437+
assert len(outputs) == 0
438+
439+
440+
def test_no_outputs2():
441+
"""Test function tasks with object inputs"""
442+
443+
@python.define(outputs=[])
444+
def TestFunc(a: A):
445+
pass
446+
447+
outputs = TestFunc(a=A(x=7))()
448+
assert len(outputs) == 0
449+
450+
451+
def test_no_outputs_fail():
452+
"""Test function tasks with object inputs"""
453+
454+
@python.define(outputs=[])
455+
def TestFunc(a: A):
456+
return a
457+
458+
with pytest.raises(
459+
ValueError, match="Returns an output but no outputs are defined"
460+
):
461+
TestFunc(a=A(x=7))()

0 commit comments

Comments
 (0)