@@ -5109,13 +5109,19 @@ struct CancelLinearizeOfDelinearizePortion final
51095109 : OpRewritePattern<affine::AffineLinearizeIndexOp> {
51105110 using OpRewritePattern::OpRewritePattern;
51115111
5112+ private:
5113+ // Struct representing a case where the cancellation pattern
5114+ // applies. A `Match` means that `length` inputs to the linearize operation
5115+ // starting at `linStart` can be cancelled with `length` outputs of
5116+ // `delinearize`, starting from `delinStart`.
51125117 struct Match {
51135118 AffineDelinearizeIndexOp delinearize;
51145119 unsigned linStart = 0 ;
51155120 unsigned delinStart = 0 ;
51165121 unsigned length = 0 ;
51175122 };
51185123
5124+ public:
51195125 LogicalResult matchAndRewrite (affine::AffineLinearizeIndexOp linearizeOp,
51205126 PatternRewriter &rewriter) const override {
51215127 SmallVector<Match> matches;
@@ -5128,7 +5134,7 @@ struct CancelLinearizeOfDelinearizePortion final
51285134 unsigned linArgIdx = 0 ;
51295135 // We only want to replace one run from the same delinearize op per
51305136 // pattern invocation lest we run into invalidation issues.
5131- llvm::SmallPtrSet<Operation *, 2 > seen ;
5137+ llvm::SmallPtrSet<Operation *, 2 > alreadyMatchedDelinearize ;
51325138 while (linArgIdx < numLinArgs) {
51335139 auto asResult = dyn_cast<OpResult>(multiIndex[linArgIdx]);
51345140 if (!asResult) {
@@ -5155,37 +5161,37 @@ struct CancelLinearizeOfDelinearizePortion final
51555161 // / - The delinearization doesn't specify a bound, but the linearization
51565162 // / is `disjoint`, which asserts that the bound on the linearization is
51575163 // / correct.
5158- unsigned firstDelinArg = asResult.getResultNumber ();
5164+ unsigned delinArgIdx = asResult.getResultNumber ();
51595165 SmallVector<OpFoldResult> delinBasis = delinearizeOp.getPaddedBasis ();
5160- OpFoldResult firstDelinBound = delinBasis[firstDelinArg ];
5166+ OpFoldResult firstDelinBound = delinBasis[delinArgIdx ];
51615167 OpFoldResult firstLinBound = linBasis[linArgIdx];
51625168 bool boundsMatch = firstDelinBound == firstLinBound;
5163- bool bothAtFront = linArgIdx == 0 && firstDelinArg == 0 ;
5169+ bool bothAtFront = linArgIdx == 0 && delinArgIdx == 0 ;
51645170 bool knownByDisjoint =
5165- linearizeOp.getDisjoint () && firstDelinArg == 0 && !firstDelinBound;
5171+ linearizeOp.getDisjoint () && delinArgIdx == 0 && !firstDelinBound;
51665172 if (!boundsMatch && !bothAtFront && !knownByDisjoint) {
51675173 linArgIdx++;
51685174 continue ;
51695175 }
51705176
51715177 unsigned j = 1 ;
51725178 unsigned numDelinOuts = delinearizeOp.getNumResults ();
5173- for (; j + linArgIdx < numLinArgs && j + firstDelinArg < numDelinOuts;
5179+ for (; j + linArgIdx < numLinArgs && j + delinArgIdx < numDelinOuts;
51745180 ++j) {
51755181 if (multiIndex[linArgIdx + j] !=
5176- delinearizeOp.getResult (firstDelinArg + j))
5182+ delinearizeOp.getResult (delinArgIdx + j))
51775183 break ;
5178- if (linBasis[linArgIdx + j] != delinBasis[firstDelinArg + j])
5184+ if (linBasis[linArgIdx + j] != delinBasis[delinArgIdx + j])
51795185 break ;
51805186 }
51815187 // If there're multiple matches against the same delinearize_index,
51825188 // only rewrite the first one we find to prevent invalidations. The next
5183- // ones will be taken caer of by subsequent pattern invocations.
5184- if (j <= 1 || !seen .insert (delinearizeOp).second ) {
5189+ // ones will be taken care of by subsequent pattern invocations.
5190+ if (j <= 1 || !alreadyMatchedDelinearize .insert (delinearizeOp).second ) {
51855191 linArgIdx++;
51865192 continue ;
51875193 }
5188- matches.push_back (Match{delinearizeOp, linArgIdx, firstDelinArg , j});
5194+ matches.push_back (Match{delinearizeOp, linArgIdx, delinArgIdx , j});
51895195 linArgIdx += j;
51905196 }
51915197
0 commit comments