Skip to content

Commit b083fb9

Browse files
committed
Type function arguments
1 parent 084fbf1 commit b083fb9

File tree

1 file changed

+52
-42
lines changed

1 file changed

+52
-42
lines changed

pytensor/compile/function/__init__.py

Lines changed: 52 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,15 @@
11
import logging
22
import re
33
import traceback as tb
4+
from collections.abc import Iterable
45
from pathlib import Path
56

7+
import pytensor.misc.pkl_utils
68
from pytensor.compile.function.pfunc import pfunc
79
from pytensor.compile.function.types import orig_function
10+
from pytensor.compile.mode import Mode
11+
from pytensor.compile.profiling import ProfileStats
12+
from pytensor.graph import Variable
813

914

1015
__all__ = ["types", "pfunc"]
@@ -15,18 +20,22 @@
1520

1621
def function_dump(
1722
filename: str | Path,
18-
inputs,
19-
outputs=None,
20-
mode=None,
21-
updates=None,
22-
givens=None,
23-
no_default_updates=False,
24-
accept_inplace=False,
25-
name=None,
26-
rebuild_strict=True,
27-
allow_input_downcast=None,
28-
profile=None,
29-
on_unused_input=None,
23+
inputs: Iterable[Variable],
24+
outputs: Variable | Iterable[Variable] | dict[str, Variable] | None = None,
25+
mode: str | Mode | None = None,
26+
updates: Iterable[tuple[Variable, Variable]]
27+
| dict[Variable, Variable]
28+
| None = None,
29+
givens: Iterable[tuple[Variable, Variable]]
30+
| dict[Variable, Variable]
31+
| None = None,
32+
no_default_updates: bool = False,
33+
accept_inplace: bool = False,
34+
name: str | None = None,
35+
rebuild_strict: bool = True,
36+
allow_input_downcast: bool | None = None,
37+
profile: bool | ProfileStats | None = None,
38+
on_unused_input: str | None = None,
3039
extra_tag_to_remove: str | None = None,
3140
):
3241
"""
@@ -60,43 +69,44 @@ def function_dump(
6069
`['annotations', 'replacement_of', 'aggregation_scheme', 'roles']`
6170
6271
"""
63-
filename = Path(filename)
64-
d = dict(
65-
inputs=inputs,
66-
outputs=outputs,
67-
mode=mode,
68-
updates=updates,
69-
givens=givens,
70-
no_default_updates=no_default_updates,
71-
accept_inplace=accept_inplace,
72-
name=name,
73-
rebuild_strict=rebuild_strict,
74-
allow_input_downcast=allow_input_downcast,
75-
profile=profile,
76-
on_unused_input=on_unused_input,
77-
)
78-
with filename.open("wb") as f:
79-
import pytensor.misc.pkl_utils
80-
72+
d = {
73+
"inputs": inputs,
74+
"outputs": outputs,
75+
"mode": mode,
76+
"updates": updates,
77+
"givens": givens,
78+
"no_default_updates": no_default_updates,
79+
"accept_inplace": accept_inplace,
80+
"name": name,
81+
"rebuild_strict": rebuild_strict,
82+
"allow_input_downcast": allow_input_downcast,
83+
"profile": profile,
84+
"on_unused_input": on_unused_input,
85+
}
86+
with Path(filename).open("wb") as f:
8187
pickler = pytensor.misc.pkl_utils.StripPickler(
8288
f, protocol=-1, extra_tag_to_remove=extra_tag_to_remove
8389
)
8490
pickler.dump(d)
8591

8692

8793
def function(
88-
inputs,
89-
outputs=None,
90-
mode=None,
91-
updates=None,
92-
givens=None,
93-
no_default_updates=False,
94-
accept_inplace=False,
95-
name=None,
96-
rebuild_strict=True,
97-
allow_input_downcast=None,
98-
profile=None,
99-
on_unused_input=None,
94+
inputs: Iterable[Variable],
95+
outputs: Variable | Iterable[Variable] | dict[str, Variable] | None = None,
96+
mode: str | Mode | None = None,
97+
updates: Iterable[tuple[Variable, Variable]]
98+
| dict[Variable, Variable]
99+
| None = None,
100+
givens: Iterable[tuple[Variable, Variable]]
101+
| dict[Variable, Variable]
102+
| None = None,
103+
no_default_updates: bool = False,
104+
accept_inplace: bool = False,
105+
name: str | None = None,
106+
rebuild_strict: bool = True,
107+
allow_input_downcast: bool | None = None,
108+
profile: bool | ProfileStats | None = None,
109+
on_unused_input: str | None = None,
100110
):
101111
"""
102112
Return a :class:`callable object <pytensor.compile.function.types.Function>`

0 commit comments

Comments
 (0)