@@ -66,21 +66,41 @@ def conversion_func_register(*args, **kwargs):
6666 )
6767
6868 def jit_compile (self , fn ):
69- """JIT compile an MLX function.
69+ import mlx . core as mx
7070
71- Parameters
72- ----------
73- fn : callable
74- The function to compile
71+ from pytensor .link .mlx .dispatch import mlx_typify
7572
76- Returns
77- -------
78- callable
79- The compiled function
80- """
81- import mlx .core as mx
73+ class wrapper :
74+ def __init__ (self , fn , gen_functors ):
75+ self .fn = mx .compile (fn )
76+ self .gen_functors = gen_functors .copy ()
77+
78+ def __call__ (self , * inputs , ** kwargs ):
79+ import pytensor .link .utils
80+
81+ # set attrs
82+ for n , fn in self .gen_functors :
83+ setattr (pytensor .link .utils , n [1 :], fn )
84+
85+ # MLX doesn't support np.ndarray as input
86+ outs = self .fn (* (mlx_typify (inp ) for inp in inputs ), ** kwargs )
87+
88+ return outs
89+
90+ # unset attrs
91+ for n , _ in self .gen_functors :
92+ if getattr (pytensor .link .utils , n [1 :], False ):
93+ delattr (pytensor .link .utils , n [1 :])
94+
95+ return tuple (out .cpu ().numpy () for out in outs )
96+
97+ def __del__ (self ):
98+ del self .gen_functors
99+
100+ inner_fn = wrapper (fn , self .gen_functors )
101+ self .gen_functors = []
82102
83- return mx . compile ( fn )
103+ return inner_fn
84104
85105 def create_thunk_inputs (self , storage_map ):
86106 """Create inputs for the MLX thunk.
0 commit comments