Skip to content

Commit 9b4059d

Browse files
author
Ian Schweer
committed
Use explicit closure lifetime and tracking
1 parent 33a4d48 commit 9b4059d

File tree

4 files changed

+17
-5
lines changed

4 files changed

+17
-5
lines changed

pytensor/link/pytorch/dispatch/basic.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import torch
66
import torch.compiler
77

8+
import pytensor.link.utils
89
from pytensor.compile import PYTORCH
910
from pytensor.compile.builders import OpFromGraph
1011
from pytensor.compile.ops import DeepCopyOp
@@ -56,9 +57,15 @@ def pytorch_funcify_FunctionGraph(
5657
fgraph_name="pytorch_funcified_fgraph",
5758
**kwargs,
5859
):
60+
def conversion_func_register(*args, **kwargs):
61+
functor = pytorch_funcify(*args, **kwargs)
62+
module = pytensor.link.utils
63+
setattr(module, kwargs["unique_name"](functor), functor)
64+
return functor
65+
5966
return fgraph_to_python(
6067
fgraph,
61-
pytorch_funcify,
68+
conversion_func_register,
6269
type_conversion_fn=pytorch_typify,
6370
fgraph_name=fgraph_name,
6471
**kwargs,
@@ -173,7 +180,6 @@ def pytorch_funcify_OpFromGraph(op, node, **kwargs):
173180

174181
# Apply inner rewrites
175182
PYTORCH.optimizer(op.fgraph)
176-
177183
fgraph_fn = pytorch_funcify(op.fgraph, **kwargs, squeeze_output=True)
178184
# Disable one step inlining to prevent torch from trying to import local functions
179185
# defined in `pytorch_funcify`

pytensor/link/pytorch/linker.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from pytensor.graph.basic import Variable
44
from pytensor.link.basic import JITLinker
5+
from pytensor.link.utils import unique_name_generator
56

67

78
class PytorchLinker(JITLinker):
@@ -18,8 +19,10 @@ def output_filter(self, var: Variable, out: Any) -> Any:
1819
def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs):
1920
from pytensor.link.pytorch.dispatch import pytorch_funcify
2021

22+
generator = unique_name_generator(["torch_linker"])
23+
built_kwargs = {"unique_name": generator, **kwargs}
2124
return pytorch_funcify(
22-
fgraph, input_storage=input_storage, storage_map=storage_map, **kwargs
25+
fgraph, input_storage=input_storage, storage_map=storage_map, **built_kwargs
2326
)
2427

2528
def jit_compile(self, fn):

pytensor/link/utils.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -704,6 +704,8 @@ def fgraph_to_python(
704704
get_name_for_object
705705
A function used to provide names for the objects referenced within the
706706
generated function.
707+
unique_name
708+
A function to make random function names for generated code (access through kwargs)
707709
squeeze_output
708710
If the `FunctionGraph` has only one output and this option is
709711
``True``, return the single output instead of a tuple with the output.
@@ -717,7 +719,8 @@ def fgraph_to_python(
717719
if storage_map is None:
718720
storage_map = {}
719721

720-
unique_name = unique_name_generator([fgraph_name])
722+
unique_name = kwargs.get("unique_name", unique_name_generator([fgraph_name]))
723+
kwargs["unique_name"] = unique_name
721724

722725
if global_env is None:
723726
global_env = {}

tests/link/pytorch/test_basic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -335,7 +335,7 @@ def test_pytorch_OpFromGraph():
335335
ofg_2 = OpFromGraph([x, y], [x * y, x - y])
336336

337337
o1, o2 = ofg_2(y, z)
338-
out = ofg_1(x, o1) + o2
338+
out = ofg_1(x, o1) / o2
339339

340340
xv = np.ones((2, 2), dtype=config.floatX)
341341
yv = np.ones((2, 2), dtype=config.floatX) * 3

0 commit comments

Comments
 (0)