@@ -30,6 +30,24 @@ using namespace mlir::torch::Torch;
3030// Utilities
3131// ===----------------------------------------------------------------------===//
3232
33+ OpFoldResult genericViewLikeFold (Attribute self, Type resultType) {
34+ auto selfAttr = dyn_cast_or_null<DenseElementsAttr>(self);
35+ if (!selfAttr)
36+ return nullptr ;
37+
38+ auto resultTy = dyn_cast_or_null<ValueTensorType>(resultType);
39+ if (!resultTy || !resultTy.areAllSizesKnown ())
40+ return nullptr ;
41+
42+ if (selfAttr.isSplat ()) {
43+ return SplatElementsAttr::get (resultTy.toBuiltinTensor (),
44+ selfAttr.getSplatValue <Attribute>());
45+ }
46+ return DenseElementsAttr::get (
47+ resultTy.toBuiltinTensor (),
48+ llvm::to_vector (selfAttr.getValues <Attribute>()));
49+ }
50+
3351Value mlir::torch::Torch::adjustStaticInformation (OpBuilder &builder,
3452 Location loc, Value value,
3553 Type desiredType,
@@ -1049,6 +1067,8 @@ void Aten_CastLongOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
10491067// ===----------------------------------------------------------------------===//
10501068
10511069OpFoldResult AtenViewOp::fold (FoldAdaptor adaptor) {
1070+ if (auto genericFold = genericViewLikeFold (adaptor.getSelf (), getType ()))
1071+ return genericFold;
10521072 auto inputType = dyn_cast<BaseTensorType>(getOperand (0 ).getType ());
10531073 if (!inputType || !inputType.hasSizes () || inputType.getSizes ().size () != 1 )
10541074 return nullptr ;
@@ -2236,10 +2256,22 @@ void AtenSizeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
22362256 });
22372257}
22382258
2259+ // ===----------------------------------------------------------------------===//
2260+ // AtenFlattenUsingIntsOp
2261+ // ===----------------------------------------------------------------------===//
2262+
// Fold aten.flatten.using_ints of a constant tensor. Flatten only reshapes:
// element count and row-major order are preserved, so the shared view-like
// folder applies whenever the result shape is statically known.
OpFoldResult AtenFlattenUsingIntsOp::fold(FoldAdaptor adaptor) {
  return genericViewLikeFold(adaptor.getSelf(), getType());
}
2266+
22392267// ===----------------------------------------------------------------------===//
22402268// AtenUnflattenIntOp
22412269// ===----------------------------------------------------------------------===//
22422270
// Fold aten.unflatten.int of a constant tensor. Unflatten only reshapes:
// element count and row-major order are preserved, so the shared view-like
// folder applies whenever the result shape is statically known.
OpFoldResult AtenUnflattenIntOp::fold(FoldAdaptor adaptor) {
  return genericViewLikeFold(adaptor.getSelf(), getType());
}
2274+
22432275void AtenUnflattenIntOp::getCanonicalizationPatterns (
22442276 RewritePatternSet &patterns, MLIRContext *context) {
22452277 // if there are only two sizes and one of them is statically 1, then convert
@@ -3737,6 +3769,69 @@ OpFoldResult AtenSubIntOp::fold(FoldAdaptor adaptor) {
37373769 adaptor.getOperands (), [](int64_t a, int64_t b) { return a - b; });
37383770}
37393771
3772+ // ===----------------------------------------------------------------------===//
3773+ // AtenTransposeIntOp
3774+ // ===----------------------------------------------------------------------===//
3775+
3776+ OpFoldResult AtenTransposeIntOp::fold (FoldAdaptor adaptor) {
3777+ // first check for no-op
3778+ IntegerAttr dim0 = dyn_cast_or_null<IntegerAttr>(adaptor.getDim0 ());
3779+ IntegerAttr dim1 = dyn_cast_or_null<IntegerAttr>(adaptor.getDim1 ());
3780+ if (!dim0 || !dim1)
3781+ return nullptr ;
3782+ int64_t _dim0 = dim0.getValue ().getSExtValue ();
3783+ int64_t _dim1 = dim1.getValue ().getSExtValue ();
3784+ auto selfTy = dyn_cast<ValueTensorType>(getSelf ().getType ());
3785+ if (!selfTy || !selfTy.hasSizes ())
3786+ return nullptr ;
3787+ int64_t rank = selfTy.getSizes ().size ();
3788+ _dim0 = toPositiveDim (_dim0, rank);
3789+ _dim1 = toPositiveDim (_dim1, rank);
3790+ if (!isValidDim (_dim0, rank) || !isValidDim (_dim1, rank))
3791+ return nullptr ;
3792+ // if dims are the same, return self
3793+ if (_dim0 == _dim1)
3794+ return getSelf ();
3795+
3796+ // We set a maximum folding size of 16. This is a reasonable upper limit
3797+ // for shape computations.
3798+ constexpr int64_t kMaxFoldSize = 16 ;
3799+ auto self = dyn_cast_or_null<DenseElementsAttr>(adaptor.getSelf ());
3800+ if (!self || self.getNumElements () > kMaxFoldSize )
3801+ return nullptr ;
3802+ auto resultTy = dyn_cast<ValueTensorType>(getType ());
3803+ if (!selfTy || !resultTy || !selfTy.areAllSizesKnown ())
3804+ return nullptr ;
3805+ if (self.isSplat ())
3806+ return SplatElementsAttr::get (resultTy.toBuiltinTensor (),
3807+ self.getSplatValue <Attribute>());
3808+
3809+ // TODO: add support for rank != 2
3810+ if (rank != 2 )
3811+ return nullptr ;
3812+
3813+ ArrayRef<int64_t > sizes = selfTy.getSizes ();
3814+ auto values = llvm::to_vector (self.getValues <Attribute>());
3815+ // reordered[i] = Trans[i//sizes[0], i % sizes[0]] = Self[i % sizes[0],
3816+ // i//sizes[0]] = values[(i % sizes[0])*sizes[1] + (i//sizes[0])].
3817+ // e.g., Self size = [4,2]; Trans size = [2,4].
3818+ // reindex(i) = (i % 4)*2 + (i // 4) .
3819+ // i = 0 -> Trans[0,0] -> Self[0,0] -> 0 .
3820+ // i = 1 -> Trans[0,1] -> Self[1,0] -> 2 .
3821+ // i = 2 -> Trans[0,2] -> Self[2,0] -> 4 .
3822+ // i = 3 -> Trans[0,3] -> Self[3,0] -> 6 .
3823+ // i = 4 -> Trans[1,0] -> Self[0,1] -> 1 .
3824+ // i = 5 -> Trans[1,1] -> Self[1,1] -> 3 .
3825+ auto reindex = [&](int64_t i) {
3826+ return (i % sizes[0 ]) * sizes[1 ] + (i / sizes[0 ]);
3827+ };
3828+ SmallVector<Attribute> reordered;
3829+ for (int64_t i = 0 ; i < self.getNumElements (); i++) {
3830+ reordered.push_back (values[reindex (i)]);
3831+ }
3832+ return DenseElementsAttr::get (resultTy.toBuiltinTensor (), reordered);
3833+ }
3834+
37403835// ===----------------------------------------------------------------------===//
37413836// AtenCatOp
37423837// ===----------------------------------------------------------------------===//
@@ -3913,15 +4008,18 @@ OpFoldResult AtenSliceTensorOp::fold(FoldAdaptor adaptor) {
39134008 // Fold the slice if the output tensor is relatively small, currently
39144009 // coded to 16:
39154010 constexpr int64_t kMaxFold = 16 ;
3916- if (input && start && step && dim && count <= kMaxFold ) {
4011+ if (input && start && step && dim && end && count <= kMaxFold ) {
39174012 int64_t begin = start.getValue ().getSExtValue ();
39184013 int64_t limit = end.getValue ().getSExtValue ();
39194014 int64_t stride = step.getValue ().getSExtValue ();
3920- if (stride < 1 )
3921- return nullptr ;
39224015 begin = begin < 0 ? begin + inType.getSizes ()[dimInt] : begin;
39234016 limit = limit < 0 ? limit + inType.getSizes ()[dimInt] : limit;
4017+ limit = limit < 0 ? -1 : limit;
39244018 limit = std::min (limit, inType.getSizes ()[dimInt]);
4019+ bool validIterArgs =
4020+ (stride > 0 && begin < limit) || (stride < 0 && begin > limit);
4021+ assert (validIterArgs &&
4022+ " aten.slice.Tensor iteration args are statically invalid." );
39254023
39264024 int64_t inputRank = inType.getSizes ().size ();
39274025 llvm::SmallVector<int64_t > inputStrides (inputRank, 1 );
@@ -3934,10 +4032,21 @@ OpFoldResult AtenSliceTensorOp::fold(FoldAdaptor adaptor) {
39344032 auto recursiveIter = [&](auto &self, int64_t currDim, int64_t currOffset) {
39354033 if (currDim >= inputRank)
39364034 return ;
3937- size_t _begin = (currDim == dimInt) ? begin : 0 ;
3938- size_t _limit = (currDim == dimInt) ? limit : inType.getSizes ()[currDim];
3939- size_t _stride = (currDim == dimInt) ? stride : 1 ;
3940- for (size_t i = _begin; i < _limit; i += _stride) {
4035+ int64_t _stride = (currDim == dimInt) ? stride : 1 ;
4036+ int64_t _begin = (currDim == dimInt) ? begin : 0 ;
4037+ int64_t _limit = (currDim == dimInt) ? limit : inType.getSizes ()[currDim];
4038+ // ensure that the limit is reached exactly (even with negative strides)
4039+ // E.g., with begin = 0, limit = 10, stride = 3, we modify limit to be 11
4040+ // = 10 + (10-0) % 3 .
4041+ // E.g., with begin = 8, limit = -1, stride = -2, limit becomes -2 = -1 +
4042+ // (-1-8) % (-2) - stride = -1 + 1 - 2 = -2 .
4043+ // Note: cpp uses true math remainder "n % d = least positive int, x, such
4044+ // that d divides (n - x)"
4045+ int64_t limit_rem = (_limit - _begin) % _stride;
4046+ limit_rem =
4047+ (_stride > 0 || limit_rem == 0 ) ? limit_rem : limit_rem - _stride;
4048+ _limit += limit_rem;
4049+ for (int64_t i = _begin; std::abs (_limit - i) > 0 ; i += _stride) {
39414050 if (currDim == inputRank - 1 ) {
39424051 values.push_back (input.getValues <Attribute>()[currOffset + i]);
39434052 }
@@ -5272,26 +5381,56 @@ OpFoldResult PrimsConvertElementTypeOp::fold(FoldAdaptor adaptor) {
52725381}
52735382
52745383// ===----------------------------------------------------------------------===//
5275- // AtenMaxPool2dWithIndicesOp
5384+ // AtenMaxPoolWithIndicesOp
52765385// ===----------------------------------------------------------------------===//
52775386
5278- void AtenMaxPool2dWithIndicesOp::getCanonicalizationPatterns (
5279- RewritePatternSet &patterns, MLIRContext *context) {
5280- patterns.add (+[](AtenMaxPool2dWithIndicesOp op, PatternRewriter &rewriter) {
5387+ namespace {
5388+
5389+ template <typename OpTy> struct MaxPoolWithoutIndices {
5390+ using type = OpTy;
5391+ };
5392+
5393+ template <> struct MaxPoolWithoutIndices <AtenMaxPool2dWithIndicesOp> {
5394+ using type = AtenMaxPool2dOp;
5395+ };
5396+
5397+ template <> struct MaxPoolWithoutIndices <AtenMaxPool3dWithIndicesOp> {
5398+ using type = AtenMaxPool3dOp;
5399+ };
5400+
5401+ } // namespace
5402+
5403+ template <typename OpTy>
5404+ struct SimplifyMaxPoolWithIndices : public mlir ::OpRewritePattern<OpTy> {
5405+ SimplifyMaxPoolWithIndices (mlir::MLIRContext *context)
5406+ : OpRewritePattern<OpTy>(context, /* benefit=*/ 1 ) {}
5407+
5408+ LogicalResult
5409+ matchAndRewrite (OpTy op, mlir::PatternRewriter &rewriter) const override {
52815410 if (!op.getResult1 ().use_empty ()) {
52825411 return rewriter.notifyMatchFailure (
5283- op, " result1 of MaxPool2dWithIndices should be unused" );
5412+ op, " result1 of MaxPoolWithIndices should be unused" );
52845413 }
52855414
5286- Value result = rewriter.create <Torch::AtenMaxPool2dOp >(
5415+ Value result = rewriter.create <typename MaxPoolWithoutIndices<OpTy>::type >(
52875416 op->getLoc (), op.getResult0 ().getType (), op.getSelf (),
52885417 op.getKernelSize (), op.getStride (), op.getPadding (), op.getDilation (),
52895418 op.getCeilMode ());
52905419
52915420 op.getResult0 ().replaceAllUsesWith (result);
52925421 rewriter.eraseOp (op);
52935422 return success ();
5294- });
5423+ }
5424+ };
5425+
void AtenMaxPool2dWithIndicesOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  // Drop the unused indices result by rewriting to aten.max_pool2d.
  patterns.add<SimplifyMaxPoolWithIndices<AtenMaxPool2dWithIndicesOp>>(context);
}
5430+
void AtenMaxPool3dWithIndicesOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  // Drop the unused indices result by rewriting to aten.max_pool3d.
  patterns.add<SimplifyMaxPoolWithIndices<AtenMaxPool3dWithIndicesOp>>(context);
}
52965435
52975436// ===----------------------------------------------------------------------===//
0 commit comments