diff --git a/backends/cadence/aot/utils.py b/backends/cadence/aot/utils.py index b710f7d4e57..c412ce7aa61 100644 --- a/backends/cadence/aot/utils.py +++ b/backends/cadence/aot/utils.py @@ -78,6 +78,76 @@ def get_conv2d_output_size( return torch.Size((in_size[0], out_channels, hout, wout)) +# Get the output size of a transposed 1D convolution given the input size and parameters +def get_conv_transpose1d_output_size( + in_size: torch.Size, + kernel_size: List[int], + out_channels: int, + stride: Tuple[int], + padding: Tuple[int], + dilation: Tuple[int], + output_padding: Tuple[int], + channel_last: bool = False, +) -> torch.Size: + assert len(in_size) == 3 + if channel_last: + N, L, C = in_size + else: + N, C, L = in_size + + # Reference: https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose1d.html + lout = ( + (L - 1) * stride[0] + - 2 * padding[0] + + dilation[0] * (kernel_size[0] - 1) + + output_padding[0] + + 1 + ) + + if channel_last: + return torch.Size((in_size[0], lout, out_channels)) + else: + return torch.Size((in_size[0], out_channels, lout)) + + +def get_conv_transpose2d_output_size( + in_size: torch.Size, + kernel_size: List[int], + out_channels: int, + stride: Tuple[int], + padding: Tuple[int], + dilation: Tuple[int], + output_padding: Tuple[int], + channel_last: bool = False, +) -> torch.Size: + assert len(in_size) == 4 + if channel_last: + N, H, W, C = in_size + else: + N, C, H, W = in_size + + # Reference: https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html + hout = ( + (H - 1) * stride[0] + - 2 * padding[0] + + dilation[0] * (kernel_size[0] - 1) + + output_padding[0] + + 1 + ) + wout = ( + (W - 1) * stride[1] + - 2 * padding[1] + + dilation[1] * (kernel_size[1] - 1) + + output_padding[1] + + 1 + ) + + if channel_last: + return torch.Size((in_size[0], hout, wout, out_channels)) + else: + return torch.Size((in_size[0], out_channels, hout, wout)) + + # Return the overload packet for the edge op def get_edge_overload_packet(edge_op: EdgeOpOverload) -> EdgeOpOverloadPacket: edge_op_namespace, edge_op_name = (