@@ -5199,7 +5199,11 @@ struct CancelLinearizeOfDelinearizePortion final
51995199 return rewriter.notifyMatchFailure (
52005200 linearizeOp, " no run of delinearize outputs to deal with" );
52015201
5202- SmallVector<std::tuple<Value, Value>> delinearizeReplacements;
5202+ // Record all the delinearize replacements so we can do them after creating
5203+ // the new linearization operation, since the new operation might use
5204+ // outputs of something we're replacing.
5205+ SmallVector<SmallVector<Value>> delinearizeReplacements;
5206+
52035207 SmallVector<Value> newIndex;
52045208 newIndex.reserve (numLinArgs);
52055209 SmallVector<OpFoldResult> newBasis;
@@ -5212,18 +5216,26 @@ struct CancelLinearizeOfDelinearizePortion final
52125216 // Update here so we don't forget this during early continues
52135217 prevMatchEnd = m.linStart + m.length ;
52145218
5219+ PatternRewriter::InsertionGuard g (rewriter);
5220+ rewriter.setInsertionPoint (m.delinearize );
5221+
5222+ ArrayRef<OpFoldResult> basisToMerge =
5223+ linBasisRef.slice (m.linStart , m.length );
52155224 // We use the slice from the linearize's basis above because of the
52165225 // "bounds inferred from `disjoint`" case above.
52175226 OpFoldResult newSize =
5218- computeProduct (linearizeOp.getLoc (), rewriter,
5219- linBasisRef.slice (m.linStart , m.length ));
5227+ computeProduct (linearizeOp.getLoc (), rewriter, basisToMerge);
52205228
52215229 // Trivial case where we can just skip past the delinearize all together
52225230 if (m.length == m.delinearize .getNumResults ()) {
52235231 newIndex.push_back (m.delinearize .getLinearIndex ());
52245232 newBasis.push_back (newSize);
5233+ // Pad out set of replacements so we don't do anything with this one.
5234+ delinearizeReplacements.push_back (SmallVector<Value>());
52255235 continue ;
52265236 }
5237+
5238+ SmallVector<Value> newDelinResults;
52275239 SmallVector<OpFoldResult> newDelinBasis = m.delinearize .getPaddedBasis ();
52285240 newDelinBasis.erase (newDelinBasis.begin () + m.delinStart ,
52295241 newDelinBasis.begin () + m.delinStart + m.length );
@@ -5232,31 +5244,39 @@ struct CancelLinearizeOfDelinearizePortion final
52325244 m.delinearize .getLoc (), m.delinearize .getLinearIndex (),
52335245 newDelinBasis);
52345246
5247+ // Since there may be other uses of the indices we just merged together,
5248+ // create a residual affine.delinearize_index that delinearizes the
5249+ // merged output into its component parts.
5250+ Value combinedElem = newDelinearize.getResult (m.delinStart );
5251+ auto residualDelinearize = rewriter.create <AffineDelinearizeIndexOp>(
5252+ m.delinearize .getLoc (), combinedElem, basisToMerge);
5253+
52355254 // Swap all the uses of the unaffected delinearize outputs to the new
52365255 // delinearization so that the old code can be removed if this
52375256 // linearize_index is the only user of the merged results.
5257+ llvm::append_range (newDelinResults,
5258+ newDelinearize.getResults ().take_front (m.delinStart ));
5259+ llvm::append_range (newDelinResults, residualDelinearize.getResults ());
52385260 llvm::append_range (
5239- delinearizeReplacements,
5240- llvm::zip_equal (
5241- m.delinearize .getResults ().take_front (m.delinStart ),
5242- newDelinearize.getResults ().take_front (m.delinStart )));
5243- llvm::append_range (
5244- delinearizeReplacements,
5245- llvm::zip_equal (
5246- m.delinearize .getResults ().drop_front (m.delinStart + m.length ),
5247- newDelinearize.getResults ().drop_front (m.delinStart + 1 )));
5261+ newDelinResults,
5262+ newDelinearize.getResults ().drop_front (m.delinStart + 1 ));
52485263
5249- Value newLinArg = newDelinearize. getResult (m. delinStart );
5250- newIndex.push_back (newLinArg );
5264+ delinearizeReplacements. push_back (newDelinResults );
5265+ newIndex.push_back (combinedElem );
52515266 newBasis.push_back (newSize);
52525267 }
52535268 llvm::append_range (newIndex, multiIndex.drop_front (prevMatchEnd));
52545269 llvm::append_range (newBasis, linBasisRef.drop_front (prevMatchEnd));
52555270 rewriter.replaceOpWithNewOp <AffineLinearizeIndexOp>(
52565271 linearizeOp, newIndex, newBasis, linearizeOp.getDisjoint ());
52575272
5258- for (auto [from, to] : delinearizeReplacements)
5259- rewriter.replaceAllUsesWith (from, to);
5273+ for (auto [m, newResults] :
5274+ llvm::zip_equal (matches, delinearizeReplacements)) {
5275+ if (newResults.empty ())
5276+ continue ;
5277+ rewriter.replaceOp (m.delinearize , newResults);
5278+ }
5279+
52605280 return success ();
52615281 }
52625282};
0 commit comments