@@ -89,6 +89,7 @@ getElementTypeAndCount(Operation *op) {
8989 return getVectorElementTypeAndCount (loadOp.getVectorType ());
9090 if (auto storeOp = dyn_cast<vector::StoreOp>(op))
9191 return getVectorElementTypeAndCount (storeOp.getVectorType ());
92+
9293 return std::nullopt ;
9394}
9495
@@ -147,7 +148,8 @@ static Value getBase(Operation *op) {
147148 return loadOp.getBase ();
148149 if (auto storeOp = dyn_cast<vector::StoreOp>(op))
149150 return storeOp.getBase ();
150- return {};
151+
152+ llvm_unreachable (" unsupported op" );
151153}
152154
153155static Value getValueToStore (Operation *op) {
@@ -156,7 +158,8 @@ static Value getValueToStore(Operation *op) {
156158 return storeOp.getValueToStore ();
157159 if (auto storeOp = dyn_cast<vector::StoreOp>(op))
158160 return storeOp.getValueToStore ();
159- return {};
161+
162+ llvm_unreachable (" unsupported op" );
160163}
161164
162165static bool isContiguousLastDim (Value val) {
@@ -182,7 +185,8 @@ static ValueRange getIndices(Operation *op) {
182185 return loadOp.getIndices ();
183186 if (auto storeOp = dyn_cast<vector::StoreOp>(op))
184187 return storeOp.getIndices ();
185- return {};
188+
189+ llvm_unreachable (" unsupported op" );
186190}
187191
188192static bool isAdjacentAffineMapIndices (Value idx1, Value idx2, int64_t offset) {
@@ -206,7 +210,7 @@ static bool isAdjacentAffineMapIndices(Value idx1, Value idx2, int64_t offset) {
206210 return diffConst && diffConst.getValue () == offset;
207211}
208212
209- // / Check if two indices are consecutive, i.e index1 + 1 == index2.
213+ // / Check if two indices are consecutive, i.e index1 + offset == index2.
210214static bool isAdjacentIndices (Value idx1, Value idx2, int64_t offset) {
211215 if (auto c1 = getConstantIntValue (idx1)) {
212216 if (auto c2 = getConstantIntValue (idx2))
@@ -232,7 +236,7 @@ static bool isAdjacentIndices(Value idx1, Value idx2, int64_t offset) {
232236}
233237
234238// / Check if two ranges of indices are consecutive, i.e fastest index differs
235- // / by 1 and all other indices are the same.
239+ // / by `offset` and all other indices are the same.
236240static bool isAdjacentIndices (ValueRange idx1, ValueRange idx2,
237241 int64_t offset) {
238242 if (idx1.empty () || idx1.size () != idx2.size ())
@@ -272,6 +276,7 @@ static bool isAdjacentOps(Operation *op1, Operation *op2) {
272276 if (typeAndCount1 != typeAndCount2)
273277 return false ;
274278
279+ // For now we are only merging ops with same elements count.
275280 return isAdjacentIndices (getIndices (op1), getIndices (op2),
276281 typeAndCount1->second );
277282}
@@ -332,6 +337,9 @@ extractContiguousGroups(const MemoryOpGroup &group) {
332337 return result;
333338}
334339
340+ // / Check if an operation is vectorizable.
341+ // / If `expectedElementsCount` is provided, check if original op had the
342+ // / specified number of elements.
335343static bool
336344isVectorizable (Operation *op,
337345 std::optional<int64_t > expectedElementsCount = std::nullopt ) {
@@ -362,7 +370,8 @@ isVectorizable(Operation *op,
362370 return true ;
363371}
364372
365- // / Get the next operation in the block, assuming `op` is not a terminator.
373+ // / Get the next operation in the block, assuming `op` is not a terminator/last
374+ // / operation in the block.
366375static Operation *nextOp (Operation *op) {
367376 assert (op && " null op" );
368377 auto it = op->getIterator ();
@@ -390,6 +399,9 @@ struct SLPGraphNode {
390399 return ops.front ();
391400 }
392401
402+ // / Get the suitable insertion point for the new vectorized op.
403+ // / This method is trying to take into account operands insertions points too
404+ // / to satisfy dominance relations.
393405 Operation *getInsertionPoint () {
394406 assert (!ops.empty () && " empty node" );
395407 if (insertionPoint)
@@ -1038,6 +1050,9 @@ SLPGraph::vectorize(IRRewriter &rewriter,
10381050 Location loc = op->getLoc ();
10391051
10401052 auto handleNonVectorInputs = [&](ValueRange operands) {
1053+ // Handle the case when op operands are not vectorized or have smaller
1054+ // vector size, construct the vector from the scalar operands using
1055+ // FromElementsOp.
10411056 for (auto [i, operand] : llvm::enumerate (operands)) {
10421057 if (getNodeForOp (operand.getDefiningOp ()))
10431058 continue ;
@@ -1064,6 +1079,8 @@ SLPGraph::vectorize(IRRewriter &rewriter,
10641079 };
10651080
10661081 auto handleNonVectorOutputs = [&](Value newResult) {
1082+ // Handle the case when op results are not vectorized or have smaller
1083+ // vector size, extract the elements from the vector.
10671084 for (auto [i, result] : llvm::enumerate (node->ops )) {
10681085 for (OpOperand &use : result->getUses ()) {
10691086 Operation *useOwner = use.getOwner ();
@@ -1077,6 +1094,7 @@ SLPGraph::vectorize(IRRewriter &rewriter,
10771094 };
10781095
10791096 auto handleVecSizeMismatch = [&](Value arg, int64_t offset = 0 ) -> Value {
1097+ // Handle vector size misamatch between 2 vectorized nodes.
10801098 auto srcType = cast<VectorType>(arg.getType ());
10811099 assert (srcType.getRank () == 1 );
10821100 if (srcType.getDimSize (0 ) == numElements)
@@ -1117,7 +1135,8 @@ SLPGraph::vectorize(IRRewriter &rewriter,
11171135 mapping.map (op->getResults (), newOp->getResults ());
11181136 handleNonVectorOutputs (newOp->getResult (0 ));
11191137 } else if (auto extract = dyn_cast<vector::ExtractOp>(op)) {
1120- // We alredy verified index is valid during graph construction.
1138+ // We alredy verified index is valid during graph construction, so
1139+ // do need to check `getExtractIndex` result.
11211140 int64_t offset = *getExtractIndex (extract);
11221141 Value val = handleVecSizeMismatch (extract.getVector (), offset);
11231142 mapping.map (extract.getResult (), val);
0 commit comments