Skip to content

Commit e308f83

Browse files
committed
AI RULES BABY MY MATE
1 parent 6b27dc4 commit e308f83

File tree

1 file changed

+16
-15
lines changed

1 file changed

+16
-15
lines changed

pytensor/link/mlx/dispatch/blockwise.py

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -18,29 +18,30 @@ def blockwise_conv1d(op, node, **kwargs):
1818
# raise NotImplementedError("Only 1D batches are supported for conv1d")
1919

2020
def inner_f(x, kernel):
21-
# 1) Validate shapes
2221
B, T = x.shape
2322
Bk, K = kernel.shape
2423
if B != Bk:
2524
raise ValueError(f"Batch mismatch: x has {B}, kernels has {Bk}")
2625

27-
# 2) Reshape x so that 'channels' = B, batch size = 1
28-
# → input shape (N=1, H=T, C_in=B)
29-
x_in = x.T[None, :, :] # shape (1, T, B)
26+
# 1) Flip each kernel for true convolution
27+
kernels_flipped = kernel[:, ::-1] # shape (B, K)
3028

31-
# 3) Build weight array of shape (C_out=B, H_f=K, C_in=1)
32-
# groups = B will slice C_in into B single-channel groups
33-
w = kernel[:, :, None] # shape (B, K, 1)
29+
# 2) Reshape input into (N=1, H=T, C_in=B)
30+
x_in = x.T[None, :, :]
3431

35-
# 4) Convolve with one group per sequence
36-
y = mx.conv1d(x_in, w,
37-
stride=1,
38-
padding=0,
39-
dilation=1,
40-
groups=B)
32+
# 3) Build weight tensor of shape (C_out=B, H_f=K, C_in=1)
33+
w = kernels_flipped[:, :, None]
4134

42-
# 5) y has shape (1, T - K + 1, B); drop the batch axis and transpose
43-
return y[0].T # final shape (B, T - K + 1)
35+
# 4) Convolve with one group per channel → valid mode
36+
y = mx.conv1d(
37+
x_in, w,
38+
stride=1,
39+
padding=0,
40+
dilation=1,
41+
groups=B
42+
)
43+
# y: (1, T-K+1, B) → drop batch and transpose to (B, T-K+1)
44+
return y[0].T
4445
return inner_f
4546

4647
@mlx_funcify.register(Blockwise)

0 commit comments

Comments
 (0)