@@ -43,6 +43,7 @@ struct MemoryOpGroup {
4343 enum class Type { Load, Store };
4444 Type type;
4545 SmallVector<Operation *> ops;
46+ int64_t elementsCount = 0 ;
4647
4748 MemoryOpGroup (Type t) : type(t) {}
4849
@@ -68,30 +69,37 @@ static bool maybeWriteOp(Operation *op) {
6869 return effectInterface.hasEffect <MemoryEffects::Write>();
6970}
7071
71- static Type getVectorElementType (VectorType vectorType) {
72- if (vectorType. getRank () > 1 || vectorType. isScalable () ||
73- vectorType.getNumElements () != 1 )
74- return {} ;
/// Returns the pair (element type, number of elements) of `vectorType`.
/// Vectors of rank greater than 1 and scalable vectors are rejected with
/// std::nullopt. Unlike the removed getVectorElementType, vectors holding
/// more than one element are accepted and their element count is reported.
72+ static std::optional<std::pair< Type, int64_t >>
73+ getVectorElementTypeAndCount (VectorType vectorType) {
74+ if ( vectorType.getRank () > 1 || vectorType. isScalable () )
75+ return std:: nullopt ;
7576
76- return vectorType.getElementType ();
77+ return std::make_pair (vectorType.getElementType (),
78+ vectorType.getNumElements ());
7779}
7880
79- static Type getElementType (Operation *op) {
/// Returns the pair (element type, elements per access) for a memory op.
///  - memref.load / memref.store: the loaded / stored value type, count 1.
///  - vector.load / vector.store: delegated to getVectorElementTypeAndCount
///    on the op's vector type, so it can be std::nullopt.
///  - any other op: std::nullopt.
/// `op` must be non-null (asserted).
/// NOTE(review): std::make_pair(type, 1) yields std::pair<Type, int>; the
/// return relies on the implicit conversion to std::pair<Type, int64_t>.
81+ static std::optional<std::pair<Type, int64_t >>
82+ getElementTypeAndCount (Operation *op) {
8083 assert (op && " null op" );
8184 if (auto loadOp = dyn_cast<memref::LoadOp>(op))
82- return loadOp.getResult ().getType ();
85+ return std::make_pair ( loadOp.getResult ().getType (), 1 );
8386 if (auto storeOp = dyn_cast<memref::StoreOp>(op))
84- return storeOp.getValueToStore ().getType ();
87+ return std::make_pair ( storeOp.getValueToStore ().getType (), 1 );
8588 if (auto loadOp = dyn_cast<vector::LoadOp>(op))
86- return getVectorElementType (loadOp.getVectorType ());
89+ return getVectorElementTypeAndCount (loadOp.getVectorType ());
8790 if (auto storeOp = dyn_cast<vector::StoreOp>(op))
88- return getVectorElementType (storeOp.getVectorType ());
89- return {} ;
91+ return getVectorElementTypeAndCount (storeOp.getVectorType ());
92+ return std:: nullopt ;
9093}
9194
/// Returns true if `op` is a memory op whose element type can be determined
/// by getElementTypeAndCount and that type is an integer, float or index
/// type. The element count is not inspected here. `op` must be non-null
/// (asserted).
9295static bool isSupportedMemOp (Operation *op) {
9396 assert (op && " null op" );
94- return isa_and_present<IntegerType, FloatType, IndexType>(getElementType (op));
97+ auto typeAndCount = getElementTypeAndCount (op);
98+ if (!typeAndCount)
99+ return false ;
100+
101+ return isa_and_present<IntegerType, FloatType, IndexType>(
102+ typeAndCount->first );
95103}
96104
97105// / Collect all memory operations in the block into groups.
@@ -177,7 +185,7 @@ static ValueRange getIndices(Operation *op) {
177185 return {};
178186}
179187
180- static bool isAdjacentAffineMapIndices (Value idx1, Value idx2) {
188+ static bool isAdjacentAffineMapIndices (Value idx1, Value idx2, int64_t offset ) {
181189 auto applyOp1 = idx1.getDefiningOp <affine::AffineApplyOp>();
182190 if (!applyOp1)
183191 return false ;
@@ -195,48 +203,52 @@ static bool isAdjacentAffineMapIndices(Value idx1, Value idx2) {
195203 simplifyAffineExpr (expr2 - expr1, 0 , applyOp1.getOperands ().size ());
196204
197205 auto diffConst = dyn_cast<AffineConstantExpr>(diff);
198- return diffConst && diffConst.getValue () == 1 ;
206+ return diffConst && diffConst.getValue () == offset ;
199207}
200208
201209/// Check if idx2 is exactly `offset` past idx1, i.e. idx1 + offset == idx2.
/// Recognized forms, tried in order:
///  - both indices are constants c1, c2 with c1 + offset == c2;
///  - idx2 is `arith.addi idx1, C` with constant C == offset (idx1 must be
///    the lhs operand; the swapped operand order is not matched);
///  - both are `arith.addi` with the same lhs and rhs operands that are
///    themselves `offset` apart (checked recursively);
///  - whatever isAdjacentAffineMapIndices accepts for the given offset.
/// Returns false when none of the forms match.
202- static bool isAdjacentIndices (Value idx1, Value idx2) {
210+ static bool isAdjacentIndices (Value idx1, Value idx2, int64_t offset ) {
203211 if (auto c1 = getConstantIntValue (idx1)) {
204212 if (auto c2 = getConstantIntValue (idx2))
205- return *c1 + 1 == *c2;
213+ return *c1 + offset == *c2;
206214 }
207215
208216 if (auto addOp2 = idx2.getDefiningOp <arith::AddIOp>()) {
209- if (addOp2.getLhs () == idx1 && getConstantIntValue (addOp2.getRhs ()) == 1 )
217+ if (addOp2.getLhs () == idx1 &&
218+ getConstantIntValue (addOp2.getRhs ()) == offset)
210219 return true ;
211220
212221 if (auto addOp1 = idx1.getDefiningOp <arith::AddIOp>()) {
213222 if (addOp1.getLhs () == addOp2.getLhs () &&
214- isAdjacentIndices (addOp1.getRhs (), addOp2.getRhs ()))
223+ isAdjacentIndices (addOp1.getRhs (), addOp2.getRhs (), offset ))
215224 return true ;
216225 }
217226 }
218227
219- if (isAdjacentAffineMapIndices (idx1, idx2))
228+ if (isAdjacentAffineMapIndices (idx1, idx2, offset ))
220229 return true ;
221230
222231 return false ;
223232}
224233
225234/// Check if two ranges of indices are adjacent, i.e. the last index differs
226235/// by `offset` and all other indices compare equal.
/// Returns false for empty ranges and for ranges of different sizes.
227- static bool isAdjacentIndices (ValueRange idx1, ValueRange idx2) {
236+ static bool isAdjacentIndices (ValueRange idx1, ValueRange idx2,
237+ int64_t offset) {
228238 if (idx1.empty () || idx1.size () != idx2.size ())
229239 return false ;
230240
231241 if (idx1.drop_back () != idx2.drop_back ())
232242 return false ;
233243
234- return isAdjacentIndices (idx1.back (), idx2.back ());
244+ return isAdjacentIndices (idx1.back (), idx2.back (), offset );
235245}
236246
237247// / Check if two operations are adjacent and can be combined into a vector op.
238248// / This is done by checking if the base memrefs are the same, the last
239- // / dimension is contiguous, and the element types and indices are compatible
249+ // / dimension is contiguous, and the element types and indices are compatible.
250+ // / If source read/write is already vectorized, only merge ops if vector
251+ // / elements count is the same.
240252static bool isAdjacentOps (Operation *op1, Operation *op2) {
241253 assert (op1 && " null op1" );
242254 assert (op2 && " null op2" );
@@ -249,10 +261,19 @@ static bool isAdjacentOps(Operation *op1, Operation *op2) {
249261 if (!isContiguousLastDim (base1))
250262 return false ;
251263
252- if (getElementType (op1) != getElementType (op2))
264+ auto typeAndCount1 = getElementTypeAndCount (op1);
265+ if (!typeAndCount1)
266+ return false ;
267+
268+ auto typeAndCount2 = getElementTypeAndCount (op2);
269+ if (!typeAndCount2)
253270 return false ;
254271
255- return isAdjacentIndices (getIndices (op1), getIndices (op2));
272+ if (typeAndCount1 != typeAndCount2)
273+ return false ;
274+
275+ return isAdjacentIndices (getIndices (op1), getIndices (op2),
276+ typeAndCount1->second );
256277}
257278
258279// Extract contiguous groups from a MemoryOpGroup
@@ -271,6 +292,7 @@ extractContiguousGroups(const MemoryOpGroup &group) {
271292 // Start a new group with this operation
272293 result.emplace_back (group.type );
273294 MemoryOpGroup &currentGroup = result.back ();
295+ currentGroup.elementsCount = getElementTypeAndCount (op)->second ;
274296 auto &currentOps = currentGroup.ops ;
275297 currentOps.push_back (op);
276298 processedOps.insert (op);
@@ -310,7 +332,9 @@ extractContiguousGroups(const MemoryOpGroup &group) {
310332 return result;
311333}
312334
313- static bool isVectorizable (Operation *op) {
335+ static bool
336+ isVectorizable (Operation *op,
337+ std::optional<int64_t > expectedElementsCount = std::nullopt ) {
314338 if (!OpTrait::hasElementwiseMappableTraits (op))
315339 return false ;
316340
@@ -319,14 +343,18 @@ static bool isVectorizable(Operation *op) {
319343
320344 for (auto type :
321345 llvm::concat<Type>(op->getResultTypes (), op->getOperandTypes ())) {
346+ int64_t vectorElementsCount = 1 ;
322347 if (auto vectorType = dyn_cast<VectorType>(type)) {
323- if (vectorType.getRank () > 1 || vectorType.isScalable () ||
324- vectorType.getNumElements () != 1 )
348+ if (vectorType.getRank () > 1 || vectorType.isScalable ())
325349 return false ;
326350
327351 type = vectorType.getElementType ();
352+ vectorElementsCount = vectorType.getNumElements ();
328353 }
329354
355+ if (expectedElementsCount && vectorElementsCount != *expectedElementsCount)
356+ return false ;
357+
330358 if (!isa<IntegerType, FloatType, IndexType>(type))
331359 return false ;
332360 }
@@ -347,13 +375,15 @@ struct SLPGraphNode {
347375 SmallVector<SLPGraphNode *> users;
348376 SmallVector<SLPGraphNode *> operands;
349377 Operation *insertionPoint = nullptr ;
378+ int64_t elementsCount = 0 ;
350379 bool isRoot = false ;
351380
352381 SLPGraphNode () = default ;
353382 SLPGraphNode (ArrayRef<Operation *> operations)
354383 : ops(operations.begin(), operations.end()) {}
355384
356385 size_t opsCount () const { return ops.size (); }
386+ size_t vectorSize () const { return elementsCount * opsCount (); }
357387
358388 Operation *op () const {
359389 assert (!ops.empty () && " empty ops" );
@@ -415,17 +445,20 @@ class SLPGraph {
415445 SLPGraph &operator =(SLPGraph &&) = default ;
416446
417447 /// Add a new node holding `operations` to the graph.
/// Stores `elementsCount` on the node and records every op -> node pair in
/// `opToNode`. The node is owned by `nodes` (held in a std::unique_ptr); the
/// returned raw pointer is non-owning.
418- SLPGraphNode *addNode (ArrayRef<Operation *> operations) {
448+ SLPGraphNode *addNode (ArrayRef<Operation *> operations,
449+ int64_t elementsCount) {
419450 nodes.push_back (std::make_unique<SLPGraphNode>(operations));
420451 auto *node = nodes.back ().get ();
452+ node->elementsCount = elementsCount;
421453 for (Operation *op : operations)
422454 opToNode[op] = node;
423455 return node;
424456 }
425457
426458 /// Add a root node (memory operation).
/// Same as addNode with the given `elementsCount`, but additionally marks
/// the new node with `isRoot = true`.
427- SLPGraphNode *addRoot (ArrayRef<Operation *> operations) {
428- auto *node = addNode (operations);
459+ SLPGraphNode *addRoot (ArrayRef<Operation *> operations,
460+ int64_t elementsCount) {
461+ auto *node = addNode (operations, elementsCount);
429462 node->isRoot = true ;
430463 return node;
431464 }
@@ -699,13 +732,14 @@ static bool
699732checkOpVecType (SLPGraphNode *node,
700733 llvm::function_ref<bool (Type, size_t )> isValidVecType) {
701734 Operation *op = node->op ();
702- size_t size = node->opsCount ();
735+ size_t size = node->vectorSize ();
703736 auto checkRes = [](bool res) -> bool {
704737 LLVM_DEBUG (llvm::dbgs () << (res ? " true" : " false" ) << " \n " );
705738 return res;
706739 };
707740
708- if (Type elementType = getElementType (op)) {
741+ if (auto typeAndCount = getElementTypeAndCount (op)) {
742+ Type elementType = typeAndCount->first ;
709743 LLVM_DEBUG (llvm::dbgs () << " Checking if type " << elementType
710744 << " with size " << size << " can be vectorized: " );
711745 return checkRes (isValidVecType (elementType, size));
@@ -777,7 +811,7 @@ static SLPGraph buildSLPGraph(ArrayRef<MemoryOpGroup> rootGroups) {
777811
778812 // First, create nodes for each contiguous memory operation group
779813 for (const auto &group : rootGroups) {
780- auto *node = graph.addRoot (group.ops );
814+ auto *node = graph.addRoot (group.ops , group. elementsCount );
781815 worklist.push_back (node);
782816
783817 LLVM_DEBUG ({
@@ -800,7 +834,7 @@ static SLPGraph buildSLPGraph(ArrayRef<MemoryOpGroup> rootGroups) {
800834 return ;
801835 }
802836
803- if (!isVectorizable (user))
837+ if (!isVectorizable (user, node-> elementsCount ))
804838 return ;
805839
806840 Fingerprint expectedFingerprint = fingerprints.getFingerprint (user);
@@ -830,7 +864,7 @@ static SLPGraph buildSLPGraph(ArrayRef<MemoryOpGroup> rootGroups) {
830864 if (currentOps.size () == 1 )
831865 return ;
832866
833- auto *newNode = graph.addNode (currentOps);
867+ auto *newNode = graph.addNode (currentOps, node-> elementsCount );
834868 graph.addEdge (node, newNode);
835869 for (Operation *op : currentOps)
836870 fingerprints.invalidate (op);
@@ -877,7 +911,7 @@ static SLPGraph buildSLPGraph(ArrayRef<MemoryOpGroup> rootGroups) {
877911 currentOps.push_back (otherOp);
878912 ++currentIndex;
879913 }
880- } else if (isVectorizable (srcOp)) {
914+ } else if (isVectorizable (srcOp, node-> elementsCount )) {
881915 LLVM_DEBUG (llvm::dbgs () << " Processing vectorizable op "
882916 << srcOp->getName () << " \n " );
883917
@@ -898,7 +932,7 @@ static SLPGraph buildSLPGraph(ArrayRef<MemoryOpGroup> rootGroups) {
898932 if (currentOps.size () == 1 )
899933 return ;
900934
901- auto *newNode = graph.addNode (currentOps);
935+ auto *newNode = graph.addNode (currentOps, node-> elementsCount );
902936 graph.addEdge (newNode, node);
903937 for (Operation *op : currentOps)
904938 fingerprints.invalidate (op);
@@ -1000,7 +1034,7 @@ SLPGraph::vectorize(IRRewriter &rewriter,
10001034 LLVM_DEBUG (llvm::dbgs () << " Insertion point: " << *ip << " \n " );
10011035
10021036 rewriter.setInsertionPoint (ip);
1003- int64_t numElements = node->opsCount ();
1037+ int64_t numElements = node->vectorSize ();
10041038 Location loc = op->getLoc ();
10051039
10061040 auto handleNonVectorInputs = [&](ValueRange operands) {
@@ -1009,10 +1043,20 @@ SLPGraph::vectorize(IRRewriter &rewriter,
10091043 continue ;
10101044
10111045 SmallVector<Value> args;
1012- for (Operation *defOp : node->ops )
1013- args.push_back (defOp->getOperand (i));
1046+ for (Operation *defOp : node->ops ) {
1047+ Value arg = defOp->getOperand (i);
1048+ if (auto vecType = dyn_cast<VectorType>(arg.getType ())) {
1049+ assert (vecType.getRank () == 1 );
1050+ for (auto j : llvm::seq (vecType.getNumElements ()))
1051+ args.push_back (rewriter.create <vector::ExtractOp>(loc, arg, j));
1052+
1053+ } else {
1054+ args.push_back (arg);
1055+ }
1056+ }
10141057
1015- auto vecType = VectorType::get (numElements, operand.getType ());
1058+ auto vecType = VectorType::get (numElements,
1059+ getElementTypeOrSelf (operand.getType ()));
10161060 Value vector =
10171061 rewriter.create <vector::FromElementsOp>(loc, vecType, args);
10181062 mapping.map (operand, vector);
@@ -1043,7 +1087,8 @@ SLPGraph::vectorize(IRRewriter &rewriter,
10431087 };
10441088
10451089 if (maybeReadOp (op)) {
1046- auto vecType = VectorType::get (numElements, getElementType (op));
1090+ auto vecType =
1091+ VectorType::get (numElements, getElementTypeAndCount (op)->first );
10471092 Value result = rewriter.create <vector::LoadOp>(loc, vecType, getBase (op),
10481093 getIndices (op));
10491094 mapping.map (op->getResult (0 ), result);
0 commit comments