Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions docs/getting_started.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -178,9 +178,7 @@
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Calling without arguments now raises an error, because dags needs the external inputs:"
]
"source": "Calling without arguments raises an `InvalidFunctionArgumentsError`, telling you exactly which inputs are missing:"
},
{
"cell_type": "code",
Expand Down
13 changes: 7 additions & 6 deletions src/dags/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -619,11 +619,6 @@ def _create_concatenated_function(
args = arglist
return_annotation = inspect.Parameter.empty

@with_signature(
args=args,
enforce=enforce_signature,
return_annotation=return_annotation,
)
def concatenated(*args: Any, **kwargs: Any) -> tuple[Any, ...]:
results = {**dict(zip(arglist, args, strict=False)), **kwargs}
for name, info in execution_info.items():
Expand All @@ -633,7 +628,13 @@ def concatenated(*args: Any, **kwargs: Any) -> tuple[Any, ...]:

return tuple(results[target] for target in targets)

return concatenated
concatenated.__name__ = "The concatenated function"
return with_signature(
concatenated,
args=args,
enforce=enforce_signature,
return_annotation=return_annotation,
)


def _infer_aggregator_return_type(
Expand Down
24 changes: 21 additions & 3 deletions src/dags/signature.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,9 @@ def wrapper_with_signature(*args: P.args, **kwargs: P.kwargs) -> R:
_fail_if_invalid_keyword_arguments(
present_kwargs, valid_kwargs, funcname
)
_fail_if_missing_arguments(
present_args, present_kwargs, valid_kwargs, funcname
)
return func(*args, **kwargs)

wrapper_with_signature.__signature__ = signature # ty: ignore[unresolved-attribute]
Expand All @@ -153,7 +156,7 @@ def _fail_if_too_many_positional_arguments(
) -> None:
if len(present_args) > len(argnames):
msg = (
f"{funcname}() takes {len(argnames)} positional arguments "
f"{funcname} takes {len(argnames)} positional arguments "
f"but {len(present_args)} were given"
)
raise InvalidFunctionArgumentsError(msg)
Expand All @@ -166,7 +169,7 @@ def _fail_if_duplicated_arguments(
if problematic:
s = "s" if len(problematic) >= 2 else "" # noqa: PLR2004
problem_str = ", ".join(list(problematic))
msg = f"{funcname}() got multiple values for argument{s} {problem_str}"
msg = f"{funcname} got multiple values for argument{s} {problem_str}"
raise InvalidFunctionArgumentsError(msg)


Expand All @@ -177,7 +180,22 @@ def _fail_if_invalid_keyword_arguments(
if problematic:
s = "s" if len(problematic) >= 2 else "" # noqa: PLR2004
problem_str = ", ".join(list(problematic))
msg = f"{funcname}() got unexpected keyword argument{s} {problem_str}"
msg = f"{funcname} got unexpected keyword argument{s} {problem_str}"
raise InvalidFunctionArgumentsError(msg)


def _fail_if_missing_arguments(
present_args: set[str],
present_kwargs: set[str],
required_args: set[str],
funcname: str,
) -> None:
provided = present_args | present_kwargs
missing = required_args - provided
if missing:
s = "s" if len(missing) >= 2 else "" # noqa: PLR2004
missing_str = ", ".join(sorted(missing))
msg = f"{funcname} is missing required argument{s}: {missing_str}"
raise InvalidFunctionArgumentsError(msg)


Expand Down
28 changes: 26 additions & 2 deletions tests/test_signature.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def f(*args, **kwargs):

with pytest.raises(
InvalidFunctionArgumentsError,
match=r"f\(\) got multiple values for argument b",
match=r"f got multiple values for argument b",
):
f(1, 2, b=3)

Expand All @@ -154,7 +154,7 @@ def f(*args, **kwargs):

with pytest.raises(
InvalidFunctionArgumentsError,
match=r"f\(\) got unexpected keyword argument d",
match=r"f got unexpected keyword argument d",
):
f(1, 2, d=4)

Expand Down Expand Up @@ -207,6 +207,30 @@ def f(d: int, e: float, *, f: bool) -> float:
}


def test_with_signature_decorator_missing_arguments() -> None:
@with_signature(args=["a", "b"], kwargs=["c"])
def f(*args, **kwargs):
return sum(args) + sum(kwargs.values())

with pytest.raises(
InvalidFunctionArgumentsError,
match="missing required argument",
):
f(1)


def test_with_signature_decorator_missing_all_arguments() -> None:
@with_signature(args=["a", "b"], kwargs=["c"])
def f(*args, **kwargs):
return sum(args) + sum(kwargs.values())

with pytest.raises(
InvalidFunctionArgumentsError,
match=r"missing required arguments.*a.*b",
):
f()


def test_with_signature_invalid_args_type() -> None:
with pytest.raises(DagsError, match="Invalid type for arg"):

Expand Down