@@ -37,6 +37,16 @@ def jit_compile(self, fn):
3737 import torch
3838
3939 class wrapper :
40+ """
41+ Pytorch would fail compiling our method when trying
42+ to resolve some of the methods returned from dispatch
43+ calls. We want to be careful to not leak the methods,
44+ so this class just holds them and provisions the expected
45+ location accordingly
46+
47+ https://discuss.pytorch.org/t/closures-are-being-gcd-and-causing-failures-to-compile/213319
48+ """
49+
4050 def __init__ (self , fn , gen_functors ):
4151 self .fn = torch .compile (fn )
4252 self .gen_functors = copy .copy (gen_functors )
@@ -46,18 +56,18 @@ def __call__(self, *args, **kwargs):
4656
4757 # set attrs
4858 for n , fn in self .gen_functors :
49- setattr (pytensor .link .utils , n , fn )
59+ setattr (pytensor .link .utils , n [ 1 :] , fn )
5060
5161 res = self .fn (* args , ** kwargs )
5262
5363 # unset attrs
5464 for n , _ in self .gen_functors :
55- delattr (pytensor .link .utils , n )
65+ if getattr (pytensor .link .utils , n , False ):
66+ delattr (pytensor .link .utils , n [1 :])
5667
5768 return res
5869
5970 def __del__ (self ):
60- print ("del" )
6171 del self .gen_functors
6272
6373 res = wrapper (fn , self .gen_functors )
@@ -73,4 +83,4 @@ def create_thunk_inputs(self, storage_map):
7383 return thunk_inputs
7484
7585 def record_fn (self , name , fn ):
76- self .gen_functors .append ((name , fn ))
86+ self .gen_functors .append ((f"_ { name } " , fn ))
0 commit comments