Skip to content

Commit b79e3f3

Browse files
committed
vector outputs handling
1 parent 31ac319 commit b79e3f3

File tree

2 files changed

+42
-5
lines changed

2 files changed

+42
-5
lines changed

mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1078,7 +1078,8 @@ SLPGraph::vectorize(IRRewriter &rewriter,
10781078
}
10791079
};
10801080

1081-
auto handleNonVectorOutputs = [&](Value newResult) {
1081+
auto handleNonVectorOutputs = [&](Value newResult,
1082+
Type originalResultType) {
10821083
// Handle the case when op results are not vectorized or have smaller
10831084
// vector size, extract the elements from the vector.
10841085
for (auto [i, result] : llvm::enumerate(node->ops)) {
@@ -1087,7 +1088,16 @@ SLPGraph::vectorize(IRRewriter &rewriter,
10871088
if (getNodeForOp(useOwner))
10881089
continue;
10891090

1090-
Value elem = rewriter.create<vector::ExtractOp>(loc, newResult, i);
1091+
int64_t offset = i * node->elementsCount;
1092+
Value elem;
1093+
1094+
if (auto vecType = dyn_cast<VectorType>(originalResultType)) {
1095+
elem = rewriter.create<vector::ExtractStridedSliceOp>(
1096+
loc, newResult, offset, vecType.getNumElements(), 1);
1097+
} else {
1098+
elem = rewriter.create<vector::ExtractOp>(loc, newResult, offset);
1099+
}
1100+
10911101
use.set(elem);
10921102
}
10931103
}
@@ -1109,8 +1119,9 @@ SLPGraph::vectorize(IRRewriter &rewriter,
11091119
VectorType::get(numElements, getElementTypeAndCount(op)->first);
11101120
Value result = rewriter.create<vector::LoadOp>(loc, vecType, getBase(op),
11111121
getIndices(op));
1112-
mapping.map(op->getResult(0), result);
1113-
handleNonVectorOutputs(result);
1122+
Value originalResult = op->getResult(0);
1123+
mapping.map(originalResult, result);
1124+
handleNonVectorOutputs(result, originalResult.getType());
11141125
} else if (maybeWriteOp(op)) {
11151126
handleNonVectorInputs(getValueToStore(op));
11161127
Value val = mapping.lookupOrDefault(getValueToStore(op));
@@ -1133,7 +1144,7 @@ SLPGraph::vectorize(IRRewriter &rewriter,
11331144
newOp->getResult(0).setType(resVectorType);
11341145

11351146
mapping.map(op->getResults(), newOp->getResults());
1136-
handleNonVectorOutputs(newOp->getResult(0));
1147+
handleNonVectorOutputs(newOp->getResult(0), op->getResultTypes().front());
11371148
} else if (auto extract = dyn_cast<vector::ExtractOp>(op)) {
11381149
// We already verified index is valid during graph construction, so
11391150
// no need to check `getExtractIndex` result.

mlir/test/Dialect/Vector/slp-vectorize.mlir

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -672,6 +672,32 @@ func.func @read_read_add_add(%arg0: memref<8xi32>, %arg1: memref<8xi32>,
672672
}
673673

674674

675+
// CHECK-LABEL: func @read_read_add_add_vec
676+
// CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>)
677+
func.func @read_read_add_add_vec(%arg0: memref<8xi32>, %arg1: memref<8xi32>) ->
678+
(vector<2xi32>, vector<2xi32>){
679+
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
680+
// CHECK: %[[V0:.*]] = vector.load %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32>
681+
// CHECK: %[[V1:.*]] = vector.load %[[ARG1]][%[[C0]]] : memref<8xi32>, vector<4xi32>
682+
// CHECK: %[[V2:.*]] = arith.addi %[[V0]], %[[V1]] : vector<4xi32>
683+
// CHECK: %[[V3:.*]] = vector.extract_strided_slice %[[V2]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xi32> to vector<2xi32>
684+
// CHECK: %[[V4:.*]] = vector.extract_strided_slice %[[V2]] {offsets = [2], sizes = [2], strides = [1]} : vector<4xi32> to vector<2xi32>
685+
// CHECK: return %[[V3]], %[[V4]] : vector<2xi32>, vector<2xi32>
686+
%c0 = arith.constant 0 : index
687+
%c2 = arith.constant 2 : index
688+
689+
%0 = vector.load %arg0[%c0] : memref<8xi32>, vector<2xi32>
690+
%2 = vector.load %arg0[%c2] : memref<8xi32>, vector<2xi32>
691+
692+
%4 = vector.load %arg1[%c0] : memref<8xi32>, vector<2xi32>
693+
%6 = vector.load %arg1[%c2] : memref<8xi32>, vector<2xi32>
694+
695+
%8 = arith.addi %0, %4 : vector<2xi32>
696+
%10 = arith.addi %2, %6 : vector<2xi32>
697+
698+
return %8, %10 : vector<2xi32>, vector<2xi32>
699+
}
700+
675701

676702
func.func private @use(i32)
677703

0 commit comments

Comments
 (0)