diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td
index 7436998749791..970e488d3494d 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td
@@ -34,4 +34,24 @@ def LowerVectorMultiReduction : Pass<"lower-vector-multi-reduction", "func::Func
   ];
 }
 
+def GreedySLPVectorizer : Pass<"greedy-slp-vectorizer"> {
+  let summary = "SLP Vectorizer Pass";
+  let description = [{
+    This pass implements the SLP (Superword Level Parallelism) vectorizer.
+    It detects consecutive operations that can be combined into vector
+    operations. The pass works bi-directionally, starting from loads or stores,
+    in search of scalars to combine.
+
+    This is a greedy vectorizer: it doesn't have any cost model (yet) and it
+    tries to create vector ops if there are at least 2 potential ops.
+  }];
+  let dependentDialects = ["mlir::vector::VectorDialect"];
+
+  let options = [
+    Option<"maxVectorBitwidth", "max-vector-bitwidth", "unsigned",
+           /*default=*/"std::numeric_limits<unsigned>::max()",
+           "Maximum supported vector bitwidth">,
+  ];
+}
+
 #endif // MLIR_DIALECT_VECTOR_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
index 8ca5cb6c6dfab..37333b739bd86 100644
--- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
@@ -12,6 +12,7 @@ add_mlir_dialect_library(MLIRVectorTransforms
   LowerVectorStep.cpp
   LowerVectorTransfer.cpp
   LowerVectorTranspose.cpp
+  SLPVectorizer.cpp
   SubsetOpInterfaceImpl.cpp
   VectorDistribute.cpp
   VectorDropLeadUnitDim.cpp
diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
new file mode 100644
index 0000000000000..58c4c5b271458
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -0,0 +1,1269 @@
+//===- SLPVectorizer.cpp - SLP Vectorizer Pass ----------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the SLP vectorizer pass for MLIR. The pass attempts to
+// combine similar independent operations into vector operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/DataLayoutAnalysis.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/Passes.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/DataLayoutInterfaces.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/SHA1.h"
+
+#define DEBUG_TYPE "slp-vectorizer"
+
+namespace mlir {
+namespace vector {
+#define GEN_PASS_DEF_GREEDYSLPVECTORIZER
+#include "mlir/Dialect/Vector/Transforms/Passes.h.inc"
+} // namespace vector
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::vector;
+
+namespace {
+/// A group of consecutive memory operations of the same type (load or store)
+/// that can potentially be vectorized together.
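+///
+/// For example (illustrative), the two loads below end up in the same load
+/// group, because no write interrupts them:
+/// ```
+///   %0 = memref.load %arg0[%c0] : memref<8xi32>
+///   %1 = memref.load %arg0[%c1] : memref<8xi32>
+/// ```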
+struct MemoryOpGroup { + enum class Type { Load, Store }; + Type type; + SmallVector ops; + int64_t elementsCount = 0; + + MemoryOpGroup(Type t) : type(t) {} + + bool isLoadGroup() const { return type == Type::Load; } + bool isStoreGroup() const { return type == Type::Store; } + + size_t opsCount() const { return ops.size(); } +}; + +static bool maybeReadOp(Operation *op) { + auto effectInterface = dyn_cast(op); + if (!effectInterface) + return true; + + return effectInterface.hasEffect(); +} + +static bool maybeWriteOp(Operation *op) { + auto effectInterface = dyn_cast(op); + if (!effectInterface) + return true; + + return effectInterface.hasEffect(); +} + +static std::optional> +getVectorElementTypeAndCount(VectorType vectorType) { + if (vectorType.getRank() > 1 || vectorType.isScalable()) + return std::nullopt; + + return std::make_pair(vectorType.getElementType(), + vectorType.getNumElements()); +} + +static std::optional> +getElementTypeAndCount(Operation *op) { + assert(op && "null op"); + if (auto loadOp = dyn_cast(op)) + return std::make_pair(loadOp.getResult().getType(), 1); + if (auto storeOp = dyn_cast(op)) + return std::make_pair(storeOp.getValueToStore().getType(), 1); + if (auto loadOp = dyn_cast(op)) + return getVectorElementTypeAndCount(loadOp.getVectorType()); + if (auto storeOp = dyn_cast(op)) + return getVectorElementTypeAndCount(storeOp.getVectorType()); + + return std::nullopt; +} + +static bool isSupportedMemOp(Operation *op) { + assert(op && "null op"); + auto typeAndCount = getElementTypeAndCount(op); + if (!typeAndCount) + return false; + + return isa_and_present( + typeAndCount->first); +} + +/// Collect all memory operations in the block into groups. +/// Each group contains either all loads or all stores, uninterrupted by +/// operations of the other type. +static SmallVector collectMemoryOpGroups(Block &block) { + SmallVector groups; + MemoryOpGroup *currentGroup = nullptr; + + for (Operation &op : block) { + // Check if current group is interrupted by a read or write op. + if (currentGroup) { + if (currentGroup->isLoadGroup() && maybeWriteOp(&op)) { + currentGroup = nullptr; + } else if (currentGroup->isStoreGroup() && maybeReadOp(&op)) { + currentGroup = nullptr; + } + } + + if (!isSupportedMemOp(&op)) + continue; + + bool isLoad = maybeReadOp(&op); + MemoryOpGroup::Type type = + isLoad ? 
MemoryOpGroup::Type::Load : MemoryOpGroup::Type::Store; + + if (!currentGroup) { + groups.emplace_back(type); + currentGroup = &groups.back(); + } + + currentGroup->ops.push_back(&op); + } + + return groups; +} + +static Value getBase(Operation *op) { + assert(op && "null op"); + if (auto loadOp = dyn_cast(op)) + return loadOp.getMemRef(); + if (auto storeOp = dyn_cast(op)) + return storeOp.getMemRef(); + if (auto loadOp = dyn_cast(op)) + return loadOp.getBase(); + if (auto storeOp = dyn_cast(op)) + return storeOp.getBase(); + + llvm_unreachable("unsupported op"); +} + +static Value getValueToStore(Operation *op) { + assert(op && "null op"); + if (auto storeOp = dyn_cast(op)) + return storeOp.getValueToStore(); + if (auto storeOp = dyn_cast(op)) + return storeOp.getValueToStore(); + + llvm_unreachable("unsupported op"); +} + +static bool isContiguousLastDim(Value val) { + auto memrefType = dyn_cast(val.getType()); + if (!memrefType) + return false; + + int64_t offset; + SmallVector strides; + if (failed(memrefType.getStridesAndOffset(strides, offset))) + return false; + + return !strides.empty() && strides.back() == 1; +} + +static ValueRange getIndices(Operation *op) { + assert(op && "null op"); + if (auto loadOp = dyn_cast(op)) + return loadOp.getIndices(); + if (auto storeOp = dyn_cast(op)) + return storeOp.getIndices(); + if (auto loadOp = dyn_cast(op)) + return loadOp.getIndices(); + if (auto storeOp = dyn_cast(op)) + return storeOp.getIndices(); + + llvm_unreachable("unsupported op"); +} + +static bool isAdjacentAffineMapIndices(Value idx1, Value idx2, int64_t offset) { + auto applyOp1 = idx1.getDefiningOp(); + if (!applyOp1) + return false; + + auto applyOp2 = idx2.getDefiningOp(); + if (!applyOp2) + return false; + + if (applyOp1.getOperands() != applyOp2.getOperands()) + return false; + + AffineExpr expr1 = applyOp1.getAffineMap().getResult(0); + AffineExpr expr2 = applyOp2.getAffineMap().getResult(0); + auto diff = + simplifyAffineExpr(expr2 - expr1, 0, applyOp1.getOperands().size()); + + auto diffConst = dyn_cast(diff); + return diffConst && diffConst.getValue() == offset; +} + +/// Check if two indices are consecutive, i.e index1 + offset == index2. +static bool isAdjacentIndices(Value idx1, Value idx2, int64_t offset) { + if (auto c1 = getConstantIntValue(idx1)) { + if (auto c2 = getConstantIntValue(idx2)) + return *c1 + offset == *c2; + } + + if (auto addOp2 = idx2.getDefiningOp()) { + if (addOp2.getLhs() == idx1 && + getConstantIntValue(addOp2.getRhs()) == offset) + return true; + + if (auto addOp1 = idx1.getDefiningOp()) { + if (addOp1.getLhs() == addOp2.getLhs() && + isAdjacentIndices(addOp1.getRhs(), addOp2.getRhs(), offset)) + return true; + } + } + + if (isAdjacentAffineMapIndices(idx1, idx2, offset)) + return true; + + return false; +} + +/// Check if two ranges of indices are consecutive, i.e fastest index differs +/// by `offset` and all other indices are the same. +static bool isAdjacentIndices(ValueRange idx1, ValueRange idx2, + int64_t offset) { + if (idx1.empty() || idx1.size() != idx2.size()) + return false; + + if (idx1.drop_back() != idx2.drop_back()) + return false; + + return isAdjacentIndices(idx1.back(), idx2.back(), offset); +} + +/// Check if two operations are adjacent and can be combined into a vector op. +/// This is done by checking if the base memrefs are the same, the last +/// dimension is contiguous, and the element types and indices are compatible. 
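+/// For example (illustrative), `memref.load %arg0[%c0]` and
+/// `memref.load %arg0[%c1]` are adjacent; two `vector.load`s of
+/// `vector<2xi32>` whose trailing indices differ by 2 are adjacent as well.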
+/// If the source read/write is already vectorized, only merge ops if the
+/// vector elements count is the same.
+static bool isAdjacentOps(Operation *op1, Operation *op2) {
+  assert(op1 && "null op1");
+  assert(op2 && "null op2");
+
+  Value base1 = getBase(op1);
+  Value base2 = getBase(op2);
+  if (base1 != base2)
+    return false;
+
+  if (!isContiguousLastDim(base1))
+    return false;
+
+  auto typeAndCount1 = getElementTypeAndCount(op1);
+  if (!typeAndCount1)
+    return false;
+
+  auto typeAndCount2 = getElementTypeAndCount(op2);
+  if (!typeAndCount2)
+    return false;
+
+  if (typeAndCount1 != typeAndCount2)
+    return false;
+
+  // For now we only merge ops with the same elements count.
+  return isAdjacentIndices(getIndices(op1), getIndices(op2),
+                           typeAndCount1->second);
+}
+
+// Extract contiguous groups from a MemoryOpGroup
+static SmallVector<MemoryOpGroup>
+extractContiguousGroups(const MemoryOpGroup &group) {
+  SmallVector<MemoryOpGroup> result;
+  if (group.ops.empty())
+    return result;
+
+  llvm::SmallDenseSet<Operation *> processedOps;
+
+  for (Operation *op : group.ops) {
+    if (processedOps.contains(op))
+      continue;
+
+    // Start a new group with this operation
+    result.emplace_back(group.type);
+    MemoryOpGroup &currentGroup = result.back();
+    currentGroup.elementsCount = getElementTypeAndCount(op)->second;
+    auto &currentOps = currentGroup.ops;
+    currentOps.push_back(op);
+    processedOps.insert(op);
+
+    // Keep adding ops to the beginning or end of the current group until no
+    // more ops can be added.
+    bool foundMore;
+    do {
+      foundMore = false;
+      for (Operation *otherOp : group.ops) {
+        if (processedOps.contains(otherOp))
+          continue;
+
+        Operation *firstOp = currentOps.front();
+        Operation *lastOp = currentOps.back();
+        if (isAdjacentOps(otherOp, firstOp)) {
+          currentOps.insert(currentOps.begin(), otherOp);
+          processedOps.insert(otherOp);
+          foundMore = true;
+        } else if (isAdjacentOps(lastOp, otherOp)) {
+          currentOps.push_back(otherOp);
+          processedOps.insert(otherOp);
+          foundMore = true;
+        }
+      }
+    } while (foundMore);
+
+    if (currentOps.size() <= 1) {
+      // Do not vectorize if there is only one op.
+      result.pop_back();
+      continue;
+    }
+
+    LLVM_DEBUG(llvm::dbgs() << "Extracted contiguous group with "
+                            << currentGroup.opsCount() << " operations\n");
+  }
+  return result;
+}
+
+/// Check if an operation is vectorizable.
+/// If `expectedElementsCount` is provided, check if the original op had the
+/// specified number of elements.
+static bool
+isVectorizable(Operation *op,
+               std::optional<int64_t> expectedElementsCount = std::nullopt) {
+  if (!OpTrait::hasElementwiseMappableTraits(op))
+    return false;
+
+  if (op->getNumResults() != 1)
+    return false;
+
+  for (auto type :
+       llvm::concat<Type>(op->getResultTypes(), op->getOperandTypes())) {
+    int64_t vectorElementsCount = 1;
+    if (auto vectorType = dyn_cast<VectorType>(type)) {
+      if (vectorType.getRank() > 1 || vectorType.isScalable())
+        return false;
+
+      type = vectorType.getElementType();
+      vectorElementsCount = vectorType.getNumElements();
+    }
+
+    if (expectedElementsCount && vectorElementsCount != *expectedElementsCount)
+      return false;
+
+    if (!isa<IntegerType, FloatType, IndexType>(type))
+      return false;
+  }
+
+  return true;
+}
+
+/// Get the next operation in the block, assuming `op` is not the terminator or
+/// the last operation in the block.
+static Operation *nextOp(Operation *op) {
+  assert(op && "null op");
+  auto it = op->getIterator();
+  return &*std::next(it);
+}
+
+/// A node in the SLP graph representing a group of vectorizable operations
+struct SLPGraphNode {
+  SmallVector<Operation *> ops;
+  SmallVector<SLPGraphNode *> users;
+  SmallVector<SLPGraphNode *> operands;
+  Operation *insertionPoint = nullptr;
+  int64_t elementsCount = 0;
+  bool isRoot = false;
+
+  SLPGraphNode() = default;
+  SLPGraphNode(ArrayRef<Operation *> operations)
+      : ops(operations.begin(), operations.end()) {}
+
+  size_t opsCount() const { return ops.size(); }
+  size_t vectorSize() const { return elementsCount * opsCount(); }
+
+  Operation *op() const {
+    assert(!ops.empty() && "empty ops");
+    return ops.front();
+  }
+
+  /// Get the suitable insertion point for the new vectorized op.
+  /// This method tries to take the operands' insertion points into account
+  /// too, to satisfy dominance relations.
+  Operation *getInsertionPoint() {
+    assert(!ops.empty() && "empty node");
+    if (insertionPoint)
+      return insertionPoint;
+
+    // Find the topologically first op, which is not necessarily the first in
+    // `ops`, as `ops` are sorted by their position in the vector.
+    Operation *ret = op();
+    for (Operation *op : ArrayRef(ops).drop_front()) {
+      if (op->isBeforeInBlock(ret))
+        ret = op;
+    }
+
+    for (Operation *op : ops) {
+      for (Value opOperand : op->getOperands()) {
+        Operation *defOp = opOperand.getDefiningOp();
+        if (!defOp || defOp->getBlock() != ret->getBlock())
+          continue;
+
+        Operation *next = nextOp(defOp);
+        if (ret->isBeforeInBlock(next))
+          ret = next;
+      }
+    }
+
+    // Try to adjust insertion point to satisfy dominance relations with
+    // operands.
+    for (SLPGraphNode *operand : operands) {
+      Operation *ip = operand->getInsertionPoint();
+      if (!ip)
+        return nullptr;
+
+      Operation *next = nextOp(ip);
+      if (next->getBlock() == ret->getBlock() && ret->isBeforeInBlock(next))
+        ret = next;
+    }
+
+    insertionPoint = ret;
+    return ret;
+  }
+};
+
+/// A graph of vectorizable operations
+class SLPGraph {
+public:
+  SLPGraph() = default;
+  ~SLPGraph() = default;
+
+  SLPGraph(const SLPGraph &) = delete;
+  SLPGraph &operator=(const SLPGraph &) = delete;
+
+  SLPGraph(SLPGraph &&) = default;
+  SLPGraph &operator=(SLPGraph &&) = default;
+
+  /// Add a new node to the graph
+  SLPGraphNode *addNode(ArrayRef<Operation *> operations,
+                        int64_t elementsCount) {
+    nodes.push_back(std::make_unique<SLPGraphNode>(operations));
+    auto *node = nodes.back().get();
+    node->elementsCount = elementsCount;
+    for (Operation *op : operations)
+      opToNode[op] = node;
+    return node;
+  }
+
+  /// Add a root node (memory operation)
+  SLPGraphNode *addRoot(ArrayRef<Operation *> operations,
+                        int64_t elementsCount) {
+    auto *node = addNode(operations, elementsCount);
+    node->isRoot = true;
+    return node;
+  }
+
+  /// Add a dependency edge between nodes
+  void addEdge(SLPGraphNode *from, SLPGraphNode *to) {
+    from->users.push_back(to);
+    to->operands.push_back(from);
+  }
+
+  /// Get all root nodes
+  SmallVector<SLPGraphNode *> getRoots() const {
+    SmallVector<SLPGraphNode *> roots;
+    for (const auto &node : nodes)
+      if (node->isRoot)
+        roots.push_back(node.get());
+    return roots;
+  }
+
+  /// Get the node associated with an operation
+  SLPGraphNode *getNodeForOp(Operation *op) const {
+    auto it = opToNode.find(op);
+    return it != opToNode.end() ?
it->second : nullptr; + } + + /// Topologically sort the nodes in the graph + SmallVector topologicalSort() const { + SmallVector result; + llvm::SmallDenseSet visited; + + SmallVector stack; + + // Process each node + for (const auto &node : nodes) { + if (visited.contains(node.get())) + continue; + + stack.emplace_back(node.get()); + while (!stack.empty()) { + SLPGraphNode *node = stack.pop_back_val(); + if (visited.contains(node)) + continue; + + stack.push_back(node); + + bool pushed = false; + for (SLPGraphNode *operand : node->operands) { + if (visited.contains(operand)) + continue; + + stack.push_back(operand); + pushed = true; + } + + if (!pushed) { + visited.insert(node); + result.push_back(node); + } + } + } + + return result; + } + + /// Vectorize the operations in the graph. + /// Returns number of nodes vectorized or failure if failed. + FailureOr + vectorize(IRRewriter &rewriter, + llvm::function_ref isValidVecType); + + /// Print the graph structure + [[maybe_unused]] void print() const { + llvm::dbgs() << "SLP Graph Structure:\n"; + llvm::dbgs() << "===================\n"; + + // First print all roots + llvm::dbgs() << "Roots:\n"; + for (const auto &node : nodes) { + if (!node->isRoot) + continue; + llvm::dbgs() << " " << (maybeReadOp(node->op()) ? "LOAD" : "STORE") + << " group with " << node->opsCount() << " operations:\n"; + for (auto *op : node->ops) { + llvm::dbgs() << " " << *op << "\n"; + } + llvm::dbgs() << " Users: "; + for (auto *user : node->users) { + llvm::dbgs() << "\n Group with " << user->opsCount() + << " operations:"; + for (auto *op : user->ops) { + llvm::dbgs() << "\n " << *op; + } + } + llvm::dbgs() << "\n"; + } + + // Then print all non-root nodes + llvm::dbgs() << "\nNon-root nodes:\n"; + for (const auto &node : nodes) { + if (node->isRoot) + continue; + llvm::dbgs() << " Group with " << node->opsCount() << " operations:\n"; + for (auto *op : node->ops) { + llvm::dbgs() << " " << *op << "\n"; + } + llvm::dbgs() << " Operands: "; + for (auto *operand : node->operands) { + llvm::dbgs() << "\n Group with " << operand->opsCount() + << " operations:"; + for (auto *op : operand->ops) { + llvm::dbgs() << "\n " << *op; + } + } + llvm::dbgs() << "\n Users: "; + for (auto *user : node->users) { + llvm::dbgs() << "\n Group with " << user->opsCount() + << " operations:"; + for (auto *op : user->ops) { + llvm::dbgs() << "\n " << *op; + } + } + llvm::dbgs() << "\n"; + } + llvm::dbgs() << "===================\n"; + } + +private: + SmallVector> nodes; + llvm::SmallDenseMap opToNode; +}; + +/// This pass implements the greedy SLP vectorizer. It detects consecutive +/// operations that can be put together into vector operations. The pass works +/// bi-directionaly, starting from reads or stores, in search of scalars to +/// combine. +/// +/// Pass is split into multiple steps: +/// 1. Collect memory operation groups within same block. +/// Group is either multiple loads uninterrupted by stores or multiple stores +/// uninterrupted by loads. +/// +/// 2. Extract contiguous groups from memory operation groups, based on the +/// ops base memrefs, load/store element types, and indices. +/// +/// 3. Build SLP graph from contiguous groups. This is done by going both +/// top-down and bottom-up through uses/operands respectively, starting from +/// contiguous memory operation groups. +/// +/// 4. Vectorize SLP graph. This is done by topological sort of the graph and +/// vectorizing each node in the order of the sort. 
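+///
+/// For example (illustrative, mirroring the accompanying tests), four adjacent
+/// scalar loads from two buffers, four scalar adds and four adjacent scalar
+/// stores become:
+/// ```
+///   %a = vector.load %arg0[%c0] : memref<8xi32>, vector<4xi32>
+///   %b = vector.load %arg1[%c0] : memref<8xi32>, vector<4xi32>
+///   %r = arith.addi %a, %b : vector<4xi32>
+///   vector.store %r, %arg0[%c0] : memref<8xi32>, vector<4xi32>
+/// ```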
+/// +/// Vectorization is done by cloning the operations and mapping the operands and +/// results. +struct GreedySLPVectorizerPass + : public mlir::vector::impl::GreedySLPVectorizerBase< + GreedySLPVectorizerPass> { + using GreedySLPVectorizerBase::GreedySLPVectorizerBase; + + void runOnOperation() override; +}; + +using Fingerprint = std::array; + +template +static void addDataToHash(llvm::SHA1 &hasher, const T &data) { + hasher.update( + ArrayRef(reinterpret_cast(&data), sizeof(T))); +} + +/// SLP vectorizer is bi-directional, so when we go top-down we can can have +/// multiple users with the same immediate op type, this class tries to compute +/// fingerprint for such ops based on the entire ops graph to maximize further +/// scalar ops merging. +/// +/// Example: +/// ``` +/// %0 = memref.load %arg0[%c0] : memref<8xi32> +/// %1 = memref.load %arg0[%c1] : memref<8xi32> +/// %2 = memref.load %arg0[%c2] : memref<8xi32> +/// %3 = memref.load %arg0[%c3] : memref<8xi32> +/// +/// %4 = memref.load %arg1[%c0] : memref<8xi32> +/// %5 = memref.load %arg1[%c1] : memref<8xi32> +/// %6 = memref.load %arg1[%c2] : memref<8xi32> +/// %7 = memref.load %arg1[%c3] : memref<8xi32> +/// +/// %8 = arith.addi %0, %4 : i32 +/// %12 = arith.addi %0, %arg2 : i32 +/// +/// %13 = arith.addi %1, %arg3 : i32 +/// %9 = arith.addi %1, %5 : i32 +/// +/// %10 = arith.addi %2, %6 : i32 +/// %14 = arith.addi %2, %arg4 : i32 +/// +/// %15 = arith.addi %3, %arg5 : i32 +/// %11 = arith.addi %3, %7 : i32 +/// ``` +/// Here each load have multiple uses, in different order, and we want to merge +/// them in a way that maximizes the number of merged ops. +/// +/// To achieve this, we compute fingerprint for each op including the other +/// operands, which will include the other loads in this example. +struct OperationsFingerprint { + OperationsFingerprint(const SLPGraph &graph) : graph(graph) {} + + Fingerprint getFingerprint(Operation *op) { + assert(op && "null op"); + auto it = fingerprints.find(op); + if (it != fingerprints.end()) + return it->second; + + SmallVector worklist; + SmallVector toposortedOps; + worklist.emplace_back(op); + while (!worklist.empty()) { + Operation *op = worklist.pop_back_val(); + toposortedOps.emplace_back(op); + if (graph.getNodeForOp(op)) + continue; + + for (Value operand : op->getOperands()) { + auto *defOp = operand.getDefiningOp(); + if (!defOp || !isVectorizable(defOp)) + continue; + + toposortedOps.emplace_back(defOp); + worklist.emplace_back(defOp); + } + } + + for (Operation *op : llvm::reverse(toposortedOps)) { + llvm::SHA1 hasher; + addDataToHash(hasher, op->getName().getTypeID()); + addDataToHash(hasher, op->getRawDictionaryAttrs()); + addDataToHash(hasher, op->hashProperties()); + for (Value operand : op->getOperands()) { + auto *defOp = operand.getDefiningOp(); + if (!defOp) + continue; + + auto *node = graph.getNodeForOp(defOp); + if (node) { + addDataToHash(hasher, node); + continue; + } + + auto it2 = fingerprints.find(defOp); + if (it2 != fingerprints.end()) { + addDataToHash(hasher, it2->second); + continue; + } + } + fingerprints[op] = hasher.result(); + } + + return fingerprints[op]; + } + + void invalidate(Operation *op) { + if (fingerprints.contains(op)) + fingerprints.clear(); + } + + const SLPGraph &graph; + DenseMap fingerprints; +}; + +/// Check if op input/output types can be vectorized. 
+static bool +checkOpVecType(SLPGraphNode *node, + llvm::function_ref isValidVecType) { + Operation *op = node->op(); + size_t size = node->vectorSize(); + auto checkRes = [](bool res) -> bool { + LLVM_DEBUG(llvm::dbgs() << (res ? "true" : "false") << "\n"); + return res; + }; + + if (auto typeAndCount = getElementTypeAndCount(op)) { + Type elementType = typeAndCount->first; + LLVM_DEBUG(llvm::dbgs() << "Checking if type " << elementType + << " with size " << size << " can be vectorized: "); + return checkRes(isValidVecType(elementType, size)); + } + + if (isVectorizable(op)) { + for (auto type : + llvm::concat(op->getResultTypes(), op->getOperandTypes())) { + Type elementType = getElementTypeOrSelf(type); + LLVM_DEBUG(llvm::dbgs() + << "Checking if type " << elementType << " with size " << size + << " can be vectorized: "); + if (!checkRes(isValidVecType(elementType, size))) + return false; + } + return true; + } + + if (auto extract = dyn_cast(op)) { + Type type = extract.getResult().getType(); + LLVM_DEBUG(llvm::dbgs() << "Checking if type " << type << " with size " + << size << " can be vectorized: "); + return checkRes(isValidVecType(type, size)); + } + + LLVM_DEBUG(llvm::dbgs() << "Unsupported op " << op->getName() << "\n"); + return false; +} + +/// Check if two ops are equivalent for the purposes of SLP vectorization, i.e. +/// they can be merged into single vector op. +static bool isEquivalent(Operation *op1, Operation *op2) { + assert(op1 && "null op1"); + assert(op2 && "null op2"); + if (op1 == op2) + return true; + + if (op1->getName() != op2->getName()) + return false; + + if (op1->getAttrs() != op2->getAttrs()) + return false; + + if (op1->getBlock() != op2->getBlock()) + return false; + + return true; +} + +/// Get static position of the extract op, if it is 1D and static. +static std::optional getExtractIndex(vector::ExtractOp extractOp) { + if (extractOp.getNumIndices() != 1 || extractOp.hasDynamicPosition()) + return std::nullopt; + + return extractOp.getStaticPosition().front(); +} + +/// Build the SLP graph starting from memory operation groups and going both +/// top-down and bottom-up through uses/operands respectively. +static SLPGraph buildSLPGraph(ArrayRef rootGroups) { + if (rootGroups.empty()) + return SLPGraph(); + + LLVM_DEBUG(llvm::dbgs() << "=== Building SLP graph from " << rootGroups.size() + << " root groups ===\n"); + SLPGraph graph; + + SmallVector worklist; + + // First, create nodes for each contiguous memory operation group + for (const auto &group : rootGroups) { + auto *node = graph.addRoot(group.ops, group.elementsCount); + worklist.push_back(node); + + LLVM_DEBUG({ + llvm::dbgs() << "Created root group node with " << node->opsCount() + << " operations of type " + << (group.isLoadGroup() ? "Load" : "Store") << "\n"; + }); + } + + OperationsFingerprint fingerprints(graph); + + // Process node uses, going top-down. 
+ auto processUse = [&](SLPGraphNode *node, OpOperand &use) { + Operation *user = use.getOwner(); + auto *existingNode = graph.getNodeForOp(user); + if (existingNode) { + LLVM_DEBUG(llvm::dbgs() << " Adding edge from " << node->op()->getName() + << " to " << user->getName() << "\n"); + graph.addEdge(node, existingNode); + return; + } + + if (!isVectorizable(user, node->elementsCount)) + return; + + Fingerprint expectedFingerprint = fingerprints.getFingerprint(user); + + SmallVector currentOps; + currentOps.emplace_back(user); + for (Operation *op : ArrayRef(node->ops).drop_front()) { + Operation *found = nullptr; + for (OpOperand &opUse : op->getUses()) { + if (opUse.getOperandNumber() != use.getOperandNumber()) + continue; + + Operation *useOwner = opUse.getOwner(); + if (!isEquivalent(useOwner, user) || + fingerprints.getFingerprint(useOwner) != expectedFingerprint) + continue; + + found = useOwner; + break; + } + if (!found) + break; + + currentOps.push_back(found); + } + + if (currentOps.size() == 1) + return; + + auto *newNode = graph.addNode(currentOps, node->elementsCount); + graph.addEdge(node, newNode); + for (Operation *op : currentOps) + fingerprints.invalidate(op); + + worklist.push_back(newNode); + }; + + // Process node operands, going bottom-up. + auto processOperands = [&](SLPGraphNode *node, Value operand, int64_t index) { + Operation *srcOp = operand.getDefiningOp(); + if (!srcOp) + return; + + auto *existingNode = graph.getNodeForOp(srcOp); + if (existingNode) { + LLVM_DEBUG(llvm::dbgs() << " Adding edge from " << srcOp->getName() + << " to " << node->op()->getName() << "\n"); + graph.addEdge(existingNode, node); + return; + } + + SmallVector currentOps; + if (auto extractOp = dyn_cast(srcOp)) { + LLVM_DEBUG(llvm::dbgs() + << " Processing vector.extract op with index " + << getExtractIndex(extractOp).value_or(-1) << "\n"); + currentOps.push_back(extractOp); + + std::optional extractIndex = getExtractIndex(extractOp); + if (!extractIndex) + return; + + Value vector = extractOp.getVector(); + int64_t currentIndex = *extractIndex; + for (Operation *op : ArrayRef(node->ops).drop_front()) { + auto otherOp = op->getOperand(index).getDefiningOp(); + if (!otherOp || otherOp.getVector() != vector) + break; + + std::optional otherExtractIndex = getExtractIndex(otherOp); + if (!otherExtractIndex || *otherExtractIndex != (currentIndex + 1)) + break; + + currentOps.push_back(otherOp); + ++currentIndex; + } + } else if (isVectorizable(srcOp, node->elementsCount)) { + LLVM_DEBUG(llvm::dbgs() << " Processing vectorizable op " + << srcOp->getName() << "\n"); + + currentOps.emplace_back(srcOp); + for (Operation *op : ArrayRef(node->ops).drop_front()) { + Operation *otherOp = op->getOperand(index).getDefiningOp(); + if (!otherOp || !isEquivalent(otherOp, srcOp)) + break; + + currentOps.push_back(otherOp); + } + } else { + LLVM_DEBUG(llvm::dbgs() + << " Unsupported op " << srcOp->getName() << "\n"); + return; + } + + if (currentOps.size() == 1) + return; + + auto *newNode = graph.addNode(currentOps, node->elementsCount); + graph.addEdge(newNode, node); + for (Operation *op : currentOps) + fingerprints.invalidate(op); + + worklist.push_back(newNode); + }; + + while (!worklist.empty()) { + SLPGraphNode *node = worklist.pop_back_val(); + LLVM_DEBUG(llvm::dbgs() + << "Processing node with " << node->opsCount() + << " operations, first op: " << node->op()->getName() << "\n"); + + Operation *op = node->op(); + for (OpOperand &use : op->getUses()) + processUse(node, use); + + for (auto [i, 
operand] : llvm::enumerate(op->getOperands())) + processOperands(node, operand, i); + } + + return graph; +} + +FailureOr +SLPGraph::vectorize(IRRewriter &rewriter, + llvm::function_ref isValidVecType) { + if (nodes.empty()) + return 0; + + LLVM_DEBUG(llvm::dbgs() << "Vectorizing SLP graph with " << nodes.size() + << " nodes\n"); + + // Get topologically sorted nodes + SmallVector sortedNodes = topologicalSort(); + if (sortedNodes.empty()) { + LLVM_DEBUG(llvm::dbgs() << "Failed to topologically sort nodes\n"); + return failure(); + } + + LLVM_DEBUG({ + llvm::dbgs() << "Topologically sorted nodes:\n"; + for (auto *node : sortedNodes) { + llvm::dbgs() << " Node with " << node->opsCount() + << " operations: " << node->op()->getName() << "\n"; + } + }); + + auto isBadNode = [&](SLPGraphNode *node) { + // Do not vectorize stray nodes which are not connected to any other + // nodes. + return (node->users.empty() && node->operands.empty()) || + node->opsCount() <= 1; + }; + + // Update node vec sizes if its inputs vec sizes are smaller. + // This is nedeed to handle situations when we have 3->3->4 sizes in tree. + // TODO: It maybe possible to reconstruct the larger vec size combining src + // smaller vector and scalar arg. + for (auto *node : sortedNodes) { + size_t size = node->opsCount(); + for (auto *operand : node->operands) + size = std::min(size, operand->opsCount()); + + if (size < node->opsCount()) { + LLVM_DEBUG(llvm::dbgs() + << "Size mismatch, resizing node with " << node->opsCount() + << " operations to " << size << "\n"); + node->ops.resize(size); + } + + while (node->opsCount() > 1) { + if (checkOpVecType(node, isValidVecType)) + break; + + LLVM_DEBUG(llvm::dbgs() << "No a valid vector type, popping back op: " + << node->ops.back()->getName() << "\n"); + node->ops.pop_back(); + } + } + + llvm::erase_if(sortedNodes, isBadNode); + + IRMapping mapping; + for (auto *node : sortedNodes) { + LLVM_DEBUG({ + llvm::dbgs() << "Processing node with " << node->opsCount() + << " operations\n"; + llvm::dbgs() << " First op: " << *node->op() << "\n"; + }); + + // `op` is the node with the smallest index in vector and not the + // nessesarily the good insertion point. + Operation *op = node->op(); + Operation *ip = node->getInsertionPoint(); + if (!ip) + return op->emitError("no insertion point found for node"); + + LLVM_DEBUG(llvm::dbgs() << " Insertion point: " << *ip << "\n"); + + rewriter.setInsertionPoint(ip); + int64_t numElements = node->vectorSize(); + Location loc = op->getLoc(); + + auto handleNonVectorInputs = [&](ValueRange operands) { + // Handle the case when op operands are not vectorized or have smaller + // vector size, construct the vector from the scalar operands using + // FromElementsOp. 
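+      // For example (illustrative), four scalar i32 operands %a, %b, %c and
+      // %d become:
+      //   %v = vector.from_elements %a, %b, %c, %d : vector<4xi32>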
+ for (auto [i, operand] : llvm::enumerate(operands)) { + if (getNodeForOp(operand.getDefiningOp())) + continue; + + SmallVector args; + for (Operation *defOp : node->ops) { + Value arg = defOp->getOperand(i); + if (auto vecType = dyn_cast(arg.getType())) { + assert(vecType.getRank() == 1); + for (auto j : llvm::seq(vecType.getNumElements())) + args.push_back(rewriter.create(loc, arg, j)); + + } else { + args.push_back(arg); + } + } + + auto vecType = VectorType::get(numElements, + getElementTypeOrSelf(operand.getType())); + Value vector = + rewriter.create(loc, vecType, args); + mapping.map(operand, vector); + } + }; + + auto handleNonVectorOutputs = [&](Value newResult, + Type originalResultType) { + // Handle the case when op results are not vectorized or have smaller + // vector size, extract the elements from the vector. + for (auto [i, result] : llvm::enumerate(node->ops)) { + for (OpOperand &use : result->getUses()) { + Operation *useOwner = use.getOwner(); + if (getNodeForOp(useOwner)) + continue; + + int64_t offset = i * node->elementsCount; + Value elem; + + if (auto vecType = dyn_cast(originalResultType)) { + assert(vecType.getRank() <= 1); + if (vecType.getRank() == 0) { + elem = rewriter.create(loc, newResult, offset); + elem = rewriter.create(loc, vecType, elem); + } else { + elem = rewriter.create( + loc, newResult, offset, vecType.getNumElements(), 1); + } + } else { + elem = rewriter.create(loc, newResult, offset); + } + + use.set(elem); + } + } + }; + + auto handleVecSizeMismatch = [&](Value arg, int64_t offset = 0) -> Value { + // Handle vector size misamatch between 2 vectorized nodes. + auto srcType = cast(arg.getType()); + assert(srcType.getRank() == 1); + if (srcType.getDimSize(0) == numElements) + return arg; + + return rewriter.create(loc, arg, offset, + numElements, 1); + }; + + if (maybeReadOp(op)) { + auto vecType = + VectorType::get(numElements, getElementTypeAndCount(op)->first); + Value result = rewriter.create(loc, vecType, getBase(op), + getIndices(op)); + Value originalResult = op->getResult(0); + mapping.map(originalResult, result); + handleNonVectorOutputs(result, originalResult.getType()); + } else if (maybeWriteOp(op)) { + handleNonVectorInputs(getValueToStore(op)); + Value val = mapping.lookupOrDefault(getValueToStore(op)); + val = handleVecSizeMismatch(val); + rewriter.create(loc, val, getBase(op), getIndices(op)); + } else if (isVectorizable(op)) { + handleNonVectorInputs(op->getOperands()); + Operation *newOp = rewriter.clone(*op, mapping); + Type resType = getElementTypeOrSelf(op->getResultTypes().front()); + auto resVectorType = VectorType::get(numElements, resType); + + { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPoint(newOp); + for (OpOperand &operand : newOp->getOpOperands()) { + Value newOperand = handleVecSizeMismatch(operand.get()); + operand.set(newOperand); + } + } + newOp->getResult(0).setType(resVectorType); + + mapping.map(op->getResults(), newOp->getResults()); + handleNonVectorOutputs(newOp->getResult(0), op->getResultTypes().front()); + } else if (auto extract = dyn_cast(op)) { + // We alredy verified index is valid during graph construction, so + // do need to check `getExtractIndex` result. 
+ int64_t offset = *getExtractIndex(extract); + Value val = handleVecSizeMismatch(extract.getVector(), offset); + mapping.map(extract.getResult(), val); + } else { + op->emitError("unsupported operation"); + return failure(); + } + } + + LLVM_DEBUG(llvm::dbgs() << "Erasing original ops\n"); + + // As all nodes were cloned, we need to erase the original ops in reverse + // topo order to avoid invalidation users. + for (auto *node : llvm::reverse(sortedNodes)) { + for (Operation *op : node->ops) { + LLVM_DEBUG(llvm::dbgs() << "Erasing op: " << *op << "\n"); + rewriter.eraseOp(op); + } + } + + LLVM_DEBUG(llvm::dbgs() << "Vectorization completed successfully\n"); + return sortedNodes.size(); +} + +/// Try to vectorize ops in a block. +/// Returns number of nodes vectorized or error flag if failed. +static FailureOr +tryToVectorizeInBlock(Block &block, + llvm::function_ref isValidVecType) { + LLVM_DEBUG(llvm::dbgs() << "Processing block in operation: " + << block.getParentOp()->getName() << "\n"); + + // Collect memory operation groups + SmallVector groups = collectMemoryOpGroups(block); + + // Process each group to find contiguous sequences + SmallVector rootGroups; + for (const auto &group : groups) { + SmallVector contiguousGroups = + extractContiguousGroups(group); + LLVM_DEBUG({ + llvm::dbgs() << "Found " << contiguousGroups.size() + << " contiguous groups in " + << (group.isLoadGroup() ? "load" : "store") << " group\n"; + for (const auto &contigGroup : contiguousGroups) { + llvm::dbgs() << " Contiguous group with " << contigGroup.opsCount() + << " operations\n"; + } + }); + rootGroups.append(contiguousGroups); + } + + // Build the SLP graph from root groups + SLPGraph graph = buildSLPGraph(rootGroups); + LLVM_DEBUG(graph.print()); + + // Vectorize the graph + IRRewriter rewriter(block.getParentOp()->getContext()); + FailureOr numNodesVectorized = + graph.vectorize(rewriter, isValidVecType); + if (failed(numNodesVectorized)) + LLVM_DEBUG(llvm::dbgs() << "Failed to vectorize graph\n"); + + return numNodesVectorized; +} + +static bool isPow2(size_t size) { + assert(size > 0); + return (size & (size - 1)) == 0; +} + +void GreedySLPVectorizerPass::runOnOperation() { + Operation *op = getOperation(); + + const DataLayout *dataLayout = nullptr; + auto isValidVecType = [&](Type type, size_t count) { + if (!isPow2(count)) + return false; + + if (!dataLayout) + dataLayout = &getAnalysis().getAtOrAbove(op); + + auto sizeInBits = dataLayout->getTypeSizeInBits(type); + + return sizeInBits * count <= this->maxVectorBitwidth; + }; + + // Run until fixed point is reached. + bool changed; + do { + changed = false; + auto visitor = [&](Block *block) -> WalkResult { + FailureOr numNodesVectorized = + tryToVectorizeInBlock(*block, isValidVecType); + if (failed(numNodesVectorized)) + return WalkResult::interrupt(); + + changed = changed || *numNodesVectorized > 0; + return WalkResult::advance(); + }; + // Walk all blocks recursively + if (op->walk(visitor).wasInterrupted()) + return signalPassFailure(); + + // Run empty `applyPatternsGreedily` for simple DCE and folding. 
+ if (changed) { + auto config = GreedyRewriteConfig().enableFolding().enableConstantCSE(); + (void)applyPatternsGreedily(op, {}, config); + } + } while (changed); +} + +} // namespace diff --git a/mlir/test/Dialect/Vector/slp-vectorize.mlir b/mlir/test/Dialect/Vector/slp-vectorize.mlir new file mode 100644 index 0000000000000..29c077d7ab34f --- /dev/null +++ b/mlir/test/Dialect/Vector/slp-vectorize.mlir @@ -0,0 +1,913 @@ +// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(greedy-slp-vectorizer{max-vector-bitwidth=256}))' | FileCheck %s + + +// CHECK-LABEL: func @negative_single_op +func.func @negative_single_op(%arg0: memref<8xi32>, %arg1: memref<8xi32>) { + // CHECK-NOT: vector + %c0 = arith.constant 0 : index + + %0 = memref.load %arg0[%c0] : memref<8xi32> + %4 = memref.load %arg1[%c0] : memref<8xi32> + %8 = arith.addi %0, %4 : i32 + memref.store %8, %arg0[%c0] : memref<8xi32> + + return +} + + +// CHECK-LABEL: func @read_write +// CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>) +func.func @read_write(%arg0: memref<8xi32>, %arg1: memref<8xi32>) { + // CHECK: %[[C0:.*]] = arith.constant 0 : index + // CHECK: %[[RES:.*]] = vector.load %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32> + // CHECK: vector.store %[[RES]], %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32> + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + + %0 = memref.load %arg0[%c0] : memref<8xi32> + %1 = memref.load %arg0[%c1] : memref<8xi32> + %2 = memref.load %arg0[%c2] : memref<8xi32> + %3 = memref.load %arg0[%c3] : memref<8xi32> + + memref.store %0, %arg0[%c0] : memref<8xi32> + memref.store %1, %arg0[%c1] : memref<8xi32> + memref.store %2, %arg0[%c2] : memref<8xi32> + memref.store %3, %arg0[%c3] : memref<8xi32> + + return +} + + +// CHECK-LABEL: func @read_write_size_mistamtch +// CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>) +func.func @read_write_size_mistamtch(%arg0: memref<8xi32>, %arg1: memref<8xi32>) { + // CHECK: %[[C0:.*]] = arith.constant 0 : index + // CHECK: %[[RES:.*]] = vector.load %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32> + // CHECK: %[[RES1:.*]] = vector.extract_strided_slice %[[RES]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xi32> to vector<2xi32> + // CHECK: vector.store %[[RES1]], %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<2xi32> + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + + %0 = memref.load %arg0[%c0] : memref<8xi32> + %1 = memref.load %arg0[%c1] : memref<8xi32> + %2 = memref.load %arg0[%c2] : memref<8xi32> + %3 = memref.load %arg0[%c3] : memref<8xi32> + + memref.store %0, %arg0[%c0] : memref<8xi32> + memref.store %1, %arg0[%c1] : memref<8xi32> + + return +} + + +// CHECK-LABEL: func @read_write_interleaved +// CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>) +func.func @read_write_interleaved(%arg0: memref<8xi32>, %arg1: memref<8xi32>) { + // CHECK: %[[C0:.*]] = arith.constant 0 : index + // CHECK: %[[RES:.*]] = vector.load %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32> + // CHECK: vector.store %[[RES]], %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32> + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + + %3 = memref.load %arg0[%c3] : memref<8xi32> + %0 = memref.load %arg0[%c0] : memref<8xi32> + %2 = memref.load %arg0[%c2] : memref<8xi32> + %1 = 
memref.load %arg0[%c1] : memref<8xi32> + + memref.store %1, %arg0[%c1] : memref<8xi32> + memref.store %0, %arg0[%c0] : memref<8xi32> + memref.store %3, %arg0[%c3] : memref<8xi32> + memref.store %2, %arg0[%c2] : memref<8xi32> + + return +} + + +// CHECK-LABEL: func @read_write_add_index +// CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>, %[[ARG2:.*]]: index) +func.func @read_write_add_index(%arg0: memref<8xi32>, %arg1: memref<8xi32>, %arg2: index) { + // CHECK: %[[RES:.*]] = vector.load %[[ARG0]][%[[ARG2]]] : memref<8xi32>, vector<4xi32> + // CHECK: vector.store %[[RES]], %[[ARG0]][%[[ARG2]]] : memref<8xi32>, vector<4xi32> + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + + %ind1 = arith.addi %arg2, %c1 : index + %ind2 = arith.addi %arg2, %c2 : index + %ind3 = arith.addi %arg2, %c3 : index + + %0 = memref.load %arg0[%arg2] : memref<8xi32> + %1 = memref.load %arg0[%ind1] : memref<8xi32> + %2 = memref.load %arg0[%ind2] : memref<8xi32> + %3 = memref.load %arg0[%ind3] : memref<8xi32> + + memref.store %0, %arg0[%arg2] : memref<8xi32> + memref.store %1, %arg0[%ind1] : memref<8xi32> + memref.store %2, %arg0[%ind2] : memref<8xi32> + memref.store %3, %arg0[%ind3] : memref<8xi32> + + return +} + + +// CHECK-LABEL: func @read_write_add_index_interleaved +// CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>, %[[ARG2:.*]]: index) +func.func @read_write_add_index_interleaved(%arg0: memref<8xi32>, %arg1: memref<8xi32>, %arg2: index) { + // CHECK: %[[RES:.*]] = vector.load %[[ARG0]][%[[ARG2]]] : memref<8xi32>, vector<4xi32> + // CHECK: vector.store %[[RES]], %[[ARG0]][%[[ARG2]]] : memref<8xi32>, vector<4xi32> + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + + %ind1 = arith.addi %arg2, %c1 : index + %ind2 = arith.addi %arg2, %c2 : index + %ind3 = arith.addi %arg2, %c3 : index + + %0 = memref.load %arg0[%arg2] : memref<8xi32> + %1 = memref.load %arg0[%ind1] : memref<8xi32> + %3 = memref.load %arg0[%ind3] : memref<8xi32> + %2 = memref.load %arg0[%ind2] : memref<8xi32> + + memref.store %3, %arg0[%ind3] : memref<8xi32> + memref.store %0, %arg0[%arg2] : memref<8xi32> + memref.store %1, %arg0[%ind1] : memref<8xi32> + memref.store %2, %arg0[%ind2] : memref<8xi32> + + return +} + + +#map0 = affine_map<()[s0, s1] -> (s1 * s0)> +#map1 = affine_map<()[s0, s1] -> (s1 * s0 + 1)> +#map2 = affine_map<()[s0, s1] -> (s1 * s0 + 2)> +#map3 = affine_map<()[s0, s1] -> (s1 * s0 + 3)> + +// CHECK-LABEL: func @read_write_affine_apply +// CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index) +func.func @read_write_affine_apply(%arg0: memref<8xi32>, %arg1: memref<8xi32>, %arg2: index, %arg3: index) { + // CHECK: %[[IDX:.*]] = affine.apply #{{.*}}()[%[[ARG2]], %[[ARG3]]] + // CHECK: %[[RES:.*]] = vector.load %[[ARG0]][%[[IDX]]] : memref<8xi32>, vector<4xi32> + // CHECK: vector.store %[[RES]], %[[ARG0]][%[[IDX]]] : memref<8xi32>, vector<4xi32> + + %ind0 = affine.apply #map0()[%arg2, %arg3] + %ind1 = affine.apply #map1()[%arg2, %arg3] + %ind2 = affine.apply #map2()[%arg2, %arg3] + %ind3 = affine.apply #map3()[%arg2, %arg3] + + %0 = memref.load %arg0[%ind0] : memref<8xi32> + %1 = memref.load %arg0[%ind1] : memref<8xi32> + %2 = memref.load %arg0[%ind2] : memref<8xi32> + %3 = memref.load %arg0[%ind3] : memref<8xi32> + + memref.store %0, %arg0[%ind0] : memref<8xi32> + memref.store %1, %arg0[%ind1] : memref<8xi32> + memref.store %2, 
%arg0[%ind2] : memref<8xi32> + memref.store %3, %arg0[%ind3] : memref<8xi32> + + return +} + + +// CHECK-LABEL: func @read_read_add +// CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>) +func.func @read_read_add(%arg0: memref<8xi32>, %arg1: memref<8xi32>) -> (i32, i32, i32, i32) { + // CHECK: %[[C0:.*]] = arith.constant 0 : index + // CHECK: %[[A:.*]] = vector.load %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32> + // CHECK: %[[B:.*]] = vector.load %[[ARG1]][%[[C0]]] : memref<8xi32>, vector<4xi32> + // CHECK: %[[RES:.*]] = arith.addi %[[A]], %[[B]] : vector<4xi32> + // CHECK: %[[R0:.*]] = vector.extract %[[RES]][0] : i32 from vector<4xi32> + // CHECK: %[[R1:.*]] = vector.extract %[[RES]][1] : i32 from vector<4xi32> + // CHECK: %[[R2:.*]] = vector.extract %[[RES]][2] : i32 from vector<4xi32> + // CHECK: %[[R3:.*]] = vector.extract %[[RES]][3] : i32 from vector<4xi32> + // CHECK: return %[[R0]], %[[R1]], %[[R2]], %[[R3]] : i32, i32, i32, i32 + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + + %0 = memref.load %arg0[%c0] : memref<8xi32> + %1 = memref.load %arg0[%c1] : memref<8xi32> + %2 = memref.load %arg0[%c2] : memref<8xi32> + %3 = memref.load %arg0[%c3] : memref<8xi32> + + %4 = memref.load %arg1[%c0] : memref<8xi32> + %5 = memref.load %arg1[%c1] : memref<8xi32> + %6 = memref.load %arg1[%c2] : memref<8xi32> + %7 = memref.load %arg1[%c3] : memref<8xi32> + + %8 = arith.addi %0, %4 : i32 + %9 = arith.addi %1, %5 : i32 + %10 = arith.addi %2, %6 : i32 + %11 = arith.addi %3, %7 : i32 + + return %8, %9, %10, %11 : i32, i32, i32, i32 +} + + +// CHECK-LABEL: func @add_write +// CHECK-SAME: (%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32, %[[ARG3:.*]]: i32, %[[ARG4:.*]]: i32, %[[ARG5:.*]]: i32, %[[ARG6:.*]]: i32, %[[ARG7:.*]]: i32, %[[ARG8:.*]]: memref<8xi32>) +func.func @add_write(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32, + %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, + %arg8: memref<8xi32>) { + // CHECK: %[[C0:.*]] = arith.constant 0 : index + // CHECK: %[[A:.*]] = vector.from_elements %[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]] : vector<4xi32> + // CHECK: %[[B:.*]] = vector.from_elements %[[ARG4]], %[[ARG5]], %[[ARG6]], %[[ARG7]] : vector<4xi32> + // CHECK: %[[RES:.*]] = arith.addi %0, %1 : vector<4xi32> + // CHECK: vector.store %[[RES]], %[[ARG8]][%[[C0]]] : memref<8xi32>, vector<4xi32> + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + + %8 = arith.addi %arg0, %arg4 : i32 + %9 = arith.addi %arg1, %arg5 : i32 + %10 = arith.addi %arg2, %arg6 : i32 + %11 = arith.addi %arg3, %arg7 : i32 + + memref.store %8, %arg8[%c0] : memref<8xi32> + memref.store %9, %arg8[%c1] : memref<8xi32> + memref.store %10, %arg8[%c2] : memref<8xi32> + memref.store %11, %arg8[%c3] : memref<8xi32> + + return +} + + +// CHECK-LABEL: func @read_read_add_write +// CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>) +func.func @read_read_add_write(%arg0: memref<8xi32>, %arg1: memref<8xi32>) { + // CHECK: %[[C0:.*]] = arith.constant 0 : index + // CHECK: %[[A:.*]] = vector.load %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32> + // CHECK: %[[B:.*]] = vector.load %[[ARG1]][%[[C0]]] : memref<8xi32>, vector<4xi32> + // CHECK: %[[RES:.*]] = arith.addi %[[A]], %[[B]] : vector<4xi32> + // CHECK: vector.store %[[RES]], %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32> + %c0 = arith.constant 0 : index + %c1 = 
arith.constant 1 : index + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + + %0 = memref.load %arg0[%c0] : memref<8xi32> + %1 = memref.load %arg0[%c1] : memref<8xi32> + %2 = memref.load %arg0[%c2] : memref<8xi32> + %3 = memref.load %arg0[%c3] : memref<8xi32> + + %4 = memref.load %arg1[%c0] : memref<8xi32> + %5 = memref.load %arg1[%c1] : memref<8xi32> + %6 = memref.load %arg1[%c2] : memref<8xi32> + %7 = memref.load %arg1[%c3] : memref<8xi32> + + %8 = arith.addi %0, %4 : i32 + %9 = arith.addi %1, %5 : i32 + %10 = arith.addi %2, %6 : i32 + %11 = arith.addi %3, %7 : i32 + + memref.store %8, %arg0[%c0] : memref<8xi32> + memref.store %9, %arg0[%c1] : memref<8xi32> + memref.store %10, %arg0[%c2] : memref<8xi32> + memref.store %11, %arg0[%c3] : memref<8xi32> + + return +} + + +// CHECK-LABEL: func @read_read_add_write_vec_0d +// CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>) +func.func @read_read_add_write_vec_0d(%arg0: memref<8xi32>, %arg1: memref<8xi32>) { + // CHECK: %[[C0:.*]] = arith.constant 0 : index + // CHECK: %[[A:.*]] = vector.load %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32> + // CHECK: %[[B:.*]] = vector.load %[[ARG1]][%[[C0]]] : memref<8xi32>, vector<4xi32> + // CHECK: %[[RES:.*]] = arith.addi %[[A]], %[[B]] : vector<4xi32> + // CHECK: vector.store %[[RES]], %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32> + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + + %0 = vector.load %arg0[%c0] : memref<8xi32>, vector + %1 = vector.load %arg0[%c1] : memref<8xi32>, vector + %2 = vector.load %arg0[%c2] : memref<8xi32>, vector + %3 = vector.load %arg0[%c3] : memref<8xi32>, vector + + %4 = vector.load %arg1[%c0] : memref<8xi32>, vector + %5 = vector.load %arg1[%c1] : memref<8xi32>, vector + %6 = vector.load %arg1[%c2] : memref<8xi32>, vector + %7 = vector.load %arg1[%c3] : memref<8xi32>, vector + + %8 = arith.addi %0, %4 : vector + %9 = arith.addi %1, %5 : vector + %10 = arith.addi %2, %6 : vector + %11 = arith.addi %3, %7 : vector + + vector.store %8, %arg0[%c0] : memref<8xi32>, vector + vector.store %9, %arg0[%c1] : memref<8xi32>, vector + vector.store %10, %arg0[%c2] : memref<8xi32>, vector + vector.store %11, %arg0[%c3] : memref<8xi32>, vector + + return +} + + +// CHECK-LABEL: func @read_read_add_write_vec_1d +// CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>) +func.func @read_read_add_write_vec_1d(%arg0: memref<8xi32>, %arg1: memref<8xi32>) { + // CHECK: %[[C0:.*]] = arith.constant 0 : index + // CHECK: %[[A:.*]] = vector.load %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32> + // CHECK: %[[B:.*]] = vector.load %[[ARG1]][%[[C0]]] : memref<8xi32>, vector<4xi32> + // CHECK: %[[RES:.*]] = arith.addi %[[A]], %[[B]] : vector<4xi32> + // CHECK: vector.store %[[RES]], %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32> + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + + %0 = vector.load %arg0[%c0] : memref<8xi32>, vector<1xi32> + %1 = vector.load %arg0[%c1] : memref<8xi32>, vector<1xi32> + %2 = vector.load %arg0[%c2] : memref<8xi32>, vector<1xi32> + %3 = vector.load %arg0[%c3] : memref<8xi32>, vector<1xi32> + + %4 = vector.load %arg1[%c0] : memref<8xi32>, vector<1xi32> + %5 = vector.load %arg1[%c1] : memref<8xi32>, vector<1xi32> + %6 = vector.load %arg1[%c2] : memref<8xi32>, vector<1xi32> + %7 = vector.load %arg1[%c3] : memref<8xi32>, vector<1xi32> + + %8 = 
arith.addi %0, %4 : vector<1xi32> + %9 = arith.addi %1, %5 : vector<1xi32> + %10 = arith.addi %2, %6 : vector<1xi32> + %11 = arith.addi %3, %7 : vector<1xi32> + + vector.store %8, %arg0[%c0] : memref<8xi32>, vector<1xi32> + vector.store %9, %arg0[%c1] : memref<8xi32>, vector<1xi32> + vector.store %10, %arg0[%c2] : memref<8xi32>, vector<1xi32> + vector.store %11, %arg0[%c3] : memref<8xi32>, vector<1xi32> + + return +} + + +// CHECK-LABEL: func @read_read_add_write_mixed_vecs +// CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>) +func.func @read_read_add_write_mixed_vecs(%arg0: memref<8xi32>, %arg1: memref<8xi32>) { + // CHECK: %[[C0:.*]] = arith.constant 0 : index + // CHECK: %[[A:.*]] = vector.load %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32> + // CHECK: %[[B:.*]] = vector.load %[[ARG1]][%[[C0]]] : memref<8xi32>, vector<4xi32> + // CHECK: %[[RES:.*]] = arith.addi %[[A]], %[[B]] : vector<4xi32> + // CHECK: vector.store %[[RES]], %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32> + %c0 = arith.constant 0 : index + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + + %0 = vector.load %arg0[%c0] : memref<8xi32>, vector<2xi32> + %2 = memref.load %arg0[%c2] : memref<8xi32> + %3 = vector.load %arg0[%c3] : memref<8xi32>, vector<1xi32> + + %4 = vector.load %arg1[%c0] : memref<8xi32>, vector<2xi32> + %6 = memref.load %arg1[%c2] : memref<8xi32> + %7 = vector.load %arg1[%c3] : memref<8xi32>, vector<1xi32> + + %8 = arith.addi %0, %4 : vector<2xi32> + %10 = arith.addi %2, %6 : i32 + %11 = arith.addi %3, %7 : vector<1xi32> + + vector.store %8, %arg0[%c0] : memref<8xi32>, vector<2xi32> + memref.store %10, %arg0[%c2] : memref<8xi32> + vector.store %11, %arg0[%c3] : memref<8xi32>, vector<1xi32> + + return +} + + +// CHECK-LABEL: func @read_read_add_write_seven +// CHECK-SAME: (%[[ARG0:.*]]: memref<8xindex>, %[[ARG1:.*]]: memref<8xindex>) +func.func @read_read_add_write_seven(%arg0: memref<8xindex>, %arg1: memref<8xindex>) { + // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index + // CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index + // CHECK-DAG: %[[C6:.*]] = arith.constant 6 : index + // CHECK: %[[A0:.*]] = vector.load %[[ARG0]][%[[C0]]] : memref<8xindex>, vector<4xindex> + // CHECK: %[[A1:.*]] = vector.load %[[ARG0]][%[[C4]]] : memref<8xindex>, vector<2xindex> + // CHECK: %[[A2:.*]] = memref.load %[[ARG0]][%[[C6]]] : memref<8xindex> + // CHECK: %[[B0:.*]] = vector.load %[[ARG1]][%[[C0]]] : memref<8xindex>, vector<4xindex> + // CHECK: %[[B1:.*]] = vector.load %[[ARG1]][%[[C4]]] : memref<8xindex>, vector<2xindex> + // CHECK: %[[B2:.*]] = memref.load %[[ARG1]][%[[C6]]] : memref<8xindex> + // CHECK: %[[RES0:.*]] = arith.addi %[[A0]], %[[B0]] : vector<4xindex> + // CHECK: %[[RES1:.*]] = arith.addi %[[A1]], %[[B1]] : vector<2xindex> + // CHECK: %[[RES2:.*]] = arith.addi %[[A2]], %[[B2]] : index + // CHECK: vector.store %[[RES0]], %[[ARG0]][%[[C0]]] : memref<8xindex>, vector<4xindex> + // CHECK: vector.store %[[RES1]], %[[ARG0]][%[[C4]]] : memref<8xindex>, vector<2xindex> + // CHECK: memref.store %[[RES2]], %[[ARG0]][%[[C6]]] : memref<8xindex> + + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + %c4 = arith.constant 4 : index + %c5 = arith.constant 5 : index + %c6 = arith.constant 6 : index + + %0 = memref.load %arg0[%c0] : memref<8xindex> + %1 = memref.load %arg0[%c1] : memref<8xindex> + %2 = memref.load %arg0[%c2] : memref<8xindex> + %3 = memref.load %arg0[%c3] : memref<8xindex> + %4 = 
memref.load %arg0[%c4] : memref<8xindex>
+  %5 = memref.load %arg0[%c5] : memref<8xindex>
+  %6 = memref.load %arg0[%c6] : memref<8xindex>
+
+  %7 = memref.load %arg1[%c0] : memref<8xindex>
+  %8 = memref.load %arg1[%c1] : memref<8xindex>
+  %9 = memref.load %arg1[%c2] : memref<8xindex>
+  %10 = memref.load %arg1[%c3] : memref<8xindex>
+  %11 = memref.load %arg1[%c4] : memref<8xindex>
+  %12 = memref.load %arg1[%c5] : memref<8xindex>
+  %13 = memref.load %arg1[%c6] : memref<8xindex>
+
+  %14 = arith.addi %0, %7 : index
+  %15 = arith.addi %1, %8 : index
+  %16 = arith.addi %2, %9 : index
+  %17 = arith.addi %3, %10 : index
+  %18 = arith.addi %4, %11 : index
+  %19 = arith.addi %5, %12 : index
+  %20 = arith.addi %6, %13 : index
+
+  memref.store %14, %arg0[%c0] : memref<8xindex>
+  memref.store %15, %arg0[%c1] : memref<8xindex>
+  memref.store %16, %arg0[%c2] : memref<8xindex>
+  memref.store %17, %arg0[%c3] : memref<8xindex>
+  memref.store %18, %arg0[%c4] : memref<8xindex>
+  memref.store %19, %arg0[%c5] : memref<8xindex>
+  memref.store %20, %arg0[%c6] : memref<8xindex>
+
+  return
+}
+
+
+// CHECK-LABEL: func @read_read_add_write_size_mismatch
+// CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>)
+func.func @read_read_add_write_size_mismatch(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {
+  // CHECK: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK: %[[A:.*]] = vector.load %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+  // CHECK: %[[B:.*]] = vector.load %[[ARG1]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+  // CHECK: %[[A1:.*]] = vector.extract_strided_slice %[[A]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xi32> to vector<2xi32>
+  // CHECK: %[[B1:.*]] = vector.extract_strided_slice %[[B]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xi32> to vector<2xi32>
+  // CHECK: %[[RES:.*]] = arith.addi %[[A1]], %[[B1]] : vector<2xi32>
+  // CHECK: vector.store %[[RES]], %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<2xi32>
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
+
+  %0 = memref.load %arg0[%c0] : memref<8xi32>
+  %1 = memref.load %arg0[%c1] : memref<8xi32>
+  %2 = memref.load %arg0[%c2] : memref<8xi32>
+  %3 = memref.load %arg0[%c3] : memref<8xi32>
+
+  %4 = memref.load %arg1[%c0] : memref<8xi32>
+  %5 = memref.load %arg1[%c1] : memref<8xi32>
+  %6 = memref.load %arg1[%c2] : memref<8xi32>
+  %7 = memref.load %arg1[%c3] : memref<8xi32>
+
+  %8 = arith.addi %0, %4 : i32
+  %9 = arith.addi %1, %5 : i32
+
+  memref.store %8, %arg0[%c0] : memref<8xi32>
+  memref.store %9, %arg0[%c1] : memref<8xi32>
+
+  return
+}
+
+
+// CHECK-LABEL: func @read_read_add_write_attrs_mismatch
+// CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>)
+func.func @read_read_add_write_attrs_mismatch(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {
+  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+  // CHECK: %[[V0:.*]] = vector.load %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+  // CHECK: %[[V1:.*]] = vector.extract_strided_slice %[[V0]] {offsets = [2], sizes = [2], strides = [1]} : vector<4xi32> to vector<2xi32>
+  // CHECK: %[[V2:.*]] = vector.load %[[ARG1]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+  // CHECK: %[[V3:.*]] = vector.extract_strided_slice %[[V2]] {offsets = [2], sizes = [2], strides = [1]} : vector<4xi32> to vector<2xi32>
+  // CHECK: %[[V4:.*]] = vector.extract_strided_slice %[[V0]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xi32> to vector<2xi32>
+  // CHECK: %[[V5:.*]] = vector.extract_strided_slice %[[V2]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xi32> to vector<2xi32>
+  // CHECK: %[[V6:.*]] = arith.addi %[[V4]], %[[V5]] overflow : vector<2xi32>
+  // CHECK: %[[V7:.*]] = arith.addi %[[V1]], %[[V3]] overflow : vector<2xi32>
+  // CHECK: vector.store %[[V6]], %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<2xi32>
+  // CHECK: vector.store %[[V7]], %[[ARG0]][%[[C2]]] : memref<8xi32>, vector<2xi32>
+
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
+
+  %0 = memref.load %arg0[%c0] : memref<8xi32>
+  %1 = memref.load %arg0[%c1] : memref<8xi32>
+  %2 = memref.load %arg0[%c2] : memref<8xi32>
+  %3 = memref.load %arg0[%c3] : memref<8xi32>
+
+  %4 = memref.load %arg1[%c0] : memref<8xi32>
+  %5 = memref.load %arg1[%c1] : memref<8xi32>
+  %6 = memref.load %arg1[%c2] : memref<8xi32>
+  %7 = memref.load %arg1[%c3] : memref<8xi32>
+
+  %8 = arith.addi %0, %4 overflow : i32
+  %9 = arith.addi %1, %5 overflow : i32
+  %10 = arith.addi %2, %6 overflow : i32
+  %11 = arith.addi %3, %7 overflow : i32
+
+  memref.store %8, %arg0[%c0] : memref<8xi32>
+  memref.store %9, %arg0[%c1] : memref<8xi32>
+  memref.store %10, %arg0[%c2] : memref<8xi32>
+  memref.store %11, %arg0[%c3] : memref<8xi32>
+
+  return
+}
+
+
+// CHECK-LABEL: func @read_read_add_write_interleaved
+// CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>)
+func.func @read_read_add_write_interleaved(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {
+  // CHECK: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK: %[[A:.*]] = vector.load %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+  // CHECK: %[[B:.*]] = vector.load %[[ARG1]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+  // CHECK: %[[RES:.*]] = arith.addi %[[A]], %[[B]] : vector<4xi32>
+  // CHECK: vector.store %[[RES]], %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
+
+  %3 = memref.load %arg0[%c3] : memref<8xi32>
+  %7 = memref.load %arg1[%c3] : memref<8xi32>
+  %11 = arith.addi %3, %7 : i32
+
+  %0 = memref.load %arg0[%c0] : memref<8xi32>
+  %4 = memref.load %arg1[%c0] : memref<8xi32>
+  %8 = arith.addi %0, %4 : i32
+
+  %2 = memref.load %arg0[%c2] : memref<8xi32>
+  %6 = memref.load %arg1[%c2] : memref<8xi32>
+  %10 = arith.addi %2, %6 : i32
+
+  %1 = memref.load %arg0[%c1] : memref<8xi32>
+  %5 = memref.load %arg1[%c1] : memref<8xi32>
+  %9 = arith.addi %1, %5 : i32
+
+  memref.store %8, %arg0[%c0] : memref<8xi32>
+  memref.store %11, %arg0[%c3] : memref<8xi32>
+  memref.store %10, %arg0[%c2] : memref<8xi32>
+  memref.store %9, %arg0[%c1] : memref<8xi32>
+
+  return
+}
+
+
+// CHECK-LABEL: func @read_read_add_add_write
+// CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>
+// CHECK-SAME: , %[[ARG2:.*]]: i32, %[[ARG3:.*]]: i32, %[[ARG4:.*]]: i32, %[[ARG5:.*]]: i32)
+func.func @read_read_add_add_write(%arg0: memref<8xi32>, %arg1: memref<8xi32>,
+                                   %arg2: i32, %arg3: i32, %arg4: i32, %arg5: i32) {
+  // Each load group has multiple uses (in potentially different order);
+  // make sure both of them were vectorized.
+  // CHECK: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK: %[[A:.*]] = vector.load %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+  // CHECK: %[[B:.*]] = vector.load %[[ARG1]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+  // CHECK: %[[ADD1:.*]] = arith.addi %[[A]], %[[B]] : vector<4xi32>
+  // CHECK: %[[C:.*]] = vector.from_elements %[[ARG2]], %[[ARG3]], %[[ARG4]], %[[ARG5]] : vector<4xi32>
+  // CHECK: %[[ADD2:.*]] = arith.addi %[[A]], %[[C]] : vector<4xi32>
+  // CHECK: vector.store %[[ADD1]], %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+  // CHECK: vector.store %[[ADD2]], %[[ARG1]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
+
+  %0 = memref.load %arg0[%c0] : memref<8xi32>
+  %1 = memref.load %arg0[%c1] : memref<8xi32>
+  %2 = memref.load %arg0[%c2] : memref<8xi32>
+  %3 = memref.load %arg0[%c3] : memref<8xi32>
+
+  %4 = memref.load %arg1[%c0] : memref<8xi32>
+  %5 = memref.load %arg1[%c1] : memref<8xi32>
+  %6 = memref.load %arg1[%c2] : memref<8xi32>
+  %7 = memref.load %arg1[%c3] : memref<8xi32>
+
+  %8 = arith.addi %0, %4 : i32
+  %12 = arith.addi %0, %arg2 : i32
+
+  %13 = arith.addi %1, %arg3 : i32
+  %9 = arith.addi %1, %5 : i32
+
+  %10 = arith.addi %2, %6 : i32
+  %14 = arith.addi %2, %arg4 : i32
+
+  %15 = arith.addi %3, %arg5 : i32
+  %11 = arith.addi %3, %7 : i32
+
+  memref.store %8, %arg0[%c0] : memref<8xi32>
+  memref.store %9, %arg0[%c1] : memref<8xi32>
+  memref.store %10, %arg0[%c2] : memref<8xi32>
+  memref.store %11, %arg0[%c3] : memref<8xi32>
+
+  memref.store %12, %arg1[%c0] : memref<8xi32>
+  memref.store %13, %arg1[%c1] : memref<8xi32>
+  memref.store %14, %arg1[%c2] : memref<8xi32>
+  memref.store %15, %arg1[%c3] : memref<8xi32>
+
+  return
+}
+
+// CHECK-LABEL: func @read_read_add_add
+// CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>
+// CHECK-SAME: , %[[ARG2:.*]]: i32, %[[ARG3:.*]]: i32, %[[ARG4:.*]]: i32, %[[ARG5:.*]]: i32)
+func.func @read_read_add_add(%arg0: memref<8xi32>, %arg1: memref<8xi32>,
+                             %arg2: i32, %arg3: i32, %arg4: i32, %arg5: i32) ->
+    (i32, i32, i32, i32, i32, i32, i32, i32){
+  // Each load group has multiple uses (in potentially different order);
+  // make sure both of them were vectorized.
+  // CHECK: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK: %[[A:.*]] = vector.load %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+  // CHECK: %[[B:.*]] = vector.load %[[ARG1]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+  // CHECK: %[[ADD1:.*]] = arith.addi %[[A]], %[[B]] : vector<4xi32>
+  // CHECK: %[[R0:.*]] = vector.extract %[[ADD1]][0] : i32 from vector<4xi32>
+  // CHECK: %[[R1:.*]] = vector.extract %[[ADD1]][1] : i32 from vector<4xi32>
+  // CHECK: %[[R2:.*]] = vector.extract %[[ADD1]][2] : i32 from vector<4xi32>
+  // CHECK: %[[R3:.*]] = vector.extract %[[ADD1]][3] : i32 from vector<4xi32>
+  // CHECK: %[[C:.*]] = vector.from_elements %[[ARG2]], %[[ARG3]], %[[ARG4]], %[[ARG5]] : vector<4xi32>
+  // CHECK: %[[ADD2:.*]] = arith.addi %[[A]], %[[C]] : vector<4xi32>
+  // CHECK: %[[R4:.*]] = vector.extract %[[ADD2]][0] : i32 from vector<4xi32>
+  // CHECK: %[[R5:.*]] = vector.extract %[[ADD2]][1] : i32 from vector<4xi32>
+  // CHECK: %[[R6:.*]] = vector.extract %[[ADD2]][2] : i32 from vector<4xi32>
+  // CHECK: %[[R7:.*]] = vector.extract %[[ADD2]][3] : i32 from vector<4xi32>
+  // CHECK: return %[[R0]], %[[R1]], %[[R2]], %[[R3]], %[[R4]], %[[R5]], %[[R6]], %[[R7]] : i32, i32, i32, i32, i32, i32, i32, i32
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
+
+  %0 = memref.load %arg0[%c0] : memref<8xi32>
+  %1 = memref.load %arg0[%c1] : memref<8xi32>
+  %2 = memref.load %arg0[%c2] : memref<8xi32>
+  %3 = memref.load %arg0[%c3] : memref<8xi32>
+
+  %4 = memref.load %arg1[%c0] : memref<8xi32>
+  %5 = memref.load %arg1[%c1] : memref<8xi32>
+  %6 = memref.load %arg1[%c2] : memref<8xi32>
+  %7 = memref.load %arg1[%c3] : memref<8xi32>
+
+  %8 = arith.addi %0, %4 : i32
+  %12 = arith.addi %0, %arg2 : i32
+
+  %13 = arith.addi %1, %arg3 : i32
+  %9 = arith.addi %1, %5 : i32
+
+  %10 = arith.addi %2, %6 : i32
+  %14 = arith.addi %2, %arg4 : i32
+
+  %15 = arith.addi %3, %arg5 : i32
+  %11 = arith.addi %3, %7 : i32
+
+  return %8, %9, %10, %11, %12, %13, %14, %15 : i32, i32, i32, i32, i32, i32, i32, i32
+}
+
+
+// CHECK-LABEL: func @read_read_add_add_vec
+// CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>)
+func.func @read_read_add_add_vec(%arg0: memref<8xi32>, %arg1: memref<8xi32>) ->
+    (vector<2xi32>, vector<2xi32>){
+  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK: %[[V0:.*]] = vector.load %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+  // CHECK: %[[V1:.*]] = vector.load %[[ARG1]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+  // CHECK: %[[V2:.*]] = arith.addi %[[V0]], %[[V1]] : vector<4xi32>
+  // CHECK: %[[V3:.*]] = vector.extract_strided_slice %[[V2]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xi32> to vector<2xi32>
+  // CHECK: %[[V4:.*]] = vector.extract_strided_slice %[[V2]] {offsets = [2], sizes = [2], strides = [1]} : vector<4xi32> to vector<2xi32>
+  // CHECK: return %[[V3]], %[[V4]] : vector<2xi32>, vector<2xi32>
+  %c0 = arith.constant 0 : index
+  %c2 = arith.constant 2 : index
+
+  %0 = vector.load %arg0[%c0] : memref<8xi32>, vector<2xi32>
+  %2 = vector.load %arg0[%c2] : memref<8xi32>, vector<2xi32>
+
+  %4 = vector.load %arg1[%c0] : memref<8xi32>, vector<2xi32>
+  %6 = vector.load %arg1[%c2] : memref<8xi32>, vector<2xi32>
+
+  %8 = arith.addi %0, %4 : vector<2xi32>
+  %10 = arith.addi %2, %6 : vector<2xi32>
+
+  return %8, %10 : vector<2xi32>, vector<2xi32>
+}
+
+
+// CHECK-LABEL: func @read_read_add_add_vec1
+// CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>)
+func.func @read_read_add_add_vec1(%arg0: memref<8xi32>, %arg1: memref<8xi32>) ->
+    (vector<1xi32>, vector<1xi32>){
+  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK: %[[V0:.*]] = vector.load %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<2xi32>
+  // CHECK: %[[V1:.*]] = vector.load %[[ARG1]][%[[C0]]] : memref<8xi32>, vector<2xi32>
+  // CHECK: %[[V2:.*]] = arith.addi %[[V0]], %[[V1]] : vector<2xi32>
+  // CHECK: %[[V3:.*]] = vector.extract_strided_slice %[[V2]] {offsets = [0], sizes = [1], strides = [1]} : vector<2xi32> to vector<1xi32>
+  // CHECK: %[[V4:.*]] = vector.extract_strided_slice %[[V2]] {offsets = [1], sizes = [1], strides = [1]} : vector<2xi32> to vector<1xi32>
+  // CHECK: return %[[V3]], %[[V4]] : vector<1xi32>, vector<1xi32>
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+
+  %0 = vector.load %arg0[%c0] : memref<8xi32>, vector<1xi32>
+  %2 = vector.load %arg0[%c1] : memref<8xi32>, vector<1xi32>
+
+  %4 = vector.load %arg1[%c0] : memref<8xi32>, vector<1xi32>
+  %6 = vector.load %arg1[%c1] : memref<8xi32>, vector<1xi32>
+
+  %8 = arith.addi %0, %4 : vector<1xi32>
+  %10 = arith.addi %2, %6 : vector<1xi32>
+
+  return %8, %10 : vector<1xi32>, vector<1xi32>
+}
+
+
+// CHECK-LABEL: func @read_read_add_add_vec0d
+// CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>)
+func.func @read_read_add_add_vec0d(%arg0: memref<8xi32>, %arg1: memref<8xi32>) ->
+    (vector<i32>, vector<i32>){
+  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK: %[[V0:.*]] = vector.load %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<2xi32>
+  // CHECK: %[[V1:.*]] = vector.load %[[ARG1]][%[[C0]]] : memref<8xi32>, vector<2xi32>
+  // CHECK: %[[V2:.*]] = arith.addi %[[V0]], %[[V1]] : vector<2xi32>
+  // CHECK: %[[V3:.*]] = vector.extract %[[V2]][0] : i32 from vector<2xi32>
+  // CHECK: %[[V4:.*]] = vector.splat %[[V3]] : vector<i32>
+  // CHECK: %[[V5:.*]] = vector.extract %[[V2]][1] : i32 from vector<2xi32>
+  // CHECK: %[[V6:.*]] = vector.splat %[[V5]] : vector<i32>
+  // CHECK: return %[[V4]], %[[V6]] : vector<i32>, vector<i32>
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+
+  %0 = vector.load %arg0[%c0] : memref<8xi32>, vector<i32>
+  %2 = vector.load %arg0[%c1] : memref<8xi32>, vector<i32>
+
+  %4 = vector.load %arg1[%c0] : memref<8xi32>, vector<i32>
+  %6 = vector.load %arg1[%c1] : memref<8xi32>, vector<i32>
+
+  %8 = arith.addi %0, %4 : vector<i32>
+  %10 = arith.addi %2, %6 : vector<i32>
+
+  return %8, %10 : vector<i32>, vector<i32>
+}
+
+
+func.func private @use(i32)
+
+// CHECK-LABEL: func @read_read_add_write_interleaved_use
+// CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>)
+func.func @read_read_add_write_interleaved_use(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {
+  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+  // CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
+  // CHECK: %[[V0:.*]] = memref.load %[[ARG0]][%[[C3]]] : memref<8xi32>
+  // CHECK: %[[V1:.*]] = memref.load %[[ARG1]][%[[C3]]] : memref<8xi32>
+  // CHECK: call @use(%[[V0]]) : (i32) -> ()
+  // CHECK: %[[V2:.*]] = vector.load %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<2xi32>
+  // CHECK: %[[V3:.*]] = vector.load %[[ARG1]][%[[C0]]] : memref<8xi32>, vector<2xi32>
+  // CHECK: %[[V4:.*]] = memref.load %[[ARG0]][%[[C2]]] : memref<8xi32>
+  // CHECK: %[[V5:.*]] = memref.load %[[ARG1]][%[[C2]]] : memref<8xi32>
+  // CHECK: %[[V6:.*]] = vector.extract %[[V2]][0] : i32 from vector<2xi32>
+  // CHECK: %[[V7:.*]] = vector.extract %[[V2]][1] : i32 from vector<2xi32>
+  // CHECK: %[[V8:.*]] = vector.from_elements %[[V6]], %[[V7]], %[[V4]], %[[V0]] : vector<4xi32>
+  // CHECK: %[[V9:.*]] = vector.extract %[[V3]][0] : i32 from vector<2xi32>
+  // CHECK: %[[V10:.*]] = vector.extract %[[V3]][1] : i32 from vector<2xi32>
+  // CHECK: %[[V11:.*]] = vector.from_elements %[[V9]], %[[V10]], %[[V5]], %[[V1]] : vector<4xi32>
+  // CHECK: %[[V12:.*]] = arith.addi %[[V8]], %[[V11]] : vector<4xi32>
+  // CHECK: vector.store %[[V12]], %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
+
+  %3 = memref.load %arg0[%c3] : memref<8xi32>
+  %7 = memref.load %arg1[%c3] : memref<8xi32>
+  call @use(%3) : (i32) -> ()
+  %11 = arith.addi %3, %7 : i32
+
+  %0 = memref.load %arg0[%c0] : memref<8xi32>
+  %4 = memref.load %arg1[%c0] : memref<8xi32>
+  %8 = arith.addi %0, %4 : i32
+
+  %2 = memref.load %arg0[%c2] : memref<8xi32>
+  %6 = memref.load %arg1[%c2] : memref<8xi32>
+  %10 = arith.addi %2, %6 : i32
+
+  %1 = memref.load %arg0[%c1] : memref<8xi32>
+  %5 = memref.load %arg1[%c1] : memref<8xi32>
+  %9 = arith.addi %1, %5 : i32
+
+  memref.store %8, %arg0[%c0] : memref<8xi32>
+  memref.store %11, %arg0[%c3] : memref<8xi32>
+  memref.store %10, %arg0[%c2] : memref<8xi32>
+  memref.store %9, %arg0[%c1] : memref<8xi32>
+
+  return
+}
+
+
+// CHECK-LABEL: func @read_read_add_write_interleaved_use_add
+// CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>)
+func.func @read_read_add_write_interleaved_use_add(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {
+  // CHECK: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK: %[[V0:.*]] = vector.load %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+  // CHECK: %[[V1:.*]] = vector.extract %[[V0]][3] : i32 from vector<4xi32>
+  // CHECK: %[[V2:.*]] = vector.load %[[ARG1]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+  // CHECK: %[[V3:.*]] = vector.extract %[[V2]][3] : i32 from vector<4xi32>
+  // CHECK: %[[V4:.*]] = arith.subi %[[V1]], %[[V3]] : i32
+  // CHECK: %[[V5:.*]] = arith.addi %[[V0]], %[[V2]] : vector<4xi32>
+  // CHECK: vector.store %[[V5]], %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+  // CHECK: call @use(%[[V4]]) : (i32) -> ()
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
+
+  %3 = memref.load %arg0[%c3] : memref<8xi32>
+  %7 = memref.load %arg1[%c3] : memref<8xi32>
+  %12 = arith.subi %3, %7 : i32
+  %11 = arith.addi %3, %7 : i32
+
+  %0 = memref.load %arg0[%c0] : memref<8xi32>
+  %4 = memref.load %arg1[%c0] : memref<8xi32>
+  %8 = arith.addi %0, %4 : i32
+
+  %2 = memref.load %arg0[%c2] : memref<8xi32>
+  %6 = memref.load %arg1[%c2] : memref<8xi32>
+  %10 = arith.addi %2, %6 : i32
+
+  %1 = memref.load %arg0[%c1] : memref<8xi32>
+  %5 = memref.load %arg1[%c1] : memref<8xi32>
+  %9 = arith.addi %1, %5 : i32
+
+  memref.store %8, %arg0[%c0] : memref<8xi32>
+  memref.store %11, %arg0[%c3] : memref<8xi32>
+  memref.store %10, %arg0[%c2] : memref<8xi32>
+  memref.store %9, %arg0[%c1] : memref<8xi32>
+
+  call @use(%12) : (i32) -> ()
+  return
+}
+
+
+// CHECK-LABEL: func @different_blocks
+// CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>)
+func.func @different_blocks(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {
+  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+  // CHECK: %[[V0:.*]] = vector.load %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+  // CHECK: %[[V1:.*]] = vector.extract_strided_slice %[[V0]] {offsets = [2], sizes = [2], strides = [1]} : vector<4xi32> to vector<2xi32>
+  // CHECK: %[[V2:.*]] = vector.load %[[ARG1]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+  // CHECK: %[[V3:.*]] = vector.extract_strided_slice %[[V2]] {offsets = [2], sizes = [2], strides = [1]} : vector<4xi32> to vector<2xi32>
+  // CHECK: cf.br ^bb1
+  // CHECK: ^bb1:
+  // CHECK: %[[V4:.*]] = vector.extract_strided_slice %[[V0]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xi32> to vector<2xi32>
+  // CHECK: %[[V5:.*]] = vector.extract_strided_slice %[[V2]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xi32> to vector<2xi32>
+  // CHECK: %[[V6:.*]] = arith.addi %[[V4]], %[[V5]] : vector<2xi32>
+  // CHECK: cf.br ^bb2
+  // CHECK: ^bb2:
+  // CHECK: %[[V7:.*]] = arith.addi %[[V1]], %[[V3]] : vector<2xi32>
+  // CHECK: cf.br ^bb3
+  // CHECK: ^bb3:
+  // CHECK: vector.store %[[V6]], %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<2xi32>
+  // CHECK: vector.store %[[V7]], %[[ARG0]][%[[C2]]] : memref<8xi32>, vector<2xi32>
+
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
+
+  %0 = memref.load %arg0[%c0] : memref<8xi32>
+  %1 = memref.load %arg0[%c1] : memref<8xi32>
+  %2 = memref.load %arg0[%c2] : memref<8xi32>
+  %3 = memref.load %arg0[%c3] : memref<8xi32>
+
+  %4 = memref.load %arg1[%c0] : memref<8xi32>
+  %5 = memref.load %arg1[%c1] : memref<8xi32>
+  %6 = memref.load %arg1[%c2] : memref<8xi32>
+  %7 = memref.load %arg1[%c3] : memref<8xi32>
+
+  cf.br ^bb0
+
+^bb0:
+  %8 = arith.addi %0, %4 : i32
+  %9 = arith.addi %1, %5 : i32
+  cf.br ^bb1
+
+^bb1:
+  %10 = arith.addi %2, %6 : i32
+  %11 = arith.addi %3, %7 : i32
+  cf.br ^bb2
+
+^bb2:
+  memref.store %8, %arg0[%c0] : memref<8xi32>
+  memref.store %9, %arg0[%c1] : memref<8xi32>
+  memref.store %10, %arg0[%c2] : memref<8xi32>
+  memref.store %11, %arg0[%c3] : memref<8xi32>
+
+  return
+}
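The tests above all exercise the same core rewrite: a run of adjacent scalar (or narrow-vector) memref accesses, together with the element-wise arithmetic between them, is replaced by single wide vector ops. A minimal before/after sketch of that shape, for orientation only; it is not an additional test case, and the value names (%src, %other, %dst, etc.) are illustrative assumptions:

  // Before: two adjacent scalar loads per source, two adds, two stores.
  %a0 = memref.load %src[%c0] : memref<8xi32>
  %a1 = memref.load %src[%c1] : memref<8xi32>
  %b0 = memref.load %other[%c0] : memref<8xi32>
  %b1 = memref.load %other[%c1] : memref<8xi32>
  %s0 = arith.addi %a0, %b0 : i32
  %s1 = arith.addi %a1, %b1 : i32
  memref.store %s0, %dst[%c0] : memref<8xi32>
  memref.store %s1, %dst[%c1] : memref<8xi32>

  // After: one 2-wide vector load per source, one vector add, one vector store.
  %va = vector.load %src[%c0] : memref<8xi32>, vector<2xi32>
  %vb = vector.load %other[%c0] : memref<8xi32>, vector<2xi32>
  %vs = arith.addi %va, %vb : vector<2xi32>
  vector.store %vs, %dst[%c0] : memref<8xi32>, vector<2xi32>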