@@ -4534,6 +4534,133 @@ LogicalResult AffineDelinearizeIndexOp::verify() {
45344534 return success ();
45354535}
45364536
4537+ namespace {
4538+
4539+ // Drops delinearization indices that correspond to unit-extent basis
4540+ struct DropUnitExtentBasis
4541+ : public OpRewritePattern<affine::AffineDelinearizeIndexOp> {
4542+ using OpRewritePattern::OpRewritePattern;
4543+
4544+ LogicalResult matchAndRewrite (affine::AffineDelinearizeIndexOp delinearizeOp,
4545+ PatternRewriter &rewriter) const override {
4546+ SmallVector<Value> replacements (delinearizeOp->getNumResults (), nullptr );
4547+ std::optional<Value> zero = std::nullopt ;
4548+ Location loc = delinearizeOp->getLoc ();
4549+ auto getZero = [&]() -> Value {
4550+ if (!zero)
4551+ zero = rewriter.create <arith::ConstantIndexOp>(loc, 0 );
4552+ return zero.value ();
4553+ };
4554+
4555+ // Replace all indices corresponding to unit-extent basis with 0.
4556+ // Remaining basis can be used to get a new `affine.delinearize_index` op.
4557+ SmallVector<Value> newOperands;
4558+ for (auto [index, basis] : llvm::enumerate (delinearizeOp.getBasis ())) {
4559+ if (matchPattern (basis, m_One ()))
4560+ replacements[index] = getZero ();
4561+ else
4562+ newOperands.push_back (basis);
4563+ }
4564+
4565+ if (newOperands.size () == delinearizeOp.getBasis ().size ())
4566+ return failure ();
4567+
4568+ if (!newOperands.empty ()) {
4569+ auto newDelinearizeOp = rewriter.create <affine::AffineDelinearizeIndexOp>(
4570+ loc, delinearizeOp.getLinearIndex (), newOperands);
4571+ int newIndex = 0 ;
4572+ // Map back the new delinearized indices to the values they replace.
4573+ for (auto &replacement : replacements) {
4574+ if (replacement)
4575+ continue ;
4576+ replacement = newDelinearizeOp->getResult (newIndex++);
4577+ }
4578+ }
4579+
4580+ rewriter.replaceOp (delinearizeOp, replacements);
4581+ return success ();
4582+ }
4583+ };
4584+
4585+ // / Drop delinearization pattern related to loops in the following way
4586+ // /
4587+ // / ```
4588+ // / <loop>(%iv) = (%c0) to (%ub) step (%c1) {
4589+ // / %0 = affine.delinearize_index %iv into (%ub) : index
4590+ // / <some_use>(%0)
4591+ // / }
4592+ // / ```
4593+ // /
4594+ // / can be canonicalized to
4595+ // /
4596+ // / ```
4597+ // / <loop>(%iv) = (%c0) to (%ub) step (%c1) {
4598+ // / <some_use>(%iv)
4599+ // / }
4600+ // / ```
4601+ struct DropDelinearizeOfSingleLoop
4602+ : public OpRewritePattern<affine::AffineDelinearizeIndexOp> {
4603+ using OpRewritePattern::OpRewritePattern;
4604+
4605+ LogicalResult matchAndRewrite (affine::AffineDelinearizeIndexOp delinearizeOp,
4606+ PatternRewriter &rewriter) const override {
4607+ auto basis = delinearizeOp.getBasis ();
4608+ if (basis.size () != 1 )
4609+ return failure ();
4610+
4611+ // Check that the `linear_index` is an induction variable.
4612+ auto inductionVar = cast<BlockArgument>(delinearizeOp.getLinearIndex ());
4613+ if (!inductionVar)
4614+ return failure ();
4615+
4616+ // Check that the parent is a `LoopLikeOpInterface`.
4617+ auto loopLikeOp = cast<LoopLikeOpInterface>(
4618+ inductionVar.getParentRegion ()->getParentOp ());
4619+ if (!loopLikeOp)
4620+ return failure ();
4621+
4622+ // Check that loop is unit-rank and that the `linear_index` is the induction
4623+ // variable.
4624+ auto inductionVars = loopLikeOp.getLoopInductionVars ();
4625+ if (!inductionVars || inductionVars->size () != 1 ||
4626+ inductionVars->front () != inductionVar) {
4627+ return rewriter.notifyMatchFailure (
4628+ delinearizeOp, " `linear_index` is not loop induction variable" );
4629+ }
4630+
4631+ // Check that the upper-bound is the basis.
4632+ auto upperBounds = loopLikeOp.getLoopUpperBounds ();
4633+ if (!upperBounds || upperBounds->size () != 1 ||
4634+ upperBounds->front () != getAsOpFoldResult (basis.front ())) {
4635+ return rewriter.notifyMatchFailure (delinearizeOp,
4636+ " `basis` is not upper bound" );
4637+ }
4638+
4639+ // Check that the lower bound is zero.
4640+ auto lowerBounds = loopLikeOp.getLoopLowerBounds ();
4641+ if (!lowerBounds || lowerBounds->size () != 1 ||
4642+ !isZeroIndex (lowerBounds->front ())) {
4643+ return rewriter.notifyMatchFailure (delinearizeOp,
4644+ " loop lower bound is not zero" );
4645+ }
4646+
4647+ // Check that the step is one.
4648+ auto steps = loopLikeOp.getLoopSteps ();
4649+ if (!steps || steps->size () != 1 || !isConstantIntValue (steps->front (), 1 ))
4650+ return rewriter.notifyMatchFailure (delinearizeOp, " loop step is not one" );
4651+
4652+ rewriter.replaceOp (delinearizeOp, inductionVar);
4653+ return success ();
4654+ }
4655+ };
4656+
4657+ } // namespace
4658+
4659+ void affine::AffineDelinearizeIndexOp::getCanonicalizationPatterns (
4660+ RewritePatternSet &patterns, MLIRContext *context) {
4661+ patterns.insert <DropDelinearizeOfSingleLoop, DropUnitExtentBasis>(context);
4662+ }
4663+
45374664// ===----------------------------------------------------------------------===//
45384665// TableGen'd op method definitions
45394666// ===----------------------------------------------------------------------===//
0 commit comments