Skip to content

Commit e690bff

Browse files
committed
modify the pytorch jit
1 parent a9ecad0 commit e690bff

File tree

1 file changed

+32
-12
lines changed

1 file changed

+32
-12
lines changed

pytensor/link/mlx/linker.py

Lines changed: 32 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -66,21 +66,41 @@ def conversion_func_register(*args, **kwargs):
6666
)
6767

6868
def jit_compile(self, fn):
69-
"""JIT compile an MLX function.
69+
import mlx.core as mx
7070

71-
Parameters
72-
----------
73-
fn : callable
74-
The function to compile
71+
from pytensor.link.mlx.dispatch import mlx_typify
7572

76-
Returns
77-
-------
78-
callable
79-
The compiled function
80-
"""
81-
import mlx.core as mx
73+
class wrapper:
74+
def __init__(self, fn, gen_functors):
75+
self.fn = mx.compile(fn)
76+
self.gen_functors = gen_functors.copy()
77+
78+
def __call__(self, *inputs, **kwargs):
79+
import pytensor.link.utils
80+
81+
# set attrs
82+
for n, fn in self.gen_functors:
83+
setattr(pytensor.link.utils, n[1:], fn)
84+
85+
# MLX doesn't support np.ndarray as input
86+
outs = self.fn(*(mlx_typify(inp) for inp in inputs), **kwargs)
87+
88+
return outs
89+
90+
# unset attrs
91+
for n, _ in self.gen_functors:
92+
if getattr(pytensor.link.utils, n[1:], False):
93+
delattr(pytensor.link.utils, n[1:])
94+
95+
return tuple(out.cpu().numpy() for out in outs)
96+
97+
def __del__(self):
98+
del self.gen_functors
99+
100+
inner_fn = wrapper(fn, self.gen_functors)
101+
self.gen_functors = []
82102

83-
return mx.compile(fn)
103+
return inner_fn
84104

85105
def create_thunk_inputs(self, storage_map):
86106
"""Create inputs for the MLX thunk.

0 commit comments

Comments
 (0)