Skip to content

Commit 8b6ddd3

Browse files
ivangarcia44Ivan Garcia
andauthored
Add support for transposed convolution negative input padding (#4096)
Currently, when a transposed convolution is lowered from the torch dialect to the linalg dialect, we get an insert_slice operation that creates the padding for the input tensor. For example:

%inserted_slice = tensor.insert_slice %arg0 into %cast[0, 0, 2, %c-1] [1, 1, 4, 7] [1, 1, 1, 1] : tensor<1x1x4x7xf32> into tensor<1x1x?x?xf32>

The above works well for the case where the input padding is positive. For transposed convolution the input padding is defined by the formula dilation * (kernel_size - 1) - padding (see https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html for details). Notice that if the padding argument is larger than the left-hand term, dilation * (kernel_size - 1), the resulting input padding is negative. For these cases PyTorch reduces the size of the input tensor. The torch-to-linalg lowering was not doing this, and therefore its result did not match what PyTorch gives (captured in the e2e tests TransposedConv2dNegativePadding and TransposedConv3dNegativePadding).

To fix this, a tensor.extract_slice operation is added just before the insert_slice operation to reduce the input tensor size as PyTorch does. For the example above we now get the code below, whose result matches the numerical values of PyTorch:

%extracted_slice = tensor.extract_slice %arg0[0, 0, 0, 1] [1, 1, 4, 5] [1, 1, 1, 1] : tensor<1x1x4x7xf32> to tensor<1x1x4x5xf32>
%inserted_slice = tensor.insert_slice %extracted_slice into %4[0, 0, 2, 0] [1, 1, 4, 5] [1, 1, 1, 1] : tensor<1x1x4x5xf32> into tensor<1x1x8x5xf32>

For each dimension with a negative padding, we add a positive offset (the absolute value of the negative padding) in the corresponding dimension of the extract_slice operation, and the dimension size is reduced by twice that amount (elements are lost on both sides of the dimension, as specified in PyTorch). Then, in the insert_slice, the dimension with negative padding has an offset of zero because the trimmed dimension fits exactly. For the case where the padding is positive, the existing behavior is kept.
@rsuderman @vivekkhandelwal1 @zjgarvey @penguin-wwy @ubfx @sahas3 @Hanumanth04 @dixinzhou @rafaelubalmw --------- Co-authored-by: Ivan Garcia <[email protected]>
1 parent a83397a commit 8b6ddd3

File tree

4 files changed

+253
-41
lines changed

4 files changed

+253
-41
lines changed

lib/Conversion/TorchToLinalg/Linear.cpp

Lines changed: 127 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -748,6 +748,31 @@ class ConvertAtenBmmOp : public OpConversionPattern<AtenBmmOp> {
748748
};
749749
} // namespace
750750

751+
namespace {
752+
/// Returns true only when `value` is produced by an operation that folds to
/// an integer constant strictly less than zero.
///
/// Every other situation yields false: block arguments (no defining op), ops
/// that do not fold, fold results that are SSA values rather than attributes,
/// and attributes that are not `IntegerAttr`. A false result therefore means
/// "not provably negative", not "non-negative".
bool isValueNegative(mlir::Value value) {
  // Block arguments have no defining op, so there is nothing to fold.
  mlir::Operation *definingOp = value.getDefiningOp();
  if (!definingOp)
    return false;

  // Attempt to fold the defining operation to constants.
  // NOTE(review): Operation::fold may also succeed by updating the op in
  // place (leaving `results` empty); that case is treated as "unknown" below.
  mlir::SmallVector<mlir::OpFoldResult, 1> results;
  if (failed(definingOp->fold(results)))
    return false;

  // fold() reports one entry per op result, so look at the entry that
  // corresponds to `value` instead of unconditionally taking the first one.
  const unsigned resultIdx =
      llvm::cast<mlir::OpResult>(value).getResultNumber();
  if (resultIdx >= results.size())
    return false;

  // Only an attribute-valued fold result carries a compile-time constant.
  const auto attr =
      llvm::dyn_cast_if_present<mlir::Attribute>(results[resultIdx]);
  if (!attr)
    return false;

  const auto intAttr = llvm::dyn_cast<mlir::IntegerAttr>(attr);
  if (!intAttr)
    return false;

  const int64_t intValue = intAttr.getInt();
  return intValue < 0;
}
774+
} // namespace
775+
751776
namespace {
752777
class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
753778
public:
@@ -1008,8 +1033,6 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
10081033
rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
10091034
Value c1 =
10101035
rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(1));
1011-
Value c2 =
1012-
rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(2));
10131036

10141037
// Transpose and flip weight
10151038
SmallVector<Value> weightInitDims = getTensorSizes(rewriter, loc, weight);
@@ -1060,45 +1083,10 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
10601083
})
10611084
.getResult(0);
10621085

1063-
// Calculate padded input size, allocate tensor
1064-
SmallVector<Value> outerSizes{inBatch, inChannels};
1065-
SmallVector<Value> innerSizes{inBatch, inChannels};
1066-
SmallVector<Value> offsets{c0, c0};
1067-
for (size_t i = 0; i < numSpatialDims; i++) {
1068-
Value innerSize = rewriter.create<arith::SubIOp>(loc, inDims[i], c1);
1069-
innerSize = rewriter.create<arith::MulIOp>(
1070-
loc, innerSize, castIntToIndex(rewriter, loc, strideIntValues[i]));
1071-
innerSize = rewriter.create<arith::AddIOp>(loc, innerSize, c1);
1072-
1073-
Value offset = rewriter.create<arith::SubIOp>(loc, weightDims[i], c1);
1074-
offset = rewriter.create<arith::MulIOp>(
1075-
loc, offset, castIntToIndex(rewriter, loc, dilationIntValues[i]));
1076-
offset = rewriter.create<arith::SubIOp>(
1077-
loc, offset, castIntToIndex(rewriter, loc, paddingIntValues[i]));
1078-
1079-
Value outerSize = rewriter.create<arith::MulIOp>(loc, offset, c2);
1080-
outerSize = rewriter.create<arith::AddIOp>(loc, outerSize, innerSize);
1081-
outerSize = rewriter.create<arith::AddIOp>(
1082-
loc, outerSize,
1083-
castIntToIndex(rewriter, loc, outputPaddingIntValues[i]));
1084-
1085-
outerSizes.push_back(outerSize);
1086-
offsets.push_back(offset);
1087-
}
1088-
1089-
// Allocate padded input tensor
1090-
Value initTensor =
1091-
createInitTensor(rewriter, loc, outerSizes, inputDTy, pad);
1092-
1093-
// Insert input into allocated tensor
1094-
SmallVector<Value> strideIndexValues{c1, c1};
1095-
for (auto stride : strideIntValues)
1096-
strideIndexValues.push_back(castIntToIndex(rewriter, loc, stride));
1097-
SmallVector<Value> insertSizes = getTensorSizes(rewriter, loc, input);
1098-
1099-
paddedInput = rewriter.create<tensor::InsertSliceOp>(
1100-
loc, torch_to_linalg::removeSizeInformation(rewriter, loc, input),
1101-
initTensor, offsets, insertSizes, strideIndexValues);
1086+
paddedInput = createTransposedInputPadding(
1087+
inBatch, inChannels, inDims, weightDims, paddingIntValues,
1088+
strideIntValues, dilationIntValues, outputPaddingIntValues, input,
1089+
inputDTy, pad, rewriter, loc, numSpatialDims, c0, c1);
11021090

11031091
// Calculate output dims
11041092
for (size_t i = 0; i < numSpatialDims; i++)
@@ -1482,9 +1470,107 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
14821470
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, conv);
14831471
return success();
14841472
}
1473+
1474+
static Value createTransposedInputPadding(
1475+
Value inBatch, Value inChannels, SmallVector<Value> &inDims,
1476+
SmallVector<Value> &weightDims, SmallVector<Value> &paddingIntValues,
1477+
SmallVector<Value> &strideIntValues,
1478+
SmallVector<Value> &dilationIntValues,
1479+
SmallVector<Value> &outputPaddingIntValues, Value input, Type inputDTy,
1480+
Value pad, PatternRewriter &rewriter, Location loc, size_t numSpatialDims,
1481+
Value c0, Value c1);
14851482
};
14861483
} // namespace
14871484

1485+
/// Builds the padded input tensor used when lowering a transposed
/// convolution.
///
/// For every spatial dimension `i` the following index values are emitted:
///   innerSize = (inDims[i] - 1) * stride[i] + 1
///   offset    = (weightDims[i] - 1) * dilation[i] - padding[i]
///   outerSize = 2 * offset + innerSize + outputPadding[i]
/// A tensor of shape {inBatch, inChannels, outerSize...} is created with
/// `createInitTensor` (using `pad`), and the input is written into it with a
/// `tensor.insert_slice` whose strides are {1, 1, stride...}.
///
/// When `offset` is a compile-time negative constant (see `isValueNegative`),
/// the input is first trimmed by `|offset|` elements on both sides of that
/// dimension with a `tensor.extract_slice`, and the insertion offset for that
/// dimension becomes 0. Dimensions whose offset is not provably negative keep
/// the plain `offset` as insertion offset.
///
/// \param inBatch, inChannels   leading two extents of the padded tensor.
/// \param inDims                per-spatial-dimension input extents.
/// \param weightDims            per-spatial-dimension kernel extents.
/// \param paddingIntValues, strideIntValues, dilationIntValues,
///        outputPaddingIntValues  per-spatial-dimension integer values; each
///        is converted with `castIntToIndex` before use.
/// \param input                 tensor to be padded.
/// \param inputDTy, pad         forwarded to `createInitTensor`.
/// \param numSpatialDims        number of spatial dimensions to process.
/// \param c0, c1                index constants 0 and 1 supplied by the caller.
/// \returns the result of the final `tensor.insert_slice`.
Value ConvertAtenConvolutionOp::createTransposedInputPadding(
    Value inBatch, Value inChannels, SmallVector<Value> &inDims,
    SmallVector<Value> &weightDims, SmallVector<Value> &paddingIntValues,
    SmallVector<Value> &strideIntValues, SmallVector<Value> &dilationIntValues,
    SmallVector<Value> &outputPaddingIntValues, Value input, Type inputDTy,
    Value pad, PatternRewriter &rewriter, Location loc, size_t numSpatialDims,
    Value c0, Value c1) {
  // Calculate padded input size, allocate tensor
  SmallVector<Value> outerSizes{inBatch, inChannels};
  SmallVector<Value> insertSliceOffsets{c0, c0};

  SmallVector<Value> inputSizes = getTensorSizes(rewriter, loc, input);
  SmallVector<Value> sliceSizes{inputSizes[0], inputSizes[1]};

  // For the case in which the padding dimension value is negative,
  // we will need to shrink the dimension. Note in the PyTorch
  // ConvTranspose2d operator documentation that the padding is
  // defined by dilation * (kernel_size - 1) - padding. If the
  // resulting padding is negative, PyTorch will extract elements
  // from both sides of the dimension.
  SmallVector<Value> extractSliceOffsets{c0, c0};
  bool anyDimensionPaddingIsNegative = false;

  Value c2 = rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(2));

  for (size_t i = 0; i < numSpatialDims; i++) {
    // Extent covered by the input once its elements are spread by the stride.
    Value innerSize = rewriter.createOrFold<arith::SubIOp>(loc, inDims[i], c1);
    innerSize = rewriter.createOrFold<arith::MulIOp>(
        loc, innerSize, castIntToIndex(rewriter, loc, strideIntValues[i]));
    innerSize = rewriter.createOrFold<arith::AddIOp>(loc, innerSize, c1);

    // Per-side padding: dilation * (kernel_size - 1) - padding.
    Value offset = rewriter.createOrFold<arith::SubIOp>(loc, weightDims[i], c1);
    offset = rewriter.createOrFold<arith::MulIOp>(
        loc, offset, castIntToIndex(rewriter, loc, dilationIntValues[i]));
    offset = rewriter.createOrFold<arith::SubIOp>(
        loc, offset, castIntToIndex(rewriter, loc, paddingIntValues[i]));

    // Total extent: padding on both sides + inner extent + output padding.
    Value outerSize = rewriter.createOrFold<arith::MulIOp>(loc, offset, c2);
    outerSize = rewriter.createOrFold<arith::AddIOp>(loc, outerSize, innerSize);
    outerSize = rewriter.createOrFold<arith::AddIOp>(
        loc, outerSize,
        castIntToIndex(rewriter, loc, outputPaddingIntValues[i]));

    outerSizes.push_back(outerSize);
    if (isValueNegative(offset)) {
      // Make the negative value positive by multiplying by -1.
      anyDimensionPaddingIsNegative = true;
      auto offsetType = offset.getType();
      auto negOneConst = rewriter.createOrFold<arith::ConstantOp>(
          loc, offsetType, rewriter.getIntegerAttr(offsetType, -1));
      auto posOffset =
          rewriter.createOrFold<arith::MulIOp>(loc, offset, negOneConst);

      // Compute the reduced dimension size due to negative padding.
      auto sizeReduction =
          rewriter.createOrFold<arith::MulIOp>(loc, posOffset, c2);
      sliceSizes.push_back(rewriter.createOrFold<arith::SubIOp>(
          loc, inputSizes[i + 2], sizeReduction));

      // The trimmed slice starts at |offset| in the input and lands at
      // position 0 of the (already shrunk) destination dimension.
      extractSliceOffsets.push_back(posOffset);
      insertSliceOffsets.push_back(c0);
    } else {
      sliceSizes.push_back(inputSizes[i + 2]);
      extractSliceOffsets.push_back(c0);
      insertSliceOffsets.push_back(offset);
    }
  }
  Value initTensor = createInitTensor(rewriter, loc, outerSizes, inputDTy, pad);

  // Insert input into allocated tensor
  SmallVector<Value> strideIndexValues{c1, c1};
  for (Value stride : strideIntValues)
    strideIndexValues.push_back(castIntToIndex(rewriter, loc, stride));

  // NOTE(review): the extract_slice below reuses the convolution strides.
  // That is a no-op for stride 1 (the only case in the tests visible with
  // this change); the combination stride > 1 with negative padding should be
  // double-checked against PyTorch.
  auto insertSliceOpInput = input;
  if (anyDimensionPaddingIsNegative) {
    insertSliceOpInput = rewriter.create<tensor::ExtractSliceOp>(
        loc, torch_to_linalg::removeSizeInformation(rewriter, loc, input),
        extractSliceOffsets, sliceSizes, strideIndexValues);
  }

  auto paddedInput = rewriter.create<tensor::InsertSliceOp>(
      loc,
      torch_to_linalg::removeSizeInformation(rewriter, loc, insertSliceOpInput),
      initTensor, insertSliceOffsets, sliceSizes, strideIndexValues);
  return paddedInput;
}
1573+
14881574
namespace {
14891575

14901576
/// Creates coefficients based on DFT definition, see

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3770,6 +3770,9 @@
37703770
"TorchPrimLoopWhileLikeModule_basic",
37713771
"TraceModule_empty",
37723772
"TraceUnsignedIntModule_empty",
3773+
"TransposedConv1dNegativePadding_basic",
3774+
"TransposedConv2dNegativePadding_basic",
3775+
"TransposedConv3dNegativePadding_basic",
37733776
"UnsafeViewCollapseDynamicWithAtenSizeIntModule_basic",
37743777
"UpSampleNearest2dBackwardScalesNone_basic",
37753778
"UpSampleNearest2dBackward_basic",
@@ -4758,6 +4761,9 @@
47584761
"TraceSignedIntModule_basic",
47594762
"TraceUnsignedIntModule_basic",
47604763
"TraceUnsignedIntModule_empty",
4764+
"TransposedConv1dNegativePadding_basic",
4765+
"TransposedConv2dNegativePadding_basic",
4766+
"TransposedConv3dNegativePadding_basic",
47614767
"TupleModule_basic",
47624768
"TypeAsDifferentModule_basic",
47634769
"TypeConversionF32ToF64Module_basic",

projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1757,3 +1757,99 @@ def forward(self, inputVec, weight, bias):
17571757
@register_test_case(module_factory=lambda: ConvolutionModule2DGroupedTranspose())
17581758
def ConvolutionModule2DGroupedTranspose_basic(module, tu: TestUtils):
17591759
module.forward(tu.rand(1, 2, 5, 7), tu.rand(2, 2, 3, 3), tu.rand(4))
1760+
1761+
1762+
class TransposedConv1dNegativePadding(torch.nn.Module):
    """1-D transposed convolution with a negative effective input padding.

    The effective padding is dilation * (kernel_size - 1) - padding. With
    dilation 1, padding 3 and a kernel extent of 3 (presumably the trailing
    weight dimension) this evaluates to -1.
    """

    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([1, 1, 7], torch.float32, True),
            ([1, 2, 3], torch.float32, True),
            ([2], torch.float32, True),
        ]
    )
    def forward(self, inputVec, weight, bias):
        # Tensor operands are passed positionally; the configuration of the
        # convolution is spelled out by keyword.
        result = torch.ops.aten.convolution(
            inputVec,
            weight,
            bias,
            stride=[1],
            padding=[3],
            dilation=[1],
            transposed=True,
            output_padding=[0],
            groups=1,
        )
        return result
1787+
1788+
1789+
@register_test_case(module_factory=lambda: TransposedConv1dNegativePadding())
def TransposedConv1dNegativePadding_basic(module, tu: TestUtils):
    # Operands are drawn in the order input, weight, bias so the sequence of
    # tu.rand calls is the same as when they are written inline.
    input_vec = tu.rand(1, 1, 7)
    kernel = tu.rand(1, 2, 3)
    bias_vec = tu.rand(2)
    module.forward(input_vec, kernel, bias_vec)
1792+
1793+
1794+
class TransposedConv2dNegativePadding(torch.nn.Module):
    """2-D transposed convolution with negative padding in one dimension.

    The effective padding is dilation * (kernel_size - 1) - padding. With
    dilation [1, 1], padding [0, 3] and kernel extents of 3 x 3 (presumably
    the trailing weight dimensions) this evaluates to [2, -1], so only the
    last dimension is negative.
    """

    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([1, 1, 4, 7], torch.float32, True),
            ([1, 2, 3, 3], torch.float32, True),
            ([2], torch.float32, True),
        ]
    )
    def forward(self, inputVec, weight, bias):
        # Tensor operands are passed positionally; the configuration of the
        # convolution is spelled out by keyword.
        result = torch.ops.aten.convolution(
            inputVec,
            weight,
            bias,
            stride=[1, 1],
            padding=[0, 3],
            dilation=[1, 1],
            transposed=True,
            output_padding=[0, 0],
            groups=1,
        )
        return result
1819+
1820+
1821+
@register_test_case(module_factory=lambda: TransposedConv2dNegativePadding())
def TransposedConv2dNegativePadding_basic(module, tu: TestUtils):
    # Operands are drawn in the order input, weight, bias so the sequence of
    # tu.rand calls is the same as when they are written inline.
    input_vec = tu.rand(1, 1, 4, 7)
    kernel = tu.rand(1, 2, 3, 3)
    bias_vec = tu.rand(2)
    module.forward(input_vec, kernel, bias_vec)
1824+
1825+
1826+
class TransposedConv3dNegativePadding(torch.nn.Module):
    """3-D transposed convolution with negative padding in one dimension.

    The effective padding is dilation * (kernel_size - 1) - padding. With
    dilation [1, 1, 1], padding [2, 1, 3] and kernel extents of 3 x 7 x 3
    (presumably the trailing weight dimensions) this evaluates to [0, 5, -1],
    so only the last dimension is negative.
    """

    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([4, 1, 8, 13, 17], torch.float32, True),
            ([1, 1, 3, 7, 3], torch.float32, True),
            ([1], torch.float32, True),
        ]
    )
    def forward(self, inputVec, weight, bias):
        # Tensor operands are passed positionally; the configuration of the
        # convolution is spelled out by keyword.
        result = torch.ops.aten.convolution(
            inputVec,
            weight,
            bias,
            stride=[1, 1, 1],
            padding=[2, 1, 3],
            dilation=[1, 1, 1],
            transposed=True,
            output_padding=[0, 0, 0],
            groups=1,
        )
        return result
1851+
1852+
1853+
@register_test_case(module_factory=lambda: TransposedConv3dNegativePadding())
def TransposedConv3dNegativePadding_basic(module, tu: TestUtils):
    # Operands are drawn in the order input, weight, bias so the sequence of
    # tu.rand calls is the same as when they are written inline.
    input_vec = tu.rand(4, 1, 8, 13, 17)
    kernel = tu.rand(1, 1, 3, 7, 3)
    bias_vec = tu.rand(1)
    module.forward(input_vec, kernel, bias_vec)

test/Conversion/TorchToLinalg/convolution.mlir

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,3 +150,27 @@ func.func @transposedGroupedConvolution2D(%arg0: !torch.vtensor<[1,2,5,7],f32>)
150150
%6 = torch.aten.convolution %arg0, %0, %1, %2, %3, %4, %true, %5, %int2 : !torch.vtensor<[1,2,5,7],f32>, !torch.vtensor<[2,2,3,3],f32>, !torch.vtensor<[4],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,4,10,14],f32>
151151
return %6 : !torch.vtensor<[1,4,10,14],f32>
152152
}
153+
154+
// Transposed 2-D convolution with stride 1, dilation 1, a 3x3 kernel and
// padding [0, 3]. The last spatial dimension has an effective padding of
// 1 * (3 - 1) - 3 = -1, so the input is expected to be trimmed by one
// element on each side with tensor.extract_slice before being inserted into
// the padded tensor at offset 0.
// CHECK-LABEL:   func.func @tranConv2dNegativePadding(
// CHECK-SAME:                                 %[[INPUT_VTENSOR:.*]]: !torch.vtensor<[1,1,4,7],f32>) -> !torch.vtensor<[1,2,6,3],f32>
// CHECK:           %[[IN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[INPUT_VTENSOR]] : !torch.vtensor<[1,1,4,7],f32> -> tensor<1x1x4x7xf32>
// CHECK:           %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[IN_TENSOR]][0, 0, 0, 1] [1, 1, 4, 5] [1, 1, 1, 1] : tensor<1x1x4x7xf32> to tensor<1x1x4x5xf32>
// CHECK:           %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[EXTRACTED_SLICE]] into %[[INIT_TENSOR:.*]][0, 0, 2, 0] [1, 1, 4, 5] [1, 1, 1, 1] : tensor<1x1x4x5xf32> into tensor<1x1x8x5xf32>
// CHECK:           %[[OUT_TENSOR:.*]] = linalg.conv_2d_nchw_fchw {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%[[INSERTED_SLICE]], %[[WEIGHTS:.*]] : tensor<1x1x8x5xf32>, tensor<2x1x3x3xf32>) outs(%[[INIT_OUT_TENSOR:.*]] : tensor<1x2x6x3xf32>) -> tensor<1x2x6x3xf32>
// CHECK:           %[[OUT_VTENSOR:.*]] = torch_c.from_builtin_tensor %[[OUT_TENSOR]] : tensor<1x2x6x3xf32> -> !torch.vtensor<[1,2,6,3],f32>
func.func @tranConv2dNegativePadding(%arg0: !torch.vtensor<[1, 1, 4, 7],f32>) -> !torch.vtensor<[1, 2, 6, 3],f32> attributes {torch.assume_strict_symbolic_shapes} {
  %int0 = torch.constant.int 0
  %true = torch.constant.bool true
  %int1 = torch.constant.int 1
  %int3 = torch.constant.int 3
  %0 = torch.vtensor.literal(dense_resource<torch_tensor_1_2_3_3_torch.float32> : tensor<1x2x3x3xf32>) : !torch.vtensor<[1,2,3,3],f32>
  %1 = torch.vtensor.literal(dense_resource<torch_tensor_2_torch.float32> : tensor<2xf32>) : !torch.vtensor<[2],f32>
  %2 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
  %3 = torch.prim.ListConstruct %int0, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
  %4 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
  %5 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
  %6 = torch.aten.convolution %arg0, %0, %1, %2, %3, %4, %true, %5, %int1 : !torch.vtensor<[1, 1, 4, 7],f32>, !torch.vtensor<[1,2,3,3],f32>, !torch.vtensor<[2],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1, 2, 6, 3],f32>
  return %6 : !torch.vtensor<[1, 2, 6, 3],f32>
}

0 commit comments

Comments
 (0)