Skip to content

Commit be553a4

Browse files
committed
support for 1-element vectors
1 parent 2b82f23 commit be553a4

File tree

2 files changed

+168
-32
lines changed

2 files changed

+168
-32
lines changed

mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp

Lines changed: 94 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,32 @@ static bool maybeWriteOp(Operation *op) {
6868
return effectInterface.hasEffect<MemoryEffects::Write>();
6969
}
7070

71+
/// Returns the element type of `vectorType` when it is a fixed-size (non
/// scalable) vector of rank 0 or 1 that holds exactly one element. Any other
/// vector shape yields a null `Type`.
static Type getVectorElementType(VectorType vectorType) {
  // Same checks, in the same order, as the negated rejection condition:
  // rank first, then scalability, then the element count.
  const bool isSingleElementVector = vectorType.getRank() <= 1 &&
                                     !vectorType.isScalable() &&
                                     vectorType.getNumElements() == 1;
  return isSingleElementVector ? vectorType.getElementType() : Type();
}
78+
79+
/// Returns the element type accessed by the memory operation `op`.
///
/// For `memref.load` this is the result type and for `memref.store` it is the
/// type of the stored value. For `vector.load` / `vector.store` it is whatever
/// `getVectorElementType` reports for the op's vector type, which is null for
/// vectors that helper rejects. Every other operation yields a null `Type`.
static Type getElementType(Operation *op) {
  assert(op && "null op");

  // Scalar memref accesses carry the element type directly.
  if (auto scalarLoad = dyn_cast<memref::LoadOp>(op))
    return scalarLoad.getResult().getType();
  if (auto scalarStore = dyn_cast<memref::StoreOp>(op))
    return scalarStore.getValueToStore().getType();

  // Vector accesses are funneled through the single-element vector filter.
  VectorType accessedVector;
  if (auto vecLoad = dyn_cast<vector::LoadOp>(op))
    accessedVector = vecLoad.getVectorType();
  else if (auto vecStore = dyn_cast<vector::StoreOp>(op))
    accessedVector = vecStore.getVectorType();

  return accessedVector ? getVectorElementType(accessedVector) : Type();
}
91+
92+
/// Returns true when `getElementType(op)` is non-null and is an integer,
/// floating-point, or index type.
static bool isSupportedMemOp(Operation *op) {
  assert(op && "null op");
  Type accessedType = getElementType(op);
  // `isa_and_present` tolerates the null type returned for unsupported ops.
  return isa_and_present<IntegerType, FloatType, IndexType>(accessedType);
}
96+
7197
/// Collect all memory operations in the block into groups.
7298
/// Each group contains either all loads or all stores, uninterrupted by
7399
/// operations of the other type.
@@ -85,7 +111,7 @@ static SmallVector<MemoryOpGroup> collectMemoryOpGroups(Block &block) {
85111
}
86112
}
87113

88-
if (!isa<memref::LoadOp, memref::StoreOp>(op))
114+
if (!isSupportedMemOp(&op))
89115
continue;
90116

91117
bool isLoad = maybeReadOp(&op);
@@ -109,6 +135,19 @@ static Value getBase(Operation *op) {
109135
return loadOp.getMemRef();
110136
if (auto storeOp = dyn_cast<memref::StoreOp>(op))
111137
return storeOp.getMemRef();
138+
if (auto loadOp = dyn_cast<vector::LoadOp>(op))
139+
return loadOp.getBase();
140+
if (auto storeOp = dyn_cast<vector::StoreOp>(op))
141+
return storeOp.getBase();
142+
return {};
143+
}
144+
145+
/// Returns the value written by `op` when it is a `memref.store` or a
/// `vector.store`; any other operation yields a null `Value`.
static Value getValueToStore(Operation *op) {
  assert(op && "null op");

  Value storedValue; // Stays null unless `op` is one of the store ops below.
  if (auto scalarStore = dyn_cast<memref::StoreOp>(op))
    storedValue = scalarStore.getValueToStore();
  else if (auto vecStore = dyn_cast<vector::StoreOp>(op))
    storedValue = vecStore.getValueToStore();
  return storedValue;
}
114153

@@ -131,15 +170,10 @@ static ValueRange getIndices(Operation *op) {
131170
return loadOp.getIndices();
132171
if (auto storeOp = dyn_cast<memref::StoreOp>(op))
133172
return storeOp.getIndices();
134-
return {};
135-
}
136-
137-
static Type getElementType(Operation *op) {
138-
assert(op && "null op");
139-
if (auto loadOp = dyn_cast<memref::LoadOp>(op))
140-
return loadOp.getResult().getType();
141-
if (auto storeOp = dyn_cast<memref::StoreOp>(op))
142-
return storeOp.getValueToStore().getType();
173+
if (auto loadOp = dyn_cast<vector::LoadOp>(op))
174+
return loadOp.getIndices();
175+
if (auto storeOp = dyn_cast<vector::StoreOp>(op))
176+
return storeOp.getIndices();
143177
return {};
144178
}
145179

@@ -285,7 +319,15 @@ static bool isVectorizable(Operation *op) {
285319

286320
for (auto type :
287321
llvm::concat<Type>(op->getResultTypes(), op->getOperandTypes())) {
288-
if (!type.isIntOrIndexOrFloat())
322+
if (auto vectorType = dyn_cast<VectorType>(type)) {
323+
if (vectorType.getRank() > 1 || vectorType.isScalable() ||
324+
vectorType.getNumElements() != 1)
325+
return false;
326+
327+
type = vectorType.getElementType();
328+
}
329+
330+
if (!isa<IntegerType, FloatType, IndexType>(type))
289331
return false;
290332
}
291333

@@ -464,8 +506,7 @@ class SLPGraph {
464506
for (const auto &node : nodes) {
465507
if (!node->isRoot)
466508
continue;
467-
llvm::dbgs() << " "
468-
<< (isa<memref::LoadOp>(node->op()) ? "LOAD" : "STORE")
509+
llvm::dbgs() << " " << (maybeReadOp(node->op()) ? "LOAD" : "STORE")
469510
<< " group with " << node->size() << " operations:\n";
470511
for (auto *op : node->ops) {
471512
llvm::dbgs() << " " << *op << "\n";
@@ -657,20 +698,36 @@ checkOpVecType(SLPGraphNode *node,
657698
llvm::function_ref<bool(Type, size_t)> isValidVecType) {
658699
Operation *op = node->op();
659700
size_t size = node->size();
660-
if (Type elementType = getElementType(op))
661-
return isValidVecType(elementType, size);
701+
auto checkRes = [](bool res) -> bool {
702+
LLVM_DEBUG(llvm::dbgs() << (res ? "true" : "false") << "\n");
703+
return res;
704+
};
705+
706+
if (Type elementType = getElementType(op)) {
707+
LLVM_DEBUG(llvm::dbgs() << "Checking if type " << elementType
708+
<< " with size " << size << " can be vectorized: ");
709+
return checkRes(isValidVecType(elementType, size));
710+
}
662711

663712
if (isVectorizable(op)) {
664713
for (auto type :
665714
llvm::concat<Type>(op->getResultTypes(), op->getOperandTypes())) {
666-
if (!isValidVecType(type, size))
715+
Type elementType = getElementTypeOrSelf(type);
716+
LLVM_DEBUG(llvm::dbgs()
717+
<< "Checking if type " << elementType << " with size " << size
718+
<< " can be vectorized: ");
719+
if (!checkRes(isValidVecType(elementType, size)))
667720
return false;
668721
}
669722
return true;
670723
}
671724

672-
if (auto extract = dyn_cast<vector::ExtractOp>(op))
673-
return isValidVecType(extract.getResult().getType(), size);
725+
if (auto extract = dyn_cast<vector::ExtractOp>(op)) {
726+
Type type = extract.getResult().getType();
727+
LLVM_DEBUG(llvm::dbgs() << "Checking if type " << type << " with size "
728+
<< size << " can be vectorized: ");
729+
return checkRes(isValidVecType(type, size));
730+
}
674731

675732
LLVM_DEBUG(llvm::dbgs() << "Unsupported op " << op->getName() << "\n");
676733
return false;
@@ -903,12 +960,19 @@ SLPGraph::vectorize(IRRewriter &rewriter,
903960
for (auto *operand : node->operands)
904961
size = std::min(size, operand->size());
905962

906-
node->ops.resize(size);
963+
if (size < node->size()) {
964+
LLVM_DEBUG(llvm::dbgs()
965+
<< "Size mismatch, resizing node with " << node->size()
966+
<< " operations to " << size << "\n");
967+
node->ops.resize(size);
968+
}
907969

908970
while (node->size() > 1) {
909971
if (checkOpVecType(node, isValidVecType))
910972
break;
911973

974+
LLVM_DEBUG(llvm::dbgs() << "No a valid vector type, popping back op: "
975+
<< node->ops.back()->getName() << "\n");
912976
node->ops.pop_back();
913977
}
914978
}
@@ -975,24 +1039,22 @@ SLPGraph::vectorize(IRRewriter &rewriter,
9751039
numElements, 1);
9761040
};
9771041

978-
if (auto load = dyn_cast<memref::LoadOp>(op)) {
979-
auto vecType =
980-
VectorType::get(numElements, load.getMemRefType().getElementType());
981-
Value result = rewriter.create<vector::LoadOp>(
982-
loc, vecType, load.getMemRef(), load.getIndices());
983-
mapping.map(load.getResult(), result);
1042+
if (maybeReadOp(op)) {
1043+
auto vecType = VectorType::get(numElements, getElementType(op));
1044+
Value result = rewriter.create<vector::LoadOp>(loc, vecType, getBase(op),
1045+
getIndices(op));
1046+
mapping.map(op->getResult(0), result);
9841047
handleNonVectorOutputs(result);
985-
} else if (auto store = dyn_cast<memref::StoreOp>(op)) {
986-
handleNonVectorInputs(store.getValueToStore());
987-
Value val = mapping.lookupOrDefault(store.getValueToStore());
1048+
} else if (maybeWriteOp(op)) {
1049+
handleNonVectorInputs(getValueToStore(op));
1050+
Value val = mapping.lookupOrDefault(getValueToStore(op));
9881051
val = handleVecSizeMismatch(val);
989-
rewriter.create<vector::StoreOp>(loc, val, store.getMemRef(),
990-
store.getIndices());
1052+
rewriter.create<vector::StoreOp>(loc, val, getBase(op), getIndices(op));
9911053
} else if (isVectorizable(op)) {
9921054
handleNonVectorInputs(op->getOperands());
9931055
Operation *newOp = rewriter.clone(*op, mapping);
994-
auto resVectorType =
995-
VectorType::get(numElements, op->getResultTypes().front());
1056+
Type resType = getElementTypeOrSelf(op->getResultTypes().front());
1057+
auto resVectorType = VectorType::get(numElements, resType);
9961058

9971059
{
9981060
OpBuilder::InsertionGuard guard(rewriter);

mlir/test/Dialect/Vector/slp-vectorize.mlir

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,80 @@ func.func @read_read_add_write(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {
276276
}
277277

278278

279+
// Test input: four 0-D `vector<i32>` loads at indices 0..3 from each of
// %arg0 and %arg1, four element-wise `arith.addi`, and four stores back to
// %arg0. The expected output is one `vector<4xi32>` load per memref, a single
// vector add, and a single vector store.
// CHECK-LABEL: func @read_read_add_write_vec_0d
280+
// CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>)
281+
func.func @read_read_add_write_vec_0d(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {
282+
// CHECK: %[[C0:.*]] = arith.constant 0 : index
283+
// CHECK: %[[A:.*]] = vector.load %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32>
284+
// CHECK: %[[B:.*]] = vector.load %[[ARG1]][%[[C0]]] : memref<8xi32>, vector<4xi32>
285+
// CHECK: %[[RES:.*]] = arith.addi %[[A]], %[[B]] : vector<4xi32>
286+
// CHECK: vector.store %[[RES]], %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32>
287+
%c0 = arith.constant 0 : index
288+
%c1 = arith.constant 1 : index
289+
%c2 = arith.constant 2 : index
290+
%c3 = arith.constant 3 : index
291+
292+
%0 = vector.load %arg0[%c0] : memref<8xi32>, vector<i32>
293+
%1 = vector.load %arg0[%c1] : memref<8xi32>, vector<i32>
294+
%2 = vector.load %arg0[%c2] : memref<8xi32>, vector<i32>
295+
%3 = vector.load %arg0[%c3] : memref<8xi32>, vector<i32>
296+
297+
%4 = vector.load %arg1[%c0] : memref<8xi32>, vector<i32>
298+
%5 = vector.load %arg1[%c1] : memref<8xi32>, vector<i32>
299+
%6 = vector.load %arg1[%c2] : memref<8xi32>, vector<i32>
300+
%7 = vector.load %arg1[%c3] : memref<8xi32>, vector<i32>
301+
302+
%8 = arith.addi %0, %4 : vector<i32>
303+
%9 = arith.addi %1, %5 : vector<i32>
304+
%10 = arith.addi %2, %6 : vector<i32>
305+
%11 = arith.addi %3, %7 : vector<i32>
306+
307+
vector.store %8, %arg0[%c0] : memref<8xi32>, vector<i32>
308+
vector.store %9, %arg0[%c1] : memref<8xi32>, vector<i32>
309+
vector.store %10, %arg0[%c2] : memref<8xi32>, vector<i32>
310+
vector.store %11, %arg0[%c3] : memref<8xi32>, vector<i32>
311+
312+
return
313+
}
314+
315+
316+
// Test input: four 1-D single-element `vector<1xi32>` loads at indices 0..3
// from each of %arg0 and %arg1, four element-wise `arith.addi`, and four
// stores back to %arg0. The expected output is one `vector<4xi32>` load per
// memref, a single vector add, and a single vector store.
// CHECK-LABEL: func @read_read_add_write_vec_1d
317+
// CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>)
318+
func.func @read_read_add_write_vec_1d(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {
319+
// CHECK: %[[C0:.*]] = arith.constant 0 : index
320+
// CHECK: %[[A:.*]] = vector.load %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32>
321+
// CHECK: %[[B:.*]] = vector.load %[[ARG1]][%[[C0]]] : memref<8xi32>, vector<4xi32>
322+
// CHECK: %[[RES:.*]] = arith.addi %[[A]], %[[B]] : vector<4xi32>
323+
// CHECK: vector.store %[[RES]], %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32>
324+
%c0 = arith.constant 0 : index
325+
%c1 = arith.constant 1 : index
326+
%c2 = arith.constant 2 : index
327+
%c3 = arith.constant 3 : index
328+
329+
%0 = vector.load %arg0[%c0] : memref<8xi32>, vector<1xi32>
330+
%1 = vector.load %arg0[%c1] : memref<8xi32>, vector<1xi32>
331+
%2 = vector.load %arg0[%c2] : memref<8xi32>, vector<1xi32>
332+
%3 = vector.load %arg0[%c3] : memref<8xi32>, vector<1xi32>
333+
334+
%4 = vector.load %arg1[%c0] : memref<8xi32>, vector<1xi32>
335+
%5 = vector.load %arg1[%c1] : memref<8xi32>, vector<1xi32>
336+
%6 = vector.load %arg1[%c2] : memref<8xi32>, vector<1xi32>
337+
%7 = vector.load %arg1[%c3] : memref<8xi32>, vector<1xi32>
338+
339+
%8 = arith.addi %0, %4 : vector<1xi32>
340+
%9 = arith.addi %1, %5 : vector<1xi32>
341+
%10 = arith.addi %2, %6 : vector<1xi32>
342+
%11 = arith.addi %3, %7 : vector<1xi32>
343+
344+
vector.store %8, %arg0[%c0] : memref<8xi32>, vector<1xi32>
345+
vector.store %9, %arg0[%c1] : memref<8xi32>, vector<1xi32>
346+
vector.store %10, %arg0[%c2] : memref<8xi32>, vector<1xi32>
347+
vector.store %11, %arg0[%c3] : memref<8xi32>, vector<1xi32>
348+
349+
return
350+
}
351+
352+
279353
// CHECK-LABEL: func @read_read_add_write_seven
280354
// CHECK-SAME: (%[[ARG0:.*]]: memref<8xindex>, %[[ARG1:.*]]: memref<8xindex>)
281355
func.func @read_read_add_write_seven(%arg0: memref<8xindex>, %arg1: memref<8xindex>) {

0 commit comments

Comments
 (0)