Skip to content

Commit 37440ff

Browse files
committed
WILLIAM YOU NEED TO GO ANOTHER MILE! GO ON MY MATEEEEEEE, GO PHILLIES!
1 parent 9f31ab1 commit 37440ff

File tree

1 file changed

+7
-4
lines changed

1 file changed

+7
-4
lines changed

pytensor/link/mlx/dispatch/blockwise.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,17 +23,20 @@ def inner_f(x, kernel):
2323

2424
b = np.broadcast_shapes(bx, bk)
2525

26-
x = x.reshape(b + (t,))
27-
kernel = kernel.reshape(b + (h,))
26+
x = mx.broadcast_to(x, b + (t,))
27+
kernel = mx.broadcast_to(kernel, b + (h,))
2828

2929
x_reshaped = x.reshape(-1, t).T # shape equals to (N, B) -> N Time as batches all together
3030
kernel_squeeze = kernel.reshape(-1, h)
3131
b_prod = kernel_squeeze.shape[0]
3232

33-
kernel_reshaped = mx.broadcast_to(a=kernel_squeeze[None, :, None], shape=(b_prod, h, b_prod))
33+
print(kernel_squeeze.shape)
34+
35+
print(b_prod, h, b_prod)
36+
kernel_reshaped = mx.broadcast_to(kernel_squeeze[:, :, None], shape=(b_prod, h, b_prod))
3437
conv_result = mx.conv1d(x_reshaped[None, :, :], kernel_reshaped, stride=1, padding=0, dilation=1)
3538
_, conv_shape, _ = conv_result.shape
36-
return mx.moveaxis(a=conv_result, source=-1, destination=0).reshape(b + (conv_shape,))
39+
mx.moveaxis(conv_result, source=-1, destination=0).reshape(b + (conv_shape,))
3740
return inner_f
3841

3942
@mlx_funcify.register(Blockwise)

0 commit comments

Comments
 (0)