77#include " mlir/Dialect/SCF/IR/SCF.h"
88#include " mlir/Dialect/SparseTensor/IR/SparseTensor.h"
99#include " mlir/Dialect/SparseTensor/Transforms/Passes.h"
10- #include " mlir/Transforms/OneToNTypeConversion .h"
10+ #include " mlir/Transforms/DialectConversion .h"
1111
1212using namespace mlir ;
1313using namespace mlir ::sparse_tensor;
1414
15+ /// Asserts that `values` contains exactly one value and returns that value.
16+ static Value getSingleValue (ValueRange values) {
17+ assert (values.size () == 1 && " expected single value" );
18+ return values.front ();
19+ }
20+
1521static void convertLevelType (SparseTensorEncodingAttr enc, Level lvl,
1622 SmallVectorImpl<Type> &fields) {
1723 // Position and coordinate buffer in the sparse structure.
@@ -54,14 +60,17 @@ static ValueRange
5460genCoIterateBranchNest (PatternRewriter &rewriter, Location loc, CoIterateOp op,
5561 Value loopCrd,
5662 ArrayRef<std::unique_ptr<SparseIterator>> iters,
57- ArrayRef<Region *> subCases, ArrayRef<Value> userReduc) {
58- if (subCases.empty ())
63+ ArrayRef<Block *> newBlocks, ArrayRef<Block *> oldBlocks,
64+ ArrayRef<Value> userReduc) {
65+ if (newBlocks.empty ())
5966 return userReduc;
6067
6168 // The current branch that we are handling.
62- Region *b = subCases.front ();
69+ Block *newBlock = newBlocks.front ();
70+ Block *oldBlock = oldBlocks.front ();
6371 Value casePred = constantI1 (rewriter, loc, true );
64- I64BitSet caseBits = op.getRegionDefinedSpace (b->getRegionNumber ());
72+ I64BitSet caseBits =
73+ op.getRegionDefinedSpace (newBlock->getParent ()->getRegionNumber ());
6574 for (unsigned i : caseBits.bits ()) {
6675 SparseIterator *it = iters[i].get ();
6776 Value pred = rewriter.create <arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
@@ -80,16 +89,20 @@ genCoIterateBranchNest(PatternRewriter &rewriter, Location loc, CoIterateOp op,
8089 for (unsigned idx : caseBits.bits ())
8190 llvm::append_range (blockArgs, iters[idx]->getCursor ());
8291
92+ // Map the old block arguments, because the dialect conversion driver does
93+ // not immediately perform SSA value replacements. This function is still
94+ // seeing the old uses.
8395 IRMapping mapping;
84- for (auto [from, to] :
85- llvm::zip_equal (b->front ().getArguments (), blockArgs)) {
96+ for (auto [from, to] : llvm::zip_equal (oldBlock->getArguments (), blockArgs)) {
8697 mapping.map (from, to);
8798 }
8899
89100 // Clone the region, we can not erase the region now because the same region
90101 // might be a subcase for multiple lattice points.
91- rewriter.cloneRegionBefore (*b , ifOp.getThenRegion (),
102+ rewriter.cloneRegionBefore (*newBlock-> getParent () , ifOp.getThenRegion (),
92103 ifOp.getThenRegion ().begin (), mapping);
104+ // Remove the block arguments, they were already replaced via `mapping`.
105+ ifOp.getThenRegion ().front ().eraseArguments (0 , blockArgs.size ());
93106
94107 // replace sparse_tensor::YieldOp -> scf::YieldOp
95108 auto spY = cast<sparse_tensor::YieldOp>(&ifOp.getThenRegion ().front ().back ());
@@ -101,7 +114,8 @@ genCoIterateBranchNest(PatternRewriter &rewriter, Location loc, CoIterateOp op,
101114 // Generates remaining case recursively.
102115 rewriter.setInsertionPointToStart (&ifOp.getElseRegion ().front ());
103116 ValueRange res = genCoIterateBranchNest (rewriter, loc, op, loopCrd, iters,
104- subCases.drop_front (), userReduc);
117+ newBlocks.drop_front (),
118+ oldBlocks.drop_front (), userReduc);
105119 if (!res.empty ())
106120 rewriter.create <scf::YieldOp>(loc, res);
107121
@@ -119,15 +133,13 @@ static ValueRange genLoopWithIterator(
119133 if (it->iteratableByFor ()) {
120134 auto [lo, hi] = it->genForCond (rewriter, loc);
121135 Value step = constantIndex (rewriter, loc, 1 );
122- scf::ForOp forOp = rewriter.create <scf::ForOp>(loc, lo, hi, step, reduc);
136+ scf::ForOp forOp = rewriter.create <scf::ForOp>(
137+ loc, lo, hi, step, reduc,
138+ [&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) {
139+ // Empty builder function to ensure that no terminator is created.
140+ });
123141 {
124142 OpBuilder::InsertionGuard guard (rewriter);
125- // Erase the implicit yield operation created by ForOp when there is no
126- // yielding values.
127- if (!forOp.getBody ()->empty ())
128- rewriter.eraseOp (&forOp.getBody ()->front ());
129- assert (forOp.getBody ()->empty ());
130-
131143 it->linkNewScope (forOp.getInductionVar ());
132144 rewriter.setInsertionPointToStart (forOp.getBody ());
133145 SmallVector<Value> ret = bodyBuilder (rewriter, loc, forOp.getBodyRegion (),
@@ -178,46 +190,47 @@ namespace {
178190
179191/// Conversion pattern for ExtractIterSpaceOp: builds a SparseIterationSpace and replaces the op with the values returned by its `toValues()`.
180192class ExtractIterSpaceConverter
181- : public OneToNOpConversionPattern <ExtractIterSpaceOp> {
193+ : public OpConversionPattern <ExtractIterSpaceOp> {
182194public:
183- using OneToNOpConversionPattern::OneToNOpConversionPattern ;
195+ using OpConversionPattern::OpConversionPattern ;
184196 LogicalResult
185- matchAndRewrite (ExtractIterSpaceOp op, OpAdaptor adaptor,
186- OneToNPatternRewriter &rewriter) const override {
197+ matchAndRewrite (ExtractIterSpaceOp op, OneToNOpAdaptor adaptor,
198+ ConversionPatternRewriter &rewriter) const override {
187199 Location loc = op.getLoc ();
188- const OneToNTypeMapping &resultMapping = adaptor.getResultMapping ();
189200
190201 // Construct the iteration space from the tensor operand, the op's level range, and the adaptor's parent-iterator values.
191- SparseIterationSpace space (loc, rewriter, op.getTensor (), 0 ,
202+ SparseIterationSpace space (loc, rewriter,
203+ getSingleValue (adaptor.getTensor ()), 0 ,
192204 op.getLvlRange (), adaptor.getParentIter ());
193205
194206 SmallVector<Value> result = space.toValues ();
195- rewriter.replaceOp (op, result, resultMapping );
207+ rewriter.replaceOpWithMultiple (op, { result} );
196208 return success ();
197209 }
198210};
199211
200212/// Conversion pattern for ExtractValOp: replaces the op with a memref::LoadOp that reads the ToValuesOp buffer of the tensor at the last iterator value.
201- class ExtractValOpConverter : public OneToNOpConversionPattern <ExtractValOp> {
213+ class ExtractValOpConverter : public OpConversionPattern <ExtractValOp> {
202214public:
203- using OneToNOpConversionPattern::OneToNOpConversionPattern ;
215+ using OpConversionPattern::OpConversionPattern ;
204216 LogicalResult
205- matchAndRewrite (ExtractValOp op, OpAdaptor adaptor,
206- OneToNPatternRewriter &rewriter) const override {
217+ matchAndRewrite (ExtractValOp op, OneToNOpAdaptor adaptor,
218+ ConversionPatternRewriter &rewriter) const override {
207219 Location loc = op.getLoc ();
208220 Value pos = adaptor.getIterator ().back (); // last iterator value; used below as the load index
209- Value valBuf = rewriter.create <ToValuesOp>(loc, op.getTensor ());
221+ Value valBuf =
222+ rewriter.create <ToValuesOp>(loc, getSingleValue (adaptor.getTensor ()));
210223 rewriter.replaceOpWithNewOp <memref::LoadOp>(op, valBuf, pos);
211224 return success ();
212225 }
213226};
214227
215- class SparseIterateOpConverter : public OneToNOpConversionPattern <IterateOp> {
228+ class SparseIterateOpConverter : public OpConversionPattern <IterateOp> {
216229public:
217- using OneToNOpConversionPattern::OneToNOpConversionPattern ;
230+ using OpConversionPattern::OpConversionPattern ;
218231 LogicalResult
219- matchAndRewrite (IterateOp op, OpAdaptor adaptor,
220- OneToNPatternRewriter &rewriter) const override {
232+ matchAndRewrite (IterateOp op, OneToNOpAdaptor adaptor,
233+ ConversionPatternRewriter &rewriter) const override {
221234 if (!op.getCrdUsedLvls ().empty ())
222235 return rewriter.notifyMatchFailure (
223236 op, " non-empty coordinates list not implemented." );
@@ -235,14 +248,15 @@ class SparseIterateOpConverter : public OneToNOpConversionPattern<IterateOp> {
235248 llvm::append_range (ivs, inits);
236249
237250 // Type conversion on iterate op block.
238- OneToNTypeMapping blockTypeMapping (op.getBody ()->getArgumentTypes ());
251+ unsigned numOrigArgs = op.getBody ()->getArgumentTypes ().size ();
252+ TypeConverter::SignatureConversion signatureConversion (numOrigArgs);
239253 if (failed (typeConverter->convertSignatureArgs (
240- op.getBody ()->getArgumentTypes (), blockTypeMapping )))
254+ op.getBody ()->getArgumentTypes (), signatureConversion )))
241255 return rewriter.notifyMatchFailure (
242256 op, " failed to convert iterate region argurment types" );
243- rewriter.applySignatureConversion (op.getBody (), blockTypeMapping);
244257
245- Block *block = op.getBody ();
258+ Block *block = rewriter.applySignatureConversion (
259+ op.getBody (), signatureConversion, getTypeConverter ());
246260 ValueRange ret = genLoopWithIterator (
247261 rewriter, loc, it.get (), ivs,
248262 [block](PatternRewriter &rewriter, Location loc, Region &loopBody,
@@ -263,19 +277,17 @@ class SparseIterateOpConverter : public OneToNOpConversionPattern<IterateOp> {
263277 return result;
264278 });
265279
266- const OneToNTypeMapping &resultMapping = adaptor.getResultMapping ();
267- rewriter.replaceOp (op, ret, resultMapping);
280+ rewriter.replaceOp (op, ret);
268281 return success ();
269282 }
270283};
271284
272- class SparseCoIterateOpConverter
273- : public OneToNOpConversionPattern<CoIterateOp> {
274- using OneToNOpConversionPattern::OneToNOpConversionPattern;
285+ class SparseCoIterateOpConverter : public OpConversionPattern <CoIterateOp> {
286+ using OpConversionPattern::OpConversionPattern;
275287
276288 LogicalResult
277- matchAndRewrite (CoIterateOp op, OpAdaptor adaptor,
278- OneToNPatternRewriter &rewriter) const override {
289+ matchAndRewrite (CoIterateOp op, OneToNOpAdaptor adaptor,
290+ ConversionPatternRewriter &rewriter) const override {
279291 assert (op.getSpaceDim () == 1 && " Not implemented" );
280292 Location loc = op.getLoc ();
281293
@@ -299,18 +311,23 @@ class SparseCoIterateOpConverter
299311 assert (!needUniv && " Not implemented" );
300312 (void )needUniv;
301313
314+ SmallVector<Block *> newBlocks;
315+ DenseMap<Block *, Block *> newToOldBlockMap;
302316 for (Region ®ion : op.getCaseRegions ()) {
303317 // Do a one-shot type conversion on all region blocks, since the same
304318 // region might be used multiple times.
305319 Block *block = ®ion.getBlocks ().front ();
306- OneToNTypeMapping blockTypeMapping (block->getArgumentTypes ());
320+ TypeConverter::SignatureConversion blockTypeMapping (
321+ block->getArgumentTypes ().size ());
307322 if (failed (typeConverter->convertSignatureArgs (block->getArgumentTypes (),
308323 blockTypeMapping))) {
309324 return rewriter.notifyMatchFailure (
310325 op, " failed to convert coiterate region argurment types" );
311326 }
312327
313- rewriter.applySignatureConversion (block, blockTypeMapping);
328+ newBlocks.push_back (rewriter.applySignatureConversion (
329+ block, blockTypeMapping, getTypeConverter ()));
330+ newToOldBlockMap[newBlocks.back ()] = block;
314331 }
315332
316333 SmallVector<SparseIterationSpace> spaces;
@@ -343,7 +360,7 @@ class SparseCoIterateOpConverter
343360
344361 // Generates a loop sequence, one loop per case.
345362 for (auto [r, caseBits] :
346- llvm::zip_equal (op. getCaseRegions () , op.getRegionDefinedSpaces ())) {
363+ llvm::zip_equal (newBlocks , op.getRegionDefinedSpaces ())) {
347364 assert (caseBits.count () > 0 && " Complement space not implemented" );
348365
348365 // Retrieves a vector of pointers to the iterators used in the case.
@@ -359,11 +376,17 @@ class SparseCoIterateOpConverter
359376 // The subcases are never empty; they must contain at least the current
360377 // region itself.
361378 // TODO: these cases should be sorted.
362- SmallVector<Region *> subCases = op.getSubCasesOf (r.getRegionNumber ());
379+ SmallVector<Region *> subCases =
380+ op.getSubCasesOf (r->getParent ()->getRegionNumber ());
381+ SmallVector<Block *> newBlocks, oldBlocks;
382+ for (Region *r : subCases) {
383+ newBlocks.push_back (&r->front ());
384+ oldBlocks.push_back (newToOldBlockMap[newBlocks.back ()]);
385+ }
363386 assert (!subCases.empty ());
364387
365- ValueRange res = genCoIterateBranchNest (rewriter, loc, op, loopCrd,
366- iters, subCases , userReduc);
388+ ValueRange res = genCoIterateBranchNest (
389+ rewriter, loc, op, loopCrd, iters, newBlocks, oldBlocks , userReduc);
367390
368391 SmallVector<Value> nextIterYields (res);
368391 // 2nd. forward the loop.
@@ -388,7 +411,7 @@ class SparseCoIterateOpConverter
388411 // This is a simple iteration loop.
389412 assert (caseBits.count () == 1 );
390413
391- Block *block = &r. getBlocks (). front () ;
414+ Block *block = r ;
392415 ValueRange curResult = genLoopWithIterator (
393416 rewriter, loc, validIters.front (), userReduc,
394417 /* bodyBuilder=*/
0 commit comments