Skip to content

Commit f3e74ec

Browse files
author
Ian Schweer
committed
Clean up after ourselves
1 parent c210fcc commit f3e74ec

File tree

2 files changed

+37
-5
lines changed

2 files changed

+37
-5
lines changed

pytensor/link/pytorch/dispatch/basic.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
import torch
66
import torch.compiler
77

8-
import pytensor.link.utils
98
from pytensor.compile import PYTORCH
109
from pytensor.compile.builders import OpFromGraph
1110
from pytensor.compile.ops import DeepCopyOp
@@ -62,8 +61,7 @@ def pytorch_funcify_FunctionGraph(
6261
# without graph breaks
6362
def conversion_func_register(*args, **kwargs):
6463
functor = pytorch_funcify(*args, **kwargs)
65-
module = pytensor.link.utils
66-
setattr(module, kwargs["unique_name"](functor), functor)
64+
kwargs["linker"].record_fn(kwargs["unique_name"](functor), functor)
6765
return functor
6866

6967
return fgraph_to_python(

pytensor/link/pytorch/linker.py

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import copy
12
from typing import Any
23

34
from pytensor.graph.basic import Variable
@@ -8,6 +9,10 @@
89
class PytorchLinker(JITLinker):
910
"""A `Linker` that compiles NumPy-based operations using torch.compile."""
1011

12+
def __init__(self, *args, **kwargs):
13+
super().__init__(*args, **kwargs)
14+
self.gen_functors = []
15+
1116
def input_filter(self, inp: Any) -> Any:
1217
from pytensor.link.pytorch.dispatch import pytorch_typify
1318

@@ -23,15 +28,41 @@ def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs):
2328
# across the entire pytensor graph, not
2429
# just the subgraph
2530
generator = unique_name_generator(["torch_linker"])
26-
built_kwargs = {"unique_name": generator, **kwargs}
31+
built_kwargs = {"unique_name": generator, "linker": self, **kwargs}
2732
return pytorch_funcify(
2833
fgraph, input_storage=input_storage, storage_map=storage_map, **built_kwargs
2934
)
3035

3136
def jit_compile(self, fn):
3237
import torch
3338

34-
return torch.compile(fn)
39+
class wrapper:
40+
def __init__(self, fn, gen_functors):
41+
self.fn = torch.compile(fn)
42+
self.gen_functors = copy.copy(gen_functors)
43+
44+
def __call__(self, *args, **kwargs):
45+
import pytensor.link.utils
46+
47+
# set attrs
48+
for n, fn in self.gen_functors:
49+
setattr(pytensor.link.utils, n, fn)
50+
51+
res = self.fn(*args, **kwargs)
52+
53+
# unset attrs
54+
for n, _ in self.gen_functors:
55+
delattr(pytensor.link.utils, n)
56+
57+
return res
58+
59+
def __del__(self):
60+
print("del")
61+
del self.gen_functors
62+
63+
res = wrapper(fn, self.gen_functors)
64+
self.gen_functors = []
65+
return res
3566

3667
def create_thunk_inputs(self, storage_map):
3768
thunk_inputs = []
@@ -40,3 +71,6 @@ def create_thunk_inputs(self, storage_map):
4071
thunk_inputs.append(sinput)
4172

4273
return thunk_inputs
74+
75+
def record_fn(self, name, fn):
76+
self.gen_functors.append((name, fn))

0 commit comments

Comments
 (0)