Conversation

@arun-thmn
Contributor

@arun-thmn arun-thmn commented Oct 14, 2025

This PR lowers vector.contract (batch or batch-reduce matmul only) to a sequence of FMAs on vectors of the given vector_size.
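For reference, a minimal sketch of the kind of contraction this targets (shapes and indexing maps are illustrative, not taken from the patch):

  // Hypothetical batch-reduce matmul contraction. The rewrite turns ops of
  // this shape into vector.load + vector.broadcast + vector.fma sequences,
  // handling vector_size elements of the N dimension per FMA.
  %res = vector.contract {
      indexing_maps = [affine_map<(b, m, n, k) -> (b, m, k)>,
                       affine_map<(b, m, n, k) -> (b, k, n)>,
                       affine_map<(b, m, n, k) -> (m, n)>],
      iterator_types = ["reduction", "parallel", "parallel", "reduction"],
      kind = #vector.kind<add>}
      %lhs, %rhs, %acc
    : vector<1x4x1xf32>, vector<1x1x32xf32> into vector<4x32xf32>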

@github-actions

github-actions bot commented Oct 14, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

Member

@rengolin rengolin left a comment

First round of review.

Member

Please add a comment on the difference between 3 and 4 dims, and what they mean.

Member

This code looks almost identical (modulo indices). Can you select the indices based on dimCount and then write the code once?

Member

nit: to avoid needing additional comments, I suggest using longer function names, like loadAccumulatorBeforeGEMM.

Still, it would be nice to document what you can do with the iter args.

Member

Since this is the public API, I recommend using more expressive function names. Same as above.

Member

I would cache these values in variables to make the code more expressive:

  int64_t nBlocks = N / vectorSize;
  bool skinny = (nBlocks < M);

@adam-smnk
Contributor

I'd also suggest checking whether something similar can be achieved by a combination of existing transforms like:

transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct"
transform.apply_patterns.vector.lower_outerproduct
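
For example, a hypothetical schedule applying both (assuming %func is a handle to the enclosing func.func):

  transform.apply_patterns to %func {
    transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct"
    transform.apply_patterns.vector.lower_outerproduct
  } : !transform.any_op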

Member

@rengolin rengolin left a comment

Second round, mostly code structure.

To continue review, we'll need a few tests to validate the assumptions in the PR.

Member

Unnecessary space between the if and its block.

Member

The pattern (N / vectorSize) < M is spread across the code with nothing connecting the uses, which makes the code hard to understand, especially when the nested blocks are almost identical.

There should be a way to prepare the indices, vectors and required ops in one block at the beginning and then (mostly) have a single sequence based on those variables.
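
Something along these lines, for example (hypothetical names, just to illustrate the suggestion):

  // Decide the traversal order once, then share a single loop body.
  struct TraversalPlan {
    unsigned outerBound, outerStep, innerBound, innerStep;
    Value outerMatrix, innerMatrix;
  };

  static TraversalPlan makePlan(unsigned M, unsigned N, unsigned vectorSize,
                                Value matA, Value matB) {
    bool isNTileLarge = (N / vectorSize) > M;
    if (isNTileLarge)
      return {M, 1, N, vectorSize, matA, matB}; // broadcast A outside
    return {N, vectorSize, M, 1, matB, matA};   // load B outside
  }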

Member

A function and a local variable above sharing the same name may be confusing.

Also, this function should be static, I think.

Member

Avoid creating multiple functions with the same name. While C++ can resolve the overloads without issues, humans are not very good at it.

Also, static?

Member

Why the namespace qualifier? Did it not resolve correctly? If not, why not?

Member

You use the 4/3 split to get the loop below, so you could cache the loop itself instead and just use it here and below.

Member

Is getNestedLoop guaranteed to return at least 2 elements?

@arun-thmn arun-thmn marked this pull request as ready for review October 27, 2025 10:59
@llvmbot
Member

llvmbot commented Oct 27, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-vector

Author: Arun Thangamani (arun-thmn)

Changes

This PR lowers vector.contract (batch or batch-reduce matmul only) to a sequence of FMAs on vectors of the given vector_size.


Patch is 48.75 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/163382.diff

12 Files Affected:

  • (modified) mlir/include/mlir/Dialect/X86Vector/CMakeLists.txt (+2)
  • (added) mlir/include/mlir/Dialect/X86Vector/TransformOps/CMakeLists.txt (+4)
  • (added) mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h (+31)
  • (added) mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td (+37)
  • (modified) mlir/include/mlir/Dialect/X86Vector/Transforms.h (+12)
  • (modified) mlir/lib/Dialect/X86Vector/CMakeLists.txt (+1)
  • (added) mlir/lib/Dialect/X86Vector/TransformOps/CMakeLists.txt (+17)
  • (added) mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp (+61)
  • (modified) mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt (+1)
  • (added) mlir/lib/Dialect/X86Vector/Transforms/NanoKernels.cpp (+659)
  • (modified) mlir/lib/RegisterAllExtensions.cpp (+2)
  • (added) mlir/test/Dialect/X86Vector/vector-contract-to-nanokernels.mlir (+215)
diff --git a/mlir/include/mlir/Dialect/X86Vector/CMakeLists.txt b/mlir/include/mlir/Dialect/X86Vector/CMakeLists.txt
index 0fe01824b8248..bbe8e4eb892dd 100644
--- a/mlir/include/mlir/Dialect/X86Vector/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/X86Vector/CMakeLists.txt
@@ -3,3 +3,5 @@ add_mlir_doc(X86Vector X86Vector Dialects/ -gen-dialect-doc -dialect=x86vector)
 
 add_mlir_interface(X86VectorInterfaces)
 add_dependencies(MLIRX86VectorIncGen MLIRX86VectorInterfacesIncGen)
+
+add_subdirectory(TransformOps)
diff --git a/mlir/include/mlir/Dialect/X86Vector/TransformOps/CMakeLists.txt b/mlir/include/mlir/Dialect/X86Vector/TransformOps/CMakeLists.txt
new file mode 100644
index 0000000000000..6f377e10fa8f8
--- /dev/null
+++ b/mlir/include/mlir/Dialect/X86Vector/TransformOps/CMakeLists.txt
@@ -0,0 +1,4 @@
+set(LLVM_TARGET_DEFINITIONS X86VectorTransformOps.td)
+mlir_tablegen(X86VectorTransformOps.h.inc -gen-op-decls)
+mlir_tablegen(X86VectorTransformOps.cpp.inc -gen-op-defs)
+add_mlir_dialect_tablegen_target(MLIRX86VectorTransformOpsIncGen)
diff --git a/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h b/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h
new file mode 100644
index 0000000000000..e1d8b8762e799
--- /dev/null
+++ b/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h
@@ -0,0 +1,31 @@
+//===- X86VectorTransformOps.h - X86Vector transform ops --------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_X86VECTOR_TRANSFORMOPS_X86VECTORTRANSFORMOPS_H
+#define MLIR_DIALECT_X86VECTOR_TRANSFORMOPS_X86VECTORTRANSFORMOPS_H
+
+#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
+#include "mlir/IR/OpImplementation.h"
+
+//===----------------------------------------------------------------------===//
+// X86Vector Transform Operations
+//===----------------------------------------------------------------------===//
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h.inc"
+
+namespace mlir {
+class DialectRegistry;
+
+namespace x86vector {
+void registerTransformDialectExtension(DialectRegistry &registry);
+
+} // namespace x86vector
+} // namespace mlir
+
+#endif // MLIR_DIALECT_X86VECTOR_TRANSFORMOPS_X86VECTORTRANSFORMOPS_H
diff --git a/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td b/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td
new file mode 100644
index 0000000000000..9db2b36a2a8aa
--- /dev/null
+++ b/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td
@@ -0,0 +1,37 @@
+//===- X86VectorTransformOps.td - X86Vector transform ops ---*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef X86VECTOR_TRANSFORM_OPS
+#define X86VECTOR_TRANSFORM_OPS
+
+include "mlir/Dialect/Transform/IR/TransformDialect.td"
+include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/IR/OpBase.td"
+include "mlir/Dialect/Transform/IR/TransformAttrs.td"
+include "mlir/Dialect/Transform/IR/TransformTypes.td"
+include "mlir/IR/RegionKindInterface.td"
+
+def ApplyVectorContractNanokernelLoweringPatternsOp : Op<Transform_Dialect,
+    "apply_patterns.x86vector.vector_contract_nanokernel_lowering",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Indicates that vector.contract operations can be lowered to
+    target-specific nanokernels.
+  }];
+
+  let arguments = (ins DefaultValuedAttr<I64Attr, "8">:$vector_size);
+
+  let assemblyFormat = [{
+    (`vector_size` `=` $vector_size^)? attr-dict
+  }];
+
+}
+
+
+#endif // X86VECTOR_TRANSFORM_OPS
diff --git a/mlir/include/mlir/Dialect/X86Vector/Transforms.h b/mlir/include/mlir/Dialect/X86Vector/Transforms.h
index d54111ca41e69..6410c12265f12 100644
--- a/mlir/include/mlir/Dialect/X86Vector/Transforms.h
+++ b/mlir/include/mlir/Dialect/X86Vector/Transforms.h
@@ -11,6 +11,10 @@
 
 #include "mlir/IR/Value.h"
 
+#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/PatternMatch.h"
+
 namespace mlir {
 
 class ImplicitLocOpBuilder;
@@ -79,6 +83,14 @@ struct MaskHelper {
   }
 };
 
+//===----------------------------------------------------------------------===//
+// Populates patterns that lower a tiled batch or batch-reduce vector
+// contraction into a sequence of nanokernels.
+// The transformation is tailored to the target machine architecture
+// and guided by the user-specified vector size.
+void populateVectorContractNanokernelLoweringPatterns(
+    RewritePatternSet &patterns, std::optional<unsigned> vectorSize = 8);
+
 //===----------------------------------------------------------------------===//
 /// Helpers extracted from:
 ///   - clang/lib/Headers/avxintrin.h
diff --git a/mlir/lib/Dialect/X86Vector/CMakeLists.txt b/mlir/lib/Dialect/X86Vector/CMakeLists.txt
index 9f57627c321fb..cb1e9d01821a2 100644
--- a/mlir/lib/Dialect/X86Vector/CMakeLists.txt
+++ b/mlir/lib/Dialect/X86Vector/CMakeLists.txt
@@ -1,2 +1,3 @@
 add_subdirectory(IR)
 add_subdirectory(Transforms)
+add_subdirectory(TransformOps)
diff --git a/mlir/lib/Dialect/X86Vector/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/X86Vector/TransformOps/CMakeLists.txt
new file mode 100644
index 0000000000000..f4c9f8a05acbc
--- /dev/null
+++ b/mlir/lib/Dialect/X86Vector/TransformOps/CMakeLists.txt
@@ -0,0 +1,17 @@
+add_mlir_dialect_library(MLIRX86VectorTransformOps
+  X86VectorTransformOps.cpp
+
+  DEPENDS
+  MLIRX86VectorTransformOpsIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRLLVMCommonConversion
+  MLIRLLVMDialect
+  MLIRVectorDialect
+  MLIRSideEffectInterfaces
+  MLIRTransformDialect
+  MLIRTransformDialectUtils
+  MLIRX86VectorDialect
+  MLIRX86VectorTransforms
+  )
diff --git a/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp b/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp
new file mode 100644
index 0000000000000..e003e3ad7cd08
--- /dev/null
+++ b/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp
@@ -0,0 +1,61 @@
+//===- X86VectorTransformOps.cpp - X86Vector transform ops ---------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h"
+#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/X86Vector/Transforms.h"
+#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
+
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/RegionKindInterface.h"
+
+using namespace mlir;
+using namespace mlir::x86vector;
+using namespace mlir::transform;
+
+void mlir::transform::ApplyVectorContractNanokernelLoweringPatternsOp::
+    populatePatterns(RewritePatternSet &patterns) {
+  x86vector::populateVectorContractNanokernelLoweringPatterns(patterns,
+                                                              getVectorSize());
+}
+
+//===----------------------------------------------------------------------===//
+// Transform op registration
+//===----------------------------------------------------------------------===//
+
+namespace {
+class X86VectorTransformDialectExtension
+    : public transform::TransformDialectExtension<
+          X86VectorTransformDialectExtension> {
+public:
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
+      X86VectorTransformDialectExtension)
+
+  X86VectorTransformDialectExtension() {
+    declareGeneratedDialect<x86vector::X86VectorDialect>();
+    declareGeneratedDialect<LLVM::LLVMDialect>();
+    registerTransformOps<
+#define GET_OP_LIST
+#include "mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp.inc"
+        >();
+  }
+};
+} // namespace
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp.inc"
+
+void mlir::x86vector::registerTransformDialectExtension(
+    DialectRegistry &registry) {
+  registry.addExtensions<X86VectorTransformDialectExtension>();
+}
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt
index c51266afe9e8f..da377763331f2 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt
@@ -1,6 +1,7 @@
 add_mlir_dialect_library(MLIRX86VectorTransforms
   AVXTranspose.cpp
   LegalizeForLLVMExport.cpp
+  NanoKernels.cpp
 
   LINK_LIBS PUBLIC
   MLIRArithDialect
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/NanoKernels.cpp b/mlir/lib/Dialect/X86Vector/Transforms/NanoKernels.cpp
new file mode 100644
index 0000000000000..4d0906a2ec057
--- /dev/null
+++ b/mlir/lib/Dialect/X86Vector/Transforms/NanoKernels.cpp
@@ -0,0 +1,659 @@
+//===- NanoKernels.cpp - Lower matmul to nanokernels ---------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements matmul rewrites as nanokernels tailored to the target
+// machine, for FP32 (selected batch or batch-reduce matmul patterns) and
+// BF16 (TODO) types.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/AMX/AMXDialect.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
+#include "mlir/Dialect/X86Vector/Transforms.h"
+#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
+
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/IR/PatternMatch.h"
+
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace mlir::vector;
+using namespace mlir::x86vector;
+
+// Enum to represent the type of matmul operation
+enum class MatMulType { Batch, BatchReduce, Others };
+
+static FailureOr<SmallVector<scf::ForOp>>
+getTiledMatmulLoopNest(vector::ContractionOp contractOp,
+                       MatMulType matmulType) {
+  SmallVector<scf::ForOp> list;
+  Operation *current = contractOp;
+  unsigned int dimCount = matmulType == MatMulType::BatchReduce ? 4 : 3;
+
+  // Walk up the register-tiled loop structure of the batch (or batch-reduce)
+  // matmul (M->N->(reduce)->K).
+  for (unsigned int i = 0; i < dimCount; i++) {
+    Operation *parent = current->getParentOfType<scf::ForOp>();
+    if (!parent)
+      return failure();
+    list.push_back(dyn_cast<scf::ForOp>(parent));
+    current = parent;
+  }
+  return list;
+}
+
+static LogicalResult checkMatmulLoopAndSubviewOffsetsMatching(
+    SmallVector<scf::ForOp> loops, SmallVector<memref::SubViewOp> subviews,
+    MatMulType matmulType) {
+  auto subviewOpLhsOffsets = subviews[0].getOffsets();
+  auto subviewOpRhsOffsets = subviews[1].getOffsets();
+  auto subviewOpAccOffsets = subviews[2].getOffsets();
+
+  if (matmulType == MatMulType::BatchReduce) {
+    Value ivK = loops[0].getInductionVar();
+    if (ivK != subviewOpLhsOffsets[2] || ivK != subviewOpRhsOffsets[1])
+      return failure();
+
+    Value ivReduction = loops[1].getInductionVar();
+    if (ivReduction != subviewOpLhsOffsets[0] ||
+        ivReduction != subviewOpRhsOffsets[0])
+      return failure();
+
+    Value ivN = loops[2].getInductionVar();
+    if (ivN != subviewOpAccOffsets[1] || ivN != subviewOpRhsOffsets[2])
+      return failure();
+
+    Value ivM = loops[3].getInductionVar();
+    if (ivM != subviewOpLhsOffsets[1] || ivM != subviewOpAccOffsets[0])
+      return failure();
+  }
+
+  if (matmulType == MatMulType::Batch) {
+    Value ivK = loops[0].getInductionVar();
+    if (ivK != subviewOpLhsOffsets[1] || ivK != subviewOpRhsOffsets[0])
+      return failure();
+
+    Value ivN = loops[1].getInductionVar();
+    if (ivN != subviewOpAccOffsets[1] || ivN != subviewOpRhsOffsets[1])
+      return failure();
+
+    Value ivM = loops[2].getInductionVar();
+    if (ivM != subviewOpLhsOffsets[0] || ivM != subviewOpAccOffsets[0])
+      return failure();
+  }
+
+  return success();
+}
+
+static SmallVector<Value>
+loadAccumulatorBeforeGEMM(Location loc, RewriterBase &rewriter,
+                          Type elementType, unsigned int M, unsigned int N,
+                          unsigned int vectorSize, Value subviewOpAcc) {
+
+  SmallVector<Value> accumulators;
+
+  // Initialize local variable on assumption that M tile is larger than N
+  unsigned int outerBound = M;
+  unsigned int innerBound = N;
+
+  unsigned int outerStep = 1;
+  unsigned int innerStep = vectorSize;
+
+  bool isNTileLarge = (N / vectorSize) > M;
+  if (isNTileLarge) {
+    outerBound = N;
+    innerBound = M;
+
+    outerStep = vectorSize;
+    innerStep = 1;
+  }
+
+  for (unsigned int i = 0; i < outerBound; i = i + outerStep) {
+    for (unsigned int j = 0; j < innerBound; j = j + innerStep) {
+      Value indexOp_A = arith::ConstantIndexOp::create(rewriter, loc, i);
+      Value indexOp_B = arith::ConstantIndexOp::create(rewriter, loc, j);
+
+      if (isNTileLarge) {
+        indexOp_A = indexOp_B;
+        indexOp_B = arith::ConstantIndexOp::create(rewriter, loc, i);
+      }
+
+      auto valueCRow = vector::LoadOp::create(
+          rewriter, loc, VectorType::get(vectorSize, elementType), subviewOpAcc,
+          ValueRange{indexOp_A, indexOp_B});
+      accumulators.push_back(valueCRow);
+    }
+  }
+
+  return accumulators;
+}
+
+// This function takes matrices A, B, and C (represented as vectors)
+// and generates equivalent target-specific nanokernels.
+// It returns the final accumulator as output.
+// Based on the M tile, N tile, and vector size, it generates optimized
+// nanokernels under the condition that the reduction and K dimension
+// of the input matrices are equal to 1.
+//
+// Input: Matrix A, Matrix B, Accumulator as M*(N/vector size) vectors, M tile
+// size, N tile size, Vector size.
+//
+// Output:
+// case i: M >= (N/vector size). For example, M=3; N=32; vector size = 16.
+//  load_B0 = load B[0-15] into vector<16xf32>
+//  load_B1 = load B[16-31] into vector<16xf32>
+//  bcst_A0 = load A[0] and broadcast it into vector<16xf32>
+//  o/p_Acc[0] = vector.fma bcst_A0, load_B0, i/p_Acc[0]
+//  o/p_Acc[1] = vector.fma bcst_A0, load_B1, i/p_Acc[1]
+//  bcst_A1 = load A[1] and broadcast it into vector<16xf32>
+//  o/p_Acc[2] = vector.fma bcst_A1, load_B0, i/p_Acc[2]
+//  o/p_Acc[3] = vector.fma bcst_A1, load_B1, i/p_Acc[3]
+//  bcst_A2 = load A[2] and broadcast it into vector<16xf32>
+//  o/p_Acc[4] = vector.fma bcst_A2, load_B0, i/p_Acc[4]
+//  o/p_Acc[5] = vector.fma bcst_A2, load_B1, i/p_Acc[5]
+//
+// case ii: M < (N/vector size). For example, M=2; N=48; vector size = 16.
+//  bcst_A0 = load A[0] and broadcast it into vector<16xf32>
+//  bcst_A1 = load A[1] and broadcast it into vector<16xf32>
+//  load_B0 = load B[0-15] into vector<16xf32>
+//  o/p_Acc[0] = vector.fma bcst_A0, load_B0, i/p_Acc[0]
+//  o/p_Acc[1] = vector.fma bcst_A1, load_B0, i/p_Acc[1]
+//  load_B1 = load B[16-31] into vector<16xf32>
+//  o/p_Acc[2] = vector.fma bcst_A0, load_B1, i/p_Acc[2]
+//  o/p_Acc[3] = vector.fma bcst_A1, load_B1, i/p_Acc[3]
+//  load_B2 = load B[32-47] into vector<16xf32>
+//  o/p_Acc[4] = vector.fma bcst_A0, load_B2, i/p_Acc[4]
+//  o/p_Acc[5] = vector.fma bcst_A1, load_B2, i/p_Acc[5]
+//
+// return o/p_Acc;
+SmallVector<Value>
+generateNanokernels(RewriterBase &rewriter, Location loc, Type elementType,
+                    unsigned int vectorSize, unsigned int vnni, unsigned int M,
+                    unsigned int N, ValueRange acc, Value matA, Value matB,
+                    MatMulType matmulType) {
+
+  SmallVector<Value> accumulators;
+  SmallVector<Value> matLoad;
+  Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
+
+  // Start with the assumption that the M tile is the smaller one and create
+  // the helper variables.
+  unsigned int outerBound = M;
+  unsigned int outerStep = 1;
+
+  unsigned int innerBound = N;
+  unsigned int innerStep = vectorSize;
+
+  Value outerMatrix = matA;
+  Value innerMatrix = matB;
+
+  unsigned int outerVectSize = vnni;
+  unsigned int innerVectSize = vectorSize;
+
+  unsigned int fmaBound = M;
+
+  // Update the helper variables if the N tile is the smaller one.
+  bool isNTileLarge = (N / vectorSize) > M;
+  if (!isNTileLarge) {
+    outerBound = N;
+    innerBound = M;
+
+    outerStep = vectorSize;
+    innerStep = 1;
+
+    outerMatrix = matB;
+    innerMatrix = matA;
+
+    outerVectSize = vectorSize;
+    innerVectSize = vnni;
+
+    fmaBound = N / vectorSize;
+  }
+
+  // Load all the elements of the A or B matrix.
+  for (unsigned int i = 0; i < outerBound; i = i + outerStep) {
+    Value indexOp_i = arith::ConstantIndexOp::create(rewriter, loc, i);
+    Value valueRow;
+
+    if (isNTileLarge) {
+      // Assuming a batch-reduce matmul, initialize the reduction, M, and
+      // K dimensions.
+      SmallVector<Value> index = {c0, indexOp_i, c0};
+
+      // Remove reduction dimension if it is a batch matmul
+      if (matmulType == MatMulType::Batch) {
+        index.erase(index.begin());
+      }
+
+      // A Matrix load + broadcast
+      Value row = vector::LoadOp::create(
+          rewriter, loc, VectorType::get(outerVectSize, elementType),
+          outerMatrix, index);
+      valueRow = vector::BroadcastOp::create(
+          rewriter, loc, VectorType::get(vectorSize, rewriter.getF32Type()),
+          row);
+    } else {
+      // Assuming a batch-reduce matmul, initialize the reduction, K, and
+      // N dimensions.
+      SmallVector<Value> index = {c0, c0, indexOp_i};
+
+      // Remove reduction dimension if it is a batch matmul
+      if (matmulType == MatMulType::Batch) {
+        index.erase(index.begin());
+      }
+
+      // B Matrix load.
+      valueRow = vector::LoadOp::create(
+          rewriter, loc, VectorType::get(outerVectSize, elementType),
+          outerMatrix, index);
+    }
+
+    matLoad.push_back(valueRow);
+  }
+
+  // Load elements of the A/B matrix one at a time and compute the FMAs.
+  for (unsigned int j = 0, k = 0; j < innerBound; j = j + innerStep) {
+    Value indexOp_j = arith::ConstantIndexOp::create(rewriter, loc, j);
+    Value valueRow;
+
+    if (!isNTileLarge) {
+      SmallVector<Value> index = {c0, indexOp_j, c0};
+      if (matmulType == MatMulType::Batch) {
+        index.erase(index.begin());
+      }
+
+      // A Matrix load + broadcast
+      Value row = vector::LoadOp::create(
+          rewriter, loc, VectorType::get(innerVectSize, elementType),
+          innerMatrix, ValueRange(index));
+      valueRow = vector::BroadcastOp::create(
+          rewriter, loc, VectorType::get(vectorSize, rewriter.getF32Type()),
+          row);
+    } else {
+      SmallVector<Value> index = {c0, c0, indexOp_j};
+      if (matmulType == MatMulType::Batch) {
+        index.erase(index.begin());
+      }
+
+      // B Matrix load
+      valueRow = vector::LoadOp::create(
+          rewriter, loc, VectorType::get(innerVectSize, elementType),
+          innerMatrix, index);
+    }
+
+    // FMAs
+    for (unsigned int i = 0; i < fmaBound; i = i + 1) {
+...
[truncated]

@arun-thmn
Contributor Author

@rengolin @adam-smnk @rolfmorel I wrapped the pass as a transform schedule. Please have a look.
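
For reference, usage in a schedule looks roughly like this (a sketch, assuming %func is a handle to the payload function; vector_size is optional and defaults to 8):

  transform.apply_patterns to %func {
    transform.apply_patterns.x86vector.vector_contract_nanokernel_lowering vector_size = 16
  } : !transform.any_op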


// Rewriter pattern for vector.contract operation.
// Input: vector.contract with tiled dimensions (batch or batch-matmul)
// Matching Pattern:
Contributor

Are prologue and epilogue ops supported in any way?

return list;
}

static LogicalResult checkMatmulLoopAndSubviewOffsetsMatching(
Contributor

Could this be somehow relaxed to avoid explicit loop matching?

return failure();

auto subviewOpAcc =
vectorReadOpAcc.getOperand(0).getDefiningOp<memref::SubViewOp>();
Contributor

Why are subviews required?

// -----

module {
func.func @not_tiled_no_rewriting(%arg0: memref<1x4x32xf32>, %arg1: memref<1x32x96xf32>, %arg2: memref<4x96xf32>) -> memref<4x96xf32> {
Contributor

nit: I'd suggest using the func @negative_... naming scheme for test cases that don't match

@adam-smnk
Contributor

A high-level concern of mine is that this transform requires really specific IR to match at all.
It might be difficult to use it in a larger schedule.

@arun-thmn
Contributor Author

cc: @shahidact

@shahidact
Contributor

A high-level concern of mine is that this transform requires really specific IR to match at all. It might be difficult to use it in a larger schedule.

+1
