Commit 02327af

Adds onnx ConvTranspose support for autopadding. (#3797)
Adds onnx ConvTranspose support for autopadding (nod-ai/SHARK-ModelDev#839).

- Adds support for the attribute auto_pad="SAME_UPPER" or "SAME_LOWER", which automatically calculates the input padding based on the output shape.
- Adds support, during auto-padding, for output_shape=[H,W], which overrides the default output shape of input_shape[i]*stride[i] (for spatial dimensions only).
- Adds lit tests for auto-padding.
- Tests are added by nod-ai/SHARK-TestSuite#370

NOTE: ConvTranspose still doesn't support asymmetric padding, so multiple original onnx tests still won't pass.
1 parent 9c70676 commit 02327af
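
To make the auto-padding arithmetic concrete, below is a minimal standalone C++ sketch of the symmetric padding rule this commit implements, worked through with the shapes from the SAME_UPPER lit test (3x3 input, 4x4 kernel, stride 2, dilation 1, no output_padding). It is illustrative only; none of these names come from the patch itself.

// Minimal sketch, not part of the patch: reproduces the SAME_UPPER/SAME_LOWER
// padding arithmetic for ConvTranspose on the lit-test shapes.
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  std::vector<int64_t> inputShape{3, 3};  // spatial dims of a [1,1,3,3] input
  std::vector<int64_t> kernelShape{4, 4}; // spatial dims of a [1,2,4,4] weight
  std::vector<int64_t> strides{2, 2}, dilations{1, 1}, outputPadding{0, 0};
  std::vector<int64_t> padBegin, padEnd;
  for (size_t i = 0; i < inputShape.size(); i++) {
    // output_shape defaults to input_shape[i] * strides[i], i.e. 6x6 here.
    int64_t outputShape = inputShape[i] * strides[i];
    int64_t totalPadding = strides[i] * (inputShape[i] - 1) + outputPadding[i] +
                           ((kernelShape[i] - 1) * dilations[i] + 1) -
                           outputShape;
    if (totalPadding % 2) {
      // The commit rejects this case: asymmetric padding is unsupported.
      std::cerr << "asymmetric padding along axis " << i << "\n";
      return 1;
    }
    // SAME_UPPER puts the floor half at the beginning; SAME_LOWER swaps the
    // halves (identical here, since odd totals are rejected above).
    padBegin.push_back(totalPadding / 2);
    padEnd.push_back(totalPadding - totalPadding / 2);
  }
  std::cout << "pads:"; // prints "pads: 1 1 1 1", i.e. symmetric [1, 1]
  for (int64_t p : padBegin)
    std::cout << " " << p;
  for (int64_t p : padEnd)
    std::cout << " " << p;
  std::cout << "\n";
}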

File tree: 2 files changed (+131, −22 lines changed)

- lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
- test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir

lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp

Lines changed: 59 additions & 22 deletions
@@ -1690,20 +1690,6 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
         std::string autoPad;
         if (binder.customOpNameStringAttr(autoPad, "auto_pad", "NOTSET"))
           return failure();
-        if (autoPad != "NOTSET") {
-          // TODO: Add support for `auto_pad` != "NOTSET"
-          return rewriter.notifyMatchFailure(
-              binder.op, "unsupported conversion: auto_pad != NOTSET");
-        }
-        SmallVector<int64_t> outputShape;
-        if (binder.s64IntegerArrayAttr(outputShape, "output_shape", {}))
-          return failure();
-        if (outputShape.size()) {
-          // TODO: Add support for non-None output_shape value.
-          return rewriter.notifyMatchFailure(
-              binder.op,
-              "unsupported conversion: output_shape should be absent");
-        }
         Torch::ValueTensorType resultType;
         Value input, weight;
         int64_t group;
@@ -1737,6 +1723,10 @@
              }
            }
          }
+        } else {
+          for (unsigned i = 0; i < weightShape.size() - 2; i++) {
+            kernelShape.push_back(weightShape[i + 2]);
+          }
         }
 
         // Determine the rank of input tensor.
@@ -1746,7 +1736,8 @@
                                            "Unimplemented: unranked tensor");
         unsigned rank = *maybeRank;
 
-        SmallVector<int64_t> padding, strides, dilations, outputPadding;
+        SmallVector<int64_t> padding, strides, dilations, outputPadding,
+            outputShape;
         SmallVector<int64_t> defaultPadding, defaultStrides, defaultDilations,
             defaultOutputPadding;
         for (unsigned i = 0; i < rank - 2; i++) {
@@ -1762,13 +1753,6 @@
         // x2_begin…x1_end, x2_end,…], where xi_begin the number of pixels added
         // at the beginning of axis i and xi_end, the number of pixels added at
         // the end of axis i.
-        if (binder.s64IntegerArrayAttr(padding, "pads", defaultPadding)) {
-          return failure();
-        }
-        if (padding.size() != rank - 2 && padding.size() != 2 * (rank - 2)) {
-          return rewriter.notifyMatchFailure(
-              binder.op, "padding list size does not match the number of axes");
-        }
         if (binder.s64IntegerArrayAttr(dilations, "dilations",
                                        defaultDilations)) {
           return failure();
@@ -1794,7 +1778,60 @@
               binder.op,
               "output_padding list size does not match the number of axes");
         }
+        auto inputTensorType = cast<Torch::ValueTensorType>(input.getType());
+        if (!inputTensorType || !inputTensorType.hasSizes()) {
+          return rewriter.notifyMatchFailure(
+              binder.op, "Expected input type having sizes");
+        }
+        ArrayRef<int64_t> inputShape = inputTensorType.getSizes();
 
+        if (autoPad == "VALID") {
+          // Zero padding.
+          padding = defaultPadding;
+        } else if (autoPad == "NOTSET") {
+          // Explicit padding; read pads with defaults.
+          if (binder.s64IntegerArrayAttr(padding, "pads", defaultPadding))
+            return failure();
+        } else { // autopad == SAME_UPPER or SAME_LOWER
+          // Auto-padding; output_shape defaults to input_shape * strides.
+          SmallVector<int64_t> defaultOutputShape;
+          for (unsigned i = 0; i < rank - 2; i++) {
+            defaultOutputShape.push_back(inputShape[2 + i] * strides[i]);
+          }
+          if (binder.s64IntegerArrayAttr(outputShape, "output_shape",
+                                         defaultOutputShape))
+            return failure();
+          SmallVector<int64_t> paddingEnd;
+          for (unsigned i = 0; i < rank - 2; i++) {
+            int64_t totalPadding =
+                strides[i] * (inputShape[2 + i] - 1) + outputPadding[i] +
+                ((kernelShape[i] - 1) * dilations[i] + 1) - outputShape[i];
+            if (totalPadding % 2) {
+              // TODO: Add support for different padding values for the
+              // beginning and ending along each spatial axis.
+              return rewriter.notifyMatchFailure(
+                  binder.op,
+                  "unsupported conversion: the combination of stride, "
+                  "input_shape, kernel_shape, dilation, output_padding and "
+                  "output_shape caused auto-padding to produce asymmetric "
+                  "padding which isn't currently supported.");
+            }
+            int64_t half = totalPadding / 2;
+            int64_t remainder = totalPadding - half;
+            if (autoPad == "SAME_UPPER") {
+              padding.push_back(half);
+              paddingEnd.push_back(remainder);
+            } else {
+              padding.push_back(remainder);
+              paddingEnd.push_back(half);
+            }
+          }
+          padding.insert(padding.end(), paddingEnd.begin(), paddingEnd.end());
+        }
+        if (padding.size() != rank - 2 && padding.size() != 2 * (rank - 2)) {
+          return rewriter.notifyMatchFailure(
+              binder.op, "padding list size does not match the number of axes");
+        }
         SmallVector<Value> cstPadding, cstStrides, cstDilations,
             cstOutputPadding;
         if (padding.size() != 2 * (rank - 2)) {
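
Working the SAME_UPPER branch above through the lit-test shapes (3x3 input, 4x4 kernel, stride 2, dilation 1, default output_shape 6x6), each spatial axis gets

totalPadding = 2 * (3 - 1) + 0 + ((4 - 1) * 1 + 1) - 6 = 2

so half = remainder = 1 and the pads are the symmetric [1, 1] checked in the tests below. As an illustrative counter-example (numbers not from the patch), a 3x3 kernel would give totalPadding = 4 + 0 + 3 - 6 = 1, which is odd, so the pattern bails out with the asymmetric-padding match failure.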

test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir

Lines changed: 72 additions & 0 deletions
@@ -1329,6 +1329,78 @@ func.func @test_convtranspose(%arg0: !torch.vtensor<[1,1,3,3],f32>, %arg1: !torc
 
 // -----
 
+// CHECK-LABEL: @test_convtranspose_autopad_same_upper
+func.func @test_convtranspose_autopad_same_upper(%arg0: !torch.vtensor<[1,1,3,3],f32>, %arg1: !torch.vtensor<[1,2,4,4],f32>) -> !torch.vtensor<[1,2,6,6],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "user-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[C1:.*]] = torch.constant.int 1
+  // CHECK: %[[C1_0:.*]] = torch.constant.int 1
+  // CHECK: %[[C1_1:.*]] = torch.constant.int 1
+  // CHECK: %[[C1_2:.*]] = torch.constant.int 1
+  // CHECK: %[[C2:.*]] = torch.constant.int 2
+  // CHECK: %[[C2_3:.*]] = torch.constant.int 2
+  // CHECK: %[[C0:.*]] = torch.constant.int 0
+  // CHECK: %[[C0_4:.*]] = torch.constant.int 0
+  // CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[C1]], %[[C1_0]] : (!torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[DILATIONS:.*]] = torch.prim.ListConstruct %[[C1_1]], %[[C1_2]] : (!torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[C2]], %[[C2_3]] : (!torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[OUTPUT_PADDING:.*]] = torch.prim.ListConstruct %[[C0]], %[[C0_4]] : (!torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[TRANSPOSED:.*]] = torch.constant.bool true
+  // CHECK: %[[BIAS:.*]] = torch.constant.none
+  // CHECK: %[[GROUPS:.*]] = torch.constant.int 1
+  // CHECK: torch.aten.convolution %arg0, %arg1, %[[BIAS]], %[[STRIDE]], %[[PADDING]], %[[DILATIONS]], %[[TRANSPOSED]], %[[OUTPUT_PADDING]], %[[GROUPS]] : !torch.vtensor<[1,1,3,3],f32>, !torch.vtensor<[1,2,4,4],f32>, !torch.none, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,2,6,6],f32>
+  %4 = torch.operator "onnx.ConvTranspose"(%arg0, %arg1) {torch.onnx.auto_pad="SAME_UPPER", torch.onnx.strides = [2 : si64, 2 : si64]} : (!torch.vtensor<[1,1,3,3],f32>, !torch.vtensor<[1,2,4,4],f32>) -> !torch.vtensor<[1,2,6,6],f32>
+  return %4 : !torch.vtensor<[1,2,6,6],f32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_convtranspose_autopad_same_lower
+func.func @test_convtranspose_autopad_same_lower(%arg0: !torch.vtensor<[1,1,3,3],f32>, %arg1: !torch.vtensor<[1,2,4,4],f32>) -> !torch.vtensor<[1,2,6,6],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "user-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[C1:.*]] = torch.constant.int 1
+  // CHECK: %[[C1_0:.*]] = torch.constant.int 1
+  // CHECK: %[[C1_1:.*]] = torch.constant.int 1
+  // CHECK: %[[C1_2:.*]] = torch.constant.int 1
+  // CHECK: %[[C2:.*]] = torch.constant.int 2
+  // CHECK: %[[C2_3:.*]] = torch.constant.int 2
+  // CHECK: %[[C0:.*]] = torch.constant.int 0
+  // CHECK: %[[C0_4:.*]] = torch.constant.int 0
+  // CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[C1]], %[[C1_0]] : (!torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[DILATIONS:.*]] = torch.prim.ListConstruct %[[C1_1]], %[[C1_2]] : (!torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[C2]], %[[C2_3]] : (!torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[OUTPUT_PADDING:.*]] = torch.prim.ListConstruct %[[C0]], %[[C0_4]] : (!torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[TRANSPOSED:.*]] = torch.constant.bool true
+  // CHECK: %[[BIAS:.*]] = torch.constant.none
+  // CHECK: %[[GROUPS:.*]] = torch.constant.int 1
+  // CHECK: torch.aten.convolution %arg0, %arg1, %[[BIAS]], %[[STRIDE]], %[[PADDING]], %[[DILATIONS]], %[[TRANSPOSED]], %[[OUTPUT_PADDING]], %[[GROUPS]] : !torch.vtensor<[1,1,3,3],f32>, !torch.vtensor<[1,2,4,4],f32>, !torch.none, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,2,6,6],f32>
+  %4 = torch.operator "onnx.ConvTranspose"(%arg0, %arg1) {torch.onnx.auto_pad="SAME_LOWER", torch.onnx.strides = [2 : si64, 2 : si64]} : (!torch.vtensor<[1,1,3,3],f32>, !torch.vtensor<[1,2,4,4],f32>) -> !torch.vtensor<[1,2,6,6],f32>
+  return %4 : !torch.vtensor<[1,2,6,6],f32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_convtranspose_autopad_valid
+func.func @test_convtranspose_autopad_valid(%arg0: !torch.vtensor<[1,1,3,3],f32>, %arg1: !torch.vtensor<[1,2,4,4],f32>) -> !torch.vtensor<[1,2,8,8],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "user-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[C0:.*]] = torch.constant.int 0
+  // CHECK: %[[C0_0:.*]] = torch.constant.int 0
+  // CHECK: %[[C1:.*]] = torch.constant.int 1
+  // CHECK: %[[C1_1:.*]] = torch.constant.int 1
+  // CHECK: %[[C2:.*]] = torch.constant.int 2
+  // CHECK: %[[C2_2:.*]] = torch.constant.int 2
+  // CHECK: %[[C0_3:.*]] = torch.constant.int 0
+  // CHECK: %[[C0_4:.*]] = torch.constant.int 0
+  // CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[C0]], %[[C0_0]] : (!torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[DILATIONS:.*]] = torch.prim.ListConstruct %[[C1]], %[[C1_1]] : (!torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[C2]], %[[C2_2]] : (!torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[OUTPUT_PADDING:.*]] = torch.prim.ListConstruct %[[C0_3]], %[[C0_4]] : (!torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[TRANSPOSED:.*]] = torch.constant.bool true
+  // CHECK: %[[BIAS:.*]] = torch.constant.none
+  // CHECK: %[[GROUPS:.*]] = torch.constant.int 1
+  // CHECK: torch.aten.convolution %arg0, %arg1, %[[BIAS]], %[[STRIDE]], %[[PADDING]], %[[DILATIONS]], %[[TRANSPOSED]], %[[OUTPUT_PADDING]], %[[GROUPS]] : !torch.vtensor<[1,1,3,3],f32>, !torch.vtensor<[1,2,4,4],f32>, !torch.none, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,2,8,8],f32>
+  %4 = torch.operator "onnx.ConvTranspose"(%arg0, %arg1) {torch.onnx.auto_pad="VALID", torch.onnx.strides = [2 : si64, 2 : si64]} : (!torch.vtensor<[1,1,3,3],f32>, !torch.vtensor<[1,2,4,4],f32>) -> !torch.vtensor<[1,2,8,8],f32>
+  return %4 : !torch.vtensor<[1,2,8,8],f32>
+}
+
+// -----
+
 // CHECK-LABEL: @test_batchnorm_epsilon
 func.func @test_batchnorm_epsilon(%arg0: !torch.vtensor<[2,3,4,5],f32>, %arg1: !torch.vtensor<[3],f32>, %arg2: !torch.vtensor<[3],f32>, %arg3: !torch.vtensor<[3],f32>, %arg4: !torch.vtensor<[3],f32>) -> !torch.vtensor<[2,3,4,5],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 15 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
   // CHECK: %[[FALSE:.*]] = torch.constant.bool false
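
A quick sanity check on the expected result types in the new tests, using standard ConvTranspose shape arithmetic rather than anything added by the patch: with a 3x3 input, 4x4 kernel, and stride 2, VALID (zero padding) gives (3 - 1) * 2 + (4 - 1) * 1 + 1 = 8 per spatial axis, hence !torch.vtensor<[1,2,8,8],f32>, while SAME_UPPER/SAME_LOWER subtract the two computed pads of 1 each, giving 8 - 2 = 6 and !torch.vtensor<[1,2,6,6],f32>.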
