Skip to content

Commit a73091e

Browse files
Streamline Blockwise impl
1 parent cd2d26e commit a73091e

File tree

1 file changed

+4
-17
lines changed

1 file changed

+4
-17
lines changed

pytensor/link/mlx/dispatch/blockwise.py

Lines changed: 4 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,10 @@ def funcify_Blockwise(op: Blockwise, node, **kwargs):
1515

1616
# 4) Handle case where no vectorization is needed
1717
if n_batch == 0:
18-
19-
def blockwise_fun(*inputs):
20-
return core_f(*inputs)
21-
22-
return blockwise_fun
18+
return core_f
2319

2420
# 5) Vectorize using mx.vmap over any batched inputs
25-
in_axes = []
21+
in_axes: list[int | None] = []
2622
for inp, sig in zip(node.inputs, op.inputs_sig):
2723
batch_ndim = inp.type.ndim - len(sig)
2824
if batch_ndim == 0:
@@ -34,15 +30,6 @@ def blockwise_fun(*inputs):
3430
in_axes.append(0 if not all(batch_bcast) else None)
3531

3632
if not any(axis == 0 for axis in in_axes):
33+
return core_f
3734

38-
def blockwise_fun(*inputs):
39-
return core_f(*inputs)
40-
41-
return blockwise_fun
42-
43-
blockwise_f = mx.vmap(core_f, in_axes=tuple(in_axes))
44-
45-
def blockwise_fun(*inputs):
46-
return blockwise_f(*inputs)
47-
48-
return blockwise_fun
35+
return mx.vmap(core_f, in_axes=tuple(in_axes))

0 commit comments

Comments
 (0)