Skip to content

Commit 8a6d73f

Browse files
rupprechtaokblast
authored andcommitted
[mlir][sparse] Include sparse emit strategy in wrapping iterator (llvm#165611)
When we create a `SparseIterator`, we sometimes wrap it in a `FilterIterator`, which delegates _some_ calls to the underlying `SparseIterator`. After construction, e.g. in `makeNonEmptySubSectIterator()`, we call `setSparseEmitStrategy()`. This sets the strategy only in one of the filters -- if we call `setSparseEmitStrategy()` immediately after creating the `SparseIterator`, then the wrapped `SparseIterator` will have the right strategy, and the `FilterIterator` strategy will be unintialized; if we call `setSparseEmitStrategy()` after wrapping the iterator in `FilterIterator`, then the opposite happens. If we make `setSparseEmitStrategy()` a virtual method so that it's included in the `FilterIterator` pattern, and then do all reads of `emitStrategy` via a virtual method as well, it's pretty simple to ensure that the value of `strategy` is being set consistently and correctly. Without this, the UB of strategy being uninitialized manifests as a sporadic test failure in mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_strided_conv_2d_nhwc_hwcf.mlir, when run downstream with the right flags (e.g. asan + assertions off). The test sometimes fails with `ne_sub<trivial<dense[0,1]>>.begin' op created with unregistered dialect`. It can also be directly observed w/ msan that this uninitialized read is the cause of that issue, but msan causes other problems w/ this test.
1 parent b46de8d commit 8a6d73f

File tree

2 files changed

+18
-6
lines changed

2 files changed

+18
-6
lines changed

mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -504,6 +504,14 @@ class SimpleWrapIterator : public SparseIterator {
504504
unsigned extraCursorVal = 0)
505505
: SparseIterator(kind, *wrap, extraCursorVal), wrap(std::move(wrap)) {}
506506

507+
void setSparseEmitStrategy(SparseEmitStrategy strategy) override {
508+
wrap->setSparseEmitStrategy(strategy);
509+
}
510+
511+
SparseEmitStrategy getSparseEmitStrategy() const override {
512+
return wrap->getSparseEmitStrategy();
513+
}
514+
507515
SmallVector<Type> getCursorValTypes(OpBuilder &b) const override {
508516
return wrap->getCursorValTypes(b);
509517
}
@@ -979,7 +987,7 @@ class SubSectIterator : public SparseIterator {
979987

980988
void SparseIterator::genInit(OpBuilder &b, Location l,
981989
const SparseIterator *p) {
982-
if (emitStrategy == SparseEmitStrategy::kDebugInterface) {
990+
if (getSparseEmitStrategy() == SparseEmitStrategy::kDebugInterface) {
983991
std::string prefix = getDebugInterfacePrefix();
984992
Operation *begin = b.create(l, b.getStringAttr(prefix + ".begin"), {},
985993
getCursorValTypes(b));
@@ -994,7 +1002,7 @@ void SparseIterator::genInit(OpBuilder &b, Location l,
9941002
}
9951003

9961004
Value SparseIterator::genNotEnd(OpBuilder &b, Location l) {
997-
if (emitStrategy == SparseEmitStrategy::kDebugInterface) {
1005+
if (getSparseEmitStrategy() == SparseEmitStrategy::kDebugInterface) {
9981006
std::string prefix = getDebugInterfacePrefix();
9991007
Operation *notEnd = b.create(l, b.getStringAttr(prefix + ".not_end"),
10001008
getCursor(), b.getI1Type());
@@ -1005,7 +1013,7 @@ Value SparseIterator::genNotEnd(OpBuilder &b, Location l) {
10051013
}
10061014

10071015
void SparseIterator::locate(OpBuilder &b, Location l, Value crd) {
1008-
if (emitStrategy == SparseEmitStrategy::kDebugInterface) {
1016+
if (getSparseEmitStrategy() == SparseEmitStrategy::kDebugInterface) {
10091017
std::string prefix = getDebugInterfacePrefix();
10101018
SmallVector<Value> args = getCursor();
10111019
args.push_back(crd);
@@ -1019,7 +1027,7 @@ void SparseIterator::locate(OpBuilder &b, Location l, Value crd) {
10191027
}
10201028

10211029
Value SparseIterator::deref(OpBuilder &b, Location l) {
1022-
if (emitStrategy == SparseEmitStrategy::kDebugInterface) {
1030+
if (getSparseEmitStrategy() == SparseEmitStrategy::kDebugInterface) {
10231031
std::string prefix = getDebugInterfacePrefix();
10241032
SmallVector<Value> args = getCursor();
10251033
Operation *deref = b.create(l, b.getStringAttr(prefix + ".deref"),
@@ -1032,7 +1040,7 @@ Value SparseIterator::deref(OpBuilder &b, Location l) {
10321040

10331041
ValueRange SparseIterator::forward(OpBuilder &b, Location l) {
10341042
assert(!randomAccessible());
1035-
if (emitStrategy == SparseEmitStrategy::kDebugInterface) {
1043+
if (getSparseEmitStrategy() == SparseEmitStrategy::kDebugInterface) {
10361044
std::string prefix = getDebugInterfacePrefix();
10371045
Operation *next = b.create(l, b.getStringAttr(prefix + ".next"),
10381046
getCursor(), getCursorValTypes(b));

mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -177,10 +177,14 @@ class SparseIterator {
177177
public:
178178
virtual ~SparseIterator() = default;
179179

180-
void setSparseEmitStrategy(SparseEmitStrategy strategy) {
180+
virtual void setSparseEmitStrategy(SparseEmitStrategy strategy) {
181181
emitStrategy = strategy;
182182
}
183183

184+
virtual SparseEmitStrategy getSparseEmitStrategy() const {
185+
return emitStrategy;
186+
}
187+
184188
virtual std::string getDebugInterfacePrefix() const = 0;
185189
virtual SmallVector<Type> getCursorValTypes(OpBuilder &b) const = 0;
186190

0 commit comments

Comments
 (0)