Skip to content

Commit 91e7b9e

Browse files
author
Peiming Liu
committed
[mlir][sparse] annotate loops that are generated by loop emitter.
Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D138155
1 parent 0ec24e1 commit 91e7b9e

File tree

4 files changed

+39
-17
lines changed

4 files changed

+39
-17
lines changed

mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -95,9 +95,10 @@ static Value genIndexAndValueForDense(OpBuilder &builder, Location loc,
9595
//===----------------------------------------------------------------------===//
9696

9797
SparseTensorLoopEmitter::SparseTensorLoopEmitter(ValueRange tensors,
98+
StringAttr loopTag,
9899
bool hasOutput,
99100
bool isSparseOut)
100-
: hasOutput(hasOutput), isSparseOut(isSparseOut),
101+
: loopTag(loopTag), hasOutput(hasOutput), isSparseOut(isSparseOut),
101102
tensors(tensors.begin(), tensors.end()), dimTypes(tensors.size()),
102103
pidxs(tensors.size()), coord(tensors.size()), highs(tensors.size()),
103104
ptrBuffer(tensors.size()), idxBuffer(tensors.size()),
@@ -284,7 +285,7 @@ Operation *SparseTensorLoopEmitter::enterLoopOverTensorAtDim(
284285
// NOTE: we can also prepare for the next dim here in advance
285286
// Push the loop into stack
286287
loopStack.emplace_back(ArrayRef<size_t>(tid), ArrayRef<size_t>(dim), loop,
287-
coord[tid][dim]);
288+
coord[tid][dim], loopTag);
288289
// Emit extra locals.
289290
emitExtraLocalsForTensorsAtDenseDims(builder, loc, extraTids, extraDims);
290291

@@ -386,7 +387,7 @@ Operation *SparseTensorLoopEmitter::enterCoIterationOverTensorsAtDims(
386387
// NOTE: we can also prepare for the next dim here in advance
387388
}
388389
// Sets up the loop stack.
389-
loopStack.emplace_back(tids, dims, whileOp, min);
390+
loopStack.emplace_back(tids, dims, whileOp, min, loopTag);
390391
assert(loopStack.size() == loopSeqStack.size());
391392

392393
// Emits extra locals

mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -331,7 +331,9 @@ class SparseTensorLoopEmitter {
331331
/// tensor id (tid) used in related functions.
332332
/// If isSparseOut is set, the loop emitter assumes that the sparse output tensor
333333
/// is empty, and will always generate loops on it based on the dim sizes.
334-
explicit SparseTensorLoopEmitter(ValueRange tensors, bool hasOutput = false,
334+
explicit SparseTensorLoopEmitter(ValueRange tensors,
335+
StringAttr loopTag = nullptr,
336+
bool hasOutput = false,
335337
bool isSparseOut = false);
336338

337339
/// Starts a loop emitting session by generating all the buffers needed to
@@ -413,11 +415,20 @@ class SparseTensorLoopEmitter {
413415
};
414416
const std::vector<Value> &getValBuffer() const { return valBuffer; };
415417

418+
constexpr static llvm::StringLiteral getLoopEmitterLoopAttrName() {
419+
return llvm::StringLiteral("Emitted from");
420+
}
421+
416422
private:
417423
struct LoopLevelInfo {
418424
LoopLevelInfo(ArrayRef<size_t> tids, ArrayRef<size_t> dims, Operation *loop,
419-
Value iv)
420-
: tids(tids), dims(dims), loop(loop), iv(iv) {}
425+
Value iv, StringAttr loopTag)
426+
: tids(tids), dims(dims), loop(loop), iv(iv) {
427+
// Attach a special tag to loops generated by the loop emitter.
428+
if (loopTag)
429+
loop->setAttr(SparseTensorLoopEmitter::getLoopEmitterLoopAttrName(),
430+
loopTag);
431+
}
421432
// TODO: maybe use a vector<pair> for tid and dim?
422433
// The set of tensors that the loop is operating on
423434
const llvm::SmallVector<size_t> tids;
@@ -485,8 +496,12 @@ class SparseTensorLoopEmitter {
485496
void exitCoIterationLoop(OpBuilder &builder, Location loc,
486497
MutableArrayRef<Value> reduc);
487498

488-
// Whether the loop emitter needs to treat the last tensor as the output
489-
// tensor.
499+
/// An optional string attribute that should be attached to the loops generated
500+
/// by the loop emitter; it might help subsequent passes to identify loops that
501+
/// operate on sparse tensors more easily.
502+
StringAttr loopTag;
503+
/// Whether the loop emitter needs to treat the last tensor as the output
504+
/// tensor.
490505
bool hasOutput;
491506
bool isSparseOut;
492507
/// Input and (optional) output tensors.

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -789,7 +789,9 @@ struct ForeachRewriter : public OpRewritePattern<ForeachOp> {
789789
auto enc = getSparseTensorEncoding(rtp);
790790

791791
// 1. Generates loop for the sparse input.
792-
SparseTensorLoopEmitter loopEmitter(ValueRange{input});
792+
SparseTensorLoopEmitter loopEmitter(
793+
ValueRange{input},
794+
StringAttr::get(getContext(), ForeachOp::getOperationName()));
793795
loopEmitter.initializeLoopEmit(rewriter, loc);
794796
for (int64_t i = 0; i < rank; i++) {
795797
// TODO: provide utility function for loop sequences that only contains

mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -53,11 +53,15 @@ enum Reduction { kNoReduc, kSum, kProduct, kAnd, kOr, kXor, kCustom };
5353

5454
// Code generation.
5555
struct CodeGen {
56-
CodeGen(SparsificationOptions o, ValueRange tensors, unsigned numTensors,
57-
unsigned numLoops, OpOperand *op, unsigned nest,
56+
CodeGen(SparsificationOptions o, MLIRContext *context, ValueRange tensors,
57+
unsigned numTensors, unsigned numLoops, OpOperand *op, unsigned nest,
5858
std::vector<unsigned> &ts)
59-
: options(o), loopEmitter(tensors, /*hasOutput=*/true,
60-
/*isSparseOut=*/op != nullptr),
59+
: options(o),
60+
loopEmitter(
61+
tensors,
62+
StringAttr::get(context, linalg::GenericOp::getOperationName()),
63+
/*hasOutput=*/true,
64+
/*isSparseOut=*/op != nullptr),
6165
sparseOut(op), outerParNest(nest), topSort(ts) {
6266
if (op)
6367
insChain = op->get();
@@ -670,8 +674,8 @@ static void genTensorStore(Merger &merger, CodeGen &codegen, OpBuilder &builder,
670674
// Select operation insertion.
671675
Value insChain = codegen.insChain;
672676
assert(insChain);
673-
scf::IfOp ifOp = builder.create<scf::IfOp>(
674-
loc, insChain.getType(), rhs, /*else=*/true);
677+
scf::IfOp ifOp = builder.create<scf::IfOp>(loc, insChain.getType(), rhs,
678+
/*else=*/true);
675679
builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
676680
// Existing value was preserved to be used here.
677681
assert(merger.exp(exp).val);
@@ -1372,8 +1376,8 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
13721376
tensors.push_back(t.get());
13731377

13741378
// Recursively generates code if admissible.
1375-
CodeGen codegen(options, tensors, numTensors, numLoops, sparseOut,
1376-
outerParNest, topSort);
1379+
CodeGen codegen(options, op.getContext(), tensors, numTensors, numLoops,
1380+
sparseOut, outerParNest, topSort);
13771381
genBuffers(merger, codegen, rewriter, op);
13781382
genStmt(merger, codegen, rewriter, op, exp, 0);
13791383
genResult(merger, codegen, rewriter, op);

0 commit comments

Comments
 (0)