Skip to content

Commit c6f1dcb

Browse files
author
Ian Schweer
committed
PR comments
1 parent f3e74ec commit c6f1dcb

File tree

3 files changed

+37
-19
lines changed

3 files changed

+37
-19
lines changed

pytensor/link/pytorch/dispatch/basic.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -54,22 +54,16 @@ def pytorch_funcify_FunctionGraph(
5454
fgraph,
5555
node=None,
5656
fgraph_name="pytorch_funcified_fgraph",
57+
conversion_func=pytorch_funcify,
5758
**kwargs,
5859
):
59-
# Ensure that torch is aware of the leaf nodes
60-
# of generated code so we can compile code
61-
# without graph breaks
62-
def conversion_func_register(*args, **kwargs):
63-
functor = pytorch_funcify(*args, **kwargs)
64-
kwargs["linker"].record_fn(kwargs["unique_name"](functor), functor)
65-
return functor
66-
60+
built_kwargs = {"conversion_func": conversion_func, **kwargs}
6761
return fgraph_to_python(
6862
fgraph,
69-
conversion_func_register,
63+
conversion_func,
7064
type_conversion_fn=pytorch_typify,
7165
fgraph_name=fgraph_name,
72-
**kwargs,
66+
**built_kwargs,
7367
)
7468

7569

pytensor/link/pytorch/linker.py

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,20 @@ def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs):
2828
# across the entire pytensor graph, not
2929
# just the subgraph
3030
generator = unique_name_generator(["torch_linker"])
31-
built_kwargs = {"unique_name": generator, "linker": self, **kwargs}
31+
32+
# Ensure that torch is aware of the generated
33+
# code so we can compile without graph breaks
34+
def conversion_func_register(*args, **kwargs):
35+
functor = pytorch_funcify(*args, **kwargs)
36+
name = kwargs["unique_name"](functor)
37+
self.gen_functors.append((f"_{name}", functor))
38+
return functor
39+
40+
built_kwargs = {
41+
"unique_name": generator,
42+
"conversion_func": conversion_func_register,
43+
**kwargs,
44+
}
3245
return pytorch_funcify(
3346
fgraph, input_storage=input_storage, storage_map=storage_map, **built_kwargs
3447
)
@@ -37,6 +50,16 @@ def jit_compile(self, fn):
3750
import torch
3851

3952
class wrapper:
53+
"""
54+
Pytorch would fail compiling our method when trying
55+
to resolve some of the methods returned from dispatch
56+
calls. We want to be careful to not leak the methods,
57+
so this class just holds them and provisions the expected
58+
location accordingly
59+
60+
https://discuss.pytorch.org/t/closures-are-being-gcd-and-causing-failures-to-compile/213319
61+
"""
62+
4063
def __init__(self, fn, gen_functors):
4164
self.fn = torch.compile(fn)
4265
self.gen_functors = copy.copy(gen_functors)
@@ -46,18 +69,18 @@ def __call__(self, *args, **kwargs):
4669

4770
# set attrs
4871
for n, fn in self.gen_functors:
49-
setattr(pytensor.link.utils, n, fn)
72+
setattr(pytensor.link.utils, n[1:], fn)
5073

5174
res = self.fn(*args, **kwargs)
5275

5376
# unset attrs
5477
for n, _ in self.gen_functors:
55-
delattr(pytensor.link.utils, n)
78+
if getattr(pytensor.link.utils, n, False):
79+
delattr(pytensor.link.utils, n[1:])
5680

5781
return res
5882

5983
def __del__(self):
60-
print("del")
6184
del self.gen_functors
6285

6386
res = wrapper(fn, self.gen_functors)
@@ -71,6 +94,3 @@ def create_thunk_inputs(self, storage_map):
7194
thunk_inputs.append(sinput)
7295

7396
return thunk_inputs
74-
75-
def record_fn(self, name, fn):
76-
self.gen_functors.append((name, fn))

pytensor/link/utils.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -675,6 +675,7 @@ def fgraph_to_python(
675675
local_env: dict[Any, Any] | None = None,
676676
get_name_for_object: Callable[[Any], str] = get_name_for_object,
677677
squeeze_output: bool = False,
678+
unique_name: Callable | None = None,
678679
**kwargs,
679680
) -> Callable:
680681
"""Convert a `FunctionGraph` into a regular Python function.
@@ -707,7 +708,7 @@ def fgraph_to_python(
707708
A function used to provide names for the objects referenced within the
708709
generated function.
709710
unique_name
710-
A function to make random function names for generated code (access through kwargs)
711+
A function to make random function names for generated code
711712
squeeze_output
712713
If the `FunctionGraph` has only one output and this option is
713714
``True``, return the single output instead of a tuple with the output.
@@ -721,7 +722,10 @@ def fgraph_to_python(
721722
if storage_map is None:
722723
storage_map = {}
723724

724-
unique_name = kwargs.get("unique_name", unique_name_generator([fgraph_name]))
725+
if not unique_name:
726+
unique_name = unique_name_generator([fgraph_name])
727+
728+
# make sure we plumb this through
725729
kwargs["unique_name"] = unique_name
726730

727731
if global_env is None:

0 commit comments

Comments
 (0)