@@ -23,17 +23,20 @@ def inner_f(x, kernel):
2323
2424 b = np .broadcast_shapes (bx , bk )
2525
26- x = x . reshape ( b + (t ,))
27- kernel = kernel . reshape ( b + (h ,))
26+ x = mx . broadcast_to ( x , b + (t ,))
27+ kernel = mx . broadcast_to ( kernel , b + (h ,))
2828
2929 x_reshaped = x .reshape (- 1 , t ).T # shape equals to (N, B) -> N Time as batches all together
3030 kernel_squeeze = kernel .reshape (- 1 , h )
3131 b_prod = kernel_squeeze .shape [0 ]
3232
33- kernel_reshaped = mx .broadcast_to (a = kernel_squeeze [None , :, None ], shape = (b_prod , h , b_prod ))
33+ print (kernel_squeeze .shape )
34+
35+ print (b_prod , h , b_prod )
36+ kernel_reshaped = mx .broadcast_to (kernel_squeeze [:, :, None ], shape = (b_prod , h , b_prod ))
3437 conv_result = mx .conv1d (x_reshaped [None , :, :], kernel_reshaped , stride = 1 , padding = 0 , dilation = 1 )
3538 _ , conv_shape , _ = conv_result .shape
36- return mx .moveaxis (a = conv_result , source = - 1 , destination = 0 ).reshape (b + (conv_shape ,))
39+ mx .moveaxis (conv_result , source = - 1 , destination = 0 ).reshape (b + (conv_shape ,))
3740 return inner_f
3841
3942@mlx_funcify .register (Blockwise )
0 commit comments