Add support for transposed convolution negative input padding (#4096)
Currently, when a transposed convolution is lowered from the torch dialect to the linalg dialect, we get a tensor.insert_slice operation that creates the padding for the input tensor. For example:
%inserted_slice = tensor.insert_slice %arg0 into %cast[0, 0, 2, %c-1] [1, 1, 4, 7] [1, 1, 1, 1] : tensor<1x1x4x7xf32> into tensor<1x1x?x?xf32>
The above works well when the input padding is positive. For transposed convolution the input padding is defined by the formula dilation * (kernel_size - 1) - padding (see https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html for details).
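As a minimal sketch of this formula, the Python below (the input_padding helper and the specific dilation/kernel_size/padding values are hypothetical, chosen to mirror the example above) shows how the input padding turns negative once padding exceeds dilation * (kernel_size - 1):

def input_padding(dilation: int, kernel_size: int, padding: int) -> int:
    # PyTorch's implicit input padding for the equivalent direct convolution.
    return dilation * (kernel_size - 1) - padding

print(input_padding(dilation=1, kernel_size=3, padding=0))  # 2: positive padding
print(input_padding(dilation=1, kernel_size=3, padding=3))  # -1: negative padding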
Notice that when padding exceeds dilation * (kernel_size - 1), the input padding becomes negative. For these cases PyTorch reduces the size of the input tensor. The torch-to-linalg lowering was not doing this, so its results did not match PyTorch's (captured in the e2e tests TransposedConv2dNegativePadding and TransposedConv3dNegativePadding).
To fix this, a tensor.extract_slice operation is added just before the insert_slice operation to reduce the input tensor size, as PyTorch does. For the example above we now get the code below, whose result matches PyTorch's numerical values:
%extracted_slice = tensor.extract_slice %arg0[0, 0, 0, 1] [1, 1, 4, 5] [1, 1, 1, 1] : tensor<1x1x4x7xf32> to tensor<1x1x4x5xf32>
%inserted_slice = tensor.insert_slice %extracted_slice into %4[0, 0, 2, 0] [1, 1, 4, 5] [1, 1, 1, 1] : tensor<1x1x4x5xf32> into tensor<1x1x8x5xf32>
For each dimension with negative padding, we add a positive offset (the absolute value of the negative padding) in the corresponding dimension of the extract_slice operation, and the dimension size is reduced by twice that amount (elements are dropped on both sides of the dimension, as PyTorch specifies).
Then, on the insert_slice, the negative-padding dimension has an offset of zero because the trimmed dimension fits exactly. When the padding is positive, the existing behavior is kept.
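The per-dimension arithmetic described above can be sketched in Python as follows (the slice_params helper is hypothetical, not the actual torch-mlir C++ lowering; it only illustrates the offset/size bookkeeping):

def slice_params(dim_size: int, pad: int):
    # Returns (extract_offset, extract_size, insert_offset, padded_size)
    # for one spatial dimension with input padding `pad`.
    if pad >= 0:
        # Positive padding: no trimming; insert at an offset equal to the
        # padding, into a dimension grown by the padding on both sides.
        return 0, dim_size, pad, dim_size + 2 * pad
    # Negative padding: trim |pad| elements from both sides and insert at
    # offset zero, since the trimmed slice fills the dimension exactly.
    trim = -pad
    return trim, dim_size - 2 * trim, 0, dim_size - 2 * trim

# The two spatial dims from the example: H has padding 2, W has padding -1.
print(slice_params(4, 2))   # (0, 4, 2, 8): insert at offset 2 into the size-8 dim
print(slice_params(7, -1))  # (1, 5, 0, 5): extract offset 1, size 5, insert at 0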
@rsuderman @vivekkhandelwal1 @zjgarvey @penguin-wwy @ubfx @sahas3 @Hanumanth04 @dixinzhou @rafaelubalmw
---------
Co-authored-by: Ivan Garcia <[email protected]>