
Commit fe2f649

[ONNX] Remove kernel shape and weight shape equivalence check from Onnx.Conv lowering (#3869)
This commit removes the equivalence check between the kernel_shape attribute and the weight tensor shape from the Onnx.Conv lowering, since the check appears to serve no purpose (it is unclear why it was part of the lowering in the first place).

Signed-off-by: Vivek Khandelwal <[email protected]>
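For reference, here is a minimal standalone sketch of the equivalence the removed loop enforced. The names `kernelShape` and `weightShape` mirror the locals in the lowering; the helper itself is illustrative and not part of the patch:

#include <cstdint>
#include <vector>

// The removed check: ONNX Conv's optional kernel_shape attribute had to match
// the spatial dimensions of the weight tensor, i.e. weightShape[2:].
// After this commit the lowering relies on the weight tensor's shape alone.
static bool kernelShapeMatchesWeight(const std::vector<int64_t> &kernelShape,
                                     const std::vector<int64_t> &weightShape) {
  for (size_t i = 0; i < kernelShape.size(); ++i)
    if (weightShape[i + 2] != kernelShape[i])
      return false;
  return true;
}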
1 parent 06d1789 commit fe2f649

File tree

1 file changed (+29, -35 lines)

lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp

Lines changed: 29 additions & 35 deletions
@@ -7,12 +7,10 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/IR/DialectResourceBlobManager.h"
 #include "torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h"
 #include "torch-mlir/Conversion/TorchOnnxToTorch/Utils.h"
 #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
 #include "torch-mlir/Dialect/Torch/Utils/Utils.h"
-#include "llvm/Support/FormatVariadic.h"
 #include <numeric>
 
 using namespace mlir;
@@ -1292,6 +1290,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
       });
   patterns.onOp(
       "Conv", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+        Location loc = binder.getLoc();
         Torch::ValueTensorType resultType;
         Value input, weight;
         int64_t group;
@@ -1316,14 +1315,6 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
                 binder.op,
                 "unsupported conversion: kernel_shape list size should have "
                 "number of values equal to weight_rank - 2");
-          } else {
-            for (unsigned i = 0; i < kernelShape.size(); i++) {
-              if (weightShape[i + 2] != kernelShape[i]) {
-                return rewriter.notifyMatchFailure(
-                    binder.op, "unsupported conversion: kernel_shape value "
-                               "should be equal to the weight tensor shape");
-              }
-            }
           }
         }
 
@@ -1380,6 +1371,11 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
           ArrayRef<int64_t> inputShape = inputTensorType.getSizes();
           padding.resize_for_overwrite(2 * spatialRank);
           for (unsigned dimIdx = 0; dimIdx < spatialRank; dimIdx++) {
+            if (weightShape[dimIdx + 2] == Torch::kUnknownSize ||
+                inputShape[dimIdx + 2] == Torch::kUnknownSize)
+              return rewriter.notifyMatchFailure(
+                  binder.op,
+                  "expected weight and input tensor to have static shape");
             const int64_t dilatedKernelSize =
                 dilations[dimIdx] * (weightShape[dimIdx + 2] - 1) + 1;
             int64_t totalPad = ((inputShape[dimIdx + 2] + strides[dimIdx] - 1) /
@@ -1405,10 +1401,10 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
         if (padding.size() != 2 * (rank - 2)) {
           for (int64_t i : padding) {
             cstPadding.push_back(rewriter.create<Torch::ConstantIntOp>(
-                binder.getLoc(), rewriter.getI64IntegerAttr(i)));
+                loc, rewriter.getI64IntegerAttr(i)));
           }
           paddingList = rewriter.create<Torch::PrimListConstructOp>(
-              binder.getLoc(),
+              loc,
               Torch::ListType::get(
                   Torch::IntType::get(binder.op->getContext())),
               cstPadding);
@@ -1431,10 +1427,10 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
           if (matchedPads) {
             for (unsigned i = 0; i < padding.size() / 2; i++) {
               cstPadding.push_back(rewriter.create<Torch::ConstantIntOp>(
-                  binder.getLoc(), rewriter.getI64IntegerAttr(padding[i])));
+                  loc, rewriter.getI64IntegerAttr(padding[i])));
             }
             paddingList = rewriter.create<Torch::PrimListConstructOp>(
-                binder.getLoc(),
+                loc,
                 Torch::ListType::get(
                     Torch::IntType::get(binder.op->getContext())),
                 cstPadding);
@@ -1443,40 +1439,40 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
             SmallVector<Value> inputPaddingList;
             for (uint32_t i = 0; i < padding.size() / 2; i++) {
               padsRearrange.emplace_back(rewriter.create<Torch::ConstantIntOp>(
-                  binder.getLoc(), rewriter.getI64IntegerAttr(
-                                       padding[padding.size() / 2 - i - 1])));
+                  loc, rewriter.getI64IntegerAttr(
+                           padding[padding.size() / 2 - i - 1])));
               padsRearrange.emplace_back(rewriter.create<Torch::ConstantIntOp>(
-                  binder.getLoc(),
+                  loc,
                   rewriter.getI64IntegerAttr(padding[padding.size() - i - 1])));
               inputPaddingList.emplace_back(
                   rewriter.create<Torch::ConstantIntOp>(
-                      binder.getLoc(), rewriter.getI64IntegerAttr(0)));
+                      loc, rewriter.getI64IntegerAttr(0)));
             }
             // The conv op itself will have no padding since the actual padding
             // is performed using the torch.pad preceding it.
             paddingList = rewriter.create<Torch::PrimListConstructOp>(
-                binder.getLoc(),
+                loc,
                 Torch::ListType::get(
                     Torch::IntType::get(binder.op->getContext())),
                 inputPaddingList);
             Value padsSizeList =
                 rewriter
                     .create<Torch::PrimListConstructOp>(
-                        binder.getLoc(),
+                        loc,
                         Torch::ListType::get(
                             rewriter.getType<Torch::IntType>()),
                         padsRearrange)
                     .getResult();
             Value modeVal = rewriter.create<Torch::ConstantStrOp>(
-                binder.getLoc(), rewriter.getStringAttr("constant"));
+                loc, rewriter.getStringAttr("constant"));
             Value constantValue;
 
             if (isa<IntegerType>(inputTensorType.getDtype()))
               constantValue = rewriter.create<Torch::ConstantIntOp>(
-                  binder.getLoc(), rewriter.getI64IntegerAttr(0));
+                  loc, rewriter.getI64IntegerAttr(0));
             if (isa<FloatType>(inputTensorType.getDtype()))
               constantValue = rewriter.create<Torch::ConstantFloatOp>(
-                  binder.getLoc(), rewriter.getF64FloatAttr(0.0f));
+                  loc, rewriter.getF64FloatAttr(0.0f));
             // Pad output shape must be computed explicitly from the pad values
             SmallVector<int64_t> newInputShape(inputTensorType.getSizes());
             for (uint32_t i = 0; i < padding.size() / 2; i++) {
@@ -1486,46 +1482,44 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
             auto padTy = rewriter.getType<Torch::ValueTensorType>(
                 newInputShape, inputTensorType.getDtype());
             paddedInput = rewriter.create<Torch::AtenPadOp>(
-                binder.getLoc(), padTy, input, padsSizeList, modeVal,
-                constantValue);
+                loc, padTy, input, padsSizeList, modeVal, constantValue);
           }
         }
         for (int64_t i : dilations) {
           cstDilations.push_back(rewriter.create<Torch::ConstantIntOp>(
-              binder.getLoc(), rewriter.getI64IntegerAttr(i)));
+              loc, rewriter.getI64IntegerAttr(i)));
         }
         for (int64_t i : strides) {
           cstStrides.push_back(rewriter.create<Torch::ConstantIntOp>(
-              binder.getLoc(), rewriter.getI64IntegerAttr(i)));
+              loc, rewriter.getI64IntegerAttr(i)));
         }
         Value cstZero = rewriter.create<Torch::ConstantIntOp>(
-            binder.getLoc(), rewriter.getI64IntegerAttr(0));
+            loc, rewriter.getI64IntegerAttr(0));
         cstOutputPadding = {cstZero, cstZero};
 
         Value dilationsList = rewriter.create<Torch::PrimListConstructOp>(
-            binder.getLoc(),
+            loc,
             Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
             cstDilations);
         Value stridesList = rewriter.create<Torch::PrimListConstructOp>(
-            binder.getLoc(),
+            loc,
             Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
             cstStrides);
         Value outputPaddingList = rewriter.create<Torch::PrimListConstructOp>(
-            binder.getLoc(),
+            loc,
             Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
             cstOutputPadding);
-        Value transposed =
-            rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);
+        Value transposed = rewriter.create<Torch::ConstantBoolOp>(loc, false);
         Value bias;
         if (binder.op->getNumOperands() == 3) {
           if (binder.tensorOperandAtIndex(bias, 2)) {
             return failure();
           }
         } else {
-          bias = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
+          bias = rewriter.create<Torch::ConstantNoneOp>(loc);
        }
         Value cstGroup = rewriter.create<Torch::ConstantIntOp>(
-            binder.getLoc(), rewriter.getI64IntegerAttr(group));
+            loc, rewriter.getI64IntegerAttr(group));
 
         rewriter.replaceOpWithNewOp<Torch::AtenConvolutionOp>(
             binder.op, resultType, paddedInput, weight, bias, stridesList,
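
A note on the static-shape guard added above: when auto_pad is SAME_UPPER or SAME_LOWER, the lowering computes padding from concrete dimension sizes, which is why dynamic (Torch::kUnknownSize) weight or input dimensions are now rejected with a match failure. Below is a minimal sketch of that per-dimension arithmetic; the completion of the totalPad expression beyond the hunk's last context line is assumed from the standard SAME-padding rule, not copied verbatim from the file:

#include <algorithm>
#include <cstdint>

// Per-dimension "SAME" padding: the dilated kernel extent comes straight from
// the weight shape, and totalPad (assumed standard formula) is the padding
// needed so that ceil(inputSize / stride) output positions cover the input.
static int64_t samePadForDim(int64_t inputSize, int64_t weightKernelSize,
                             int64_t stride, int64_t dilation) {
  const int64_t dilatedKernelSize = dilation * (weightKernelSize - 1) + 1;
  int64_t totalPad = ((inputSize + stride - 1) / stride - 1) * stride +
                     dilatedKernelSize - inputSize;
  return std::max<int64_t>(totalPad, 0);
}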
