[mlir][x86vector] Lower vector.contract to sequence of FMAs #163382
Conversation
✅ With the latest revision this PR passed the C/C++ code formatter.
First round of review.
Please add a comment on the difference between 3 and 4 dims, and what they mean.
This code looks almost identical (modulo indices). Can you select the indices based on dimCount and then write the code only once?
nit: to avoid needing additional comments, I suggest using longer function names, like loadAccumulatorBeforeGEMM.
Though, it would still be nice to document what you can do with the iter args.
Since this is the public API, I recommend using more expressive function names. Same as above.
I would cache this into variables to make the code more expressive.
  int64_t nBlocks = N / vectorSize;
  bool skinny = (nBlocks < M);
I'd also suggest checking whether something similar can't be achieved by a combination of existing transforms like:
  transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct"
  transform.apply_patterns.vector.lower_outerproduct
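For reference, a minimal transform script sketch of that suggestion (the named sequence wrapper and the func.func match are assumptions for illustration; only the two apply_patterns ops above come from existing MLIR transforms):

  module attributes {transform.with_named_sequence} {
    transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
      // Assumed anchor: apply the patterns to every function under the root.
      %func = transform.structured.match ops{["func.func"]} in %root
        : (!transform.any_op) -> !transform.any_op
      transform.apply_patterns to %func {
        // First rewrite vector.contract into vector.outerproduct ops ...
        transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct"
        // ... then lower each outer product to broadcasts and FMAs.
        transform.apply_patterns.vector.lower_outerproduct
      } : !transform.any_op
      transform.yield
    }
  }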
Second round, mostly code structure.
To continue the review, we'll need a few tests to validate the assumptions in the PR.
Unnecessary space between if and block.
The pattern (N / vectorSize) < M is spread across the code without anything connecting the occurrences, which makes the code hard to understand, especially when the nested blocks are almost identical.
There should be a way to prepare the indices, vectors, and required ops in one block at the beginning and then (mostly) have a single sequence based on those variables.
A function and a local variable above with the same name may be confusing.
Also, this function should be static, I think.
Avoid creating multiple functions with the same name. While C++ can resolve them without issues, humans are not very good at it.
Also: static?
Why the namespace? Did that not resolve correctly? If not, why not?
You use the 4/3 split to get the loop below, so you could cache the loop itself instead and just use it here and below.
Is getNestedLoop guaranteed to return at least 2 elements?
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-vector

Author: Arun Thangamani (arun-thmn)

Changes

This PR lowers vector.contract (only batch-reduce or batch matmul) to a sequence of FMAs with respect to the vector_size.

Patch is 48.75 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/163382.diff

12 Files Affected:
 diff --git a/mlir/include/mlir/Dialect/X86Vector/CMakeLists.txt b/mlir/include/mlir/Dialect/X86Vector/CMakeLists.txt
index 0fe01824b8248..bbe8e4eb892dd 100644
--- a/mlir/include/mlir/Dialect/X86Vector/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/X86Vector/CMakeLists.txt
@@ -3,3 +3,5 @@ add_mlir_doc(X86Vector X86Vector Dialects/ -gen-dialect-doc -dialect=x86vector)
 
 add_mlir_interface(X86VectorInterfaces)
 add_dependencies(MLIRX86VectorIncGen MLIRX86VectorInterfacesIncGen)
+
+add_subdirectory(TransformOps)
diff --git a/mlir/include/mlir/Dialect/X86Vector/TransformOps/CMakeLists.txt b/mlir/include/mlir/Dialect/X86Vector/TransformOps/CMakeLists.txt
new file mode 100644
index 0000000000000..6f377e10fa8f8
--- /dev/null
+++ b/mlir/include/mlir/Dialect/X86Vector/TransformOps/CMakeLists.txt
@@ -0,0 +1,4 @@
+set(LLVM_TARGET_DEFINITIONS X86VectorTransformOps.td)
+mlir_tablegen(X86VectorTransformOps.h.inc -gen-op-decls)
+mlir_tablegen(X86VectorTransformOps.cpp.inc -gen-op-defs)
+add_mlir_dialect_tablegen_target(MLIRX86VectorTransformOpsIncGen)
diff --git a/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h b/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h
new file mode 100644
index 0000000000000..e1d8b8762e799
--- /dev/null
+++ b/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h
@@ -0,0 +1,31 @@
+//===- X86VectorTransformOps.h - X86Vector transform ops --------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_X86VECTOR_TRANSFORMOPS_X86VECTORTRANSFORMOPS_H
+#define MLIR_DIALECT_X86VECTOR_TRANSFORMOPS_X86VECTORTRANSFORMOPS_H
+
+#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
+#include "mlir/IR/OpImplementation.h"
+
+//===----------------------------------------------------------------------===//
+// X86Vector Transform Operations
+//===----------------------------------------------------------------------===//
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h.inc"
+
+namespace mlir {
+class DialectRegistry;
+
+namespace x86vector {
+void registerTransformDialectExtension(DialectRegistry &registry);
+
+} // namespace x86vector
+} // namespace mlir
+
+#endif // MLIR_DIALECT_X86VECTOR_TRANSFORMOPS_X86VECTORTRANSFORMOPS_H
diff --git a/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td b/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td
new file mode 100644
index 0000000000000..9db2b36a2a8aa
--- /dev/null
+++ b/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td
@@ -0,0 +1,37 @@
+//===- X86VectorTransformOps.td - X86Vector transform ops ---*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef X86VECTOR_TRANSFORM_OPS
+#define X86VECTOR_TRANSFORM_OPS
+
+include "mlir/Dialect/Transform/IR/TransformDialect.td"
+include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/IR/OpBase.td"
+include "mlir/Dialect/Transform/IR/TransformAttrs.td"
+include "mlir/Dialect/Transform/IR/TransformTypes.td"
+include "mlir/IR/RegionKindInterface.td"
+
+def ApplyVectorContractNanokernelLoweringPatternsOp : Op<Transform_Dialect,
+    "apply_patterns.x86vector.vector_contract_nanokernel_lowering",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Indicates that the vector contract operation can be lowered to
+    target-specific nanokernels.
+  }];
+
+  let arguments = (ins DefaultValuedAttr<I64Attr, "8">:$vector_size);
+
+  let assemblyFormat = [{
+    (`vector_size` `=` $vector_size^)? attr-dict
+  }];
+
+}
+
+
+#endif // X86VECTOR_TRANSFORM_OPS
diff --git a/mlir/include/mlir/Dialect/X86Vector/Transforms.h b/mlir/include/mlir/Dialect/X86Vector/Transforms.h
index d54111ca41e69..6410c12265f12 100644
--- a/mlir/include/mlir/Dialect/X86Vector/Transforms.h
+++ b/mlir/include/mlir/Dialect/X86Vector/Transforms.h
@@ -11,6 +11,10 @@
 
 #include "mlir/IR/Value.h"
 
+#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/PatternMatch.h"
+
 namespace mlir {
 
 class ImplicitLocOpBuilder;
@@ -79,6 +83,14 @@ struct MaskHelper {
   }
 };
 
+//===----------------------------------------------------------------------===//
+// Populates patterns to lower a tiled batch or batch-reduce vector
+// contraction into a sequence of nanokernels.
+// The transformation is tailored to the target machine architecture
+// and guided by the user-specified vector size.
+void populateVectorContractNanokernelLoweringPatterns(
+    RewritePatternSet &patterns, std::optional<unsigned> vectorSize = 8);
+
 //===----------------------------------------------------------------------===//
 /// Helpers extracted from:
 ///   - clang/lib/Headers/avxintrin.h
diff --git a/mlir/lib/Dialect/X86Vector/CMakeLists.txt b/mlir/lib/Dialect/X86Vector/CMakeLists.txt
index 9f57627c321fb..cb1e9d01821a2 100644
--- a/mlir/lib/Dialect/X86Vector/CMakeLists.txt
+++ b/mlir/lib/Dialect/X86Vector/CMakeLists.txt
@@ -1,2 +1,3 @@
 add_subdirectory(IR)
 add_subdirectory(Transforms)
+add_subdirectory(TransformOps)
diff --git a/mlir/lib/Dialect/X86Vector/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/X86Vector/TransformOps/CMakeLists.txt
new file mode 100644
index 0000000000000..f4c9f8a05acbc
--- /dev/null
+++ b/mlir/lib/Dialect/X86Vector/TransformOps/CMakeLists.txt
@@ -0,0 +1,17 @@
+add_mlir_dialect_library(MLIRX86VectorTransformOps
+  X86VectorTransformOps.cpp
+
+  DEPENDS
+  MLIRX86VectorTransformOpsIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRLLVMCommonConversion
+  MLIRLLVMDialect
+  MLIRVectorDialect
+  MLIRSideEffectInterfaces
+  MLIRTransformDialect
+  MLIRTransformDialectUtils
+  MLIRX86VectorDialect
+  MLIRX86VectorTransforms
+  )
diff --git a/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp b/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp
new file mode 100644
index 0000000000000..e003e3ad7cd08
--- /dev/null
+++ b/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp
@@ -0,0 +1,61 @@
+//===- X86VectorTransformOps.cpp - X86Vector transform ops ---------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h"
+#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/X86Vector/Transforms.h"
+#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
+
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/RegionKindInterface.h"
+
+using namespace mlir;
+using namespace mlir::x86vector;
+using namespace mlir::transform;
+
+void mlir::transform::ApplyVectorContractNanokernelLoweringPatternsOp::
+    populatePatterns(RewritePatternSet &patterns) {
+  x86vector::populateVectorContractNanokernelLoweringPatterns(patterns,
+                                                              getVectorSize());
+}
+
+//===----------------------------------------------------------------------===//
+// Transform op registration
+//===----------------------------------------------------------------------===//
+
+namespace {
+class X86VectorTransformDialectExtension
+    : public transform::TransformDialectExtension<
+          X86VectorTransformDialectExtension> {
+public:
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
+      X86VectorTransformDialectExtension)
+
+  X86VectorTransformDialectExtension() {
+    declareGeneratedDialect<x86vector::X86VectorDialect>();
+    declareGeneratedDialect<LLVM::LLVMDialect>();
+    registerTransformOps<
+#define GET_OP_LIST
+#include "mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp.inc"
+        >();
+  }
+};
+} // namespace
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp.inc"
+
+void mlir::x86vector::registerTransformDialectExtension(
+    DialectRegistry &registry) {
+  registry.addExtensions<X86VectorTransformDialectExtension>();
+}
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt
index c51266afe9e8f..da377763331f2 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt
@@ -1,6 +1,7 @@
 add_mlir_dialect_library(MLIRX86VectorTransforms
   AVXTranspose.cpp
   LegalizeForLLVMExport.cpp
+  NanoKernels.cpp
 
   LINK_LIBS PUBLIC
   MLIRArithDialect
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/NanoKernels.cpp b/mlir/lib/Dialect/X86Vector/Transforms/NanoKernels.cpp
new file mode 100644
index 0000000000000..4d0906a2ec057
--- /dev/null
+++ b/mlir/lib/Dialect/X86Vector/Transforms/NanoKernels.cpp
@@ -0,0 +1,659 @@
+//===- NanoKernels.cpp - Lower matmul to Nanokernels ---------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements matmul rewrites as nanokernels with respect to target
+// machine for FP32 (for selective batch or batch-reduce matmul patterns) and
+// BF16 (TODO) types.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/AMX/AMXDialect.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
+#include "mlir/Dialect/X86Vector/Transforms.h"
+#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
+
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/IR/PatternMatch.h"
+
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace mlir::vector;
+using namespace mlir::x86vector;
+
+// Enum to represent the type of matmul operation
+enum class MatMulType { Batch, BatchReduce, Others };
+
+static FailureOr<SmallVector<scf::ForOp>>
+getTiledMatmulLoopNest(vector::ContractionOp contractOp,
+                       MatMulType matmulType) {
+  SmallVector<scf::ForOp> list;
+  Operation *current = contractOp;
+  unsigned int dimCount = matmulType == MatMulType::BatchReduce ? 4 : 3;
+
+  // Walk the register-tiled loop structure of a batch (or batch-reduce)
+  // matmul (M->N->(reduce)->K).
+  for (unsigned int i = 0; i < dimCount; i++) {
+    Operation *parent = current->getParentOfType<scf::ForOp>();
+    if (!parent)
+      return failure();
+    list.push_back(dyn_cast<scf::ForOp>(parent));
+    current = parent;
+  }
+  return list;
+}
+
+static LogicalResult checkMatmulLoopAndSubviewOffsetsMatching(
+    SmallVector<scf::ForOp> loops, SmallVector<memref::SubViewOp> subviews,
+    MatMulType matmulType) {
+  auto subviewOpLhsOffsets = subviews[0].getOffsets();
+  auto subviewOpRhsOffsets = subviews[1].getOffsets();
+  auto subviewOpAccOffsets = subviews[2].getOffsets();
+
+  if (matmulType == MatMulType::BatchReduce) {
+    Value ivK = loops[0].getInductionVar();
+    if (ivK != subviewOpLhsOffsets[2] || ivK != subviewOpRhsOffsets[1])
+      return failure();
+
+    Value ivReduction = loops[1].getInductionVar();
+    if (ivReduction != subviewOpLhsOffsets[0] ||
+        ivReduction != subviewOpRhsOffsets[0])
+      return failure();
+
+    Value ivN = loops[2].getInductionVar();
+    if (ivN != subviewOpAccOffsets[1] || ivN != subviewOpRhsOffsets[2])
+      return failure();
+
+    Value ivM = loops[3].getInductionVar();
+    if (ivM != subviewOpLhsOffsets[1] || ivM != subviewOpAccOffsets[0])
+      return failure();
+  }
+
+  if (matmulType == MatMulType::Batch) {
+    Value ivK = loops[0].getInductionVar();
+    if (ivK != subviewOpLhsOffsets[1] || ivK != subviewOpRhsOffsets[0])
+      return failure();
+
+    Value ivN = loops[1].getInductionVar();
+    if (ivN != subviewOpAccOffsets[1] || ivN != subviewOpRhsOffsets[1])
+      return failure();
+
+    Value ivM = loops[2].getInductionVar();
+    if (ivM != subviewOpLhsOffsets[0] || ivM != subviewOpAccOffsets[0])
+      return failure();
+  }
+
+  return success();
+}
+
+static SmallVector<Value>
+loadAccumulatorBeforeGEMM(Location loc, RewriterBase &rewriter,
+                          Type elementType, unsigned int M, unsigned int N,
+                          unsigned int vectorSize, Value subviewOpAcc) {
+
+  SmallVector<Value> accumulators;
+
+  // Initialize local variables assuming the M tile is larger than the N tile.
+  unsigned int outerBound = M;
+  unsigned int innerBound = N;
+
+  unsigned int outerStep = 1;
+  unsigned int innerStep = vectorSize;
+
+  bool isNTileLarge = (N / vectorSize) > M;
+  if (isNTileLarge) {
+    outerBound = N;
+    innerBound = M;
+
+    outerStep = vectorSize;
+    innerStep = 1;
+  }
+
+  for (unsigned int i = 0; i < outerBound; i = i + outerStep) {
+    for (unsigned int j = 0; j < innerBound; j = j + innerStep) {
+      Value indexOp_A = arith::ConstantIndexOp::create(rewriter, loc, i);
+      Value indexOp_B = arith::ConstantIndexOp::create(rewriter, loc, j);
+
+      if (isNTileLarge) {
+        indexOp_A = indexOp_B;
+        indexOp_B = arith::ConstantIndexOp::create(rewriter, loc, i);
+      }
+
+      auto valueCRow = vector::LoadOp::create(
+          rewriter, loc, VectorType::get(vectorSize, elementType), subviewOpAcc,
+          ValueRange{indexOp_A, indexOp_B});
+      accumulators.push_back(valueCRow);
+    }
+  }
+
+  return accumulators;
+}
+
+// This function takes matrices A, B, and C (represented as vectors)
+// and generates equivalent target-specific nanokernels.
+// It returns the final accumulator as output.
+// Based on the M tile, N tile, and vector size, it generates optimized
+// nanokernels under the condition that the reduction and K dimension
+// of the input matrices are equal to 1.
+//
+// Input: Matrix A, Matrix B, Accumulator as M*(N/vector size) vectors, M tile
+// size, N tile size, Vector size.
+//
+// Output:
+// case i: M >= (N/vector size). For example, M=3; N=32; vector size = 16.
+//  load_B0 = load B[0-15] into vector<16xf32>
+//  load_B1 = load B[16-31] into vector<16xf32>
+//  bcst_A0 = load A[0] and broadcast it into vector<16xf32>
+//  o/p_Acc[0] = vector.fma bcst_A0, load_B0, i/p_Acc[0]
+//  o/p_Acc[1] = vector.fma bcst_A0, load_B1, i/p_Acc[1]
+//  bcst_A1 = load A[1] and broadcast it into vector<16xf32>
+//  o/p_Acc[2] = vector.fma bcst_A1, load_B0, i/p_Acc[2]
+//  o/p_Acc[3] = vector.fma bcst_A1, load_B1, i/p_Acc[3]
+//  bcst_A2 = load A[2] and broadcast it into vector<16xf32>
+//  o/p_Acc[4] = vector.fma bcst_A2, load_B0, i/p_Acc[4]
+//  o/p_Acc[5] = vector.fma bcst_A2, load_B1, i/p_Acc[5]
+//
+// case ii: M < (N/vector size). For example, M=2; N=48; vector size = 16.
+//  bcst_A0 = load A[0] and broadcast it into vector<16xf32>
+//  bcst_A1 = load A[1] and broadcast it into vector<16xf32>
+//  bcst_A2 = load A[2] and broadcast it into vector<16xf32>
+//  load_B0 = load B[0-15] into vector<16xf32>
+//  o/p_Acc[0] = vector.fma bcst_A0, load_B0, i/p_Acc[0]
+//  o/p_Acc[1] = vector.fma bcst_A1, load_B0, i/p_Acc[1]
+//  load_B1 = load B[16-31] into vector<16xf32>
+//  o/p_Acc[2] = vector.fma bcst_A0, load_B1, i/p_Acc[2]
+//  o/p_Acc[3] = vector.fma bcst_A1, load_B1, i/p_Acc[3]
+//  load_B2 = load B[32-47] into vector<16xf32>
+//  o/p_Acc[4] = vector.fma bcst_A0, load_B2, i/p_Acc[4]
+//  o/p_Acc[5] = vector.fma bcst_A1, load_B2, i/p_Acc[5]
+//
+// return o/p_Acc;
+SmallVector<Value>
+generateNanokernels(RewriterBase &rewriter, Location loc, Type elementType,
+                    unsigned int vectorSize, unsigned int vnni, unsigned int M,
+                    unsigned int N, ValueRange acc, Value matA, Value matB,
+                    MatMulType matmulType) {
+
+  SmallVector<Value> accumulators;
+  SmallVector<Value> matLoad;
+  Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
+
+  // Start with the assumption that the M tile size is smaller and create
+  // the helper variables.
+  unsigned int outerBound = M;
+  unsigned int outerStep = 1;
+
+  unsigned int innerBound = N;
+  unsigned int innerStep = vectorSize;
+
+  Value outerMatrix = matA;
+  Value innerMatrix = matB;
+
+  unsigned int outerVectSize = vnni;
+  unsigned int innerVectSize = vectorSize;
+
+  unsigned int fmaBound = M;
+
+  // Update the helper variables if the N tile size is smaller.
+  bool isNTileLarge = (N / vectorSize) > M;
+  if (!isNTileLarge) {
+    outerBound = N;
+    innerBound = M;
+
+    outerStep = vectorSize;
+    innerStep = 1;
+
+    outerMatrix = matB;
+    innerMatrix = matA;
+
+    outerVectSize = vectorSize;
+    innerVectSize = vnni;
+
+    fmaBound = N / vectorSize;
+  }
+
+  // Load all the elements of the A or B matrix.
+  for (unsigned int i = 0; i < outerBound; i = i + outerStep) {
+    Value indexOp_i = arith::ConstantIndexOp::create(rewriter, loc, i);
+    Value valueRow;
+
+    if (isNTileLarge) {
+
+      // Assuming a batch-reduce matmul, initialize the reduction, M, and
+      // K dimensions.
+      SmallVector<Value> index = {c0, indexOp_i, c0};
+
+      // Remove reduction dimension if it is a batch matmul
+      if (matmulType == MatMulType::Batch) {
+        index.erase(index.begin());
+      }
+
+      // A Matrix load + broadcast
+      Value row = vector::LoadOp::create(
+          rewriter, loc, VectorType::get(outerVectSize, elementType),
+          outerMatrix, index);
+      valueRow = vector::BroadcastOp::create(
+          rewriter, loc, VectorType::get(vectorSize, rewriter.getF32Type()),
+          row);
+    } else {
+
+      // Assuming a batch-reduce matmul, initialize the reduction, K, and
+      // N dimensions.
+      SmallVector<Value> index = {c0, c0, indexOp_i};
+
+      // Remove reduction dimension if it is a batch matmul
+      if (matmulType == MatMulType::Batch) {
+        index.erase(index.begin());
+      }
+
+      // B Matrix load.
+      valueRow = vector::LoadOp::create(
+          rewriter, loc, VectorType::get(outerVectSize, elementType),
+          outerMatrix, index);
+    }
+
+    matLoad.push_back(valueRow);
+  }
+
+  // Load elements of A/B Matrix one at a time and compute FMA
+  for (unsigned int j = 0, k = 0; j < innerBound; j = j + innerStep) {
+    Value indexOp_j = arith::ConstantIndexOp::create(rewriter, loc, j);
+    Value valueRow;
+
+    if (!isNTileLarge) {
+      SmallVector<Value> index = {c0, indexOp_j, c0};
+      if (matmulType == MatMulType::Batch) {
+        index.erase(index.begin());
+      }
+
+      // A Matrix load + broadcast
+      Value row = vector::LoadOp::create(
+          rewriter, loc, VectorType::get(innerVectSize, elementType),
+          innerMatrix, ValueRange(index));
+      valueRow = vector::BroadcastOp::create(
+          rewriter, loc, VectorType::get(vectorSize, rewriter.getF32Type()),
+          row);
+    } else {
+
+      SmallVector<Value> index = {c0, c0, indexOp_j};
+      if (matmulType == MatMulType::Batch) {
+        index.erase(index.begin());
+      }
+
+      // B Matrix load
+      valueRow = vector::LoadOp::create(
+          rewriter, loc, VectorType::get(innerVectSize, elementType),
+          innerMatrix, index);
+    }
+
+    // FMAs
+    for (unsigned int i = 0; i < fmaBound; i = i + 1) {
+...
[truncated]
@rengolin @adam-smnk @rolfmorel wrapped the pass as
// Rewriter pattern for vector.contract operation.
// Input: vector.contract with tiled dimensions (batch or batch-matmul)
// Matching Pattern:
Are prologue and epilogue ops supported in any way?
  return list;
}

static LogicalResult checkMatmulLoopAndSubviewOffsetsMatching(
Could this be somehow relaxed to avoid explicit loop matching?
    return failure();

  auto subviewOpAcc =
      vectorReadOpAcc.getOperand(0).getDefiningOp<memref::SubViewOp>();
Why are subviews required?
// -----

module {
  func.func @not_tiled_no_rewriting(%arg0: memref<1x4x32xf32>, %arg1: memref<1x32x96xf32>, %arg2: memref<4x96xf32>) -> memref<4x96xf32> {
nit: I'd suggest using the func @negative_... naming scheme for test cases that don't match
A high-level concern of mine is that this transform requires really specific IR to match at all.
cc: @shahidact
+1
This PR lowers vector.contract (only batch-reduce or batch matmul) to a sequence of FMAs with respect to the vector_size.
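For context, the new transform op added by this patch would be used roughly as follows (a sketch derived from the assemblyFormat in X86VectorTransformOps.td above; the enclosing named sequence, the func.func match, and the chosen vector_size value are illustrative assumptions, not taken from the PR's tests):

  module attributes {transform.with_named_sequence} {
    transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
      // Assumed anchor: apply the nanokernel lowering to every function.
      %func = transform.structured.match ops{["func.func"]} in %root
        : (!transform.any_op) -> !transform.any_op
      transform.apply_patterns to %func {
        // vector_size defaults to 8 per the I64Attr; 16 matches AVX-512 f32 lanes.
        transform.apply_patterns.x86vector.vector_contract_nanokernel_lowering vector_size = 16
      } : !transform.any_op
      transform.yield
    }
  }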