Skip to content

Commit b805b21

Browse files
committed
merge vectorized ops too
1 parent 8f63e1e commit b805b21

File tree

2 files changed

+105
-58
lines changed

2 files changed

+105
-58
lines changed

mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp

Lines changed: 87 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ struct MemoryOpGroup {
4343
enum class Type { Load, Store };
4444
Type type;
4545
SmallVector<Operation *> ops;
46+
int64_t elementsCount = 0;
4647

4748
MemoryOpGroup(Type t) : type(t) {}
4849

@@ -68,30 +69,37 @@ static bool maybeWriteOp(Operation *op) {
6869
return effectInterface.hasEffect<MemoryEffects::Write>();
6970
}
7071

71-
static Type getVectorElementType(VectorType vectorType) {
72-
if (vectorType.getRank() > 1 || vectorType.isScalable() ||
73-
vectorType.getNumElements() != 1)
74-
return {};
72+
static std::optional<std::pair<Type, int64_t>>
73+
getVectorElementTypeAndCount(VectorType vectorType) {
74+
if (vectorType.getRank() > 1 || vectorType.isScalable())
75+
return std::nullopt;
7576

76-
return vectorType.getElementType();
77+
return std::make_pair(vectorType.getElementType(),
78+
vectorType.getNumElements());
7779
}
7880

79-
static Type getElementType(Operation *op) {
81+
static std::optional<std::pair<Type, int64_t>>
82+
getElementTypeAndCount(Operation *op) {
8083
assert(op && "null op");
8184
if (auto loadOp = dyn_cast<memref::LoadOp>(op))
82-
return loadOp.getResult().getType();
85+
return std::make_pair(loadOp.getResult().getType(), 1);
8386
if (auto storeOp = dyn_cast<memref::StoreOp>(op))
84-
return storeOp.getValueToStore().getType();
87+
return std::make_pair(storeOp.getValueToStore().getType(), 1);
8588
if (auto loadOp = dyn_cast<vector::LoadOp>(op))
86-
return getVectorElementType(loadOp.getVectorType());
89+
return getVectorElementTypeAndCount(loadOp.getVectorType());
8790
if (auto storeOp = dyn_cast<vector::StoreOp>(op))
88-
return getVectorElementType(storeOp.getVectorType());
89-
return {};
91+
return getVectorElementTypeAndCount(storeOp.getVectorType());
92+
return std::nullopt;
9093
}
9194

9295
static bool isSupportedMemOp(Operation *op) {
9396
assert(op && "null op");
94-
return isa_and_present<IntegerType, FloatType, IndexType>(getElementType(op));
97+
auto typeAndCount = getElementTypeAndCount(op);
98+
if (!typeAndCount)
99+
return false;
100+
101+
return isa_and_present<IntegerType, FloatType, IndexType>(
102+
typeAndCount->first);
95103
}
96104

97105
/// Collect all memory operations in the block into groups.
@@ -177,7 +185,7 @@ static ValueRange getIndices(Operation *op) {
177185
return {};
178186
}
179187

180-
static bool isAdjacentAffineMapIndices(Value idx1, Value idx2) {
188+
static bool isAdjacentAffineMapIndices(Value idx1, Value idx2, int64_t offset) {
181189
auto applyOp1 = idx1.getDefiningOp<affine::AffineApplyOp>();
182190
if (!applyOp1)
183191
return false;
@@ -195,48 +203,52 @@ static bool isAdjacentAffineMapIndices(Value idx1, Value idx2) {
195203
simplifyAffineExpr(expr2 - expr1, 0, applyOp1.getOperands().size());
196204

197205
auto diffConst = dyn_cast<AffineConstantExpr>(diff);
198-
return diffConst && diffConst.getValue() == 1;
206+
return diffConst && diffConst.getValue() == offset;
199207
}
200208

201209
/// Check if two indices are `offset` apart, i.e. index1 + offset == index2.
202-
static bool isAdjacentIndices(Value idx1, Value idx2) {
210+
static bool isAdjacentIndices(Value idx1, Value idx2, int64_t offset) {
203211
if (auto c1 = getConstantIntValue(idx1)) {
204212
if (auto c2 = getConstantIntValue(idx2))
205-
return *c1 + 1 == *c2;
213+
return *c1 + offset == *c2;
206214
}
207215

208216
if (auto addOp2 = idx2.getDefiningOp<arith::AddIOp>()) {
209-
if (addOp2.getLhs() == idx1 && getConstantIntValue(addOp2.getRhs()) == 1)
217+
if (addOp2.getLhs() == idx1 &&
218+
getConstantIntValue(addOp2.getRhs()) == offset)
210219
return true;
211220

212221
if (auto addOp1 = idx1.getDefiningOp<arith::AddIOp>()) {
213222
if (addOp1.getLhs() == addOp2.getLhs() &&
214-
isAdjacentIndices(addOp1.getRhs(), addOp2.getRhs()))
223+
isAdjacentIndices(addOp1.getRhs(), addOp2.getRhs(), offset))
215224
return true;
216225
}
217226
}
218227

219-
if (isAdjacentAffineMapIndices(idx1, idx2))
228+
if (isAdjacentAffineMapIndices(idx1, idx2, offset))
220229
return true;
221230

222231
return false;
223232
}
224233

225234
/// Check if two ranges of indices are consecutive, i.e. the fastest index
226235
/// differs by `offset` and all other indices are the same.
227-
static bool isAdjacentIndices(ValueRange idx1, ValueRange idx2) {
236+
static bool isAdjacentIndices(ValueRange idx1, ValueRange idx2,
237+
int64_t offset) {
228238
if (idx1.empty() || idx1.size() != idx2.size())
229239
return false;
230240

231241
if (idx1.drop_back() != idx2.drop_back())
232242
return false;
233243

234-
return isAdjacentIndices(idx1.back(), idx2.back());
244+
return isAdjacentIndices(idx1.back(), idx2.back(), offset);
235245
}
236246

237247
/// Check if two operations are adjacent and can be combined into a vector op.
238248
/// This is done by checking if the base memrefs are the same, the last
239-
/// dimension is contiguous, and the element types and indices are compatible
249+
/// dimension is contiguous, and the element types and indices are compatible.
250+
/// If the source read/write is already vectorized, only merge ops whose
250+
/// vector element counts are the same.
240252
static bool isAdjacentOps(Operation *op1, Operation *op2) {
241253
assert(op1 && "null op1");
242254
assert(op2 && "null op2");
@@ -249,10 +261,19 @@ static bool isAdjacentOps(Operation *op1, Operation *op2) {
249261
if (!isContiguousLastDim(base1))
250262
return false;
251263

252-
if (getElementType(op1) != getElementType(op2))
264+
auto typeAndCount1 = getElementTypeAndCount(op1);
265+
if (!typeAndCount1)
266+
return false;
267+
268+
auto typeAndCount2 = getElementTypeAndCount(op2);
269+
if (!typeAndCount2)
253270
return false;
254271

255-
return isAdjacentIndices(getIndices(op1), getIndices(op2));
272+
if (typeAndCount1 != typeAndCount2)
273+
return false;
274+
275+
return isAdjacentIndices(getIndices(op1), getIndices(op2),
276+
typeAndCount1->second);
256277
}
257278

258279
// Extract contiguous groups from a MemoryOpGroup
@@ -271,6 +292,7 @@ extractContiguousGroups(const MemoryOpGroup &group) {
271292
// Start a new group with this operation
272293
result.emplace_back(group.type);
273294
MemoryOpGroup &currentGroup = result.back();
295+
currentGroup.elementsCount = getElementTypeAndCount(op)->second;
274296
auto &currentOps = currentGroup.ops;
275297
currentOps.push_back(op);
276298
processedOps.insert(op);
@@ -310,7 +332,9 @@ extractContiguousGroups(const MemoryOpGroup &group) {
310332
return result;
311333
}
312334

313-
static bool isVectorizable(Operation *op) {
335+
static bool
336+
isVectorizable(Operation *op,
337+
std::optional<int64_t> expectedElementsCount = std::nullopt) {
314338
if (!OpTrait::hasElementwiseMappableTraits(op))
315339
return false;
316340

@@ -319,14 +343,18 @@ static bool isVectorizable(Operation *op) {
319343

320344
for (auto type :
321345
llvm::concat<Type>(op->getResultTypes(), op->getOperandTypes())) {
346+
int64_t vectorElementsCount = 1;
322347
if (auto vectorType = dyn_cast<VectorType>(type)) {
323-
if (vectorType.getRank() > 1 || vectorType.isScalable() ||
324-
vectorType.getNumElements() != 1)
348+
if (vectorType.getRank() > 1 || vectorType.isScalable())
325349
return false;
326350

327351
type = vectorType.getElementType();
352+
vectorElementsCount = vectorType.getNumElements();
328353
}
329354

355+
if (expectedElementsCount && vectorElementsCount != *expectedElementsCount)
356+
return false;
357+
330358
if (!isa<IntegerType, FloatType, IndexType>(type))
331359
return false;
332360
}
@@ -347,13 +375,15 @@ struct SLPGraphNode {
347375
SmallVector<SLPGraphNode *> users;
348376
SmallVector<SLPGraphNode *> operands;
349377
Operation *insertionPoint = nullptr;
378+
int64_t elementsCount = 0;
350379
bool isRoot = false;
351380

352381
SLPGraphNode() = default;
353382
SLPGraphNode(ArrayRef<Operation *> operations)
354383
: ops(operations.begin(), operations.end()) {}
355384

356385
size_t opsCount() const { return ops.size(); }
386+
size_t vectorSize() const { return elementsCount * opsCount(); }
357387

358388
Operation *op() const {
359389
assert(!ops.empty() && "empty ops");
@@ -415,17 +445,20 @@ class SLPGraph {
415445
SLPGraph &operator=(SLPGraph &&) = default;
416446

417447
/// Add a new node to the graph
418-
SLPGraphNode *addNode(ArrayRef<Operation *> operations) {
448+
SLPGraphNode *addNode(ArrayRef<Operation *> operations,
449+
int64_t elementsCount) {
419450
nodes.push_back(std::make_unique<SLPGraphNode>(operations));
420451
auto *node = nodes.back().get();
452+
node->elementsCount = elementsCount;
421453
for (Operation *op : operations)
422454
opToNode[op] = node;
423455
return node;
424456
}
425457

426458
/// Add a root node (memory operation)
427-
SLPGraphNode *addRoot(ArrayRef<Operation *> operations) {
428-
auto *node = addNode(operations);
459+
SLPGraphNode *addRoot(ArrayRef<Operation *> operations,
460+
int64_t elementsCount) {
461+
auto *node = addNode(operations, elementsCount);
429462
node->isRoot = true;
430463
return node;
431464
}
@@ -699,13 +732,14 @@ static bool
699732
checkOpVecType(SLPGraphNode *node,
700733
llvm::function_ref<bool(Type, size_t)> isValidVecType) {
701734
Operation *op = node->op();
702-
size_t size = node->opsCount();
735+
size_t size = node->vectorSize();
703736
auto checkRes = [](bool res) -> bool {
704737
LLVM_DEBUG(llvm::dbgs() << (res ? "true" : "false") << "\n");
705738
return res;
706739
};
707740

708-
if (Type elementType = getElementType(op)) {
741+
if (auto typeAndCount = getElementTypeAndCount(op)) {
742+
Type elementType = typeAndCount->first;
709743
LLVM_DEBUG(llvm::dbgs() << "Checking if type " << elementType
710744
<< " with size " << size << " can be vectorized: ");
711745
return checkRes(isValidVecType(elementType, size));
@@ -777,7 +811,7 @@ static SLPGraph buildSLPGraph(ArrayRef<MemoryOpGroup> rootGroups) {
777811

778812
// First, create nodes for each contiguous memory operation group
779813
for (const auto &group : rootGroups) {
780-
auto *node = graph.addRoot(group.ops);
814+
auto *node = graph.addRoot(group.ops, group.elementsCount);
781815
worklist.push_back(node);
782816

783817
LLVM_DEBUG({
@@ -800,7 +834,7 @@ static SLPGraph buildSLPGraph(ArrayRef<MemoryOpGroup> rootGroups) {
800834
return;
801835
}
802836

803-
if (!isVectorizable(user))
837+
if (!isVectorizable(user, node->elementsCount))
804838
return;
805839

806840
Fingerprint expectedFingerprint = fingerprints.getFingerprint(user);
@@ -830,7 +864,7 @@ static SLPGraph buildSLPGraph(ArrayRef<MemoryOpGroup> rootGroups) {
830864
if (currentOps.size() == 1)
831865
return;
832866

833-
auto *newNode = graph.addNode(currentOps);
867+
auto *newNode = graph.addNode(currentOps, node->elementsCount);
834868
graph.addEdge(node, newNode);
835869
for (Operation *op : currentOps)
836870
fingerprints.invalidate(op);
@@ -877,7 +911,7 @@ static SLPGraph buildSLPGraph(ArrayRef<MemoryOpGroup> rootGroups) {
877911
currentOps.push_back(otherOp);
878912
++currentIndex;
879913
}
880-
} else if (isVectorizable(srcOp)) {
914+
} else if (isVectorizable(srcOp, node->elementsCount)) {
881915
LLVM_DEBUG(llvm::dbgs() << " Processing vectorizable op "
882916
<< srcOp->getName() << "\n");
883917

@@ -898,7 +932,7 @@ static SLPGraph buildSLPGraph(ArrayRef<MemoryOpGroup> rootGroups) {
898932
if (currentOps.size() == 1)
899933
return;
900934

901-
auto *newNode = graph.addNode(currentOps);
935+
auto *newNode = graph.addNode(currentOps, node->elementsCount);
902936
graph.addEdge(newNode, node);
903937
for (Operation *op : currentOps)
904938
fingerprints.invalidate(op);
@@ -1000,7 +1034,7 @@ SLPGraph::vectorize(IRRewriter &rewriter,
10001034
LLVM_DEBUG(llvm::dbgs() << " Insertion point: " << *ip << "\n");
10011035

10021036
rewriter.setInsertionPoint(ip);
1003-
int64_t numElements = node->opsCount();
1037+
int64_t numElements = node->vectorSize();
10041038
Location loc = op->getLoc();
10051039

10061040
auto handleNonVectorInputs = [&](ValueRange operands) {
@@ -1009,10 +1043,20 @@ SLPGraph::vectorize(IRRewriter &rewriter,
10091043
continue;
10101044

10111045
SmallVector<Value> args;
1012-
for (Operation *defOp : node->ops)
1013-
args.push_back(defOp->getOperand(i));
1046+
for (Operation *defOp : node->ops) {
1047+
Value arg = defOp->getOperand(i);
1048+
if (auto vecType = dyn_cast<VectorType>(arg.getType())) {
1049+
assert(vecType.getRank() == 1);
1050+
for (auto j : llvm::seq(vecType.getNumElements()))
1051+
args.push_back(rewriter.create<vector::ExtractOp>(loc, arg, j));
1052+
1053+
} else {
1054+
args.push_back(arg);
1055+
}
1056+
}
10141057

1015-
auto vecType = VectorType::get(numElements, operand.getType());
1058+
auto vecType = VectorType::get(numElements,
1059+
getElementTypeOrSelf(operand.getType()));
10161060
Value vector =
10171061
rewriter.create<vector::FromElementsOp>(loc, vecType, args);
10181062
mapping.map(operand, vector);
@@ -1043,7 +1087,8 @@ SLPGraph::vectorize(IRRewriter &rewriter,
10431087
};
10441088

10451089
if (maybeReadOp(op)) {
1046-
auto vecType = VectorType::get(numElements, getElementType(op));
1090+
auto vecType =
1091+
VectorType::get(numElements, getElementTypeAndCount(op)->first);
10471092
Value result = rewriter.create<vector::LoadOp>(loc, vecType, getBase(op),
10481093
getIndices(op));
10491094
mapping.map(op->getResult(0), result);

mlir/test/Dialect/Vector/slp-vectorize.mlir

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -646,22 +646,24 @@ func.func private @use(i32)
646646
// CHECK-LABEL: func @read_read_add_write_interleaved_use
647647
// CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>)
648648
func.func @read_read_add_write_interleaved_use(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {
649-
// CHECK: %[[C0:.*]] = arith.constant 0 : index
650-
// CHECK: %[[C2:.*]] = arith.constant 2 : index
651-
// CHECK: %[[C3:.*]] = arith.constant 3 : index
652-
// CHECK: %[[V0:.*]] = memref.load %arg0[%[[C3]]] : memref<8xi32>
653-
// CHECK: %[[V1:.*]] = memref.load %arg1[%[[C3]]] : memref<8xi32>
654-
// CHECK: call @use(%[[V0]]) : (i32) -> ()
655-
// CHECK: %[[V2:.*]] = vector.load %arg0[%[[C0]]] : memref<8xi32>, vector<2xi32>
656-
// CHECK: %[[V3:.*]] = vector.load %arg1[%[[C0]]] : memref<8xi32>, vector<2xi32>
657-
// CHECK: %[[V4:.*]] = memref.load %arg0[%[[C2]]] : memref<8xi32>
658-
// CHECK: %[[V5:.*]] = memref.load %arg1[%[[C2]]] : memref<8xi32>
659-
// CHECK: %[[V6:.*]] = vector.from_elements %[[V4]], %[[V0]] : vector<2xi32>
660-
// CHECK: %[[V7:.*]] = vector.from_elements %[[V5]], %[[V1]] : vector<2xi32>
661-
// CHECK: %[[V8:.*]] = arith.addi %[[V6]], %[[V7]] : vector<2xi32>
662-
// CHECK: %[[V9:.*]] = arith.addi %[[V2]], %[[V3]] : vector<2xi32>
663-
// CHECK: vector.store %[[V9]], %arg0[%[[C0]]] : memref<8xi32>, vector<2xi32>
664-
// CHECK: vector.store %[[V8]], %arg0[%[[C2]]] : memref<8xi32>, vector<2xi32>
649+
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
650+
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
651+
// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
652+
// CHECK: %[[V0:.*]] = memref.load %[[ARG0]][%[[C3]]] : memref<8xi32>
653+
// CHECK: %[[V1:.*]] = memref.load %[[ARG1]][%[[C3]]] : memref<8xi32>
654+
// CHECK: call @use(%[[V0]]) : (i32) -> ()
655+
// CHECK: %[[V2:.*]] = vector.load %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<2xi32>
656+
// CHECK: %[[V3:.*]] = vector.load %[[ARG1]][%[[C0]]] : memref<8xi32>, vector<2xi32>
657+
// CHECK: %[[V4:.*]] = memref.load %[[ARG0]][%[[C2]]] : memref<8xi32>
658+
// CHECK: %[[V5:.*]] = memref.load %[[ARG1]][%[[C2]]] : memref<8xi32>
659+
// CHECK: %[[V6:.*]] = vector.extract %[[V2]][0] : i32 from vector<2xi32>
660+
// CHECK: %[[V7:.*]] = vector.extract %[[V2]][1] : i32 from vector<2xi32>
661+
// CHECK: %[[V8:.*]] = vector.from_elements %[[V6]], %[[V7]], %[[V4]], %[[V0]] : vector<4xi32>
662+
// CHECK: %[[V9:.*]] = vector.extract %[[V3]][0] : i32 from vector<2xi32>
663+
// CHECK: %[[V10:.*]] = vector.extract %[[V3]][1] : i32 from vector<2xi32>
664+
// CHECK: %[[V11:.*]] = vector.from_elements %[[V9]], %[[V10]], %[[V5]], %[[V1]] : vector<4xi32>
665+
// CHECK: %[[V12:.*]] = arith.addi %[[V8]], %[[V11]] : vector<4xi32>
666+
// CHECK: vector.store %[[V12]], %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32>
665667

666668
%c0 = arith.constant 0 : index
667669
%c1 = arith.constant 1 : index

0 commit comments

Comments
 (0)