diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp index 4cf07bc167eab..67d7da622a355 100644 --- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp +++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp @@ -4666,14 +4666,16 @@ struct DropUnitExtentBasis }; /// If a `affine.delinearize_index`'s input is a `affine.linearize_index -/// disjoint` and the two operations have the same basis, replace the -/// delinearizeation results with the inputs of the `affine.linearize_index` -/// since they are exact inverses of each other. +/// disjoint` and the two operations end with the same basis elements, +/// cancel those parts of the operations out because they are inverses +/// of each other. +/// +/// If the operations have the same basis, cancel them entirely. /// /// The `disjoint` flag is needed on the `affine.linearize_index` because /// otherwise, there is no guarantee that the inputs to the linearization are /// in-bounds the way the outputs of the delinearization would be. -struct CancelDelinearizeOfLinearizeDisjointExact +struct CancelDelinearizeOfLinearizeDisjointExactTail : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -4685,12 +4687,45 @@ struct CancelDelinearizeOfLinearizeDisjointExact return rewriter.notifyMatchFailure(delinearizeOp, "index doesn't come from linearize"); - if (!linearizeOp.getDisjoint() || - linearizeOp.getEffectiveBasis() != delinearizeOp.getEffectiveBasis()) + if (!linearizeOp.getDisjoint()) + return rewriter.notifyMatchFailure(linearizeOp, "not disjoint"); + + ValueRange linearizeIns = linearizeOp.getMultiIndex(); + // Note: we use the full basis so we don't lose outer bounds later. 
+ SmallVector linearizeBasis = linearizeOp.getMixedBasis(); + SmallVector delinearizeBasis = delinearizeOp.getMixedBasis(); + size_t numMatches = 0; + for (auto [linSize, delinSize] : llvm::zip( + llvm::reverse(linearizeBasis), llvm::reverse(delinearizeBasis))) { + if (linSize != delinSize) + break; + ++numMatches; + } + + if (numMatches == 0) return rewriter.notifyMatchFailure( - linearizeOp, "not disjoint or basis doesn't match delinearize"); + delinearizeOp, "final basis element doesn't match linearize"); + + // The easy case: everything lines up and the bases match up completely. + if (numMatches == linearizeBasis.size() && + numMatches == delinearizeBasis.size() && + linearizeIns.size() == delinearizeOp.getNumResults()) { + rewriter.replaceOp(delinearizeOp, linearizeOp.getMultiIndex()); + return success(); + } - rewriter.replaceOp(delinearizeOp, linearizeOp.getMultiIndex()); + Value newLinearize = rewriter.create( + linearizeOp.getLoc(), linearizeIns.drop_back(numMatches), + ArrayRef{linearizeBasis}.drop_back(numMatches), + linearizeOp.getDisjoint()); + auto newDelinearize = rewriter.create( + delinearizeOp.getLoc(), newLinearize, + ArrayRef{delinearizeBasis}.drop_back(numMatches), + delinearizeOp.hasOuterBound()); + SmallVector mergedResults(newDelinearize.getResults()); + mergedResults.append(linearizeIns.take_back(numMatches).begin(), + linearizeIns.take_back(numMatches).end()); + rewriter.replaceOp(delinearizeOp, mergedResults); return success(); } }; @@ -4698,9 +4733,8 @@ struct CancelDelinearizeOfLinearizeDisjointExact void affine::AffineDelinearizeIndexOp::getCanonicalizationPatterns( RewritePatternSet &patterns, MLIRContext *context) { - patterns - .insert( - context); + patterns.insert(context); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Affine/canonicalize.mlir b/mlir/test/Dialect/Affine/canonicalize.mlir index b54a13cffe777..5384977151b47 100644 --- 
a/mlir/test/Dialect/Affine/canonicalize.mlir +++ b/mlir/test/Dialect/Affine/canonicalize.mlir @@ -1739,6 +1739,24 @@ func.func @cancel_delinearize_linearize_disjoint_delinearize_extra_bound(%arg0: // ----- +// CHECK-LABEL: func @cancel_delinearize_linearize_disjoint_partial( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index, +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index, +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index, +// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index, +// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: index) +// CHECK: %[[LIN:.+]] = affine.linearize_index disjoint [%[[ARG0]], %[[ARG1]]] by (%[[ARG3]], 4) : index +// CHECK: %[[DELIN:.+]]:2 = affine.delinearize_index %[[LIN]] into (8) : index, index +// CHECK: return %[[DELIN]]#0, %[[DELIN]]#1, %[[ARG2]] +func.func @cancel_delinearize_linearize_disjoint_partial(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index) -> (index, index, index) { + %0 = affine.linearize_index disjoint [%arg0, %arg1, %arg2] by (%arg3, 4, %arg4) : index + %1:3 = affine.delinearize_index %0 into (8, %arg4) + : index, index, index + return %1#0, %1#1, %1#2 : index, index, index +} + +// ----- + // Without `disjoint`, the cancelation isn't guaranteed to be the identity. // CHECK-LABEL: func @no_cancel_delinearize_linearize_exact( // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index,