Skip to content

Commit d7b570c

Browse files
author
Ian Schweer
committed
PR comments
1 parent 270f231 commit d7b570c

File tree

3 files changed

+27
-10
lines changed

3 files changed

+27
-10
lines changed

pytensor/link/pytorch/dispatch/basic.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -56,12 +56,15 @@ def pytorch_funcify_FunctionGraph(
5656
fgraph_name="pytorch_funcified_fgraph",
5757
**kwargs,
5858
):
59-
# Ensure that torch is aware of the leaf nodes
60-
# of generated code so we can compile code
61-
# without graph breaks
59+
# Ensure that torch is aware of the generated
60+
# code so we can compile without graph breaks
6261
def conversion_func_register(*args, **kwargs):
6362
functor = pytorch_funcify(*args, **kwargs)
64-
kwargs["linker"].record_fn(kwargs["unique_name"](functor), functor)
63+
64+
# The linker could be missing cause we're only using
65+
# function graph
66+
if {"linker", "unique_name"}.issubset(set(kwargs.keys())):
67+
kwargs["linker"].record_fn(kwargs["unique_name"](functor), functor)
6568
return functor
6669

6770
return fgraph_to_python(

pytensor/link/pytorch/linker.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,16 @@ def jit_compile(self, fn):
3737
import torch
3838

3939
class wrapper:
40+
"""
41+
Pytorch would fail compiling our method when trying
42+
to resolve some of the methods returned from dispatch
43+
calls. We want to be careful to not leak the methods,
44+
so this class just holds them and provisions the expected
45+
location accordingly
46+
47+
https://discuss.pytorch.org/t/closures-are-being-gcd-and-causing-failures-to-compile/213319
48+
"""
49+
4050
def __init__(self, fn, gen_functors):
4151
self.fn = torch.compile(fn)
4252
self.gen_functors = copy.copy(gen_functors)
@@ -46,18 +56,18 @@ def __call__(self, *args, **kwargs):
4656

4757
# set attrs
4858
for n, fn in self.gen_functors:
49-
setattr(pytensor.link.utils, n, fn)
59+
setattr(pytensor.link.utils, n[1:], fn)
5060

5161
res = self.fn(*args, **kwargs)
5262

5363
# unset attrs
5464
for n, _ in self.gen_functors:
55-
delattr(pytensor.link.utils, n)
65+
if getattr(pytensor.link.utils, n, False):
66+
delattr(pytensor.link.utils, n[1:])
5667

5768
return res
5869

5970
def __del__(self):
60-
print("del")
6171
del self.gen_functors
6272

6373
res = wrapper(fn, self.gen_functors)
@@ -73,4 +83,4 @@ def create_thunk_inputs(self, storage_map):
7383
return thunk_inputs
7484

7585
def record_fn(self, name, fn):
76-
self.gen_functors.append((name, fn))
86+
self.gen_functors.append((f"_{name}", fn))

pytensor/link/utils.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -673,6 +673,7 @@ def fgraph_to_python(
673673
local_env: dict[Any, Any] | None = None,
674674
get_name_for_object: Callable[[Any], str] = get_name_for_object,
675675
squeeze_output: bool = False,
676+
unique_name: Callable | None = None,
676677
**kwargs,
677678
) -> Callable:
678679
"""Convert a `FunctionGraph` into a regular Python function.
@@ -705,7 +706,7 @@ def fgraph_to_python(
705706
A function used to provide names for the objects referenced within the
706707
generated function.
707708
unique_name
708-
A function to make random function names for generated code (access through kwargs)
709+
A function to make random function names for generated code
709710
squeeze_output
710711
If the `FunctionGraph` has only one output and this option is
711712
``True``, return the single output instead of a tuple with the output.
@@ -719,7 +720,10 @@ def fgraph_to_python(
719720
if storage_map is None:
720721
storage_map = {}
721722

722-
unique_name = kwargs.get("unique_name", unique_name_generator([fgraph_name]))
723+
if not unique_name:
724+
unique_name = unique_name_generator([fgraph_name])
725+
726+
# make sure we plumb this through
723727
kwargs["unique_name"] = unique_name
724728

725729
if global_env is None:

0 commit comments

Comments
 (0)