1+ import copy
12from typing import Any
23
34from pytensor .graph .basic import Variable
89class PytorchLinker (JITLinker ):
910 """A `Linker` that compiles NumPy-based operations using torch.compile."""
1011
12+ def __init__ (self , * args , ** kwargs ):
13+ super ().__init__ (* args , ** kwargs )
14+ self .gen_functors = []
15+
1116 def input_filter (self , inp : Any ) -> Any :
1217 from pytensor .link .pytorch .dispatch import pytorch_typify
1318
@@ -23,15 +28,41 @@ def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs):
2328 # across the entire pytensor graph, not
2429 # just the subgraph
2530 generator = unique_name_generator (["torch_linker" ])
26- built_kwargs = {"unique_name" : generator , ** kwargs }
31+ built_kwargs = {"unique_name" : generator , "linker" : self , ** kwargs }
2732 return pytorch_funcify (
2833 fgraph , input_storage = input_storage , storage_map = storage_map , ** built_kwargs
2934 )
3035
3136 def jit_compile (self , fn ):
3237 import torch
3338
34- return torch .compile (fn )
39+ class wrapper :
40+ def __init__ (self , fn , gen_functors ):
41+ self .fn = torch .compile (fn )
42+ self .gen_functors = copy .copy (gen_functors )
43+
44+ def __call__ (self , * args , ** kwargs ):
45+ import pytensor .link .utils
46+
47+ # set attrs
48+ for n , fn in self .gen_functors :
49+ setattr (pytensor .link .utils , n , fn )
50+
51+ res = self .fn (* args , ** kwargs )
52+
53+ # unset attrs
54+ for n , _ in self .gen_functors :
55+ delattr (pytensor .link .utils , n )
56+
57+ return res
58+
59+ def __del__ (self ):
60+ print ("del" )
61+ del self .gen_functors
62+
63+ res = wrapper (fn , self .gen_functors )
64+ self .gen_functors = []
65+ return res
3566
3667 def create_thunk_inputs (self , storage_map ):
3768 thunk_inputs = []
@@ -40,3 +71,6 @@ def create_thunk_inputs(self, storage_map):
4071 thunk_inputs .append (sinput )
4172
4273 return thunk_inputs
74+
75+ def record_fn (self , name , fn ):
76+ self .gen_functors .append ((name , fn ))
0 commit comments