Skip to content

Commit ae310b4

Browse files
[ONNX] Handle dynamic shaped inputs for conv (#4066)
Handle dynamic shaped inputs for "onnx.Conv" operator. This is required for some of the onnx zoo models which have dynamic shaped inputs for conv. Currently, only static inputs are being handled for the same. --------- Co-authored-by: Praveen G <[email protected]>
1 parent 87e1f76 commit ae310b4

File tree

3 files changed

+185
-124
lines changed

3 files changed

+185
-124
lines changed

lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp

Lines changed: 118 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1327,6 +1327,8 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
13271327

13281328
SmallVector<int64_t> padding, strides, dilations;
13291329
SmallVector<int64_t> defaultPadding, defaultStrides, defaultDilations;
1330+
SmallVector<Value> paddingValues;
1331+
13301332
for (unsigned i = 0; i < rank - 2; i++) {
13311333
defaultPadding.push_back(0);
13321334
defaultStrides.push_back(1);
@@ -1360,36 +1362,88 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
13601362
// at the beginning of axis i and xi_end, the number of pixels added at
13611363
// the end of axis i.
13621364
if (autoPad == "NOTSET") {
1363-
if (binder.s64IntegerArrayAttr(padding, "pads", defaultPadding)) {
1365+
if (binder.s64IntegerArrayAttr(padding, "pads", defaultPadding))
13641366
return failure();
1365-
}
1367+
1368+
// Use the padding values
1369+
for (int64_t pad : padding)
1370+
paddingValues.push_back(rewriter.create<Torch::ConstantIntOp>(
1371+
loc, rewriter.getI64IntegerAttr(pad)));
13661372
} else if (autoPad == "VALID") {
1367-
padding = defaultPadding;
1373+
for (int64_t pad : defaultPadding)
1374+
paddingValues.push_back(rewriter.create<Torch::ConstantIntOp>(
1375+
loc, rewriter.getI64IntegerAttr(pad)));
13681376
} else {
13691377
const bool isSameLower = autoPad == "SAME_LOWER";
13701378
const unsigned spatialRank = rank - 2;
1371-
ArrayRef<int64_t> inputShape = inputTensorType.getSizes();
1372-
padding.resize_for_overwrite(2 * spatialRank);
1379+
paddingValues.resize_for_overwrite(2 * spatialRank);
1380+
13731381
for (unsigned dimIdx = 0; dimIdx < spatialRank; dimIdx++) {
1374-
if (weightShape[dimIdx + 2] == Torch::kUnknownSize ||
1375-
inputShape[dimIdx + 2] == Torch::kUnknownSize)
1376-
return rewriter.notifyMatchFailure(
1377-
binder.op,
1378-
"expected weight and input tensor to have static shape");
1379-
const int64_t dilatedKernelSize =
1380-
dilations[dimIdx] * (weightShape[dimIdx + 2] - 1) + 1;
1381-
int64_t totalPad = ((inputShape[dimIdx + 2] + strides[dimIdx] - 1) /
1382-
strides[dimIdx] -
1383-
1) *
1384-
strides[dimIdx] +
1385-
dilatedKernelSize - inputShape[dimIdx + 2];
1386-
totalPad = totalPad >= 0 ? totalPad : 0;
1387-
padding[dimIdx] =
1388-
isSameLower ? ((totalPad + 1) / 2) : (totalPad / 2);
1389-
padding[spatialRank + dimIdx] = totalPad - padding[dimIdx];
1382+
// dilatedSize = dilations[dimIdx]*(weightShape[dimIdx + 2] - 1) + 1
1383+
Value cstOne = rewriter.create<Torch::ConstantIntOp>(
1384+
loc, rewriter.getI64IntegerAttr(1));
1385+
Value dilationValue = rewriter.create<Torch::ConstantIntOp>(
1386+
loc, rewriter.getI64IntegerAttr(dilations[dimIdx]));
1387+
Value weightDimSize =
1388+
Torch::getTensorDimSize(rewriter, weight, dimIdx + 2);
1389+
Value weightMinusOne = rewriter.create<Torch::AtenSubIntOp>(
1390+
loc, weightDimSize, cstOne);
1391+
Value dilationMulWeight = rewriter.create<Torch::AtenMulIntOp>(
1392+
loc, dilationValue, weightMinusOne);
1393+
Value dilatedKernelSize = rewriter.create<Torch::AtenAddIntOp>(
1394+
loc, dilationMulWeight, cstOne);
1395+
1396+
// totalPad = (((inputShape[dimIdx + 2] + strides[dimIdx] -1) /
1397+
// strides[dimIdx]) - 1) * strides[dimIdx] +
1398+
// dilatedKernelSize - inputShape[dimIdx + 2];
1399+
1400+
Value stridesValue = rewriter.create<Torch::ConstantIntOp>(
1401+
loc, rewriter.getI64IntegerAttr(strides[dimIdx]));
1402+
Value inputDimSize =
1403+
Torch::getTensorDimSize(rewriter, input, dimIdx + 2);
1404+
Value stridesMinusOne =
1405+
rewriter.create<Torch::AtenSubIntOp>(loc, stridesValue, cstOne);
1406+
Value inputStrides = rewriter.create<Torch::AtenAddIntOp>(
1407+
loc, inputDimSize, stridesMinusOne);
1408+
inputStrides = rewriter.create<Torch::AtenFloordivIntOp>(
1409+
loc, inputStrides, stridesValue);
1410+
inputStrides =
1411+
rewriter.create<Torch::AtenSubIntOp>(loc, inputStrides, cstOne);
1412+
inputStrides = rewriter.create<Torch::AtenMulIntOp>(
1413+
loc, inputStrides, stridesValue);
1414+
Value strideWithDilation = rewriter.create<Torch::AtenAddIntOp>(
1415+
loc, inputStrides, dilatedKernelSize);
1416+
Value totalPad = rewriter.create<Torch::AtenSubIntOp>(
1417+
loc, strideWithDilation, inputDimSize);
1418+
1419+
// totalPad = totalPad > 0 ? totalPad : 0;
1420+
Value cstZero = rewriter.create<Torch::ConstantIntOp>(
1421+
loc, rewriter.getI64IntegerAttr(0));
1422+
totalPad =
1423+
rewriter.create<Torch::PrimMaxIntOp>(loc, totalPad, cstZero);
1424+
1425+
// padding[dimIdx] =
1426+
// isSameLower ? ((totalPad + 1) / 2) : (totalPad / 2);
1427+
// padding[spatialRank + dimIdx] = totalPad - padding[dimIdx];
1428+
Value cstTwo = rewriter.create<Torch::ConstantIntOp>(
1429+
loc, rewriter.getI64IntegerAttr(2));
1430+
if (isSameLower) {
1431+
auto padPlusOne =
1432+
rewriter.create<Torch::AtenAddIntOp>(loc, totalPad, cstOne);
1433+
paddingValues[dimIdx] = rewriter.create<Torch::AtenFloordivIntOp>(
1434+
loc, padPlusOne, cstTwo);
1435+
} else {
1436+
paddingValues[dimIdx] = rewriter.create<Torch::AtenFloordivIntOp>(
1437+
loc, totalPad, cstTwo);
1438+
}
1439+
paddingValues[spatialRank + dimIdx] =
1440+
rewriter.create<Torch::AtenSubIntOp>(loc, totalPad,
1441+
paddingValues[dimIdx]);
13901442
}
13911443
}
1392-
if (padding.size() != rank - 2 && padding.size() != 2 * (rank - 2)) {
1444+
1445+
if (paddingValues.size() != rank - 2 &&
1446+
paddingValues.size() != 2 * (rank - 2)) {
13931447
return rewriter.notifyMatchFailure(
13941448
binder.op, "padding list size does not match the number of axes");
13951449
}
@@ -1398,11 +1452,12 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
13981452
cstOutputPadding;
13991453
Value paddedInput = input;
14001454
Value paddingList;
1401-
if (padding.size() != 2 * (rank - 2)) {
1402-
for (int64_t i : padding) {
1403-
cstPadding.push_back(rewriter.create<Torch::ConstantIntOp>(
1404-
loc, rewriter.getI64IntegerAttr(i)));
1405-
}
1455+
1456+
Value cstZero = rewriter.create<Torch::ConstantIntOp>(
1457+
loc, rewriter.getI64IntegerAttr(0));
1458+
1459+
if (paddingValues.size() != 2 * (rank - 2)) {
1460+
cstPadding = paddingValues;
14061461
paddingList = rewriter.create<Torch::PrimListConstructOp>(
14071462
loc,
14081463
Torch::ListType::get(
@@ -1418,17 +1473,20 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
14181473
// rightmost dim start and end, then next to last, and so on, e.g. {l,
14191474
// r, t, b}.
14201475
bool matchedPads = true;
1421-
for (unsigned i = 0; i < padding.size() / 2; i++) {
1422-
if (padding[i] != padding[i + (padding.size() / 2)]) {
1476+
for (unsigned i = 0; i < paddingValues.size() / 2; i++) {
1477+
int64_t padLow, padHigh;
1478+
if (!matchPattern(paddingValues[i],
1479+
Torch::m_TorchConstantInt(&padLow)) ||
1480+
!matchPattern(paddingValues[i + (paddingValues.size() / 2)],
1481+
Torch::m_TorchConstantInt(&padHigh)) ||
1482+
padLow != padHigh) {
14231483
matchedPads = false;
14241484
break;
14251485
}
14261486
}
14271487
if (matchedPads) {
1428-
for (unsigned i = 0; i < padding.size() / 2; i++) {
1429-
cstPadding.push_back(rewriter.create<Torch::ConstantIntOp>(
1430-
loc, rewriter.getI64IntegerAttr(padding[i])));
1431-
}
1488+
for (unsigned i = 0; i < paddingValues.size() / 2; i++)
1489+
cstPadding.push_back(paddingValues[i]);
14321490
paddingList = rewriter.create<Torch::PrimListConstructOp>(
14331491
loc,
14341492
Torch::ListType::get(
@@ -1437,16 +1495,12 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
14371495
} else {
14381496
SmallVector<Value> padsRearrange;
14391497
SmallVector<Value> inputPaddingList;
1440-
for (uint32_t i = 0; i < padding.size() / 2; i++) {
1441-
padsRearrange.emplace_back(rewriter.create<Torch::ConstantIntOp>(
1442-
loc, rewriter.getI64IntegerAttr(
1443-
padding[padding.size() / 2 - i - 1])));
1444-
padsRearrange.emplace_back(rewriter.create<Torch::ConstantIntOp>(
1445-
loc,
1446-
rewriter.getI64IntegerAttr(padding[padding.size() - i - 1])));
1447-
inputPaddingList.emplace_back(
1448-
rewriter.create<Torch::ConstantIntOp>(
1449-
loc, rewriter.getI64IntegerAttr(0)));
1498+
for (uint32_t i = 0; i < paddingValues.size() / 2; i++) {
1499+
padsRearrange.emplace_back(
1500+
paddingValues[paddingValues.size() / 2 - i - 1]);
1501+
padsRearrange.emplace_back(
1502+
(paddingValues[paddingValues.size() - i - 1]));
1503+
inputPaddingList.emplace_back(cstZero);
14501504
}
14511505
// The conv op itself will have no padding since the actual padding
14521506
// is performed using the torch.pad preceding it.
@@ -1468,23 +1522,38 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
14681522
Value constantValue;
14691523

14701524
if (isa<IntegerType>(inputTensorType.getDtype()))
1471-
constantValue = rewriter.create<Torch::ConstantIntOp>(
1472-
loc, rewriter.getI64IntegerAttr(0));
1525+
constantValue = cstZero;
14731526
if (isa<FloatType>(inputTensorType.getDtype()))
14741527
constantValue = rewriter.create<Torch::ConstantFloatOp>(
14751528
loc, rewriter.getF64FloatAttr(0.0f));
1529+
1530+
auto getPadOutputSizeForInput = [&](int64_t low, int64_t high,
1531+
int64_t inputSize) {
1532+
int64_t padLow, padHigh;
1533+
if (inputSize == Torch::kUnknownSize ||
1534+
!matchPattern(paddingValues[low],
1535+
Torch::m_TorchConstantInt(&padLow)) ||
1536+
!matchPattern(paddingValues[high],
1537+
Torch::m_TorchConstantInt(&padHigh)))
1538+
return Torch::kUnknownSize;
1539+
return inputSize + padLow + padHigh;
1540+
};
1541+
14761542
// Pad output shape must be computed explicitly from the pad values
1543+
// for static dims
14771544
SmallVector<int64_t> newInputShape(inputTensorType.getSizes());
1478-
for (uint32_t i = 0; i < padding.size() / 2; i++) {
1479-
newInputShape[2 + i] +=
1480-
padding[i] + padding[(padding.size() / 2) + i];
1545+
for (uint32_t i = 0; i < paddingValues.size() / 2; i++) {
1546+
newInputShape[2 + i] = getPadOutputSizeForInput(
1547+
i, (paddingValues.size() / 2) + i, newInputShape[2 + i]);
14811548
}
1549+
14821550
auto padTy = rewriter.getType<Torch::ValueTensorType>(
14831551
newInputShape, inputTensorType.getDtype());
14841552
paddedInput = rewriter.create<Torch::AtenPadOp>(
14851553
loc, padTy, input, padsSizeList, modeVal, constantValue);
14861554
}
14871555
}
1556+
14881557
for (int64_t i : dilations) {
14891558
cstDilations.push_back(rewriter.create<Torch::ConstantIntOp>(
14901559
loc, rewriter.getI64IntegerAttr(i)));
@@ -1493,8 +1562,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
14931562
cstStrides.push_back(rewriter.create<Torch::ConstantIntOp>(
14941563
loc, rewriter.getI64IntegerAttr(i)));
14951564
}
1496-
Value cstZero = rewriter.create<Torch::ConstantIntOp>(
1497-
loc, rewriter.getI64IntegerAttr(0));
1565+
14981566
cstOutputPadding = {cstZero, cstZero};
14991567

15001568
Value dilationsList = rewriter.create<Torch::PrimListConstructOp>(

0 commit comments

Comments
 (0)