|
4 | 4 | from pytensor.tensor.blockwise import Blockwise |
5 | 5 | from pytensor.tensor.signal.conv import Conv1d |
6 | 6 |
|
import numpy as np


def blockwise_conv1d(op, node, **kwargs):
    """Funcify a ``Blockwise(Conv1d)`` node into an MLX closure.

    Only Conv1d's "valid" mode is supported.  The returned ``inner_f``
    performs one independent 1D convolution per (broadcast) batch entry by
    packing all batch entries into the channel dimension of a single
    ``mx.conv1d`` call with a block-diagonal weight, so output channel ``i``
    only ever sees input channel ``i``.

    Parameters
    ----------
    op : Blockwise
        The blockwise op being dispatched; ``op.core_op`` is the Conv1d core.
    node : Apply
        The node being funcified (unused beyond mode validation here).
    **kwargs
        Ignored; accepted for dispatcher-signature compatibility.

    Returns
    -------
    callable
        ``inner_f(x, kernel) -> array`` with batched-"valid" conv output.

    Raises
    ------
    NotImplementedError
        If the core op's mode is not "valid".
    """
    if op.core_op.mode != "valid":
        raise NotImplementedError("Only 'valid' mode is supported for conv1d")

    def inner_f(x, kernel):
        # Split the trailing core dims (t = signal length, h = kernel taps)
        # from the leading batch dims, then broadcast the batch dims together.
        *bx, t = x.shape
        *bk, h = kernel.shape
        b = np.broadcast_shapes(bx, bk)

        # broadcast_to (not reshape!): reshape would fail whenever a size-1
        # batch dim actually needs expanding, since the element count differs.
        x = mx.broadcast_to(x, b + (t,))
        kernel = mx.broadcast_to(kernel, b + (h,))

        x_flat = x.reshape(-1, t)       # (B, t): one signal per batch entry
        k_flat = kernel.reshape(-1, h)  # (B, h): one kernel per batch entry
        b_prod = k_flat.shape[0]

        # mx.conv1d computes cross-correlation, while Conv1d follows
        # np.convolve semantics, so reverse the taps.
        # NOTE(review): confirm Conv1d's "valid" mode flips the kernel.
        k_flat = k_flat[:, ::-1]

        # Block-diagonal weight of shape (C_out=B, h, C_in=B):
        # W[i, k, j] = k_flat[i, k] if i == j else 0, so each output channel
        # convolves only its own input channel.  (Broadcasting one kernel
        # across all input channels would sum every batch's signal into
        # every output.)
        weight = mx.eye(b_prod)[:, None, :] * k_flat[:, :, None]

        # mx.conv1d expects (N, L, C_in); treat the flattened batch as
        # channels of a single length-t example.
        conv_result = mx.conv1d(
            x_flat.T[None, :, :], weight, stride=1, padding=0, dilation=1
        )
        _, conv_len, _ = conv_result.shape

        # (1, L_out, B) -> (B, 1, L_out) -> restore the batch shape.
        return mx.moveaxis(conv_result, source=-1, destination=0).reshape(
            b + (conv_len,)
        )

    return inner_f
28 | 38 |
|
29 | 39 | @mlx_funcify.register(Blockwise) |
|
0 commit comments