Skip to content

Commit 31ac319

Browse files
committed
comments
1 parent 5a5d411 commit 31ac319

File tree

1 file changed

+26
-7
lines changed

1 file changed

+26
-7
lines changed

mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ getElementTypeAndCount(Operation *op) {
8989
return getVectorElementTypeAndCount(loadOp.getVectorType());
9090
if (auto storeOp = dyn_cast<vector::StoreOp>(op))
9191
return getVectorElementTypeAndCount(storeOp.getVectorType());
92+
9293
return std::nullopt;
9394
}
9495

@@ -147,7 +148,8 @@ static Value getBase(Operation *op) {
147148
return loadOp.getBase();
148149
if (auto storeOp = dyn_cast<vector::StoreOp>(op))
149150
return storeOp.getBase();
150-
return {};
151+
152+
llvm_unreachable("unsupported op");
151153
}
152154

153155
static Value getValueToStore(Operation *op) {
@@ -156,7 +158,8 @@ static Value getValueToStore(Operation *op) {
156158
return storeOp.getValueToStore();
157159
if (auto storeOp = dyn_cast<vector::StoreOp>(op))
158160
return storeOp.getValueToStore();
159-
return {};
161+
162+
llvm_unreachable("unsupported op");
160163
}
161164

162165
static bool isContiguousLastDim(Value val) {
@@ -182,7 +185,8 @@ static ValueRange getIndices(Operation *op) {
182185
return loadOp.getIndices();
183186
if (auto storeOp = dyn_cast<vector::StoreOp>(op))
184187
return storeOp.getIndices();
185-
return {};
188+
189+
llvm_unreachable("unsupported op");
186190
}
187191

188192
static bool isAdjacentAffineMapIndices(Value idx1, Value idx2, int64_t offset) {
@@ -206,7 +210,7 @@ static bool isAdjacentAffineMapIndices(Value idx1, Value idx2, int64_t offset) {
206210
return diffConst && diffConst.getValue() == offset;
207211
}
208212

209-
/// Check if two indices are consecutive, i.e index1 + 1 == index2.
213+
/// Check if two indices are consecutive, i.e. index1 + offset == index2.
210214
static bool isAdjacentIndices(Value idx1, Value idx2, int64_t offset) {
211215
if (auto c1 = getConstantIntValue(idx1)) {
212216
if (auto c2 = getConstantIntValue(idx2))
@@ -232,7 +236,7 @@ static bool isAdjacentIndices(Value idx1, Value idx2, int64_t offset) {
232236
}
233237

234238
/// Check if two ranges of indices are consecutive, i.e. fastest index differs
235-
/// by 1 and all other indices are the same.
239+
/// by `offset` and all other indices are the same.
236240
static bool isAdjacentIndices(ValueRange idx1, ValueRange idx2,
237241
int64_t offset) {
238242
if (idx1.empty() || idx1.size() != idx2.size())
@@ -272,6 +276,7 @@ static bool isAdjacentOps(Operation *op1, Operation *op2) {
272276
if (typeAndCount1 != typeAndCount2)
273277
return false;
274278

279+
// For now we are only merging ops with the same element count.
275280
return isAdjacentIndices(getIndices(op1), getIndices(op2),
276281
typeAndCount1->second);
277282
}
@@ -332,6 +337,9 @@ extractContiguousGroups(const MemoryOpGroup &group) {
332337
return result;
333338
}
334339

340+
/// Check if an operation is vectorizable.
341+
/// If `expectedElementsCount` is provided, check if original op had the
342+
/// specified number of elements.
335343
static bool
336344
isVectorizable(Operation *op,
337345
std::optional<int64_t> expectedElementsCount = std::nullopt) {
@@ -362,7 +370,8 @@ isVectorizable(Operation *op,
362370
return true;
363371
}
364372

365-
/// Get the next operation in the block, assuming `op` is not a terminator.
373+
/// Get the next operation in the block, assuming `op` is not a terminator/last
374+
/// operation in the block.
366375
static Operation *nextOp(Operation *op) {
367376
assert(op && "null op");
368377
auto it = op->getIterator();
@@ -390,6 +399,9 @@ struct SLPGraphNode {
390399
return ops.front();
391400
}
392401

402+
/// Get the suitable insertion point for the new vectorized op.
403+
/// This method also takes the operands' insertion points into account
404+
/// to satisfy dominance relations.
393405
Operation *getInsertionPoint() {
394406
assert(!ops.empty() && "empty node");
395407
if (insertionPoint)
@@ -1038,6 +1050,9 @@ SLPGraph::vectorize(IRRewriter &rewriter,
10381050
Location loc = op->getLoc();
10391051

10401052
auto handleNonVectorInputs = [&](ValueRange operands) {
1053+
// Handle the case when op operands are not vectorized or have smaller
1054+
// vector size, construct the vector from the scalar operands using
1055+
// FromElementsOp.
10411056
for (auto [i, operand] : llvm::enumerate(operands)) {
10421057
if (getNodeForOp(operand.getDefiningOp()))
10431058
continue;
@@ -1064,6 +1079,8 @@ SLPGraph::vectorize(IRRewriter &rewriter,
10641079
};
10651080

10661081
auto handleNonVectorOutputs = [&](Value newResult) {
1082+
// Handle the case when op results are not vectorized or have smaller
1083+
// vector size, extract the elements from the vector.
10671084
for (auto [i, result] : llvm::enumerate(node->ops)) {
10681085
for (OpOperand &use : result->getUses()) {
10691086
Operation *useOwner = use.getOwner();
@@ -1077,6 +1094,7 @@ SLPGraph::vectorize(IRRewriter &rewriter,
10771094
};
10781095

10791096
auto handleVecSizeMismatch = [&](Value arg, int64_t offset = 0) -> Value {
1097+
// Handle vector size mismatch between 2 vectorized nodes.
10801098
auto srcType = cast<VectorType>(arg.getType());
10811099
assert(srcType.getRank() == 1);
10821100
if (srcType.getDimSize(0) == numElements)
@@ -1117,7 +1135,8 @@ SLPGraph::vectorize(IRRewriter &rewriter,
11171135
mapping.map(op->getResults(), newOp->getResults());
11181136
handleNonVectorOutputs(newOp->getResult(0));
11191137
} else if (auto extract = dyn_cast<vector::ExtractOp>(op)) {
1120-
// We alredy verified index is valid during graph construction.
1138+
// We already verified index is valid during graph construction, so
1139+
// we do not need to check the `getExtractIndex` result.
11211140
int64_t offset = *getExtractIndex(extract);
11221141
Value val = handleVecSizeMismatch(extract.getVector(), offset);
11231142
mapping.map(extract.getResult(), val);

0 commit comments

Comments
 (0)