@@ -34,6 +34,20 @@ convertIterSpaceType(IterSpaceType itSp, SmallVectorImpl<Type> &fields) {
3434 return success ();
3535}
3636
37+ static std::optional<LogicalResult>
38+ convertIteratorType (IteratorType itTp, SmallVectorImpl<Type> &fields) {
39+ // The actually Iterator Values (that are updated every iteration).
40+ auto idxTp = IndexType::get (itTp.getContext ());
41+ // TODO: handle batch dimension.
42+ assert (itTp.getEncoding ().getBatchLvlRank () == 0 );
43+ if (!itTp.isUnique ()) {
44+ // Segment high for non-unique iterator.
45+ fields.push_back (idxTp);
46+ }
47+ fields.push_back (idxTp);
48+ return success ();
49+ }
50+
3751namespace {
3852
3953// / Sparse codegen rule for number of entries operator.
@@ -57,10 +71,114 @@ class ExtractIterSpaceConverter
5771 }
5872};
5973
74+ class SparseIterateOpConverter : public OneToNOpConversionPattern <IterateOp> {
75+ public:
76+ using OneToNOpConversionPattern::OneToNOpConversionPattern;
77+ LogicalResult
78+ matchAndRewrite (IterateOp op, OpAdaptor adaptor,
79+ OneToNPatternRewriter &rewriter) const override {
80+ if (!op.getCrdUsedLvls ().empty ())
81+ return rewriter.notifyMatchFailure (
82+ op, " non-empty coordinates list not implemented." );
83+
84+ Location loc = op.getLoc ();
85+
86+ auto iterSpace = SparseIterationSpace::fromValues (
87+ op.getIterSpace ().getType (), adaptor.getIterSpace (), 0 );
88+
89+ std::unique_ptr<SparseIterator> it =
90+ iterSpace.extractIterator (rewriter, loc);
91+
92+ if (it->iteratableByFor ()) {
93+ auto [lo, hi] = it->genForCond (rewriter, loc);
94+ Value step = constantIndex (rewriter, loc, 1 );
95+ SmallVector<Value> ivs;
96+ for (ValueRange inits : adaptor.getInitArgs ())
97+ llvm::append_range (ivs, inits);
98+ scf::ForOp forOp = rewriter.create <scf::ForOp>(loc, lo, hi, step, ivs);
99+
100+ Block *loopBody = op.getBody ();
101+ OneToNTypeMapping bodyTypeMapping (loopBody->getArgumentTypes ());
102+ if (failed (typeConverter->convertSignatureArgs (
103+ loopBody->getArgumentTypes (), bodyTypeMapping)))
104+ return failure ();
105+ rewriter.applySignatureConversion (loopBody, bodyTypeMapping);
106+
107+ forOp.getBody ()->erase ();
108+ Region &dstRegion = forOp.getRegion ();
109+ rewriter.inlineRegionBefore (op.getRegion (), dstRegion, dstRegion.end ());
110+
111+ auto yieldOp =
112+ llvm::cast<sparse_tensor::YieldOp>(forOp.getBody ()->getTerminator ());
113+
114+ rewriter.setInsertionPointToEnd (forOp.getBody ());
115+ // replace sparse_tensor.yield with scf.yield.
116+ rewriter.create <scf::YieldOp>(loc, yieldOp.getResults ());
117+ yieldOp.erase ();
118+
119+ const OneToNTypeMapping &resultMapping = adaptor.getResultMapping ();
120+ rewriter.replaceOp (op, forOp.getResults (), resultMapping);
121+ } else {
122+ SmallVector<Value> ivs;
123+ llvm::append_range (ivs, it->getCursor ());
124+ for (ValueRange inits : adaptor.getInitArgs ())
125+ llvm::append_range (ivs, inits);
126+
127+ assert (llvm::all_of (ivs, [](Value v) { return v != nullptr ; }));
128+
129+ TypeRange types = ValueRange (ivs).getTypes ();
130+ auto whileOp = rewriter.create <scf::WhileOp>(loc, types, ivs);
131+ SmallVector<Location> l (types.size (), op.getIterator ().getLoc ());
132+
133+ // Generates loop conditions.
134+ Block *before = rewriter.createBlock (&whileOp.getBefore (), {}, types, l);
135+ rewriter.setInsertionPointToStart (before);
136+ ValueRange bArgs = before->getArguments ();
137+ auto [whileCond, remArgs] = it->genWhileCond (rewriter, loc, bArgs);
138+ assert (remArgs.size () == adaptor.getInitArgs ().size ());
139+ rewriter.create <scf::ConditionOp>(loc, whileCond, before->getArguments ());
140+
141+ // Generates loop body.
142+ Block *loopBody = op.getBody ();
143+ OneToNTypeMapping bodyTypeMapping (loopBody->getArgumentTypes ());
144+ if (failed (typeConverter->convertSignatureArgs (
145+ loopBody->getArgumentTypes (), bodyTypeMapping)))
146+ return failure ();
147+ rewriter.applySignatureConversion (loopBody, bodyTypeMapping);
148+
149+ Region &dstRegion = whileOp.getAfter ();
150+ // TODO: handle uses of coordinate!
151+ rewriter.inlineRegionBefore (op.getRegion (), dstRegion, dstRegion.end ());
152+ ValueRange aArgs = whileOp.getAfterArguments ();
153+ auto yieldOp = llvm::cast<sparse_tensor::YieldOp>(
154+ whileOp.getAfterBody ()->getTerminator ());
155+
156+ rewriter.setInsertionPointToEnd (whileOp.getAfterBody ());
157+
158+ aArgs = it->linkNewScope (aArgs);
159+ ValueRange nx = it->forward (rewriter, loc);
160+ SmallVector<Value> yields;
161+ llvm::append_range (yields, nx);
162+ llvm::append_range (yields, yieldOp.getResults ());
163+
164+ // replace sparse_tensor.yield with scf.yield.
165+ yieldOp->erase ();
166+ rewriter.create <scf::YieldOp>(loc, yields);
167+
168+ const OneToNTypeMapping &resultMapping = adaptor.getResultMapping ();
169+ rewriter.replaceOp (
170+ op, whileOp.getResults ().drop_front (it->getCursor ().size ()),
171+ resultMapping);
172+ }
173+ return success ();
174+ }
175+ };
176+
60177} // namespace
61178
62179mlir::SparseIterationTypeConverter::SparseIterationTypeConverter () {
63180 addConversion ([](Type type) { return type; });
181+ addConversion (convertIteratorType);
64182 addConversion (convertIterSpaceType);
65183
66184 addSourceMaterialization ([](OpBuilder &builder, IterSpaceType spTp,
@@ -74,5 +192,6 @@ mlir::SparseIterationTypeConverter::SparseIterationTypeConverter() {
74192
75193void mlir::populateLowerSparseIterationToSCFPatterns (
76194 TypeConverter &converter, RewritePatternSet &patterns) {
77- patterns.add <ExtractIterSpaceConverter>(converter, patterns.getContext ());
195+ patterns.add <ExtractIterSpaceConverter, SparseIterateOpConverter>(
196+ converter, patterns.getContext ());
78197}
0 commit comments