@@ -18,29 +18,30 @@ def blockwise_conv1d(op, node, **kwargs):
1818 # raise NotImplementedError("Only 1D batches are supported for conv1d")
1919
2020 def inner_f (x , kernel ):
21- # 1) Validate shapes
2221 B , T = x .shape
2322 Bk , K = kernel .shape
2423 if B != Bk :
2524 raise ValueError (f"Batch mismatch: x has { B } , kernels has { Bk } " )
2625
27- # 2) Reshape x so that 'channels' = B, batch size = 1
28- # → input shape (N=1, H=T, C_in=B)
29- x_in = x .T [None , :, :] # shape (1, T, B)
26+ # 1) Flip each kernel for true convolution
27+ kernels_flipped = kernel [:, ::- 1 ] # shape (B, K)
3028
31- # 3) Build weight array of shape (C_out=B, H_f=K, C_in=1)
32- # groups = B will slice C_in into B single-channel groups
33- w = kernel [:, :, None ] # shape (B, K, 1)
29+ # 2) Reshape input into (N=1, H=T, C_in=B)
30+ x_in = x .T [None , :, :]
3431
35- # 4) Convolve with one group per sequence
36- y = mx .conv1d (x_in , w ,
37- stride = 1 ,
38- padding = 0 ,
39- dilation = 1 ,
40- groups = B )
32+ # 3) Build weight tensor of shape (C_out=B, H_f=K, C_in=1)
33+ w = kernels_flipped [:, :, None ]
4134
42- # 5) y has shape (1, T - K + 1, B); drop the batch axis and transpose
43- return y [0 ].T # final shape (B, T - K + 1)
35+ # 4) Convolve with one group per channel → valid mode
36+ y = mx .conv1d (
37+ x_in , w ,
38+ stride = 1 ,
39+ padding = 0 ,
40+ dilation = 1 ,
41+ groups = B
42+ )
43+ # y: (1, T-K+1, B) → drop batch and transpose to (B, T-K+1)
44+ return y [0 ].T
4445 return inner_f
4546
4647@mlx_funcify .register (Blockwise )
0 commit comments