@@ -1078,7 +1078,8 @@ SLPGraph::vectorize(IRRewriter &rewriter,
10781078 }
10791079 };
10801080
1081- auto handleNonVectorOutputs = [&](Value newResult) {
1081+ auto handleNonVectorOutputs = [&](Value newResult,
1082+ Type originalResultType) {
10821083 // Handle the case when op results are not vectorized or have smaller
10831084 // vector size, extract the elements from the vector.
10841085 for (auto [i, result] : llvm::enumerate (node->ops )) {
@@ -1087,7 +1088,16 @@ SLPGraph::vectorize(IRRewriter &rewriter,
10871088 if (getNodeForOp (useOwner))
10881089 continue ;
10891090
1090- Value elem = rewriter.create <vector::ExtractOp>(loc, newResult, i);
1091+ int64_t offset = i * node->elementsCount ;
1092+ Value elem;
1093+
1094+ if (auto vecType = dyn_cast<VectorType>(originalResultType)) {
1095+ elem = rewriter.create <vector::ExtractStridedSliceOp>(
1096+ loc, newResult, offset, vecType.getNumElements (), 1 );
1097+ } else {
1098+ elem = rewriter.create <vector::ExtractOp>(loc, newResult, offset);
1099+ }
1100+
10911101 use.set (elem);
10921102 }
10931103 }
@@ -1109,8 +1119,9 @@ SLPGraph::vectorize(IRRewriter &rewriter,
11091119 VectorType::get (numElements, getElementTypeAndCount (op)->first );
11101120 Value result = rewriter.create <vector::LoadOp>(loc, vecType, getBase (op),
11111121 getIndices (op));
1112- mapping.map (op->getResult (0 ), result);
1113- handleNonVectorOutputs (result);
1122+ Value originalResult = op->getResult (0 );
1123+ mapping.map (originalResult, result);
1124+ handleNonVectorOutputs (result, originalResult.getType ());
11141125 } else if (maybeWriteOp (op)) {
11151126 handleNonVectorInputs (getValueToStore (op));
11161127 Value val = mapping.lookupOrDefault (getValueToStore (op));
@@ -1133,7 +1144,7 @@ SLPGraph::vectorize(IRRewriter &rewriter,
11331144 newOp->getResult (0 ).setType (resVectorType);
11341145
11351146 mapping.map (op->getResults (), newOp->getResults ());
1136- handleNonVectorOutputs (newOp->getResult (0 ));
1147+ handleNonVectorOutputs (newOp->getResult (0 ), op-> getResultTypes (). front () );
11371148 } else if (auto extract = dyn_cast<vector::ExtractOp>(op)) {
11381149 // We alredy verified index is valid during graph construction, so
11391150 // do need to check `getExtractIndex` result.
0 commit comments