
Commit 9f31ab1

Guys, I'm getting sad. We need help yisus!!!!!
1 parent 9d3eca8 · commit 9f31ab1

1 file changed: +23 additions, -13 deletions

pytensor/link/mlx/dispatch/blockwise.py

Lines changed: 23 additions & 13 deletions
@@ -4,26 +4,36 @@
 from pytensor.tensor.blockwise import Blockwise
 from pytensor.tensor.signal.conv import Conv1d

-def blockwise_conv1d(op, node):
+import numpy as np
+
+def blockwise_conv1d(op, node, **kwargs):
     if op.core_op.mode != "valid":
         raise NotImplementedError("Only 'valid' mode is supported for conv1d")
-    batches_ndim = op.batch_ndim(node)
-    if batches_ndim != 1:
-        raise NotImplementedError("Only 1D batches are supported for conv1d")
+    # batches_ndim = op.batch_ndim(node)
+    # if batches_ndim != 1:
+    #     raise NotImplementedError("Only 1D batches are supported for conv1d")

-    _, kernel = node.inputs
-    if not all(kernel.type.broadcastable[:batches_ndim]):
-        raise NotImplementedError("Only 1D batches are supported for conv1d")
+    # _, kernel = node.inputs
+    # if not all(kernel.type.broadcastable[:batches_ndim]):
+    #     raise NotImplementedError("Only 1D batches are supported for conv1d")

     def inner_f(x, kernel):
-        x_reshaped = x.reshape(-1, x.shape[-1]).T  # shape equals to (N, B) -> N Time as batches all together
-        b = x_reshaped.shape[1]  #
-        kernel_squeeze = kernel.reshape(-1)
-        f = kernel_squeeze.shape[0]  # Number of filters
-        kernel_reshaped = mx.broadcast_to(a=kernel_squeeze[None, :, None], shape=(b, f, b))
+        *bx, t = x.shape
+        *bk, h = kernel.shape
+
+        b = np.broadcast_shapes(bx, bk)
+
+        x = x.reshape(b + (t,))
+        kernel = kernel.reshape(b + (h,))
+
+        x_reshaped = x.reshape(-1, t).T  # shape equals to (N, B) -> N Time as batches all together
+        kernel_squeeze = kernel.reshape(-1, h)
+        b_prod = kernel_squeeze.shape[0]
+
+        kernel_reshaped = mx.broadcast_to(a=kernel_squeeze[None, :, None], shape=(b_prod, h, b_prod))
         conv_result = mx.conv1d(x_reshaped[None, :, :], kernel_reshaped, stride=1, padding=0, dilation=1)
         _, conv_shape, _ = conv_result.shape
-        return mx.moveaxis(a=conv_result, source=-1, destination=0).reshape(x.shape[:-1] + (conv_shape,))
+        return mx.moveaxis(a=conv_result, source=-1, destination=0).reshape(b + (conv_shape,))

     return inner_f

 @mlx_funcify.register(Blockwise)
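
For orientation, the shape logic this commit moves toward can be sketched on its own. The snippet below is a minimal NumPy-only illustration, not the committed MLX code: it broadcasts the batch dimensions of the signal and the kernel to a common shape, flattens them into a single leading axis, runs a "valid" 1D convolution per entry, and restores the batch shape. The helper name batched_valid_conv1d is hypothetical, np.convolve stands in for mx.conv1d, and np.broadcast_to is used so the sketch also runs when the two batch shapes merely broadcast rather than match exactly.

import numpy as np


def batched_valid_conv1d(x, kernel):
    # Split the trailing (core) axis from the leading batch axes.
    *bx, t = x.shape
    *bk, h = kernel.shape

    # Common batch shape, mirroring the np.broadcast_shapes call in the diff.
    b = np.broadcast_shapes(tuple(bx), tuple(bk))

    # Materialize the broadcast and flatten all batch axes into one.
    x_flat = np.broadcast_to(x, b + (t,)).reshape(-1, t)
    k_flat = np.broadcast_to(kernel, b + (h,)).reshape(-1, h)

    # One "valid" 1D convolution per flattened batch entry.
    out = np.stack(
        [np.convolve(xi, ki, mode="valid") for xi, ki in zip(x_flat, k_flat)]
    )

    # Restore the broadcast batch shape around the new core axis.
    return out.reshape(b + (out.shape[-1],))


# Example: a batch of 3 signals of length 10 against a single kernel of length 4.
x = np.random.randn(3, 10)
k = np.random.randn(1, 4)
print(batched_valid_conv1d(x, k).shape)  # -> (3, 7)

Flattening every batch axis into one leading axis is what lets the MLX version hand all batch entries to a single mx.conv1d call instead of looping in Python.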
