Skip to content

Commit 294c271

Browse files
committed
THE SUPER BLOCKWISEE YA YA YA YA JUUUUU
1 parent e7cf10e commit 294c271

File tree

1 file changed

+31
-1
lines changed

1 file changed

+31
-1
lines changed

pytensor/link/mlx/dispatch/blockwise.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,41 @@
22

33
from pytensor.link.mlx.dispatch import mlx_funcify
44
from pytensor.tensor.blockwise import Blockwise
5+
from pytensor.tensor.signal.conv import Conv1d
56

7+
def blockwise_conv1d(op, node):
8+
if op.core_op.mode != "valid":
9+
raise NotImplementedError("Only 'valid' mode is supported for conv1d")
10+
batches_ndim = op.batch_ndim(node)
11+
if batches_ndim != 1:
12+
raise NotImplementedError("Only 1D batches are supported for conv1d")
13+
14+
_, kernel = node.inputs
15+
if not all(kernel.type.broadcastable[:batches_ndim]):
16+
raise NotImplementedError("Only 1D batches are supported for conv1d")
17+
18+
def inner_f(x, kernel):
19+
x_reshaped = x.reshape(-1, x.shape[-1]).T # shape equals to (N, B) -> N Time as batches all together
20+
b = x_reshaped.shape[1] #
21+
kernel_squeeze = kernel.reshape(-1)
22+
f = kernel_squeeze.shape[0] # Number of filters
23+
kernel_reshaped = mx.broadcast_to(a=kernel_squeeze[None, :, None], shape=(b, f, b))
24+
conv_result = mx.conv1d(x_reshaped[None, :, :], kernel_reshaped, stride=1, padding=0, dilation=1)
25+
_, conv_shape, _ = conv_result.shape
26+
return mx.moveaxis(a=conv_result, source=-1, destination=0).reshape(x.shape[:-1] + (conv_shape,))
27+
return inner_f
628

729
@mlx_funcify.register(Blockwise)
8-
def funcify_Blockwise(op: Blockwise, node, *args, **kwargs):
30+
def funcify_Blockwise(op: Blockwise, node, **kwargs):
31+
if isinstance(op.core_op, Conv1d):
32+
return blockwise_conv1d(op, node, **kwargs)
33+
34+
core_f = mlx_funcify(op.core_op)
35+
36+
def blockwise_f(*inputs):
37+
return blockwise_f(*inputs)
938
core_node = op._create_dummy_core_node(node.inputs)
39+
1040
core_f = mlx_funcify(op.core_op, core_node)
1141
blockwise_f = core_f
1242
for i in range(op.batch_ndim(node)):

0 commit comments

Comments
 (0)