Skip to content

Commit 5abd32d

Browse files
committed
Changes
1 parent 4d5b34b commit 5abd32d

File tree

2 files changed

+18
-2
lines changed

2 files changed

+18
-2
lines changed
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
import mlx.core as mx
2+
3+
from pytensor.link.mlx.dispatch import mlx_funcify
4+
from pytensor.tensor.blockwise import Blockwise
5+
6+
@mlx_funcify.register(Blockwise)
7+
def funcify_Blockwise(op: Blockwise, node, *args, **kwargs):
8+
core_f = mlx_funcify(op.core_op)
9+
batched_f = core_f
10+
for _ in range(op.batch_ndim(node)):
11+
batched_f = mx.vmap(batched_f)
12+
13+
def wrapped_blockwise_f(*inputs):
14+
return batched_f(*inputs)
15+
16+
return wrapped_blockwise_f

pytensor/link/mlx/dispatch/signal/conv.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
1+
import mlx.core as mx
2+
13
from pytensor.link.mlx.dispatch import mlx_funcify
24
from pytensor.tensor.signal.conv import Conv1d
35

4-
import mlx.core as mx
5-
66

77
@mlx_funcify.register(Conv1d)
88
def mlx_funcify_Conv1d(op, node, **kwargs):

0 commit comments

Comments
 (0)