|
1 | 1 | import logging
|
2 | 2 | import re
|
3 | 3 | import traceback as tb
|
| 4 | +from collections.abc import Iterable |
4 | 5 | from pathlib import Path
|
5 | 6 |
|
| 7 | +import pytensor.misc.pkl_utils |
6 | 8 | from pytensor.compile.function.pfunc import pfunc
|
7 | 9 | from pytensor.compile.function.types import orig_function
|
| 10 | +from pytensor.compile.mode import Mode |
| 11 | +from pytensor.compile.profiling import ProfileStats |
| 12 | +from pytensor.graph import Variable |
8 | 13 |
|
9 | 14 |
|
10 | 15 | __all__ = ["types", "pfunc"]
|
|
15 | 20 |
|
16 | 21 | def function_dump(
|
17 | 22 | filename: str | Path,
|
18 |
| - inputs, |
19 |
| - outputs=None, |
20 |
| - mode=None, |
21 |
| - updates=None, |
22 |
| - givens=None, |
23 |
| - no_default_updates=False, |
24 |
| - accept_inplace=False, |
25 |
| - name=None, |
26 |
| - rebuild_strict=True, |
27 |
| - allow_input_downcast=None, |
28 |
| - profile=None, |
29 |
| - on_unused_input=None, |
| 23 | + inputs: Iterable[Variable], |
| 24 | + outputs: Variable | Iterable[Variable] | dict[str, Variable] | None = None, |
| 25 | + mode: str | Mode | None = None, |
| 26 | + updates: Iterable[tuple[Variable, Variable]] |
| 27 | + | dict[Variable, Variable] |
| 28 | + | None = None, |
| 29 | + givens: Iterable[tuple[Variable, Variable]] |
| 30 | + | dict[Variable, Variable] |
| 31 | + | None = None, |
| 32 | + no_default_updates: bool = False, |
| 33 | + accept_inplace: bool = False, |
| 34 | + name: str | None = None, |
| 35 | + rebuild_strict: bool = True, |
| 36 | + allow_input_downcast: bool | None = None, |
| 37 | + profile: bool | ProfileStats | None = None, |
| 38 | + on_unused_input: str | None = None, |
30 | 39 | extra_tag_to_remove: str | None = None,
|
31 | 40 | ):
|
32 | 41 | """
|
@@ -60,43 +69,44 @@ def function_dump(
|
60 | 69 | `['annotations', 'replacement_of', 'aggregation_scheme', 'roles']`
|
61 | 70 |
|
62 | 71 | """
|
63 |
| - filename = Path(filename) |
64 |
| - d = dict( |
65 |
| - inputs=inputs, |
66 |
| - outputs=outputs, |
67 |
| - mode=mode, |
68 |
| - updates=updates, |
69 |
| - givens=givens, |
70 |
| - no_default_updates=no_default_updates, |
71 |
| - accept_inplace=accept_inplace, |
72 |
| - name=name, |
73 |
| - rebuild_strict=rebuild_strict, |
74 |
| - allow_input_downcast=allow_input_downcast, |
75 |
| - profile=profile, |
76 |
| - on_unused_input=on_unused_input, |
77 |
| - ) |
78 |
| - with filename.open("wb") as f: |
79 |
| - import pytensor.misc.pkl_utils |
80 |
| - |
| 72 | + d = { |
| 73 | + "inputs": inputs, |
| 74 | + "outputs": outputs, |
| 75 | + "mode": mode, |
| 76 | + "updates": updates, |
| 77 | + "givens": givens, |
| 78 | + "no_default_updates": no_default_updates, |
| 79 | + "accept_inplace": accept_inplace, |
| 80 | + "name": name, |
| 81 | + "rebuild_strict": rebuild_strict, |
| 82 | + "allow_input_downcast": allow_input_downcast, |
| 83 | + "profile": profile, |
| 84 | + "on_unused_input": on_unused_input, |
| 85 | + } |
| 86 | + with Path(filename).open("wb") as f: |
81 | 87 | pickler = pytensor.misc.pkl_utils.StripPickler(
|
82 | 88 | f, protocol=-1, extra_tag_to_remove=extra_tag_to_remove
|
83 | 89 | )
|
84 | 90 | pickler.dump(d)
|
85 | 91 |
|
86 | 92 |
|
87 | 93 | def function(
|
88 |
| - inputs, |
89 |
| - outputs=None, |
90 |
| - mode=None, |
91 |
| - updates=None, |
92 |
| - givens=None, |
93 |
| - no_default_updates=False, |
94 |
| - accept_inplace=False, |
95 |
| - name=None, |
96 |
| - rebuild_strict=True, |
97 |
| - allow_input_downcast=None, |
98 |
| - profile=None, |
99 |
| - on_unused_input=None, |
| 94 | + inputs: Iterable[Variable], |
| 95 | + outputs: Variable | Iterable[Variable] | dict[str, Variable] | None = None, |
| 96 | + mode: str | Mode | None = None, |
| 97 | + updates: Iterable[tuple[Variable, Variable]] |
| 98 | + | dict[Variable, Variable] |
| 99 | + | None = None, |
| 100 | + givens: Iterable[tuple[Variable, Variable]] |
| 101 | + | dict[Variable, Variable] |
| 102 | + | None = None, |
| 103 | + no_default_updates: bool = False, |
| 104 | + accept_inplace: bool = False, |
| 105 | + name: str | None = None, |
| 106 | + rebuild_strict: bool = True, |
| 107 | + allow_input_downcast: bool | None = None, |
| 108 | + profile: bool | ProfileStats | None = None, |
| 109 | + on_unused_input: str | None = None, |
100 | 110 | ):
|
101 | 111 | """
|
102 | 112 | Return a :class:`callable object <pytensor.compile.function.types.Function>`
|
|
0 commit comments