@@ -68,6 +68,32 @@ static bool maybeWriteOp(Operation *op) {
6868 return effectInterface.hasEffect <MemoryEffects::Write>();
6969}
7070
71+ static Type getVectorElementType (VectorType vectorType) {
72+ if (vectorType.getRank () > 1 || vectorType.isScalable () ||
73+ vectorType.getNumElements () != 1 )
74+ return {};
75+
76+ return vectorType.getElementType ();
77+ }
78+
79+ static Type getElementType (Operation *op) {
80+ assert (op && " null op" );
81+ if (auto loadOp = dyn_cast<memref::LoadOp>(op))
82+ return loadOp.getResult ().getType ();
83+ if (auto storeOp = dyn_cast<memref::StoreOp>(op))
84+ return storeOp.getValueToStore ().getType ();
85+ if (auto loadOp = dyn_cast<vector::LoadOp>(op))
86+ return getVectorElementType (loadOp.getVectorType ());
87+ if (auto storeOp = dyn_cast<vector::StoreOp>(op))
88+ return getVectorElementType (storeOp.getVectorType ());
89+ return {};
90+ }
91+
92+ static bool isSupportedMemOp (Operation *op) {
93+ assert (op && " null op" );
94+ return isa_and_present<IntegerType, FloatType, IndexType>(getElementType (op));
95+ }
96+
7197// / Collect all memory operations in the block into groups.
7298// / Each group contains either all loads or all stores, uninterrupted by
7399// / operations of the other type.
@@ -85,7 +111,7 @@ static SmallVector<MemoryOpGroup> collectMemoryOpGroups(Block &block) {
85111 }
86112 }
87113
88- if (!isa<memref::LoadOp, memref::StoreOp>( op))
114+ if (!isSupportedMemOp (& op))
89115 continue ;
90116
91117 bool isLoad = maybeReadOp (&op);
@@ -109,6 +135,19 @@ static Value getBase(Operation *op) {
109135 return loadOp.getMemRef ();
110136 if (auto storeOp = dyn_cast<memref::StoreOp>(op))
111137 return storeOp.getMemRef ();
138+ if (auto loadOp = dyn_cast<vector::LoadOp>(op))
139+ return loadOp.getBase ();
140+ if (auto storeOp = dyn_cast<vector::StoreOp>(op))
141+ return storeOp.getBase ();
142+ return {};
143+ }
144+
145+ static Value getValueToStore (Operation *op) {
146+ assert (op && " null op" );
147+ if (auto storeOp = dyn_cast<memref::StoreOp>(op))
148+ return storeOp.getValueToStore ();
149+ if (auto storeOp = dyn_cast<vector::StoreOp>(op))
150+ return storeOp.getValueToStore ();
112151 return {};
113152}
114153
@@ -131,15 +170,10 @@ static ValueRange getIndices(Operation *op) {
131170 return loadOp.getIndices ();
132171 if (auto storeOp = dyn_cast<memref::StoreOp>(op))
133172 return storeOp.getIndices ();
134- return {};
135- }
136-
137- static Type getElementType (Operation *op) {
138- assert (op && " null op" );
139- if (auto loadOp = dyn_cast<memref::LoadOp>(op))
140- return loadOp.getResult ().getType ();
141- if (auto storeOp = dyn_cast<memref::StoreOp>(op))
142- return storeOp.getValueToStore ().getType ();
173+ if (auto loadOp = dyn_cast<vector::LoadOp>(op))
174+ return loadOp.getIndices ();
175+ if (auto storeOp = dyn_cast<vector::StoreOp>(op))
176+ return storeOp.getIndices ();
143177 return {};
144178}
145179
@@ -285,7 +319,15 @@ static bool isVectorizable(Operation *op) {
285319
286320 for (auto type :
287321 llvm::concat<Type>(op->getResultTypes (), op->getOperandTypes ())) {
288- if (!type.isIntOrIndexOrFloat ())
322+ if (auto vectorType = dyn_cast<VectorType>(type)) {
323+ if (vectorType.getRank () > 1 || vectorType.isScalable () ||
324+ vectorType.getNumElements () != 1 )
325+ return false ;
326+
327+ type = vectorType.getElementType ();
328+ }
329+
330+ if (!isa<IntegerType, FloatType, IndexType>(type))
289331 return false ;
290332 }
291333
@@ -464,8 +506,7 @@ class SLPGraph {
464506 for (const auto &node : nodes) {
465507 if (!node->isRoot )
466508 continue ;
467- llvm::dbgs () << " "
468- << (isa<memref::LoadOp>(node->op ()) ? " LOAD" : " STORE" )
509+ llvm::dbgs () << " " << (maybeReadOp (node->op ()) ? " LOAD" : " STORE" )
469510 << " group with " << node->size () << " operations:\n " ;
470511 for (auto *op : node->ops ) {
471512 llvm::dbgs () << " " << *op << " \n " ;
@@ -657,20 +698,36 @@ checkOpVecType(SLPGraphNode *node,
657698 llvm::function_ref<bool (Type, size_t )> isValidVecType) {
658699 Operation *op = node->op ();
659700 size_t size = node->size ();
660- if (Type elementType = getElementType (op))
661- return isValidVecType (elementType, size);
701+ auto checkRes = [](bool res) -> bool {
702+ LLVM_DEBUG (llvm::dbgs () << (res ? " true" : " false" ) << " \n " );
703+ return res;
704+ };
705+
706+ if (Type elementType = getElementType (op)) {
707+ LLVM_DEBUG (llvm::dbgs () << " Checking if type " << elementType
708+ << " with size " << size << " can be vectorized: " );
709+ return checkRes (isValidVecType (elementType, size));
710+ }
662711
663712 if (isVectorizable (op)) {
664713 for (auto type :
665714 llvm::concat<Type>(op->getResultTypes (), op->getOperandTypes ())) {
666- if (!isValidVecType (type, size))
715+ Type elementType = getElementTypeOrSelf (type);
716+ LLVM_DEBUG (llvm::dbgs ()
717+ << " Checking if type " << elementType << " with size " << size
718+ << " can be vectorized: " );
719+ if (!checkRes (isValidVecType (elementType, size)))
667720 return false ;
668721 }
669722 return true ;
670723 }
671724
672- if (auto extract = dyn_cast<vector::ExtractOp>(op))
673- return isValidVecType (extract.getResult ().getType (), size);
725+ if (auto extract = dyn_cast<vector::ExtractOp>(op)) {
726+ Type type = extract.getResult ().getType ();
727+ LLVM_DEBUG (llvm::dbgs () << " Checking if type " << type << " with size "
728+ << size << " can be vectorized: " );
729+ return checkRes (isValidVecType (type, size));
730+ }
674731
675732 LLVM_DEBUG (llvm::dbgs () << " Unsupported op " << op->getName () << " \n " );
676733 return false ;
@@ -903,12 +960,19 @@ SLPGraph::vectorize(IRRewriter &rewriter,
903960 for (auto *operand : node->operands )
904961 size = std::min (size, operand->size ());
905962
906- node->ops .resize (size);
963+ if (size < node->size ()) {
964+ LLVM_DEBUG (llvm::dbgs ()
965+ << " Size mismatch, resizing node with " << node->size ()
966+ << " operations to " << size << " \n " );
967+ node->ops .resize (size);
968+ }
907969
908970 while (node->size () > 1 ) {
909971 if (checkOpVecType (node, isValidVecType))
910972 break ;
911973
974+ LLVM_DEBUG (llvm::dbgs () << " No a valid vector type, popping back op: "
975+ << node->ops .back ()->getName () << " \n " );
912976 node->ops .pop_back ();
913977 }
914978 }
@@ -975,24 +1039,22 @@ SLPGraph::vectorize(IRRewriter &rewriter,
9751039 numElements, 1 );
9761040 };
9771041
978- if (auto load = dyn_cast<memref::LoadOp>(op)) {
979- auto vecType =
980- VectorType::get (numElements, load.getMemRefType ().getElementType ());
981- Value result = rewriter.create <vector::LoadOp>(
982- loc, vecType, load.getMemRef (), load.getIndices ());
983- mapping.map (load.getResult (), result);
1042+ if (maybeReadOp (op)) {
1043+ auto vecType = VectorType::get (numElements, getElementType (op));
1044+ Value result = rewriter.create <vector::LoadOp>(loc, vecType, getBase (op),
1045+ getIndices (op));
1046+ mapping.map (op->getResult (0 ), result);
9841047 handleNonVectorOutputs (result);
985- } else if (auto store = dyn_cast<memref::StoreOp> (op)) {
986- handleNonVectorInputs (store. getValueToStore ());
987- Value val = mapping.lookupOrDefault (store. getValueToStore ());
1048+ } else if (maybeWriteOp (op)) {
1049+ handleNonVectorInputs (getValueToStore (op ));
1050+ Value val = mapping.lookupOrDefault (getValueToStore (op ));
9881051 val = handleVecSizeMismatch (val);
989- rewriter.create <vector::StoreOp>(loc, val, store.getMemRef (),
990- store.getIndices ());
1052+ rewriter.create <vector::StoreOp>(loc, val, getBase (op), getIndices (op));
9911053 } else if (isVectorizable (op)) {
9921054 handleNonVectorInputs (op->getOperands ());
9931055 Operation *newOp = rewriter.clone (*op, mapping);
994- auto resVectorType =
995- VectorType::get (numElements, op-> getResultTypes (). front () );
1056+ Type resType = getElementTypeOrSelf (op-> getResultTypes (). front ());
1057+ auto resVectorType = VectorType::get (numElements, resType );
9961058
9971059 {
9981060 OpBuilder::InsertionGuard guard (rewriter);
0 commit comments