@@ -95,8 +95,9 @@ class DenseLevel : public SparseTensorLevel {
9595 ValueRange getLvlBuffers () const override { return {}; }
9696
9797 ValuePair peekRangeAt (OpBuilder &b, Location l, ValueRange batchPrefix,
98- ValueRange parentPos) const override {
98+ ValueRange parentPos, Value inPadZone ) const override {
9999 assert (parentPos.size () == 1 && " Dense level can not be non-unique." );
100+ assert (!inPadZone && " Not implemented" );
100101 Value p = parentPos.front ();
101102 Value posLo = MULI (p, lvlSize);
102103 return {posLo, lvlSize};
@@ -115,7 +116,8 @@ class BatchLevel : public SparseTensorLevel {
115116 ValueRange getLvlBuffers () const override { return {}; }
116117
117118 ValuePair peekRangeAt (OpBuilder &b, Location l, ValueRange,
118- ValueRange parentPos) const override {
119+ ValueRange parentPos, Value inPadZone) const override {
120+ assert (!inPadZone && " Not implemented" );
119121 assert (parentPos.size () == 1 && " Dense level can not be non-unique." );
120122 // No need to linearize the position for non-annotated tensors.
121123 return {C_IDX (0 ), lvlSize};
@@ -129,18 +131,42 @@ class CompressedLevel : public SparseLevel</*hasPosBuf=*/true> {
129131 : SparseLevel(tid, lvl, lt, lvlSize, {posBuffer, crdBuffer}) {}
130132
// Diff view of CompressedLevel::peekRangeAt (`-` = old body, `+` = new body).
// New behavior: the position range [pLo, pHi) is loaded from the position
// buffer at indices `p` and `p + 1` (prefixed by batchPrefix), where `p` is
// the single parent position. If `inPadZone` is null the loads are emitted
// directly; otherwise they are wrapped in an scf.if on `inPadZone` that
// yields the range (0, 0) on the true branch and the loaded range on the
// else branch, and the scf.if results are returned.
131133 ValuePair peekRangeAt (OpBuilder &b, Location l, ValueRange batchPrefix,
132- ValueRange parentPos) const override {
134+ ValueRange parentPos, Value inPadZone ) const override {
133135
134136 assert (parentPos.size () == 1 &&
135137 " compressed level must be the first non-unique level." );
136- Value p = parentPos.front ();
137138
138- SmallVector<Value> memCrd (batchPrefix);
139- memCrd.push_back (p);
140- Value pLo = genIndexLoad (b, l, getPosBuf (), memCrd);
141- memCrd.back () = ADDI (p, C_IDX (1 ));
142- Value pHi = genIndexLoad (b, l, getPosBuf (), memCrd);
143- return {pLo, pHi};
// Emits the two position-buffer loads at the builder's current insertion
// point; shared by the unconditional path and the scf.if else branch.
139+ auto loadRange = [&b, l, parentPos, batchPrefix, this ]() -> ValuePair {
140+ Value p = parentPos.front ();
141+ SmallVector<Value> memCrd (batchPrefix);
142+ memCrd.push_back (p);
143+ Value pLo = genIndexLoad (b, l, getPosBuf (), memCrd);
144+ memCrd.back () = ADDI (p, C_IDX (1 ));
145+ Value pHi = genIndexLoad (b, l, getPosBuf (), memCrd);
146+ return {pLo, pHi};
147+ };
148+
// No pad-zone flag supplied: keep the original unconditional loads.
149+ if (inPadZone == nullptr )
150+ return loadRange ();
151+
152+ SmallVector<Type, 2 > types{b.getIndexType (), b.getIndexType ()};
153+ scf::IfOp posRangeIf = b.create <scf::IfOp>(l, types, inPadZone, true );
154+ // True branch, returns a "fake" empty range [0, 0) if parent
155+ // iterator is in pad zone.
156+ b.setInsertionPointToStart (posRangeIf.thenBlock ());
157+
158+ SmallVector<Value, 2 > emptyRange{C_IDX (0 ), C_IDX (0 )};
159+ b.create <scf::YieldOp>(l, emptyRange);
160+
161+ // False branch, returns the actual range.
162+ b.setInsertionPointToStart (posRangeIf.elseBlock ());
163+ auto [pLo, pHi] = loadRange ();
164+ SmallVector<Value, 2 > loadedRange{pLo, pHi};
165+ b.create <scf::YieldOp>(l, loadedRange);
166+
// Move the insertion point past the scf.if so later ops are emitted after it.
167+ b.setInsertionPointAfter (posRangeIf);
168+ ValueRange posRange = posRangeIf.getResults ();
169+ return {posRange.front (), posRange.back ()};
144170 }
145171};
146172
@@ -151,9 +177,10 @@ class LooseCompressedLevel : public SparseLevel</*hasPosBuf=*/true> {
151177 : SparseLevel(tid, lvl, lt, lvlSize, {posBuffer, crdBuffer}) {}
152178
153179 ValuePair peekRangeAt (OpBuilder &b, Location l, ValueRange batchPrefix,
154- ValueRange parentPos) const override {
180+ ValueRange parentPos, Value inPadZone ) const override {
155181 assert (parentPos.size () == 1 &&
156182 " loose-compressed level must be the first non-unique level." );
183+ assert (!inPadZone && " Not implemented" );
157184 SmallVector<Value> memCrd (batchPrefix);
158185 Value p = parentPos.front ();
159186 p = MULI (p, C_IDX (2 ));
@@ -172,8 +199,9 @@ class SingletonLevel : public SparseLevel</*hasPosBuf=*/false> {
172199 : SparseLevel(tid, lvl, lt, lvlSize, {crdBuffer}) {}
173200
174201 ValuePair peekRangeAt (OpBuilder &b, Location l, ValueRange batchPrefix,
175- ValueRange parentPos) const override {
202+ ValueRange parentPos, Value inPadZone ) const override {
176203 assert (parentPos.size () == 1 || parentPos.size () == 2 );
204+ assert (!inPadZone && " Not implemented" );
177205 Value p = parentPos.front ();
178206 Value segHi = parentPos.size () == 2 ? parentPos.back () : nullptr ;
179207
@@ -191,9 +219,10 @@ class NOutOfMLevel : public SparseLevel</*hasPosBuf=*/false> {
191219 : SparseLevel(tid, lvl, lt, lvlSize, {crdBuffer}) {}
192220
193221 ValuePair peekRangeAt (OpBuilder &b, Location l, ValueRange batchPrefix,
194- ValueRange parentPos) const override {
222+ ValueRange parentPos, Value inPadZone ) const override {
195223 assert (parentPos.size () == 1 && isUnique () &&
196224 " n:m level can not be non-unique." );
225+ assert (!inPadZone && " Not implemented" );
197226 // Each n:m blk has exactly n specified elements.
198227 auto n = getN (lt);
199228 Value posLo = MULI (parentPos.front (), C_IDX (n));
@@ -325,23 +354,7 @@ class TrivialIterator : public ConcreteIterator {
325354 };
326355
// Diff view: the inline body of TrivialIterator::genInitImpl (`-` lines) is
// removed and only the declaration is kept (`+` line); the definition is
// provided out of line as TrivialIterator::genInitImpl further down.
327356 void genInitImpl (OpBuilder &b, Location l,
328- const SparseIterator *parent) override {
329-
330- if (isBatchIterator () && batchCrds.size () <= stl.lvl )
331- batchCrds.resize (stl.lvl + 1 , nullptr );
332-
333- Value c0 = C_IDX (0 );
334- ValueRange pPos = c0;
335- // If the parent iterator is a batch iterator, we also start from 0 (but
336- // on a different batch).
337- if (parent && !parent->isBatchIterator ())
338- pPos = parent->getCurPosition ();
339-
340- ValueRange batchPrefix = parent ? parent->getBatchCrds () : ValueRange{};
341- std::tie (posLo, posHi) = stl.peekRangeAt (b, l, batchPrefix, pPos);
342- // Seek to the lowest position.
343- seek (posLo);
344- }
357+ const SparseIterator *parent) override ;
345358
346359 ValuePair genForCond (OpBuilder &b, Location l) override {
347360 if (randomAccessible ())
@@ -465,15 +478,17 @@ class DedupIterator : public ConcreteIterator {
465478// A util base-iterator that delegates all methods to the wrapped iterator.
466479class SimpleWrapIterator : public SparseIterator {
467480public:
// Constructs a wrapper that takes ownership of `wrap`. The SparseIterator
// base is initialized from `kind` and `*wrap` before the pointer is moved
// into the `wrap` member. The `+` lines add `extraCursorVal` (default 0),
// which is forwarded to the base constructor.
// NOTE(review): presumably the number of cursor values the wrapper keeps in
// addition to those of the wrapped iterator -- confirm against SparseIterator.
468- SimpleWrapIterator (std::unique_ptr<SparseIterator> &&wrap, IterKind kind)
469- : SparseIterator(kind, *wrap), wrap(std::move(wrap)) {}
481+ SimpleWrapIterator (std::unique_ptr<SparseIterator> &&wrap, IterKind kind,
482+ unsigned extraCursorVal = 0 )
483+ : SparseIterator(kind, *wrap, extraCursorVal), wrap(std::move(wrap)) {}
470484
// Pure delegation: every override below forwards the call unchanged to the
// wrapped iterator `wrap`.
471485 SmallVector<Type> getCursorValTypes (OpBuilder &b) const override {
472486 return wrap->getCursorValTypes (b);
473487 }
474488 bool isBatchIterator () const override { return wrap->isBatchIterator (); }
475489 bool randomAccessible () const override { return wrap->randomAccessible (); };
476490 bool iteratableByFor () const override { return wrap->iteratableByFor (); };
491+
// Cursor (de)serialization and the current position also come from `wrap`.
477492 SmallVector<Value> serialize () const override { return wrap->serialize (); };
478493 void deserialize (ValueRange vs) override { wrap->deserialize (vs); };
479494 ValueRange getCurPosition () const override { return wrap->getCurPosition (); }
@@ -586,10 +601,9 @@ class PadIterator : public SimpleWrapIterator {
586601public:
// Wraps `wrap` as a padded iterator with `padLow` / `padHigh` padding.
// The `+` lines drop the old "random-accessible not implemented" assertion
// and instead request one extra cursor value from SimpleWrapIterator when the
// wrapped iterator is random accessible (zero otherwise).
//
// Reading `wrap->randomAccessible ()` in the same argument list as
// `std::move(wrap)` is safe here: `wrap` in the initializer names the
// constructor parameter, std::move is only a cast, and SimpleWrapIterator
// takes the pointer by rvalue reference and moves out of it in its own
// member initializer, i.e. after all arguments have been evaluated.
587602 PadIterator (std::unique_ptr<SparseIterator> &&wrap, Value padLow,
588603 Value padHigh)
589- : SimpleWrapIterator(std::move(wrap), IterKind::kPad ), padLow(padLow),
590- padHigh (padHigh) {
591- assert (!randomAccessible () && " Not implemented." );
592- }
604+ : SimpleWrapIterator(std::move(wrap), IterKind::kPad ,
605+ wrap->randomAccessible () ? 1 : 0),
606+ padLow(padLow), padHigh(padHigh) {}
593607
594608 // For LLVM-style RTTI.
595609 static bool classof (const SparseIterator *from) {
@@ -600,6 +614,26 @@ class PadIterator : public SimpleWrapIterator {
600614 return std::string (" pad<" ) + wrap->getDebugInterfacePrefix () + " >" ;
601615 }
602616
617+ // Returns a pair of values for *lower*, *upper* bound respectively.
// Random-accessible case: the bounds are the current coordinate and the
// padded upper bound. Otherwise the wrapped iterator's bounds are returned
// unchanged.
618+ ValuePair genForCond (OpBuilder &b, Location l) override {
619+ if (randomAccessible ())
620+ return {getCrd (), upperBound (b, l)};
621+ return wrap->genForCond (b, l);
622+ }
623+
624+ // Returns the whole cursor as the current position. For a padded dense
625+ // (random-accessible) iterator the cursor carries a trailing `inPadZone: bool`
// in addition to the values used by the wrapped iterator (its type is
// appended in getCursorValTypes and its value is written in locateImpl).
626+ ValueRange getCurPosition () const override { return getCursor (); }
627+
// Cursor value types: those of the wrapped iterator, followed by one i1 when
// the iterator is random accessible.
628+ SmallVector<Type> getCursorValTypes (OpBuilder &b) const override {
629+ SmallVector<Type> ret = wrap->getCursorValTypes (b);
630+ // Need an extra boolean value `inPadZone` for padded dense iterator.
631+ if (randomAccessible ())
632+ ret.push_back (b.getI1Type ());
633+
634+ return ret;
635+ }
636+
603637 // The upper bound after padding becomes `size + padLow + padHigh`.
604638 Value upperBound (OpBuilder &b, Location l) const override {
605639 return ADDI (ADDI (wrap->upperBound (b, l), padLow), padHigh);
@@ -613,6 +647,14 @@ class PadIterator : public SimpleWrapIterator {
613647
// Positions the (random-accessible) padded iterator at coordinate `crd`.
// The `+` lines fill in the previously assert-only body: the wrapped iterator
// is located at `crd - padLow`, the last cursor value is set to the
// `inPadZone` flag, and the iterator's coordinate is updated to `crd`.
614648 void locateImpl (OpBuilder &b, Location l, Value crd) override {
615649 assert (randomAccessible ());
// NOTE(review): when crd < padLow this subtraction presumably wraps around;
// the wrapped iterator's position then looks meaningful only when the
// inPadZone flag computed below is false -- confirm with consumers.
650+ wrap->locate (b, l, SUBI (crd, padLow));
651+
652+ // inPadZone = crd < padLow || crd >= size + padLow.
653+ Value inPadLow = CMPI (ult, crd, padLow);
654+ Value inPadHigh = CMPI (uge, crd, ADDI (wrap->upperBound (b, l), padLow));
655+ getMutCursorVals ().back () = ORI (inPadLow, inPadHigh);
656+
657+ updateCrd (crd);
616658 }
617659
618660 Value padLow, padHigh;
@@ -1227,6 +1269,33 @@ ValueRange NonEmptySubSectIterator::inflateSubSectTree(
12271269 return p->inflateSubSectTree (b, l, reduc, visitDenseSubSect);
12281270}
12291271
// Out-of-line definition of TrivialIterator::genInitImpl (all lines added).
// Computes [posLo, posHi) for this level via stl.peekRangeAt and seeks to
// posLo. The parent position defaults to the constant 0 and is taken from
// `parent` only when a non-batch parent is given. If that parent is a
// random-accessible PadIterator, the last value of its position is split off
// and passed to peekRangeAt as `inPadZone`; otherwise `inPadZone` is null.
1272+ void TrivialIterator::genInitImpl (OpBuilder &b, Location l,
1273+ const SparseIterator *parent) {
1274+
1275+ if (isBatchIterator () && batchCrds.size () <= stl.lvl )
1276+ batchCrds.resize (stl.lvl + 1 , nullptr );
1277+
1278+ Value c0 = C_IDX (0 );
1279+ ValueRange pPos = c0;
1280+ Value inPadZone = nullptr ;
1281+ // If the parent iterator is a batch iterator, we also start from 0 (but
1282+ // on a different batch).
1283+ if (parent && !parent->isBatchIterator ()) {
1284+ pPos = parent->getCurPosition ();
1285+ if (llvm::isa<PadIterator>(parent) && parent->randomAccessible ()) {
1286+ // A padded dense iterator creates a "sparse" padded zone, which needs to be
1287+ // handled specially.
// The trailing position value is the pad-zone flag; strip it so that only
// the remaining position value(s) are passed on as the parent position.
1288+ inPadZone = pPos.back ();
1289+ pPos = pPos.drop_back ();
1290+ }
1291+ }
1292+
1293+ ValueRange batchPrefix = parent ? parent->getBatchCrds () : ValueRange{};
1294+ std::tie (posLo, posHi) = stl.peekRangeAt (b, l, batchPrefix, pPos, inPadZone);
1295+ // Seek to the lowest position.
1296+ seek (posLo);
1297+ }
1298+
12301299void NonEmptySubSectIterator::genInitImpl (OpBuilder &b, Location l,
12311300 const SparseIterator *) {
12321301 Value c0 = C_IDX (0 );
0 commit comments