Skip to content

Commit 191404b

Browse files
committed
[AutoBump] Merge with ad9dfe9 (Oct 25)
2 parents f3d07ce + ad9dfe9 commit 191404b

File tree

11 files changed

+467
-104
lines changed

11 files changed

+467
-104
lines changed

include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7352,6 +7352,7 @@ def Torch_AtenMaxPool3dWithIndicesOp : Torch_Op<"aten.max_pool3d_with_indices",
73527352
printDefaultTorchOp(printer, *this, 6, 2);
73537353
}
73547354
}];
7355+
let hasCanonicalizer = 1;
73557356
}
73567357

73577358
def Torch_AtenMaxPool3dWithIndicesBackwardOp : Torch_Op<"aten.max_pool3d_with_indices_backward", [
@@ -8079,6 +8080,7 @@ def Torch_AtenTransposeIntOp : Torch_Op<"aten.transpose.int", [
80798080
printDefaultTorchOp(printer, *this, 3, 1);
80808081
}
80818082
}];
8083+
let hasFolder = 1;
80828084
}
80838085

80848086
def Torch_AtenPixelShuffleOp : Torch_Op<"aten.pixel_shuffle", [
@@ -9671,6 +9673,7 @@ def Torch_AtenFlattenUsingIntsOp : Torch_Op<"aten.flatten.using_ints", [
96719673
printDefaultTorchOp(printer, *this, 3, 1);
96729674
}
96739675
}];
9676+
let hasFolder = 1;
96749677
}
96759678

96769679
def Torch_AtenUnflattenIntOp : Torch_Op<"aten.unflatten.int", [
@@ -9695,6 +9698,7 @@ def Torch_AtenUnflattenIntOp : Torch_Op<"aten.unflatten.int", [
96959698
printDefaultTorchOp(printer, *this, 3, 1);
96969699
}
96979700
}];
9701+
let hasFolder = 1;
96989702
let hasCanonicalizer = 1;
96999703
}
97009704

lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1087,9 +1087,6 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
10871087
if (binder.customOpNameStringAttr(autoPad, "auto_pad", "NOTSET"))
10881088
return rewriter.notifyMatchFailure(binder.op,
10891089
"auto_pad bind failure");
1090-
if (autoPad != "NOTSET")
1091-
return rewriter.notifyMatchFailure(
1092-
binder.op, "unsupported conversion: auto_pad != NOTSET");
10931090

10941091
Torch::ValueTensorType resultTypeOut;
10951092
Value operand;
@@ -1136,13 +1133,42 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
11361133
return rewriter.notifyMatchFailure(binder.op,
11371134
"dilations bind failure");
11381135

1136+
// set default padding
11391137
if (padding.empty())
11401138
padding.resize(spatial, 0);
11411139
if (strides.empty())
11421140
strides.resize(spatial, 1);
11431141
if (dilations.empty())
11441142
dilations.resize(spatial, 1);
11451143

1144+
auto inputTensorType = cast<Torch::ValueTensorType>(operand.getType());
1145+
1146+
// Padding for the beginning and ending along each spatial axis, it can
1147+
// take any value greater than or equal to 0. The value represent the
1148+
// number of pixels added to the beginning and end part of the
1149+
// corresponding axis. pads format should be as follow [x1_begin,
1150+
// x2_begin…x1_end, x2_end,…], where xi_begin the number of pixels added
1151+
// at the beginning of axis i and xi_end, the number of pixels added at
1152+
// the end of axis i.
1153+
if (autoPad != "NOTSET" && autoPad != "VALID") {
1154+
const bool isSameLower = autoPad == "SAME_LOWER";
1155+
ArrayRef<int64_t> inputShape = inputTensorType.getSizes();
1156+
padding.resize_for_overwrite(2 * spatial);
1157+
for (unsigned dimIdx = 0; dimIdx < spatial; dimIdx++) {
1158+
const int64_t dilatedKernelSize =
1159+
dilations[dimIdx] * (kernel[dimIdx] - 1) + 1;
1160+
int64_t totalPad = ((inputShape[dimIdx + 2] + strides[dimIdx] - 1) /
1161+
strides[dimIdx] -
1162+
1) *
1163+
strides[dimIdx] +
1164+
dilatedKernelSize - inputShape[dimIdx + 2];
1165+
totalPad = totalPad >= 0 ? totalPad : 0;
1166+
padding[dimIdx] =
1167+
isSameLower ? ((totalPad + 1) / 2) : (totalPad / 2);
1168+
padding[spatial + dimIdx] = totalPad - padding[dimIdx];
1169+
}
1170+
}
1171+
11461172
// If the padding is symmetric we can push the padding operation to the
11471173
// torch operator.
11481174
if (padding.size() == static_cast<size_t>(2 * spatial)) {

lib/Dialect/Torch/IR/TorchOps.cpp

Lines changed: 153 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,24 @@ using namespace mlir::torch::Torch;
3030
// Utilities
3131
//===----------------------------------------------------------------------===//
3232

33+
/// Shared folder for view-like ops (ops that only reinterpret the shape of
/// their operand).
///
/// \param self        Constant attribute of the operand, or null if the
///                    operand is not constant.
/// \param resultType  Result type of the op being folded.
/// \returns A dense constant carrying the operand's elements with the result
///          type's shape, or a null fold result when `self` is not a
///          `DenseElementsAttr`, the result is not a fully static
///          `ValueTensorType`, or the element counts disagree.
OpFoldResult genericViewLikeFold(Attribute self, Type resultType) {
  auto selfAttr = dyn_cast_or_null<DenseElementsAttr>(self);
  if (!selfAttr)
    return nullptr;

  auto resultTy = dyn_cast_or_null<ValueTensorType>(resultType);
  if (!resultTy || !resultTy.areAllSizesKnown())
    return nullptr;

  // Building a dense attribute whose value count does not match the type's
  // element count asserts, so refuse to fold if the annotated result shape is
  // inconsistent with the constant operand instead of crashing in a folder.
  int64_t resultNumElements = 1;
  for (int64_t size : resultTy.getSizes())
    resultNumElements *= size;
  if (resultNumElements != selfAttr.getNumElements())
    return nullptr;

  // A splat stays a splat under any reshape; only the type changes.
  if (selfAttr.isSplat()) {
    return SplatElementsAttr::get(resultTy.toBuiltinTensor(),
                                  selfAttr.getSplatValue<Attribute>());
  }

  // Otherwise the elements are carried over in their existing order and only
  // the shape in the type changes.
  return DenseElementsAttr::get(
      resultTy.toBuiltinTensor(),
      llvm::to_vector(selfAttr.getValues<Attribute>()));
}
50+
3351
Value mlir::torch::Torch::adjustStaticInformation(OpBuilder &builder,
3452
Location loc, Value value,
3553
Type desiredType,
@@ -1049,6 +1067,8 @@ void Aten_CastLongOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
10491067
//===----------------------------------------------------------------------===//
10501068

10511069
OpFoldResult AtenViewOp::fold(FoldAdaptor adaptor) {
1070+
if (auto genericFold = genericViewLikeFold(adaptor.getSelf(), getType()))
1071+
return genericFold;
10521072
auto inputType = dyn_cast<BaseTensorType>(getOperand(0).getType());
10531073
if (!inputType || !inputType.hasSizes() || inputType.getSizes().size() != 1)
10541074
return nullptr;
@@ -2236,10 +2256,22 @@ void AtenSizeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
22362256
});
22372257
}
22382258

2259+
//===----------------------------------------------------------------------===//
2260+
// AtenFlattenUsingIntsOp
2261+
//===----------------------------------------------------------------------===//
2262+
2263+
/// Folds `aten.flatten.using_ints` by handing the operand's constant
/// attribute and this op's result type to `genericViewLikeFold`.
OpFoldResult AtenFlattenUsingIntsOp::fold(FoldAdaptor adaptor) {
  Attribute constantSelf = adaptor.getSelf();
  Type flattenedType = getType();
  return genericViewLikeFold(constantSelf, flattenedType);
}
2266+
22392267
//===----------------------------------------------------------------------===//
22402268
// AtenUnflattenIntOp
22412269
//===----------------------------------------------------------------------===//
22422270

2271+
/// Folds `aten.unflatten.int` by handing the operand's constant attribute and
/// this op's result type to `genericViewLikeFold`.
OpFoldResult AtenUnflattenIntOp::fold(FoldAdaptor adaptor) {
  Attribute constantSelf = adaptor.getSelf();
  Type unflattenedType = getType();
  return genericViewLikeFold(constantSelf, unflattenedType);
}
2274+
22432275
void AtenUnflattenIntOp::getCanonicalizationPatterns(
22442276
RewritePatternSet &patterns, MLIRContext *context) {
22452277
// if there are only two sizes and one of them is statically 1, then convert
@@ -3737,6 +3769,69 @@ OpFoldResult AtenSubIntOp::fold(FoldAdaptor adaptor) {
37373769
adaptor.getOperands(), [](int64_t a, int64_t b) { return a - b; });
37383770
}
37393771

3772+
//===----------------------------------------------------------------------===//
3773+
// AtenTransposeIntOp
3774+
//===----------------------------------------------------------------------===//
3775+
3776+
/// Folds `aten.transpose.int` in two situations:
///   1. Both dims normalize to the same axis: the op is an identity and folds
///      to its operand, provided the operand and result types agree.
///   2. The operand is a constant with at most `kMaxFoldSize` elements that is
///      either a splat or a rank-2 tensor: the transposed constant is built
///      directly.
/// Returns a null fold result in every other case.
OpFoldResult AtenTransposeIntOp::fold(FoldAdaptor adaptor) {
  // Both dims must be constant integers to reason about the permutation.
  auto dim0Attr = dyn_cast_or_null<IntegerAttr>(adaptor.getDim0());
  auto dim1Attr = dyn_cast_or_null<IntegerAttr>(adaptor.getDim1());
  if (!dim0Attr || !dim1Attr)
    return nullptr;

  auto selfTy = dyn_cast<ValueTensorType>(getSelf().getType());
  if (!selfTy || !selfTy.hasSizes())
    return nullptr;

  // Normalize negative dims against the operand rank and validate them.
  const int64_t rank = selfTy.getSizes().size();
  const int64_t dim0 = toPositiveDim(dim0Attr.getValue().getSExtValue(), rank);
  const int64_t dim1 = toPositiveDim(dim1Attr.getValue().getSExtValue(), rank);
  if (!isValidDim(dim0, rank) || !isValidDim(dim1, rank))
    return nullptr;

  // Swapping an axis with itself is a no-op. A fold may only substitute a
  // value of exactly the result type, so bail out if the annotations differ.
  if (dim0 == dim1) {
    if (getType() != getSelf().getType())
      return nullptr;
    return getSelf();
  }

  // We set a maximum folding size of 16. This is a reasonable upper limit
  // for shape computations.
  constexpr int64_t kMaxFoldSize = 16;
  auto self = dyn_cast_or_null<DenseElementsAttr>(adaptor.getSelf());
  if (!self || self.getNumElements() > kMaxFoldSize)
    return nullptr;

  // The folded constant is created with the result's builtin tensor type, so
  // both the operand and the result must be fully static.
  auto resultTy = dyn_cast<ValueTensorType>(getType());
  if (!resultTy || !selfTy.areAllSizesKnown() || !resultTy.areAllSizesKnown())
    return nullptr;

  // Transposing a splat only changes the shape, not the data.
  if (self.isSplat())
    return SplatElementsAttr::get(resultTy.toBuiltinTensor(),
                                  self.getSplatValue<Attribute>());

  // TODO: add support for rank != 2
  if (rank != 2)
    return nullptr;

  ArrayRef<int64_t> sizes = selfTy.getSizes();
  auto values = llvm::to_vector(self.getValues<Attribute>());
  // reordered[i] = Trans[i//sizes[0], i % sizes[0]] = Self[i % sizes[0],
  // i//sizes[0]] = values[(i % sizes[0])*sizes[1] + (i//sizes[0])].
  // e.g., Self size = [4,2]; Trans size = [2,4].
  // reindex(i) = (i % 4)*2 + (i // 4) .
  // i = 0 -> Trans[0,0] -> Self[0,0] -> 0 .
  // i = 1 -> Trans[0,1] -> Self[1,0] -> 2 .
  // i = 2 -> Trans[0,2] -> Self[2,0] -> 4 .
  // i = 3 -> Trans[0,3] -> Self[3,0] -> 6 .
  // i = 4 -> Trans[1,0] -> Self[0,1] -> 1 .
  // i = 5 -> Trans[1,1] -> Self[1,1] -> 3 .
  auto reindex = [&](int64_t i) {
    return (i % sizes[0]) * sizes[1] + (i / sizes[0]);
  };
  const int64_t numElements = self.getNumElements();
  SmallVector<Attribute> reordered;
  reordered.reserve(numElements);
  for (int64_t i = 0; i < numElements; ++i)
    reordered.push_back(values[reindex(i)]);
  return DenseElementsAttr::get(resultTy.toBuiltinTensor(), reordered);
}
3834+
37403835
//===----------------------------------------------------------------------===//
37413836
// AtenCatOp
37423837
//===----------------------------------------------------------------------===//
@@ -3913,15 +4008,18 @@ OpFoldResult AtenSliceTensorOp::fold(FoldAdaptor adaptor) {
39134008
// Fold the slice if the output tensor is relatively small, currently
39144009
// coded to 16:
39154010
constexpr int64_t kMaxFold = 16;
3916-
if (input && start && step && dim && count <= kMaxFold) {
4011+
if (input && start && step && dim && end && count <= kMaxFold) {
39174012
int64_t begin = start.getValue().getSExtValue();
39184013
int64_t limit = end.getValue().getSExtValue();
39194014
int64_t stride = step.getValue().getSExtValue();
3920-
if (stride < 1)
3921-
return nullptr;
39224015
begin = begin < 0 ? begin + inType.getSizes()[dimInt] : begin;
39234016
limit = limit < 0 ? limit + inType.getSizes()[dimInt] : limit;
4017+
limit = limit < 0 ? -1 : limit;
39244018
limit = std::min(limit, inType.getSizes()[dimInt]);
4019+
bool validIterArgs =
4020+
(stride > 0 && begin < limit) || (stride < 0 && begin > limit);
4021+
assert(validIterArgs &&
4022+
"aten.slice.Tensor iteration args are statically invalid.");
39254023

39264024
int64_t inputRank = inType.getSizes().size();
39274025
llvm::SmallVector<int64_t> inputStrides(inputRank, 1);
@@ -3934,10 +4032,21 @@ OpFoldResult AtenSliceTensorOp::fold(FoldAdaptor adaptor) {
39344032
auto recursiveIter = [&](auto &self, int64_t currDim, int64_t currOffset) {
39354033
if (currDim >= inputRank)
39364034
return;
3937-
size_t _begin = (currDim == dimInt) ? begin : 0;
3938-
size_t _limit = (currDim == dimInt) ? limit : inType.getSizes()[currDim];
3939-
size_t _stride = (currDim == dimInt) ? stride : 1;
3940-
for (size_t i = _begin; i < _limit; i += _stride) {
4035+
int64_t _stride = (currDim == dimInt) ? stride : 1;
4036+
int64_t _begin = (currDim == dimInt) ? begin : 0;
4037+
int64_t _limit = (currDim == dimInt) ? limit : inType.getSizes()[currDim];
4038+
// ensure that the limit is reached exactly (even with negative strides)
4039+
// E.g., with begin = 0, limit = 10, stride = 3, we modify limit to be 11
4040+
// = 10 + (10-0) % 3 .
4041+
// E.g., with begin = 8, limit = -1, stride = -2, limit becomes -2 = -1 +
4042+
// (-1-8) % (-2) - stride = -1 + 1 - 2 = -2 .
4043+
// Note: cpp uses true math remainder "n % d = least positive int, x, such
4044+
// that d divides (n - x)"
4045+
int64_t limit_rem = (_limit - _begin) % _stride;
4046+
limit_rem =
4047+
(_stride > 0 || limit_rem == 0) ? limit_rem : limit_rem - _stride;
4048+
_limit += limit_rem;
4049+
for (int64_t i = _begin; std::abs(_limit - i) > 0; i += _stride) {
39414050
if (currDim == inputRank - 1) {
39424051
values.push_back(input.getValues<Attribute>()[currOffset + i]);
39434052
}
@@ -5272,26 +5381,56 @@ OpFoldResult PrimsConvertElementTypeOp::fold(FoldAdaptor adaptor) {
52725381
}
52735382

52745383
//===----------------------------------------------------------------------===//
5275-
// AtenMaxPool2dWithIndicesOp
5384+
// AtenMaxPoolWithIndicesOp
52765385
//===----------------------------------------------------------------------===//
52775386

5278-
void AtenMaxPool2dWithIndicesOp::getCanonicalizationPatterns(
5279-
RewritePatternSet &patterns, MLIRContext *context) {
5280-
patterns.add(+[](AtenMaxPool2dWithIndicesOp op, PatternRewriter &rewriter) {
5387+
namespace {
5388+
5389+
template <typename OpTy> struct MaxPoolWithoutIndices {
5390+
using type = OpTy;
5391+
};
5392+
5393+
template <> struct MaxPoolWithoutIndices<AtenMaxPool2dWithIndicesOp> {
5394+
using type = AtenMaxPool2dOp;
5395+
};
5396+
5397+
template <> struct MaxPoolWithoutIndices<AtenMaxPool3dWithIndicesOp> {
5398+
using type = AtenMaxPool3dOp;
5399+
};
5400+
5401+
} // namespace
5402+
5403+
template <typename OpTy>
5404+
struct SimplifyMaxPoolWithIndices : public mlir::OpRewritePattern<OpTy> {
5405+
SimplifyMaxPoolWithIndices(mlir::MLIRContext *context)
5406+
: OpRewritePattern<OpTy>(context, /*benefit=*/1) {}
5407+
5408+
LogicalResult
5409+
matchAndRewrite(OpTy op, mlir::PatternRewriter &rewriter) const override {
52815410
if (!op.getResult1().use_empty()) {
52825411
return rewriter.notifyMatchFailure(
5283-
op, "result1 of MaxPool2dWithIndices should be unused");
5412+
op, "result1 of MaxPoolWithIndices should be unused");
52845413
}
52855414

5286-
Value result = rewriter.create<Torch::AtenMaxPool2dOp>(
5415+
Value result = rewriter.create<typename MaxPoolWithoutIndices<OpTy>::type>(
52875416
op->getLoc(), op.getResult0().getType(), op.getSelf(),
52885417
op.getKernelSize(), op.getStride(), op.getPadding(), op.getDilation(),
52895418
op.getCeilMode());
52905419

52915420
op.getResult0().replaceAllUsesWith(result);
52925421
rewriter.eraseOp(op);
52935422
return success();
5294-
});
5423+
}
5424+
};
5425+
5426+
/// Registers the canonicalization pattern for `aten.max_pool2d_with_indices`:
/// the `SimplifyMaxPoolWithIndices` pattern instantiated for this op.
void AtenMaxPool2dWithIndicesOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  using Simplify2d = SimplifyMaxPoolWithIndices<AtenMaxPool2dWithIndicesOp>;
  patterns.add<Simplify2d>(context);
}
5430+
5431+
/// Registers the canonicalization pattern for `aten.max_pool3d_with_indices`:
/// the `SimplifyMaxPoolWithIndices` pattern instantiated for this op.
void AtenMaxPool3dWithIndicesOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  using Simplify3d = SimplifyMaxPoolWithIndices<AtenMaxPool3dWithIndicesOp>;
  patterns.add<Simplify3d>(context);
}
52965435

52975436
//===----------------------------------------------------------------------===//

lib/Dialect/Torch/Transforms/SimplifyAbstractInterpCalculationsUtils.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,16 +32,19 @@ class FoldPrimUncheckedCastOp : public OpRewritePattern<PrimUncheckedCastOp> {
3232
} // namespace
3333

3434
namespace {
35-
// TODO: Only unroll inside the shape calculation region.
36-
// Maybe do this by only applying patterns and folding greedily on the ops
37-
// inside the region + the shape.calculate op itself?
3835
class FullyUnrollPrimLoopOp : public OpRewritePattern<PrimLoopOp> {
3936
public:
4037
using OpRewritePattern::OpRewritePattern;
4138
LogicalResult matchAndRewrite(PrimLoopOp op,
4239
PatternRewriter &rewriter) const override {
4340
Location loc = op->getLoc();
4441
MLIRContext *context = op->getContext();
42+
// Only unroll loops if they are contained in a shape calculate region.
43+
Region *region = op->getParentRegion();
44+
Operation *parentOp = region->getParentOp();
45+
if (!parentOp || !isa<Torch::ShapeCalculateOp>(parentOp))
46+
return rewriter.notifyMatchFailure(
47+
op, "Loop is not contained in a shape calculation region.");
4548
if (!op.isForLike())
4649
return rewriter.notifyMatchFailure(op, "Loop is not for-like");
4750
int64_t maxTripCount;

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 0 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -394,15 +394,6 @@
394394
"AtenIntBoolOpModule_basic",
395395
"AtenIntMM_basic",
396396
"AtenItemFpOpModule_basic",
397-
"AtenMatmulQMixedSigni8Transpose_basic",
398-
"AtenMatmulQMixedSigni8_basic",
399-
"AtenMatmulQint8MV_basic",
400-
"AtenMatmulQint8_basic",
401-
"AtenMatmulQint8VM_basic",
402-
"AtenMatmulQint8VV_basic",
403-
"AtenMmQMixedSigni8_basic",
404-
"AtenMmQint8_basic",
405-
"AtenMmQuint8_basic",
406397
"QuantizedReluInt32_basic",
407398
"QuantizedReluInt8_basic",
408399
"QuantizedReluUint8_basic",
@@ -2734,20 +2725,6 @@
27342725
"MultinomialModule2D_basic",
27352726
"MultinomialModule2D_F32",
27362727
"PixelShuffleModuleStaticRank4Float32_basic",
2737-
"ReflectionPad1dModule2dInput_Right",
2738-
"ReflectionPad1dModule2dInput_basic",
2739-
"ReflectionPad1dModule3dInput_Left",
2740-
"ReflectionPad1dModule3dInput_basic",
2741-
"ReflectionPad2dModule_Bottom",
2742-
"ReflectionPad2dModule_Left",
2743-
"ReflectionPad2dModule_Right",
2744-
"ReflectionPad2dModule_Top",
2745-
"ReflectionPad2dModule_basic",
2746-
"ReplicationPad2dModule_basic",
2747-
"ReplicationPad2dModule_bottom0",
2748-
"ReplicationPad2dModule_left0",
2749-
"ReplicationPad2dModule_right0",
2750-
"ReplicationPad2dModule_top0",
27512728
"SliceCopyEndGreaterThanDimSize_Module_basic",
27522729
"SliceCopyNegative_Module_basic",
27532730
"SliceCopyNonZeroDim_Module_basic",

0 commit comments

Comments
 (0)