@@ -1327,6 +1327,8 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
 
         SmallVector<int64_t> padding, strides, dilations;
         SmallVector<int64_t> defaultPadding, defaultStrides, defaultDilations;
+        SmallVector<Value> paddingValues;
+
         for (unsigned i = 0; i < rank - 2; i++) {
          defaultPadding.push_back(0);
          defaultStrides.push_back(1);
@@ -1360,36 +1362,88 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
         // at the beginning of axis i and xi_end, the number of pixels added at
         // the end of axis i.
         if (autoPad == "NOTSET") {
-          if (binder.s64IntegerArrayAttr(padding, "pads", defaultPadding)) {
+          if (binder.s64IntegerArrayAttr(padding, "pads", defaultPadding))
            return failure();
-          }
+
+          // Use the padding values
+          for (int64_t pad : padding)
+            paddingValues.push_back(rewriter.create<Torch::ConstantIntOp>(
+                loc, rewriter.getI64IntegerAttr(pad)));
         } else if (autoPad == "VALID") {
-          padding = defaultPadding;
+          for (int64_t pad : defaultPadding)
+            paddingValues.push_back(rewriter.create<Torch::ConstantIntOp>(
+                loc, rewriter.getI64IntegerAttr(pad)));
         } else {
          const bool isSameLower = autoPad == "SAME_LOWER";
          const unsigned spatialRank = rank - 2;
-          ArrayRef<int64_t> inputShape = inputTensorType.getSizes();
-          padding.resize_for_overwrite(2 * spatialRank);
+          paddingValues.resize_for_overwrite(2 * spatialRank);
+
          for (unsigned dimIdx = 0; dimIdx < spatialRank; dimIdx++) {
-            if (weightShape[dimIdx + 2] == Torch::kUnknownSize ||
-                inputShape[dimIdx + 2] == Torch::kUnknownSize)
-              return rewriter.notifyMatchFailure(
-                  binder.op,
-                  "expected weight and input tensor to have static shape");
-            const int64_t dilatedKernelSize =
-                dilations[dimIdx] * (weightShape[dimIdx + 2] - 1) + 1;
-            int64_t totalPad = ((inputShape[dimIdx + 2] + strides[dimIdx] - 1) /
-                                    strides[dimIdx] -
-                                1) *
-                                   strides[dimIdx] +
-                               dilatedKernelSize - inputShape[dimIdx + 2];
-            totalPad = totalPad >= 0 ? totalPad : 0;
-            padding[dimIdx] =
-                isSameLower ? ((totalPad + 1) / 2) : (totalPad / 2);
-            padding[spatialRank + dimIdx] = totalPad - padding[dimIdx];
+            // dilatedKernelSize = dilations[dimIdx] * (weightShape[dimIdx + 2] - 1) + 1
+            Value cstOne = rewriter.create<Torch::ConstantIntOp>(
+                loc, rewriter.getI64IntegerAttr(1));
+            Value dilationValue = rewriter.create<Torch::ConstantIntOp>(
+                loc, rewriter.getI64IntegerAttr(dilations[dimIdx]));
+            Value weightDimSize =
+                Torch::getTensorDimSize(rewriter, weight, dimIdx + 2);
+            Value weightMinusOne = rewriter.create<Torch::AtenSubIntOp>(
+                loc, weightDimSize, cstOne);
+            Value dilationMulWeight = rewriter.create<Torch::AtenMulIntOp>(
+                loc, dilationValue, weightMinusOne);
+            Value dilatedKernelSize = rewriter.create<Torch::AtenAddIntOp>(
+                loc, dilationMulWeight, cstOne);
+
+            // totalPad = (((inputShape[dimIdx + 2] + strides[dimIdx] - 1) /
+            //               strides[dimIdx]) - 1) * strides[dimIdx] +
+            //            dilatedKernelSize - inputShape[dimIdx + 2];
+
+            Value stridesValue = rewriter.create<Torch::ConstantIntOp>(
+                loc, rewriter.getI64IntegerAttr(strides[dimIdx]));
+            Value inputDimSize =
+                Torch::getTensorDimSize(rewriter, input, dimIdx + 2);
+            Value stridesMinusOne =
+                rewriter.create<Torch::AtenSubIntOp>(loc, stridesValue, cstOne);
+            Value inputStrides = rewriter.create<Torch::AtenAddIntOp>(
+                loc, inputDimSize, stridesMinusOne);
+            inputStrides = rewriter.create<Torch::AtenFloordivIntOp>(
+                loc, inputStrides, stridesValue);
+            inputStrides =
+                rewriter.create<Torch::AtenSubIntOp>(loc, inputStrides, cstOne);
+            inputStrides = rewriter.create<Torch::AtenMulIntOp>(
+                loc, inputStrides, stridesValue);
+            Value strideWithDilation = rewriter.create<Torch::AtenAddIntOp>(
+                loc, inputStrides, dilatedKernelSize);
+            Value totalPad = rewriter.create<Torch::AtenSubIntOp>(
+                loc, strideWithDilation, inputDimSize);
+
+            // totalPad = totalPad >= 0 ? totalPad : 0;
+            Value cstZero = rewriter.create<Torch::ConstantIntOp>(
+                loc, rewriter.getI64IntegerAttr(0));
+            totalPad =
+                rewriter.create<Torch::PrimMaxIntOp>(loc, totalPad, cstZero);
+
+            // padding[dimIdx] =
+            //     isSameLower ? ((totalPad + 1) / 2) : (totalPad / 2);
+            // padding[spatialRank + dimIdx] = totalPad - padding[dimIdx];
+            Value cstTwo = rewriter.create<Torch::ConstantIntOp>(
+                loc, rewriter.getI64IntegerAttr(2));
+            if (isSameLower) {
+              auto padPlusOne =
+                  rewriter.create<Torch::AtenAddIntOp>(loc, totalPad, cstOne);
+              paddingValues[dimIdx] = rewriter.create<Torch::AtenFloordivIntOp>(
+                  loc, padPlusOne, cstTwo);
+            } else {
+              paddingValues[dimIdx] = rewriter.create<Torch::AtenFloordivIntOp>(
+                  loc, totalPad, cstTwo);
+            }
+            paddingValues[spatialRank + dimIdx] =
+                rewriter.create<Torch::AtenSubIntOp>(loc, totalPad,
+                                                     paddingValues[dimIdx]);
          }
         }
-        if (padding.size() != rank - 2 && padding.size() != 2 * (rank - 2)) {
+
+        if (paddingValues.size() != rank - 2 &&
+            paddingValues.size() != 2 * (rank - 2)) {
          return rewriter.notifyMatchFailure(
              binder.op, "padding list size does not match the number of axes");
         }
@@ -1398,11 +1452,12 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
            cstOutputPadding;
         Value paddedInput = input;
         Value paddingList;
-        if (padding.size() != 2 * (rank - 2)) {
-          for (int64_t i : padding) {
-            cstPadding.push_back(rewriter.create<Torch::ConstantIntOp>(
-                loc, rewriter.getI64IntegerAttr(i)));
-          }
+
+        Value cstZero = rewriter.create<Torch::ConstantIntOp>(
+            loc, rewriter.getI64IntegerAttr(0));
+
+        if (paddingValues.size() != 2 * (rank - 2)) {
+          cstPadding = paddingValues;
          paddingList = rewriter.create<Torch::PrimListConstructOp>(
              loc,
              Torch::ListType::get(
@@ -1418,17 +1473,20 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
          // rightmost dim start and end, then next to last, and so on, e.g. {l,
          // r, t, b}.
          bool matchedPads = true;
-          for (unsigned i = 0; i < padding.size() / 2; i++) {
-            if (padding[i] != padding[i + (padding.size() / 2)]) {
+          for (unsigned i = 0; i < paddingValues.size() / 2; i++) {
+            int64_t padLow, padHigh;
+            if (!matchPattern(paddingValues[i],
+                              Torch::m_TorchConstantInt(&padLow)) ||
+                !matchPattern(paddingValues[i + (paddingValues.size() / 2)],
+                              Torch::m_TorchConstantInt(&padHigh)) ||
+                padLow != padHigh) {
              matchedPads = false;
              break;
            }
          }
          if (matchedPads) {
-            for (unsigned i = 0; i < padding.size() / 2; i++) {
-              cstPadding.push_back(rewriter.create<Torch::ConstantIntOp>(
-                  loc, rewriter.getI64IntegerAttr(padding[i])));
-            }
+            for (unsigned i = 0; i < paddingValues.size() / 2; i++)
+              cstPadding.push_back(paddingValues[i]);
            paddingList = rewriter.create<Torch::PrimListConstructOp>(
                loc,
                Torch::ListType::get(
@@ -1437,16 +1495,12 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
          } else {
            SmallVector<Value> padsRearrange;
            SmallVector<Value> inputPaddingList;
-            for (uint32_t i = 0; i < padding.size() / 2; i++) {
-              padsRearrange.emplace_back(rewriter.create<Torch::ConstantIntOp>(
-                  loc, rewriter.getI64IntegerAttr(
-                           padding[padding.size() / 2 - i - 1])));
-              padsRearrange.emplace_back(rewriter.create<Torch::ConstantIntOp>(
-                  loc,
-                  rewriter.getI64IntegerAttr(padding[padding.size() - i - 1])));
-              inputPaddingList.emplace_back(
-                  rewriter.create<Torch::ConstantIntOp>(
-                      loc, rewriter.getI64IntegerAttr(0)));
+            for (uint32_t i = 0; i < paddingValues.size() / 2; i++) {
+              padsRearrange.emplace_back(
+                  paddingValues[paddingValues.size() / 2 - i - 1]);
+              padsRearrange.emplace_back(
+                  paddingValues[paddingValues.size() - i - 1]);
+              inputPaddingList.emplace_back(cstZero);
            }
            // The conv op itself will have no padding since the actual padding
            // is performed using the torch.pad preceding it.
@@ -1468,23 +1522,38 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
            Value constantValue;
 
            if (isa<IntegerType>(inputTensorType.getDtype()))
-              constantValue = rewriter.create<Torch::ConstantIntOp>(
-                  loc, rewriter.getI64IntegerAttr(0));
+              constantValue = cstZero;
            if (isa<FloatType>(inputTensorType.getDtype()))
              constantValue = rewriter.create<Torch::ConstantFloatOp>(
                  loc, rewriter.getF64FloatAttr(0.0f));
+
+            auto getPadOutputSizeForInput = [&](int64_t low, int64_t high,
+                                                int64_t inputSize) {
+              int64_t padLow, padHigh;
+              if (inputSize == Torch::kUnknownSize ||
+                  !matchPattern(paddingValues[low],
+                                Torch::m_TorchConstantInt(&padLow)) ||
+                  !matchPattern(paddingValues[high],
+                                Torch::m_TorchConstantInt(&padHigh)))
+                return Torch::kUnknownSize;
+              return inputSize + padLow + padHigh;
+            };
+
            // Pad output shape must be computed explicitly from the pad values
+            // for static dims
            SmallVector<int64_t> newInputShape(inputTensorType.getSizes());
-            for (uint32_t i = 0; i < padding.size() / 2; i++) {
-              newInputShape[2 + i] +=
-                  padding[i] + padding[(padding.size() / 2) + i];
+            for (uint32_t i = 0; i < paddingValues.size() / 2; i++) {
+              newInputShape[2 + i] = getPadOutputSizeForInput(
+                  i, (paddingValues.size() / 2) + i, newInputShape[2 + i]);
            }
+
            auto padTy = rewriter.getType<Torch::ValueTensorType>(
                newInputShape, inputTensorType.getDtype());
            paddedInput = rewriter.create<Torch::AtenPadOp>(
                loc, padTy, input, padsSizeList, modeVal, constantValue);
          }
         }
+
         for (int64_t i : dilations) {
          cstDilations.push_back(rewriter.create<Torch::ConstantIntOp>(
              loc, rewriter.getI64IntegerAttr(i)));
@@ -1493,8 +1562,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
          cstStrides.push_back(rewriter.create<Torch::ConstantIntOp>(
              loc, rewriter.getI64IntegerAttr(i)));
         }
-        Value cstZero = rewriter.create<Torch::ConstantIntOp>(
-            loc, rewriter.getI64IntegerAttr(0));
+
         cstOutputPadding = {cstZero, cstZero};
 
         Value dilationsList = rewriter.create<Torch::PrimListConstructOp>(