@@ -7,12 +7,10 @@
 //
 //===----------------------------------------------------------------------===//

-#include "mlir/IR/DialectResourceBlobManager.h"
 #include "torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h"
 #include "torch-mlir/Conversion/TorchOnnxToTorch/Utils.h"
 #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
 #include "torch-mlir/Dialect/Torch/Utils/Utils.h"
-#include "llvm/Support/FormatVariadic.h"
 #include <numeric>

 using namespace mlir;
@@ -1292,6 +1290,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
       });
   patterns.onOp(
       "Conv", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+        Location loc = binder.getLoc();
         Torch::ValueTensorType resultType;
         Value input, weight;
         int64_t group;
@@ -1316,14 +1315,6 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
                 binder.op,
                 "unsupported conversion: kernel_shape list size should have "
                 "number of values equal to weight_rank - 2");
-          } else {
-            for (unsigned i = 0; i < kernelShape.size(); i++) {
-              if (weightShape[i + 2] != kernelShape[i]) {
-                return rewriter.notifyMatchFailure(
-                    binder.op, "unsupported conversion: kernel_shape value "
-                               "should be equal to the weight tensor shape");
-              }
-            }
           }
         }

@@ -1380,6 +1371,11 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
           ArrayRef<int64_t> inputShape = inputTensorType.getSizes();
           padding.resize_for_overwrite(2 * spatialRank);
           for (unsigned dimIdx = 0; dimIdx < spatialRank; dimIdx++) {
+            if (weightShape[dimIdx + 2] == Torch::kUnknownSize ||
+                inputShape[dimIdx + 2] == Torch::kUnknownSize)
+              return rewriter.notifyMatchFailure(
+                  binder.op,
+                  "expected weight and input tensor to have static shape");
             const int64_t dilatedKernelSize =
                 dilations[dimIdx] * (weightShape[dimIdx + 2] - 1) + 1;
             int64_t totalPad = ((inputShape[dimIdx + 2] + strides[dimIdx] - 1) /
@@ -1405,10 +1401,10 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
         if (padding.size() != 2 * (rank - 2)) {
           for (int64_t i : padding) {
             cstPadding.push_back(rewriter.create<Torch::ConstantIntOp>(
-                binder.getLoc(), rewriter.getI64IntegerAttr(i)));
+                loc, rewriter.getI64IntegerAttr(i)));
           }
           paddingList = rewriter.create<Torch::PrimListConstructOp>(
-              binder.getLoc(),
+              loc,
               Torch::ListType::get(
                   Torch::IntType::get(binder.op->getContext())),
               cstPadding);
@@ -1431,10 +1427,10 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
           if (matchedPads) {
             for (unsigned i = 0; i < padding.size() / 2; i++) {
               cstPadding.push_back(rewriter.create<Torch::ConstantIntOp>(
-                  binder.getLoc(), rewriter.getI64IntegerAttr(padding[i])));
+                  loc, rewriter.getI64IntegerAttr(padding[i])));
             }
             paddingList = rewriter.create<Torch::PrimListConstructOp>(
-                binder.getLoc(),
+                loc,
                 Torch::ListType::get(
                     Torch::IntType::get(binder.op->getContext())),
                 cstPadding);
@@ -1443,40 +1439,40 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
             SmallVector<Value> inputPaddingList;
             for (uint32_t i = 0; i < padding.size() / 2; i++) {
               padsRearrange.emplace_back(rewriter.create<Torch::ConstantIntOp>(
-                  binder.getLoc(), rewriter.getI64IntegerAttr(
-                                       padding[padding.size() / 2 - i - 1])));
+                  loc, rewriter.getI64IntegerAttr(
+                           padding[padding.size() / 2 - i - 1])));
               padsRearrange.emplace_back(rewriter.create<Torch::ConstantIntOp>(
-                  binder.getLoc(),
+                  loc,
                   rewriter.getI64IntegerAttr(padding[padding.size() - i - 1])));
               inputPaddingList.emplace_back(
                   rewriter.create<Torch::ConstantIntOp>(
-                      binder.getLoc(), rewriter.getI64IntegerAttr(0)));
+                      loc, rewriter.getI64IntegerAttr(0)));
             }
             // The conv op itself will have no padding since the actual padding
             // is performed using the torch.pad preceding it.
             paddingList = rewriter.create<Torch::PrimListConstructOp>(
-                binder.getLoc(),
+                loc,
                 Torch::ListType::get(
                     Torch::IntType::get(binder.op->getContext())),
                 inputPaddingList);
             Value padsSizeList =
                 rewriter
                     .create<Torch::PrimListConstructOp>(
-                        binder.getLoc(),
+                        loc,
                         Torch::ListType::get(
                             rewriter.getType<Torch::IntType>()),
                         padsRearrange)
                     .getResult();
             Value modeVal = rewriter.create<Torch::ConstantStrOp>(
-                binder.getLoc(), rewriter.getStringAttr("constant"));
+                loc, rewriter.getStringAttr("constant"));
             Value constantValue;

             if (isa<IntegerType>(inputTensorType.getDtype()))
               constantValue = rewriter.create<Torch::ConstantIntOp>(
-                  binder.getLoc(), rewriter.getI64IntegerAttr(0));
+                  loc, rewriter.getI64IntegerAttr(0));
             if (isa<FloatType>(inputTensorType.getDtype()))
               constantValue = rewriter.create<Torch::ConstantFloatOp>(
-                  binder.getLoc(), rewriter.getF64FloatAttr(0.0f));
+                  loc, rewriter.getF64FloatAttr(0.0f));
             // Pad output shape must be computed explicitly from the pad values
             SmallVector<int64_t> newInputShape(inputTensorType.getSizes());
             for (uint32_t i = 0; i < padding.size() / 2; i++) {
@@ -1486,46 +1482,44 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
             auto padTy = rewriter.getType<Torch::ValueTensorType>(
                 newInputShape, inputTensorType.getDtype());
             paddedInput = rewriter.create<Torch::AtenPadOp>(
-                binder.getLoc(), padTy, input, padsSizeList, modeVal,
-                constantValue);
+                loc, padTy, input, padsSizeList, modeVal, constantValue);
           }
         }
         for (int64_t i : dilations) {
           cstDilations.push_back(rewriter.create<Torch::ConstantIntOp>(
-              binder.getLoc(), rewriter.getI64IntegerAttr(i)));
+              loc, rewriter.getI64IntegerAttr(i)));
         }
         for (int64_t i : strides) {
           cstStrides.push_back(rewriter.create<Torch::ConstantIntOp>(
-              binder.getLoc(), rewriter.getI64IntegerAttr(i)));
+              loc, rewriter.getI64IntegerAttr(i)));
         }
         Value cstZero = rewriter.create<Torch::ConstantIntOp>(
-            binder.getLoc(), rewriter.getI64IntegerAttr(0));
+            loc, rewriter.getI64IntegerAttr(0));
         cstOutputPadding = {cstZero, cstZero};

         Value dilationsList = rewriter.create<Torch::PrimListConstructOp>(
-            binder.getLoc(),
+            loc,
             Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
             cstDilations);
         Value stridesList = rewriter.create<Torch::PrimListConstructOp>(
-            binder.getLoc(),
+            loc,
             Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
             cstStrides);
         Value outputPaddingList = rewriter.create<Torch::PrimListConstructOp>(
-            binder.getLoc(),
+            loc,
             Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
             cstOutputPadding);
-        Value transposed =
-            rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);
+        Value transposed = rewriter.create<Torch::ConstantBoolOp>(loc, false);
         Value bias;
         if (binder.op->getNumOperands() == 3) {
           if (binder.tensorOperandAtIndex(bias, 2)) {
             return failure();
           }
         } else {
-          bias = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
+          bias = rewriter.create<Torch::ConstantNoneOp>(loc);
         }
         Value cstGroup = rewriter.create<Torch::ConstantIntOp>(
-            binder.getLoc(), rewriter.getI64IntegerAttr(group));
+            loc, rewriter.getI64IntegerAttr(group));

         rewriter.replaceOpWithNewOp<Torch::AtenConvolutionOp>(
             binder.op, resultType, paddedInput, weight, bias, stridesList,
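
For context on the new static-shape guard in the @@ -1380,6 +1371,11 @@ hunk: below is a minimal standalone sketch of the usual ONNX "SAME" auto_pad arithmetic that the dilatedKernelSize/totalPad computation performs, assuming the standard ceil(in/stride) output-size formula. samePadForDim and the example values are illustrative only, not part of the patch.

// Standalone sketch (not the converter): per-dimension "SAME" padding,
// computable only when the input and kernel extents are static.
#include <cstdint>
#include <iostream>

int64_t samePadForDim(int64_t inSize, int64_t kernel, int64_t stride,
                      int64_t dilation) {
  // Effective kernel extent once dilation is applied.
  const int64_t dilatedKernelSize = dilation * (kernel - 1) + 1;
  // SAME keeps ceil(inSize / stride) output elements; back out the total
  // padding needed to reach that output size.
  const int64_t outSize = (inSize + stride - 1) / stride;
  const int64_t totalPad = (outSize - 1) * stride + dilatedKernelSize - inSize;
  return totalPad > 0 ? totalPad : 0;
}

int main() {
  // Example: spatial dim of size 10, kernel 3, stride 2, dilation 1.
  std::cout << samePadForDim(10, 3, 2, 1) << "\n"; // prints 1
}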