def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs):
    """Convert ``fgraph`` to a callable via the PyTorch dispatcher.

    A single name generator is shared across the whole conversion so
    generated names are unique over the entire pytensor graph, not
    just this subgraph.
    """
    generator = unique_name_generator(["torch_linker"])

    # Record every functor produced by the dispatcher so that torch
    # can later resolve the generated code and compile it without
    # graph breaks.
    def conversion_func_register(*args, **kwargs):
        functor = pytorch_funcify(*args, **kwargs)
        # NOTE(review): assumes the dispatcher always forwards
        # "unique_name" in kwargs — confirm against pytorch_funcify.
        name = kwargs["unique_name"](functor)
        self.gen_functors.append((f"_{name}", functor))
        return functor

    built_kwargs = dict(
        unique_name=generator,
        conversion_func=conversion_func_register,
        **kwargs,
    )
    return pytorch_funcify(
        fgraph,
        input_storage=input_storage,
        storage_map=storage_map,
        **built_kwargs,
    )
@@ -37,6 +50,16 @@ def jit_compile(self, fn):
3750 import torch
3851
3952 class wrapper :
53+ """
54+ Pytorch would fail compiling our method when trying
55+ to resolve some of the methods returned from dispatch
56+ calls. We want to be careful to not leak the methods,
57+ so this class just holds them and provisions the expected
58+ location accordingly
59+
60+ https://discuss.pytorch.org/t/closures-are-being-gcd-and-causing-failures-to-compile/213319
61+ """
62+
4063 def __init__ (self , fn , gen_functors ):
4164 self .fn = torch .compile (fn )
4265 self .gen_functors = copy .copy (gen_functors )
@@ -46,18 +69,18 @@ def __call__(self, *args, **kwargs):
4669
4770 # set attrs
4871 for n , fn in self .gen_functors :
49- setattr (pytensor .link .utils , n , fn )
72+ setattr (pytensor .link .utils , n [ 1 :] , fn )
5073
5174 res = self .fn (* args , ** kwargs )
5275
5376 # unset attrs
5477 for n , _ in self .gen_functors :
55- delattr (pytensor .link .utils , n )
78+ if getattr (pytensor .link .utils , n , False ):
79+ delattr (pytensor .link .utils , n [1 :])
5680
5781 return res
5882
5983 def __del__ (self ):
60- print ("del" )
6184 del self .gen_functors
6285
6386 res = wrapper (fn , self .gen_functors )
@@ -71,6 +94,3 @@ def create_thunk_inputs(self, storage_map):
7194 thunk_inputs .append (sinput )
7295
7396 return thunk_inputs
74-
75- def record_fn (self , name , fn ):
76- self .gen_functors .append ((name , fn ))
0 commit comments