diff --git a/backends/cadence/aot/ops_registrations.py b/backends/cadence/aot/ops_registrations.py index 3a7bf3ee41b..fa7b7feb208 100644 --- a/backends/cadence/aot/ops_registrations.py +++ b/backends/cadence/aot/ops_registrations.py @@ -132,7 +132,11 @@ def quantized_conv_meta( out_shift: torch.Tensor, channel_last: bool = False, ) -> torch.Tensor: - out_channels, _in_channels, *kernel_size = weight.shape + if channel_last: + out_channels, *kernel_size, _ = weight.shape + else: + out_channels, _, *kernel_size = weight.shape + in_size = input.shape # Assert that the input tensor has at least 3 dimensions, and at most 6 assert len(in_size) > 2 @@ -141,7 +145,13 @@ def quantized_conv_meta( # Compute the output tensor size output_size = ( get_conv1d_output_size( - in_size, out_channels, stride[1], padding[1], dilation[1], kernel_size[0] + in_size, + out_channels, + stride[1], + padding[1], + dilation[1], + kernel_size[0], + channel_last, ) if len(in_size) == 3 else get_conv2d_output_size( diff --git a/backends/cadence/aot/utils.py b/backends/cadence/aot/utils.py index 9e32f3472da..12eb899d9d8 100644 --- a/backends/cadence/aot/utils.py +++ b/backends/cadence/aot/utils.py @@ -43,14 +43,20 @@ def get_conv1d_output_size( padding: int, dilation: int, kernel_size: int, + channel_last: bool, ) -> torch.Size: assert len(in_size) == 3 - N, C, L = in_size + if channel_last: + N, L, C = in_size + else: + N, C, L = in_size # Reference: https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html lout = (L + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1 - return torch.Size((in_size[0], out_channels, lout)) + if channel_last: + return torch.Size((N, lout, out_channels)) + return torch.Size((N, out_channels, lout)) # Get the output size of a 2D convolution given the input size and parameters @@ -76,7 +82,8 @@ def get_conv2d_output_size( wout = (W + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1) // stride[ 1 ] + 1 - + if channel_last: + return torch.Size((N, hout, wout, out_channels)) return torch.Size((in_size[0], out_channels, hout, wout))