@@ -9,19 +9,6 @@ def __init__(self, *args, **kwargs):
99 super ().__init__ (* args , ** kwargs )
1010 self .gen_functors = []
1111
12- def input_filter (self , inp ):
13- from pytensor .link .pytorch .dispatch import pytorch_typify
14-
15- return pytorch_typify (inp )
16-
17- def output_filter (self , var , out ):
18- from torch import is_tensor
19-
20- if is_tensor (out ):
21- return out .cpu ()
22- else :
23- return out
24-
2512 def fgraph_convert (self , fgraph , input_storage , storage_map , ** kwargs ):
2613 from pytensor .link .pytorch .dispatch import pytorch_funcify
2714
@@ -67,34 +54,30 @@ def __init__(self, fn, gen_functors):
6754 self .fn = torch .compile (fn )
6855 self .gen_functors = gen_functors .copy ()
6956
70- def __call__ (self , * args , ** kwargs ):
57+ def __call__ (self , * inputs , ** kwargs ):
7158 import pytensor .link .utils
7259
7360 # set attrs
7461 for n , fn in self .gen_functors :
7562 setattr (pytensor .link .utils , n [1 :], fn )
7663
77- res = self .fn (* args , ** kwargs )
64+ # Torch does not accept numpy inputs and may return GPU objects
65+ outs = self .fn (* (pytorch_typify (inp ) for inp in inputs ), ** kwargs )
7866
7967 # unset attrs
8068 for n , _ in self .gen_functors :
8169 if getattr (pytensor .link .utils , n [1 :], False ):
8270 delattr (pytensor .link .utils , n [1 :])
8371
84- return res
72+ return tuple ( out . cpu (). numpy () for out in outs )
8573
8674 def __del__ (self ):
8775 del self .gen_functors
8876
8977 inner_fn = wrapper (fn , self .gen_functors )
9078 self .gen_functors = []
9179
92- # Torch does not accept numpy inputs and may return GPU objects
93- def create_outputs (* inputs , inner_fn = inner_fn ):
94- outs = inner_fn (* (pytorch_typify (inp ) for inp in inputs ))
95- return tuple (out .cpu ().numpy () for out in outs )
96-
97- return create_outputs
80+ return inner_fn
9881
9982 def create_thunk_inputs (self , storage_map ):
10083 thunk_inputs = []
0 commit comments