Skip to content

Commit ca82a5c

Browse files
committed
enable passing list[Arg] to 'inputs' in decorator
1 parent aa35926 commit ca82a5c

File tree

3 files changed

+30
-5
lines changed

3 files changed

+30
-5
lines changed

pydra/compose/base/helpers.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -185,10 +185,15 @@ def extract_function_inputs_and_outputs(
185185
input_defaults[p.name] = p.default
186186
if inputs is not None:
187187
if not isinstance(inputs, dict):
188-
raise ValueError(
189-
f"Input names ({inputs}) should not be provided when "
190-
"wrapping/decorating a function as "
191-
)
188+
if non_named_args := [
189+
i for i in inputs if not isinstance(i, Arg) or i.name is None
190+
]:
191+
raise ValueError(
192+
"Only named Arg objects should be provided as inputs (i.e. not names or "
193+
"other objects should not be provided when wrapping/decorating a "
194+
f"function: found {non_named_args} when wrapping/decorating {function!r}"
195+
)
196+
inputs = {i.name: i for i in inputs}
192197
if not has_varargs:
193198
if unrecognised := set(inputs) - set(input_types):
194199
raise ValueError(

pydra/compose/tests/test_python_fields.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,26 @@ def func(function: ty.Callable) -> ty.Callable:
4141
return function
4242

4343

44+
def test_function_arg_fail2():
45+
46+
with pytest.raises(
47+
ValueError, match="Only named Arg objects should be provided as inputs"
48+
):
49+
50+
@python.define(inputs=[python.arg(help="an int")])
51+
def func(a: int) -> int:
52+
return a * 2
53+
54+
55+
def test_function_arg_add_help():
56+
57+
@python.define(inputs=[python.arg(name="a", help="an int")])
58+
def func(a: int) -> int:
59+
return a * 2
60+
61+
assert task_fields(func).a.help == "an int"
62+
63+
4464
def test_interface_wrap_function_with_default():
4565
def func(a: int, k: float = 2.0) -> float:
4666
"""Sample function with inputs and outputs"""

pydra/utils/general.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
logger = logging.getLogger("pydra")
2323
if ty.TYPE_CHECKING:
24-
from pydra.compose.base import Task
24+
from pydra.compose.base import Task, Field # noqa
2525
from pydra.compose import workflow
2626

2727

0 commit comments

Comments
 (0)