@@ -18,25 +18,29 @@ def blockwise_conv1d(op, node, **kwargs):
1818 # raise NotImplementedError("Only 1D batches are supported for conv1d")
1919
def inner_f(x, kernel):
    """Run a 1D convolution of each row of ``x`` with the matching row of ``kernel``.

    All rows are processed in a single ``mx.conv1d`` call by treating every
    sequence as its own channel and using one convolution group per channel.

    Parameters
    ----------
    x
        2D array of shape ``(B, T)``: ``B`` sequences of length ``T``.
    kernel
        2D array of shape ``(B, K)``: one length-``K`` kernel per sequence.
        A batch size of 1 on either input is broadcast against the other.

    Returns
    -------
    Array of shape ``(B, T - K + 1)`` (stride 1, no padding, no dilation).

    Raises
    ------
    ValueError
        If either input is not 2D, or if the batch sizes differ and neither
        of them is 1.
    """
    # NOTE(review): the kernel is passed to mx.conv1d as-is (no flip here);
    # confirm this matches the convolution convention of the wrapped op.

    # 1) Validate ranks up front so the failure is explicit rather than an
    #    opaque tuple-unpacking error.
    if len(x.shape) != 2 or len(kernel.shape) != 2:
        raise ValueError(
            "Expected 2D inputs of shape (batch, length): "
            f"x has shape {tuple(x.shape)}, kernel has shape {tuple(kernel.shape)}"
        )
    n_seq, seq_len = x.shape
    n_ker, ker_len = kernel.shape

    # 2) Reconcile batch sizes: a singleton batch is broadcast to the other
    #    side; any other mismatch is an error.
    if n_seq != n_ker:
        if n_seq == 1:
            x = mx.broadcast_to(x, (n_ker, seq_len))
            n_seq = n_ker
        elif n_ker == 1:
            kernel = mx.broadcast_to(kernel, (n_seq, ker_len))
        else:
            raise ValueError(f"Batch mismatch: x has {n_seq}, kernels has {n_ker}")

    # 3) Lay the sequences out as channels of a single batch item:
    #    (B, T) -> (N=1, L=T, C_in=B).
    x_in = x.T[None, :, :]

    # 4) One single-channel filter per sequence: (C_out=B, K, C_in/groups=1).
    weights = kernel[:, :, None]

    # 5) groups=B makes each output channel see only its own input channel,
    #    so sequence i is only ever combined with kernel i.
    y = mx.conv1d(
        x_in,
        weights,
        stride=1,
        padding=0,
        dilation=1,
        groups=n_seq,
    )

    # 6) y is (1, T - K + 1, B): drop the dummy batch axis and put the
    #    sequence axis back in front.
    return y[0].T
4044 return inner_f
4145
4246@mlx_funcify .register (Blockwise )
0 commit comments