Skip to content

Commit 6b27dc4

Browse files
committed
AI COULD BE COOL? OR WE ARE JUST FUCKING AROUND?
1 parent 4e4923f commit 6b27dc4

File tree

1 file changed

+19
-15
lines changed

1 file changed

+19
-15
lines changed

pytensor/link/mlx/dispatch/blockwise.py

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -18,25 +18,29 @@ def blockwise_conv1d(op, node, **kwargs):
1818
# raise NotImplementedError("Only 1D batches are supported for conv1d")
1919

2020
def inner_f(x, kernel):
21-
*bx, t = x.shape
22-
*bk, h = kernel.shape
21+
# 1) Validate shapes
22+
B, T = x.shape
23+
Bk, K = kernel.shape
24+
if B != Bk:
25+
raise ValueError(f"Batch mismatch: x has {B}, kernels has {Bk}")
2326

24-
b = np.broadcast_shapes(bx, bk)
27+
# 2) Reshape x so that 'channels' = B, batch size = 1
28+
# → input shape (N=1, H=T, C_in=B)
29+
x_in = x.T[None, :, :] # shape (1, T, B)
2530

26-
x = mx.broadcast_to(x, b + (t,))
27-
kernel = mx.broadcast_to(kernel, b + (h,))
31+
# 3) Build weight array of shape (C_out=B, H_f=K, C_in=1)
32+
# groups = B will slice C_in into B single-channel groups
33+
w = kernel[:, :, None] # shape (B, K, 1)
2834

29-
x_reshaped = x.reshape(-1, t).T # shape equals to (N, B) -> N Time as batches all together
30-
kernel_squeeze = kernel.reshape(-1, h)
31-
b_prod = kernel_squeeze.shape[0]
35+
# 4) Convolve with one group per sequence
36+
y = mx.conv1d(x_in, w,
37+
stride=1,
38+
padding=0,
39+
dilation=1,
40+
groups=B)
3241

33-
print(kernel_squeeze.shape)
34-
35-
print(b_prod, h, b_prod)
36-
kernel_reshaped = mx.broadcast_to(kernel_squeeze[:, :, None], shape=(b_prod, h, b_prod))
37-
conv_result = mx.conv1d(x_reshaped[None, :, :], kernel_reshaped, stride=1, padding=0, dilation=1)
38-
_, conv_shape, _ = conv_result.shape
39-
return mx.moveaxis(conv_result, source=-1, destination=0).reshape(b + (conv_shape,))
42+
# 5) y has shape (1, T - K + 1, B); drop the batch axis and transpose
43+
return y[0].T # final shape (B, T - K + 1)
4044
return inner_f
4145

4246
@mlx_funcify.register(Blockwise)

0 commit comments

Comments
 (0)