From ce42a09598082c2ca7562bbf80312f5b9c3d60f8 Mon Sep 17 00:00:00 2001 From: Arun Thangamani Date: Tue, 14 Oct 2025 05:29:32 -0700 Subject: [PATCH 1/7] lower vector.contract to sequence of FMAs --- .../mlir/Dialect/X86Vector/Transforms.h | 9 + .../X86Vector/Transforms/CMakeLists.txt | 1 + .../X86Vector/Transforms/NanoKernels.cpp | 483 ++++++++++++++++++ 3 files changed, 493 insertions(+) create mode 100644 mlir/lib/Dialect/X86Vector/Transforms/NanoKernels.cpp diff --git a/mlir/include/mlir/Dialect/X86Vector/Transforms.h b/mlir/include/mlir/Dialect/X86Vector/Transforms.h index d54111ca41e69..cde890038f20a 100644 --- a/mlir/include/mlir/Dialect/X86Vector/Transforms.h +++ b/mlir/include/mlir/Dialect/X86Vector/Transforms.h @@ -11,6 +11,10 @@ #include "mlir/IR/Value.h" +#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/PatternMatch.h" + namespace mlir { class ImplicitLocOpBuilder; @@ -79,6 +83,11 @@ struct MaskHelper { } }; +//===----------------------------------------------------------------------===// +// Nano-kernels +LogicalResult nanoKernels(RewriterBase &rewriter, + vector::ContractionOp contractOp, int64_t vectorSize); + //===----------------------------------------------------------------------===// /// Helpers extracted from: /// - clang/lib/Headers/avxintrin.h diff --git a/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt index c51266afe9e8f..da377763331f2 100644 --- a/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt @@ -1,6 +1,7 @@ add_mlir_dialect_library(MLIRX86VectorTransforms AVXTranspose.cpp LegalizeForLLVMExport.cpp + NanoKernels.cpp LINK_LIBS PUBLIC MLIRArithDialect diff --git a/mlir/lib/Dialect/X86Vector/Transforms/NanoKernels.cpp b/mlir/lib/Dialect/X86Vector/Transforms/NanoKernels.cpp new file mode 100644 index 0000000000000..bc03b567e06ff --- /dev/null +++ b/mlir/lib/Dialect/X86Vector/Transforms/NanoKernels.cpp @@ -0,0 +1,483 @@ +//===- NanoKernels.cpp - Lower matmul to Nanokernels -- -------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements matmul rewrites as nanokernels with respect to target +// machine for FP32 and BF16 (TODO) types. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/AMX/AMXDialect.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Dialect/X86Vector/X86VectorDialect.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "mlir/Dialect/Vector/Utils/VectorUtils.h" +#include "mlir/Dialect/X86Vector/Transforms.h" +#include "mlir/IR/PatternMatch.h" + +using namespace mlir; +using namespace mlir::vector; +using namespace mlir::x86vector; + +static FailureOr> +getNestedLoop(vector::ContractionOp contractOp, int64_t dimCount) { + SmallVector list; + Operation *current = contractOp; + // It is register tiled loop structure on batch reduce matmul + // (M->N->Batch-reduce->K). 
+ for (int i = 0; i < dimCount; i++) { + Operation *parent = current->getParentOfType(); + if (!parent) + return failure(); + list.push_back(dyn_cast(parent)); + current = parent; + } + return list; +} + +static LogicalResult checkNestedLoop(SmallVector loops, + SmallVector subviews, + int64_t dimCount) { + auto subviewOpLhsOffsets = subviews[0].getOffsets(); + auto subviewOpRhsOffsets = subviews[1].getOffsets(); + auto subviewOpAccOffsets = subviews[2].getOffsets(); + + if (dimCount == 4) { + Value ivK = loops[0].getInductionVar(); + if (ivK != subviewOpLhsOffsets[2] || ivK != subviewOpRhsOffsets[1]) + return failure(); + + Value ivReduction = loops[1].getInductionVar(); + if (ivReduction != subviewOpLhsOffsets[0] || + ivReduction != subviewOpRhsOffsets[0]) + return failure(); + + Value ivN = loops[2].getInductionVar(); + if (ivN != subviewOpAccOffsets[1] || ivN != subviewOpRhsOffsets[2]) + return failure(); + + Value ivM = loops[3].getInductionVar(); + if (ivM != subviewOpLhsOffsets[1] || ivM != subviewOpAccOffsets[0]) + return failure(); + } + + if (dimCount == 3) { + Value ivK = loops[0].getInductionVar(); + if (ivK != subviewOpLhsOffsets[1] || ivK != subviewOpRhsOffsets[0]) + return failure(); + + Value ivN = loops[1].getInductionVar(); + if (ivN != subviewOpAccOffsets[1] || ivN != subviewOpRhsOffsets[1]) + return failure(); + + Value ivM = loops[2].getInductionVar(); + if (ivM != subviewOpLhsOffsets[0] || ivM != subviewOpAccOffsets[0]) + return failure(); + } + + return success(); +} + +static SmallVector loadAcc(Location loc, RewriterBase &rewriter, + Type elementType, int64_t M, int64_t N, + int64_t vectorSize, Value subviewOpAcc) { + + SmallVector loopItrArgs; + int64_t outerBound = M; + int64_t innerBound = N; + + int64_t outerStep = 1; + int64_t innerStep = vectorSize; + + if ((N / vectorSize) > M) { + outerBound = N; + innerBound = M; + + outerStep = vectorSize; + innerStep = 1; + } + + for (int i = 0; i < outerBound; i = i + outerStep) { + for (int j = 0; j < innerBound; j = j + innerStep) { + Value indexOp_A = rewriter.create(loc, i); + Value indexOp_B = rewriter.create(loc, j); + + if ((N / vectorSize) > M) { + indexOp_A = indexOp_B; + indexOp_B = rewriter.create(loc, i); + } + + auto valueCRow = rewriter.create( + loc, VectorType::get(vectorSize, elementType), subviewOpAcc, + ValueRange{indexOp_A, indexOp_B}); + loopItrArgs.push_back(valueCRow); + } + } + + return loopItrArgs; +} + +SmallVector nanoKernels(RewriterBase &rewriter, Location loc, + Type elementType, int64_t vectorSize, + int64_t vnni, int64_t M, int64_t N, + ValueRange acc, Value matA, Value matB, + int64_t dimCount) { + + SmallVector accVector; + SmallVector matLoad; + Value c0 = rewriter.create(loc, 0); + + int64_t outerBound = M; + int64_t outerStep = 1; + + int64_t innerBound = N; + int64_t innerStep = vectorSize; + + Value outerMatrix = matA; + Value innerMatrix = matB; + + int64_t outerVectSize = vnni; + int64_t innerVectSize = vectorSize; + + int64_t fmaBound = M; + + if ((N / vectorSize) < M) { + outerBound = N; + innerBound = M; + + outerStep = vectorSize; + innerStep = 1; + + outerMatrix = matB; + innerMatrix = matA; + + outerVectSize = vectorSize; + innerVectSize = vnni; + + fmaBound = N / vectorSize; + } + + for (int i = 0; i < outerBound; i = i + outerStep) { + Value indexOp_i = rewriter.create(loc, i); + Value valueRow; + + if ((N / vectorSize) > M) { + + SmallVector index = {c0, indexOp_i, c0}; + if (dimCount == 3) { + index.erase(index.begin()); + } + + Value row = rewriter.create( + loc, 
VectorType::get(outerVectSize, elementType), outerMatrix, index); + valueRow = rewriter.create( + loc, VectorType::get(vectorSize, rewriter.getF32Type()), row); + } else { + + SmallVector index = {c0, c0, indexOp_i}; + if (dimCount == 3) { + index.erase(index.begin()); + } + + valueRow = rewriter.create( + loc, VectorType::get(outerVectSize, elementType), outerMatrix, index); + } + + matLoad.push_back(valueRow); + } + + for (int j = 0, k = 0; j < innerBound; j = j + innerStep) { + Value indexOp_j = rewriter.create(loc, j); + Value valueRow; + + if ((N / vectorSize) < M) { + SmallVector index = {c0, indexOp_j, c0}; + if (dimCount == 3) { + index.erase(index.begin()); + } + Value row = rewriter.create( + loc, VectorType::get(innerVectSize, elementType), innerMatrix, + ValueRange(index)); + valueRow = rewriter.create( + loc, VectorType::get(vectorSize, rewriter.getF32Type()), row); + } else { + + SmallVector index = {c0, c0, indexOp_j}; + if (dimCount == 3) { + index.erase(index.begin()); + } + + valueRow = rewriter.create( + loc, VectorType::get(innerVectSize, elementType), innerMatrix, index); + } + + for (int i = 0; i < fmaBound; i = i + 1) { + auto fmaOdd = + rewriter.create(loc, matLoad[i], valueRow, acc[k]); + k++; + accVector.push_back(fmaOdd); + } + } + + return accVector; +} + +Value accVector(RewriterBase &rewriter, Location loc, VectorType vecType, + SmallVector FMAs, Value accVec, int64_t vecSize, + int64_t M, int64_t N) { + + auto strides = rewriter.getI64ArrayAttr({1}); + if ((N / vecSize) > M) { + for (int j = 0, k = 0; j < (N / vecSize); j++) { + for (int i = 0; i < M; i++) { + int64_t off = (j * vecSize) + (i * N); + auto offsets = rewriter.getI64ArrayAttr({off}); + accVec = rewriter.create( + loc, vecType, FMAs[k], accVec, offsets, strides); + k++; + } + } + + } else { + for (int i = 0, k = 0; i < M * N; i = i + vecSize) { + auto offsets = rewriter.getI64ArrayAttr({i}); + accVec = rewriter.create( + loc, vecType, FMAs[k], accVec, offsets, strides); + k++; + } + } + return accVec; +} + +scf::ForOp createLoop(RewriterBase &rewriter, Location loc, scf::ForOp kForOp, + vector::TransferReadOp vectorReadOpLhs, + vector::TransferReadOp vectorReadOpRhs, + Value ivNewReductionForOp, Type elementType, + int64_t vectorSize, int64_t vnni, int64_t M, int64_t N, + ValueRange iterArgsNewReductionForOp, int64_t dimCount) { + auto newKForOp = rewriter.create( + kForOp.getLoc(), kForOp.getLowerBound(), kForOp.getUpperBound(), + kForOp.getStep(), iterArgsNewReductionForOp, + [&](OpBuilder &rewriterNewKForOp, Location locNewKForOp, + Value ivNewKForOp, ValueRange iterArgsNewKForOp) { + IRMapping mapping; + mapping.map(vectorReadOpLhs.getBase().getDefiningOp()->getOperand(1), + ivNewReductionForOp); + mapping.map(vectorReadOpLhs.getBase().getDefiningOp()->getOperand(3), + ivNewKForOp); + auto lhsClone = rewriterNewKForOp.clone( + *vectorReadOpLhs.getBase().getDefiningOp(), mapping); + + IRMapping rhsMapping; + rhsMapping.map(vectorReadOpRhs.getBase().getDefiningOp()->getOperand(1), + ivNewReductionForOp); + rhsMapping.map(vectorReadOpRhs.getBase().getDefiningOp()->getOperand(2), + ivNewKForOp); + auto rhsClone = rewriterNewKForOp.clone( + *vectorReadOpRhs.getBase().getDefiningOp(), rhsMapping); + + auto evenFMAs = + nanoKernels(rewriter, kForOp.getLoc(), elementType, vectorSize, + vnni, M, N, iterArgsNewKForOp, lhsClone->getResult(0), + rhsClone->getResult(0), dimCount); + + rewriterNewKForOp.create(locNewKForOp, evenFMAs); + }); + + return newKForOp; +} + +scf::ForOp createLoop(RewriterBase 
&rewriter, Location loc, scf::ForOp kForOp, + vector::TransferReadOp vectorReadOpLhs, + vector::TransferReadOp vectorReadOpRhs, Type elementType, + int64_t vectorSize, int64_t vnni, int64_t M, int64_t N, + ValueRange iterArgsNewReductionForOp, int64_t dimCount) { + + auto newKForOp = rewriter.create( + kForOp.getLoc(), kForOp.getLowerBound(), kForOp.getUpperBound(), + kForOp.getStep(), iterArgsNewReductionForOp, + [&](OpBuilder &rewriterNewKForOp, Location locNewKForOp, + Value ivNewKForOp, ValueRange iterArgsNewKForOp) { + IRMapping mapping; + mapping.map(vectorReadOpLhs.getBase().getDefiningOp()->getOperand(2), + ivNewKForOp); + auto lhsClone = rewriterNewKForOp.clone( + *vectorReadOpLhs.getBase().getDefiningOp(), mapping); + + IRMapping rhsMapping; + rhsMapping.map(vectorReadOpRhs.getBase().getDefiningOp()->getOperand(1), + ivNewKForOp); + auto rhsClone = rewriterNewKForOp.clone( + *vectorReadOpRhs.getBase().getDefiningOp(), rhsMapping); + + auto evenFMAs = + nanoKernels(rewriter, loc, elementType, vectorSize, vnni, M, N, + iterArgsNewKForOp, lhsClone->getResult(0), + rhsClone->getResult(0), dimCount); + + rewriterNewKForOp.create(locNewKForOp, evenFMAs); + }); + + return newKForOp; +} + +LogicalResult mlir::x86vector::nanoKernels(RewriterBase &rewriter, + vector::ContractionOp contractOp, int64_t vectorSize) { + auto loc = contractOp.getLoc(); + + if (contractOp.getKind() != vector::CombiningKind::ADD) { + return rewriter.notifyMatchFailure(contractOp, + "Expects add combining kind"); + } + + auto dimCount = contractOp.getRhsType().getRank() + 1; + + if ((dimCount != 3) && (dimCount != 4)) + return rewriter.notifyMatchFailure(contractOp, + "Expects batch-reduce or batch matmuls"); + + // Get the M, N, K, and batch-reduce loops + auto loops = getNestedLoop(contractOp, dimCount); + if (failed(loops)) + return rewriter.notifyMatchFailure(contractOp, + "Invalid loop nest in contract pattern"); + + auto nestedLoops = *loops; + scf::ForOp kForOp = nestedLoops[0]; + scf::ForOp reductionForOp; + + vector::TransferReadOp vectorReadOpAcc; + + if (dimCount == 4) { + reductionForOp = nestedLoops[1]; + vectorReadOpAcc = + reductionForOp.getInitArgs()[0].getDefiningOp(); + } + + if (dimCount == 3) { + vectorReadOpAcc = + kForOp.getInitArgs()[0].getDefiningOp(); + } + + auto vectorReadOpLhs = + contractOp.getLhs().getDefiningOp(); + auto vectorReadOpRhs = + contractOp.getRhs().getDefiningOp(); + + if (!vectorReadOpAcc || !vectorReadOpLhs || !vectorReadOpRhs) + return failure(); + + auto subviewOpAcc = + vectorReadOpAcc.getOperand(0).getDefiningOp(); + auto subviewOpLhs = + vectorReadOpLhs.getOperand(0).getDefiningOp(); + auto subviewOpRhs = + vectorReadOpRhs.getOperand(0).getDefiningOp(); + + if (!subviewOpAcc || !subviewOpLhs || !subviewOpRhs) + return failure(); + + SmallVector subviews; + subviews.push_back(subviewOpLhs); + subviews.push_back(subviewOpRhs); + subviews.push_back(subviewOpAcc); + + // The M, N, K, and batch-reduce loop iv should match the iv's + // used in the subviews + auto checkLoops = checkNestedLoop(*loops, subviews, dimCount); + if (failed(checkLoops)) + return rewriter.notifyMatchFailure( + contractOp, "Loops doesn't match the iv in subviews"); + + auto elementType = + (cast(subviewOpLhs.getType())).getElementType(); + + // TODO: Support for BF16 Type + if (!elementType.isF32()) + return rewriter.notifyMatchFailure( + contractOp, "Only, FP32 type is supported"); + + auto lhsType = dyn_cast(vectorReadOpLhs.getType()); + auto rhsType = dyn_cast(vectorReadOpRhs.getType()); + + 
// Get M, N, and K dimension size + int64_t M = lhsType.getDimSize(lhsType.getRank() - 2); + int64_t N = rhsType.getDimSize(rhsType.getRank() - 1); + int64_t K = lhsType.getDimSize(lhsType.getRank() - 1); + int64_t vnni = 1; + + if (K != 1) + return rewriter.notifyMatchFailure(contractOp, "The k-dim should be 1"); + + if (dimCount == 4 && lhsType.getDimSize(lhsType.getRank() - 3) != 1) + return rewriter.notifyMatchFailure(contractOp, + "The reduction-dim should be 1"); + + + if (dimCount == 4) + rewriter.setInsertionPoint(reductionForOp); + + if (dimCount == 3) + rewriter.setInsertionPoint(kForOp); + + // Load MxN C sub matrix into acc vectors (e.g, ) + SmallVector loopItrArgs = + loadAcc(loc, rewriter, elementType, M, N, vectorSize, subviewOpAcc); + + // Create the batch-reduce and K-loop with acc vectors as the loop + // iterargs (batch-reduce matmul) + nanokernel generation + scf::ForOp newLoop; + if (dimCount == 4) { + newLoop = rewriter.create( + reductionForOp.getLoc(), reductionForOp.getLowerBound(), + reductionForOp.getUpperBound(), reductionForOp.getStep(), loopItrArgs, + [&](OpBuilder &rewriterNewReductionForOp, Location locNewReductionForOp, + Value ivNewReductionForOp, ValueRange iterArgsNewReductionForOp) { + scf::ForOp newKForOp = createLoop( + rewriter, loc, kForOp, vectorReadOpLhs, vectorReadOpRhs, + ivNewReductionForOp, elementType, vectorSize, vnni, M, N, + iterArgsNewReductionForOp, dimCount); + + rewriterNewReductionForOp.create( + locNewReductionForOp, newKForOp.getResults()); + }); + } + + // Create only the K-loop (batch matmul) + nanokernel generation + if (dimCount == 3) { + newLoop = + createLoop(rewriter, loc, kForOp, vectorReadOpLhs, vectorReadOpRhs, + elementType, vectorSize, vnni, M, N, loopItrArgs, dimCount); + } + + + // Combine all acc vectors into a MxN C matrix + auto vecType = VectorType::get({M * N}, rewriter.getF32Type()); + auto zeroAttr = + DenseElementsAttr::get(vecType, rewriter.getF32FloatAttr(0.0)); + Value accVec = rewriter.create(loc, vecType, zeroAttr); + + accVec = accVector(rewriter, loc, vecType, newLoop.getResults(), accVec, vectorSize, M, N); + + auto accTy = dyn_cast(contractOp.getAccType()); + auto reshapeAcc = rewriter.create(loc, accTy, accVec); + + // Replace all the use of vector.contract with results of nanokernels + if (dimCount == 4) + rewriter.replaceAllUsesWith(reductionForOp.getResult(0), reshapeAcc); + + if (dimCount == 3) + rewriter.replaceAllUsesWith(kForOp.getResult(0), reshapeAcc); + + return success(); +} From 176cb06fed4c7358603dba8453ab3c1d79910dbd Mon Sep 17 00:00:00 2001 From: Arun Thangamani Date: Wed, 22 Oct 2025 04:29:22 -0700 Subject: [PATCH 2/7] initial code to wrap as a Transform dialect --- .../mlir/Dialect/X86Vector/CMakeLists.txt | 2 + .../X86Vector/TransformOps/CMakeLists.txt | 4 ++ .../TransformOps/X86VectorTransformOps.td | 38 +++++++++++++++++ .../mlir/Dialect/X86Vector/Transforms.h | 4 +- mlir/lib/Dialect/X86Vector/CMakeLists.txt | 1 + .../X86Vector/TransformOps/CMakeLists.txt | 20 +++++++++ .../TransformOps/X86VectorTransformOps.cpp | 42 +++++++++++++++++++ .../X86Vector/Transforms/NanoKernels.cpp | 21 +++++++++- 8 files changed, 128 insertions(+), 4 deletions(-) create mode 100644 mlir/include/mlir/Dialect/X86Vector/TransformOps/CMakeLists.txt create mode 100644 mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td create mode 100644 mlir/lib/Dialect/X86Vector/TransformOps/CMakeLists.txt create mode 100644 mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp diff --git 
a/mlir/include/mlir/Dialect/X86Vector/CMakeLists.txt b/mlir/include/mlir/Dialect/X86Vector/CMakeLists.txt index 0fe01824b8248..bbe8e4eb892dd 100644 --- a/mlir/include/mlir/Dialect/X86Vector/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/X86Vector/CMakeLists.txt @@ -3,3 +3,5 @@ add_mlir_doc(X86Vector X86Vector Dialects/ -gen-dialect-doc -dialect=x86vector) add_mlir_interface(X86VectorInterfaces) add_dependencies(MLIRX86VectorIncGen MLIRX86VectorInterfacesIncGen) + +add_subdirectory(TransformOps) diff --git a/mlir/include/mlir/Dialect/X86Vector/TransformOps/CMakeLists.txt b/mlir/include/mlir/Dialect/X86Vector/TransformOps/CMakeLists.txt new file mode 100644 index 0000000000000..6f377e10fa8f8 --- /dev/null +++ b/mlir/include/mlir/Dialect/X86Vector/TransformOps/CMakeLists.txt @@ -0,0 +1,4 @@ +set(LLVM_TARGET_DEFINITIONS X86VectorTransformOps.td) +mlir_tablegen(X86VectorTransformOps.h.inc -gen-op-decls) +mlir_tablegen(X86VectorTransformOps.cpp.inc -gen-op-defs) +add_mlir_dialect_tablegen_target(MLIRX86VectorTransformOpsIncGen) diff --git a/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td b/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td new file mode 100644 index 0000000000000..23f0eebaebe34 --- /dev/null +++ b/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td @@ -0,0 +1,38 @@ +//===- X86VectorTransformOps.td - X86Vector transform ops ---*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef X86VECTOR_TRANSFORM_OPS +#define X86VECTOR_TRANSFORM_OPS + +include "mlir/Dialect/Transform/IR/TransformDialect.td" +include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td" +include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/IR/OpBase.td" +include "mlir/Dialect/Transform/IR/TransformAttrs.td" +include "mlir/Dialect/Transform/IR/TransformTypes.td" +include "mlir/IR/RegionKindInterface.td" + +def ApplyVectorContractNanokernelLoweringPatternsOp : Op]> { + let description = [{ + Indicates that vector contract operation can be lowered to target + specific nanokernels. + }]; + + //let arguments = (ins DefaultValuedAttr:$vector_size); + + //let assemblyFormat = [{ + //(`vector_size` `=` $vector_size^)? 
attr-dict + //}]; + + let assemblyFormat = "attr-dict"; +} + + +#endif // X86VECTOR_TRANSFORM_OPS diff --git a/mlir/include/mlir/Dialect/X86Vector/Transforms.h b/mlir/include/mlir/Dialect/X86Vector/Transforms.h index cde890038f20a..6ddb4e542ba54 100644 --- a/mlir/include/mlir/Dialect/X86Vector/Transforms.h +++ b/mlir/include/mlir/Dialect/X86Vector/Transforms.h @@ -85,8 +85,8 @@ struct MaskHelper { //===----------------------------------------------------------------------===// // Nano-kernels -LogicalResult nanoKernels(RewriterBase &rewriter, - vector::ContractionOp contractOp, int64_t vectorSize); +void populateVectorContractNanokernelLoweringPatterns(RewritePatternSet &patterns); //RewriterBase &rewriter, + //vector::ContractionOp contractOp, int64_t vectorSize); //===----------------------------------------------------------------------===// /// Helpers extracted from: diff --git a/mlir/lib/Dialect/X86Vector/CMakeLists.txt b/mlir/lib/Dialect/X86Vector/CMakeLists.txt index 9f57627c321fb..cb1e9d01821a2 100644 --- a/mlir/lib/Dialect/X86Vector/CMakeLists.txt +++ b/mlir/lib/Dialect/X86Vector/CMakeLists.txt @@ -1,2 +1,3 @@ add_subdirectory(IR) add_subdirectory(Transforms) +add_subdirectory(TransformOps) diff --git a/mlir/lib/Dialect/X86Vector/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/X86Vector/TransformOps/CMakeLists.txt new file mode 100644 index 0000000000000..8814547620e58 --- /dev/null +++ b/mlir/lib/Dialect/X86Vector/TransformOps/CMakeLists.txt @@ -0,0 +1,20 @@ +add_mlir_dialect_library(MLIRX86VectorTransformOps + X86VectorTransformOps.cpp + + DEPENDS + MLIRX86VectorTransformOpsIncGen + + LINK_LIBS PUBLIC + MLIRIR + MLIRLLVMCommonConversion + MLIRLLVMDialect + MLIRVectorDialect + MLIRVectorToLLVM + MLIRVectorTransforms + MLIRSideEffectInterfaces + MLIRTransformDialect + MLIRTransformDialectUtils + MLIRVectorDialect + MLIRVectorToSCF + MLIRX86VectorTransforms + ) diff --git a/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp b/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp new file mode 100644 index 0000000000000..1570972db1855 --- /dev/null +++ b/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp @@ -0,0 +1,42 @@ +//===- X86VectorTransformOps.cpp - Implementation of Vector transform ops --===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + + +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/Transform/IR/TransformDialect.h" +#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Dialect/X86Vector/Transforms.h" + +#include "mlir/Dialect/Transform/IR/TransformAttrs.h" +#include "mlir/Dialect/Transform/IR/TransformTypes.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/RegionKindInterface.h" + +using namespace mlir; +using namespace mlir::x86vector; +using namespace mlir::transform; + + + +void mlir::transform::ApplyVectorContractNanokernelLoweringPatternsOp::populatePatterns( + RewritePatternSet &patterns) { + vector::populateVectorTransferLoweringPatterns(patterns);//, + //getVectorSize()); +} + + +#define GET_OP_CLASSES +#include "mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp.inc" + +void mlir::x86vector::registerTransformDialectExtension( + DialectRegistry ®istry) { + registry.addExtensions(); +} + diff --git a/mlir/lib/Dialect/X86Vector/Transforms/NanoKernels.cpp b/mlir/lib/Dialect/X86Vector/Transforms/NanoKernels.cpp index bc03b567e06ff..5aeba6cad0445 100644 --- a/mlir/lib/Dialect/X86Vector/Transforms/NanoKernels.cpp +++ b/mlir/lib/Dialect/X86Vector/Transforms/NanoKernels.cpp @@ -21,6 +21,10 @@ #include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/IR/Dominance.h" +#include "mlir/Interfaces/DestinationStyleOpInterface.h" + #include "mlir/Dialect/Vector/Utils/VectorUtils.h" #include "mlir/Dialect/X86Vector/Transforms.h" #include "mlir/IR/PatternMatch.h" @@ -331,9 +335,16 @@ scf::ForOp createLoop(RewriterBase &rewriter, Location loc, scf::ForOp kForOp, return newKForOp; } -LogicalResult mlir::x86vector::nanoKernels(RewriterBase &rewriter, - vector::ContractionOp contractOp, int64_t vectorSize) { +//LogicalResult mlir::x86vector::nanoKernels(RewriterBase &rewriter, + // vector::ContractionOp contractOp)//, int64_t vectorSize) { + +struct VectorContractNanokernelLowering final : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::ContractionOp contractOp, + PatternRewriter &rewriter) const override { auto loc = contractOp.getLoc(); + int64_t vectorSize = 16; if (contractOp.getKind() != vector::CombiningKind::ADD) { return rewriter.notifyMatchFailure(contractOp, @@ -481,3 +492,9 @@ LogicalResult mlir::x86vector::nanoKernels(RewriterBase &rewriter, return success(); } +}; + + +void x86vector::populateVectorContractNanokernelLoweringPatterns(RewritePatternSet &patterns) { + patterns.add(patterns.getContext()); +} From 15f4c5d54ccf3c820097dc8f4e2dc8d2999e3fe8 Mon Sep 17 00:00:00 2001 From: Arun Thangamani Date: Wed, 22 Oct 2025 07:37:41 -0700 Subject: [PATCH 3/7] fix build errors in transform code --- .../TransformOps/X86VectorTransformOps.h | 36 +++++++++++++++++++ .../X86Vector/TransformOps/CMakeLists.txt | 2 +- .../TransformOps/X86VectorTransformOps.cpp | 30 +++++++++++++--- 3 files changed, 63 insertions(+), 5 deletions(-) create mode 100644 mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h diff --git a/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h 
b/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h new file mode 100644 index 0000000000000..abb5da75e5bfd --- /dev/null +++ b/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h @@ -0,0 +1,36 @@ +//===- VectorTransformOps.h - Vector transform ops --------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_X86VECTOR_TRANSFORMOPS_X86VECTORTRANSFORMOPS_H +#define MLIR_DIALECT_X86VECTOR_TRANSFORMOPS_X86VECTORTRANSFORMOPS_H + +#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" +#include "mlir/IR/OpImplementation.h" + +namespace mlir { +namespace x86vector { +} // namespace vector +} // namespace mlir + +//===----------------------------------------------------------------------===// +// Vector Transform Operations +//===----------------------------------------------------------------------===// + +#define GET_OP_CLASSES +#include "mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h.inc" + +namespace mlir { +class DialectRegistry; + +namespace x86vector { +void registerTransformDialectExtension(DialectRegistry ®istry); + +} // namespace vector +} // namespace mlir + +#endif // MLIR_DIALECT_X86VECTOR_TRANSFORMOPS_X86VECTORTRANSFORMOPS_H diff --git a/mlir/lib/Dialect/X86Vector/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/X86Vector/TransformOps/CMakeLists.txt index 8814547620e58..5f85f7af60d01 100644 --- a/mlir/lib/Dialect/X86Vector/TransformOps/CMakeLists.txt +++ b/mlir/lib/Dialect/X86Vector/TransformOps/CMakeLists.txt @@ -14,7 +14,7 @@ add_mlir_dialect_library(MLIRX86VectorTransformOps MLIRSideEffectInterfaces MLIRTransformDialect MLIRTransformDialectUtils - MLIRVectorDialect + MLIRX86VectorDialect MLIRVectorToSCF MLIRX86VectorTransforms ) diff --git a/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp b/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp index 1570972db1855..5702106558409 100644 --- a/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp +++ b/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp @@ -14,11 +14,13 @@ #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Dialect/X86Vector/Transforms.h" -#include "mlir/Dialect/Transform/IR/TransformAttrs.h" -#include "mlir/Dialect/Transform/IR/TransformTypes.h" +#include "mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h" + #include "mlir/IR/OpImplementation.h" #include "mlir/IR/RegionKindInterface.h" +#include "mlir/Dialect/X86Vector/X86VectorDialect.h" + using namespace mlir; using namespace mlir::x86vector; using namespace mlir::transform; @@ -27,10 +29,31 @@ using namespace mlir::transform; void mlir::transform::ApplyVectorContractNanokernelLoweringPatternsOp::populatePatterns( RewritePatternSet &patterns) { - vector::populateVectorTransferLoweringPatterns(patterns);//, + x86vector::populateVectorContractNanokernelLoweringPatterns(patterns);//, //getVectorSize()); } +//===----------------------------------------------------------------------===// +// Transform op registration +//===----------------------------------------------------------------------===// + +namespace { +class X86VectorTransformDialectExtension + : public transform::TransformDialectExtension< + X86VectorTransformDialectExtension> { 
+public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(X86VectorTransformDialectExtension) + + X86VectorTransformDialectExtension() { + declareGeneratedDialect(); + declareGeneratedDialect(); + registerTransformOps< +#define GET_OP_LIST +#include "mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp.inc" + >(); + } +}; +} // namespace #define GET_OP_CLASSES #include "mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp.inc" @@ -39,4 +62,3 @@ void mlir::x86vector::registerTransformDialectExtension( DialectRegistry ®istry) { registry.addExtensions(); } - From ca4e25291d23c9140fbbc4904f816437aa2d48bb Mon Sep 17 00:00:00 2001 From: Arun Thangamani Date: Thu, 23 Oct 2025 23:08:27 -0700 Subject: [PATCH 4/7] Round1: env set up + clang format fix --- .../TransformOps/X86VectorTransformOps.h | 11 +- .../TransformOps/X86VectorTransformOps.td | 9 +- .../mlir/Dialect/X86Vector/Transforms.h | 4 +- .../X86Vector/TransformOps/CMakeLists.txt | 3 - .../TransformOps/X86VectorTransformOps.cpp | 23 +- .../X86Vector/Transforms/NanoKernels.cpp | 363 +++++++++--------- mlir/lib/RegisterAllExtensions.cpp | 2 + .../vector-contract-to-nanokernels.mlir | 48 +++ 8 files changed, 256 insertions(+), 207 deletions(-) create mode 100644 mlir/test/Dialect/X86Vector/vector-contract-to-nanokernels.mlir diff --git a/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h b/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h index abb5da75e5bfd..e1d8b8762e799 100644 --- a/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h +++ b/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h @@ -1,4 +1,4 @@ -//===- VectorTransformOps.h - Vector transform ops --------------*- C++ -*-===// +//===- X86VectorTransformOps.h - X86Vector transform ops --------*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -12,13 +12,8 @@ #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" #include "mlir/IR/OpImplementation.h" -namespace mlir { -namespace x86vector { -} // namespace vector -} // namespace mlir - //===----------------------------------------------------------------------===// -// Vector Transform Operations +// X86Vector Transform Operations //===----------------------------------------------------------------------===// #define GET_OP_CLASSES @@ -30,7 +25,7 @@ class DialectRegistry; namespace x86vector { void registerTransformDialectExtension(DialectRegistry ®istry); -} // namespace vector +} // namespace x86vector } // namespace mlir #endif // MLIR_DIALECT_X86VECTOR_TRANSFORMOPS_X86VECTORTRANSFORMOPS_H diff --git a/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td b/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td index 23f0eebaebe34..9db2b36a2a8aa 100644 --- a/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td +++ b/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td @@ -25,13 +25,12 @@ def ApplyVectorContractNanokernelLoweringPatternsOp : Op:$vector_size); + let arguments = (ins DefaultValuedAttr:$vector_size); - //let assemblyFormat = [{ - //(`vector_size` `=` $vector_size^)? attr-dict - //}]; + let assemblyFormat = [{ + (`vector_size` `=` $vector_size^)? 
attr-dict + }]; - let assemblyFormat = "attr-dict"; } diff --git a/mlir/include/mlir/Dialect/X86Vector/Transforms.h b/mlir/include/mlir/Dialect/X86Vector/Transforms.h index 6ddb4e542ba54..a9487adba002a 100644 --- a/mlir/include/mlir/Dialect/X86Vector/Transforms.h +++ b/mlir/include/mlir/Dialect/X86Vector/Transforms.h @@ -85,8 +85,8 @@ struct MaskHelper { //===----------------------------------------------------------------------===// // Nano-kernels -void populateVectorContractNanokernelLoweringPatterns(RewritePatternSet &patterns); //RewriterBase &rewriter, - //vector::ContractionOp contractOp, int64_t vectorSize); +void populateVectorContractNanokernelLoweringPatterns( + RewritePatternSet &patterns, std::optional vectorSize = 8); //===----------------------------------------------------------------------===// /// Helpers extracted from: diff --git a/mlir/lib/Dialect/X86Vector/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/X86Vector/TransformOps/CMakeLists.txt index 5f85f7af60d01..f4c9f8a05acbc 100644 --- a/mlir/lib/Dialect/X86Vector/TransformOps/CMakeLists.txt +++ b/mlir/lib/Dialect/X86Vector/TransformOps/CMakeLists.txt @@ -9,12 +9,9 @@ add_mlir_dialect_library(MLIRX86VectorTransformOps MLIRLLVMCommonConversion MLIRLLVMDialect MLIRVectorDialect - MLIRVectorToLLVM - MLIRVectorTransforms MLIRSideEffectInterfaces MLIRTransformDialect MLIRTransformDialectUtils MLIRX86VectorDialect - MLIRVectorToSCF MLIRX86VectorTransforms ) diff --git a/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp b/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp index 5702106558409..e003e3ad7cd08 100644 --- a/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp +++ b/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp @@ -1,4 +1,5 @@ -//===- X86VectorTransformOps.cpp - Implementation of Vector transform ops --===// +//===- X86VectorTransformOps.cpp - Implementation of Vector transform ops +//--===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. 
@@ -6,31 +7,26 @@ // //===----------------------------------------------------------------------===// - +#include "mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h" #include "mlir/Conversion/LLVMCommon/TypeConverter.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/Transform/IR/TransformDialect.h" #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Dialect/X86Vector/Transforms.h" - -#include "mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h" +#include "mlir/Dialect/X86Vector/X86VectorDialect.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/RegionKindInterface.h" -#include "mlir/Dialect/X86Vector/X86VectorDialect.h" - using namespace mlir; using namespace mlir::x86vector; using namespace mlir::transform; - - -void mlir::transform::ApplyVectorContractNanokernelLoweringPatternsOp::populatePatterns( - RewritePatternSet &patterns) { - x86vector::populateVectorContractNanokernelLoweringPatterns(patterns);//, - //getVectorSize()); +void mlir::transform::ApplyVectorContractNanokernelLoweringPatternsOp:: + populatePatterns(RewritePatternSet &patterns) { + x86vector::populateVectorContractNanokernelLoweringPatterns(patterns, + getVectorSize()); } //===----------------------------------------------------------------------===// @@ -42,7 +38,8 @@ class X86VectorTransformDialectExtension : public transform::TransformDialectExtension< X86VectorTransformDialectExtension> { public: - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(X86VectorTransformDialectExtension) + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( + X86VectorTransformDialectExtension) X86VectorTransformDialectExtension() { declareGeneratedDialect(); diff --git a/mlir/lib/Dialect/X86Vector/Transforms/NanoKernels.cpp b/mlir/lib/Dialect/X86Vector/Transforms/NanoKernels.cpp index 5aeba6cad0445..583334333a49d 100644 --- a/mlir/lib/Dialect/X86Vector/Transforms/NanoKernels.cpp +++ b/mlir/lib/Dialect/X86Vector/Transforms/NanoKernels.cpp @@ -16,30 +16,28 @@ #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Dialect/Vector/Utils/VectorUtils.h" +#include "mlir/Dialect/X86Vector/Transforms.h" #include "mlir/Dialect/X86Vector/X86VectorDialect.h" -#include "mlir/IR/BuiltinAttributes.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/Dominance.h" -#include "mlir/Interfaces/DestinationStyleOpInterface.h" - -#include "mlir/Dialect/Vector/Utils/VectorUtils.h" -#include "mlir/Dialect/X86Vector/Transforms.h" #include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + using namespace mlir; using namespace mlir::vector; using namespace mlir::x86vector; static FailureOr> -getNestedLoop(vector::ContractionOp contractOp, int64_t dimCount) { +getNestedLoop(vector::ContractionOp contractOp, unsigned int dimCount) { SmallVector list; Operation *current = contractOp; // It is register tiled loop structure on batch reduce matmul // (M->N->Batch-reduce->K). 
- for (int i = 0; i < dimCount; i++) { + for (unsigned int i = 0; i < dimCount; i++) { Operation *parent = current->getParentOfType(); if (!parent) return failure(); @@ -51,7 +49,7 @@ getNestedLoop(vector::ContractionOp contractOp, int64_t dimCount) { static LogicalResult checkNestedLoop(SmallVector loops, SmallVector subviews, - int64_t dimCount) { + unsigned int dimCount) { auto subviewOpLhsOffsets = subviews[0].getOffsets(); auto subviewOpRhsOffsets = subviews[1].getOffsets(); auto subviewOpAccOffsets = subviews[2].getOffsets(); @@ -93,15 +91,16 @@ static LogicalResult checkNestedLoop(SmallVector loops, } static SmallVector loadAcc(Location loc, RewriterBase &rewriter, - Type elementType, int64_t M, int64_t N, - int64_t vectorSize, Value subviewOpAcc) { + Type elementType, unsigned int M, + unsigned int N, unsigned int vectorSize, + Value subviewOpAcc) { SmallVector loopItrArgs; - int64_t outerBound = M; - int64_t innerBound = N; + unsigned int outerBound = M; + unsigned int innerBound = N; - int64_t outerStep = 1; - int64_t innerStep = vectorSize; + unsigned int outerStep = 1; + unsigned int innerStep = vectorSize; if ((N / vectorSize) > M) { outerBound = N; @@ -111,8 +110,8 @@ static SmallVector loadAcc(Location loc, RewriterBase &rewriter, innerStep = 1; } - for (int i = 0; i < outerBound; i = i + outerStep) { - for (int j = 0; j < innerBound; j = j + innerStep) { + for (unsigned int i = 0; i < outerBound; i = i + outerStep) { + for (unsigned int j = 0; j < innerBound; j = j + innerStep) { Value indexOp_A = rewriter.create(loc, i); Value indexOp_B = rewriter.create(loc, j); @@ -132,28 +131,28 @@ static SmallVector loadAcc(Location loc, RewriterBase &rewriter, } SmallVector nanoKernels(RewriterBase &rewriter, Location loc, - Type elementType, int64_t vectorSize, - int64_t vnni, int64_t M, int64_t N, - ValueRange acc, Value matA, Value matB, - int64_t dimCount) { + Type elementType, unsigned int vectorSize, + unsigned int vnni, unsigned int M, + unsigned int N, ValueRange acc, Value matA, + Value matB, unsigned int dimCount) { SmallVector accVector; SmallVector matLoad; Value c0 = rewriter.create(loc, 0); - int64_t outerBound = M; - int64_t outerStep = 1; + unsigned int outerBound = M; + unsigned int outerStep = 1; - int64_t innerBound = N; - int64_t innerStep = vectorSize; + unsigned int innerBound = N; + unsigned int innerStep = vectorSize; Value outerMatrix = matA; Value innerMatrix = matB; - int64_t outerVectSize = vnni; - int64_t innerVectSize = vectorSize; + unsigned int outerVectSize = vnni; + unsigned int innerVectSize = vectorSize; - int64_t fmaBound = M; + unsigned int fmaBound = M; if ((N / vectorSize) < M) { outerBound = N; @@ -171,7 +170,7 @@ SmallVector nanoKernels(RewriterBase &rewriter, Location loc, fmaBound = N / vectorSize; } - for (int i = 0; i < outerBound; i = i + outerStep) { + for (unsigned int i = 0; i < outerBound; i = i + outerStep) { Value indexOp_i = rewriter.create(loc, i); Value valueRow; @@ -200,7 +199,7 @@ SmallVector nanoKernels(RewriterBase &rewriter, Location loc, matLoad.push_back(valueRow); } - for (int j = 0, k = 0; j < innerBound; j = j + innerStep) { + for (unsigned int j = 0, k = 0; j < innerBound; j = j + innerStep) { Value indexOp_j = rewriter.create(loc, j); Value valueRow; @@ -225,7 +224,7 @@ SmallVector nanoKernels(RewriterBase &rewriter, Location loc, loc, VectorType::get(innerVectSize, elementType), innerMatrix, index); } - for (int i = 0; i < fmaBound; i = i + 1) { + for (unsigned int i = 0; i < fmaBound; i = i + 1) { auto fmaOdd = 
rewriter.create(loc, matLoad[i], valueRow, acc[k]); k++; @@ -237,14 +236,14 @@ SmallVector nanoKernels(RewriterBase &rewriter, Location loc, } Value accVector(RewriterBase &rewriter, Location loc, VectorType vecType, - SmallVector FMAs, Value accVec, int64_t vecSize, - int64_t M, int64_t N) { + SmallVector FMAs, Value accVec, unsigned int vecSize, + unsigned int M, unsigned int N) { auto strides = rewriter.getI64ArrayAttr({1}); if ((N / vecSize) > M) { - for (int j = 0, k = 0; j < (N / vecSize); j++) { - for (int i = 0; i < M; i++) { - int64_t off = (j * vecSize) + (i * N); + for (unsigned int j = 0, k = 0; j < (N / vecSize); j++) { + for (unsigned int i = 0; i < M; i++) { + unsigned int off = (j * vecSize) + (i * N); auto offsets = rewriter.getI64ArrayAttr({off}); accVec = rewriter.create( loc, vecType, FMAs[k], accVec, offsets, strides); @@ -253,7 +252,7 @@ Value accVector(RewriterBase &rewriter, Location loc, VectorType vecType, } } else { - for (int i = 0, k = 0; i < M * N; i = i + vecSize) { + for (unsigned int i = 0, k = 0; i < M * N; i = i + vecSize) { auto offsets = rewriter.getI64ArrayAttr({i}); accVec = rewriter.create( loc, vecType, FMAs[k], accVec, offsets, strides); @@ -267,8 +266,10 @@ scf::ForOp createLoop(RewriterBase &rewriter, Location loc, scf::ForOp kForOp, vector::TransferReadOp vectorReadOpLhs, vector::TransferReadOp vectorReadOpRhs, Value ivNewReductionForOp, Type elementType, - int64_t vectorSize, int64_t vnni, int64_t M, int64_t N, - ValueRange iterArgsNewReductionForOp, int64_t dimCount) { + unsigned int vectorSize, unsigned int vnni, + unsigned int M, unsigned int N, + ValueRange iterArgsNewReductionForOp, + unsigned int dimCount) { auto newKForOp = rewriter.create( kForOp.getLoc(), kForOp.getLowerBound(), kForOp.getUpperBound(), kForOp.getStep(), iterArgsNewReductionForOp, @@ -304,8 +305,10 @@ scf::ForOp createLoop(RewriterBase &rewriter, Location loc, scf::ForOp kForOp, scf::ForOp createLoop(RewriterBase &rewriter, Location loc, scf::ForOp kForOp, vector::TransferReadOp vectorReadOpLhs, vector::TransferReadOp vectorReadOpRhs, Type elementType, - int64_t vectorSize, int64_t vnni, int64_t M, int64_t N, - ValueRange iterArgsNewReductionForOp, int64_t dimCount) { + unsigned int vectorSize, unsigned int vnni, + unsigned int M, unsigned int N, + ValueRange iterArgsNewReductionForOp, + unsigned int dimCount) { auto newKForOp = rewriter.create( kForOp.getLoc(), kForOp.getLowerBound(), kForOp.getUpperBound(), @@ -335,166 +338,174 @@ scf::ForOp createLoop(RewriterBase &rewriter, Location loc, scf::ForOp kForOp, return newKForOp; } -//LogicalResult mlir::x86vector::nanoKernels(RewriterBase &rewriter, - // vector::ContractionOp contractOp)//, int64_t vectorSize) { - -struct VectorContractNanokernelLowering final : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; +struct VectorContractNanokernelLowering + : public OpRewritePattern { + VectorContractNanokernelLowering(MLIRContext *context, + std::optional vecSize) + : OpRewritePattern(context), + userVectorSize(vecSize) {} LogicalResult matchAndRewrite(vector::ContractionOp contractOp, PatternRewriter &rewriter) const override { - auto loc = contractOp.getLoc(); - int64_t vectorSize = 16; - - if (contractOp.getKind() != vector::CombiningKind::ADD) { - return rewriter.notifyMatchFailure(contractOp, - "Expects add combining kind"); - } - - auto dimCount = contractOp.getRhsType().getRank() + 1; - if ((dimCount != 3) && (dimCount != 4)) - return rewriter.notifyMatchFailure(contractOp, - "Expects batch-reduce 
or batch matmuls"); + auto loc = contractOp.getLoc(); - // Get the M, N, K, and batch-reduce loops - auto loops = getNestedLoop(contractOp, dimCount); - if (failed(loops)) - return rewriter.notifyMatchFailure(contractOp, - "Invalid loop nest in contract pattern"); + unsigned int vectorSize = 8; - auto nestedLoops = *loops; - scf::ForOp kForOp = nestedLoops[0]; - scf::ForOp reductionForOp; + if (userVectorSize) + vectorSize = *userVectorSize; - vector::TransferReadOp vectorReadOpAcc; + if (contractOp.getKind() != vector::CombiningKind::ADD) { + return rewriter.notifyMatchFailure(contractOp, + "Expects add combining kind"); + } - if (dimCount == 4) { - reductionForOp = nestedLoops[1]; - vectorReadOpAcc = - reductionForOp.getInitArgs()[0].getDefiningOp(); - } + auto dimCount = contractOp.getRhsType().getRank() + 1; - if (dimCount == 3) { - vectorReadOpAcc = - kForOp.getInitArgs()[0].getDefiningOp(); - } + if ((dimCount != 3) && (dimCount != 4)) + return rewriter.notifyMatchFailure( + contractOp, "Expects batch-reduce or batch matmuls"); - auto vectorReadOpLhs = - contractOp.getLhs().getDefiningOp(); - auto vectorReadOpRhs = - contractOp.getRhs().getDefiningOp(); + // Get the M, N, K, and batch-reduce loops + auto loops = getNestedLoop(contractOp, dimCount); + if (failed(loops)) + return rewriter.notifyMatchFailure( + contractOp, "Invalid loop nest in contract pattern"); - if (!vectorReadOpAcc || !vectorReadOpLhs || !vectorReadOpRhs) - return failure(); + auto nestedLoops = *loops; + scf::ForOp kForOp = nestedLoops[0]; + scf::ForOp reductionForOp; - auto subviewOpAcc = - vectorReadOpAcc.getOperand(0).getDefiningOp(); - auto subviewOpLhs = - vectorReadOpLhs.getOperand(0).getDefiningOp(); - auto subviewOpRhs = - vectorReadOpRhs.getOperand(0).getDefiningOp(); + vector::TransferReadOp vectorReadOpAcc; - if (!subviewOpAcc || !subviewOpLhs || !subviewOpRhs) - return failure(); - - SmallVector subviews; - subviews.push_back(subviewOpLhs); - subviews.push_back(subviewOpRhs); - subviews.push_back(subviewOpAcc); + if (dimCount == 4) { + reductionForOp = nestedLoops[1]; + vectorReadOpAcc = reductionForOp.getInitArgs()[0] + .getDefiningOp(); + } - // The M, N, K, and batch-reduce loop iv should match the iv's - // used in the subviews - auto checkLoops = checkNestedLoop(*loops, subviews, dimCount); - if (failed(checkLoops)) - return rewriter.notifyMatchFailure( - contractOp, "Loops doesn't match the iv in subviews"); + if (dimCount == 3) { + vectorReadOpAcc = + kForOp.getInitArgs()[0].getDefiningOp(); + } - auto elementType = - (cast(subviewOpLhs.getType())).getElementType(); + auto vectorReadOpLhs = + contractOp.getLhs().getDefiningOp(); + auto vectorReadOpRhs = + contractOp.getRhs().getDefiningOp(); - // TODO: Support for BF16 Type - if (!elementType.isF32()) - return rewriter.notifyMatchFailure( - contractOp, "Only, FP32 type is supported"); + if (!vectorReadOpAcc || !vectorReadOpLhs || !vectorReadOpRhs) + return failure(); - auto lhsType = dyn_cast(vectorReadOpLhs.getType()); - auto rhsType = dyn_cast(vectorReadOpRhs.getType()); + auto subviewOpAcc = + vectorReadOpAcc.getOperand(0).getDefiningOp(); + auto subviewOpLhs = + vectorReadOpLhs.getOperand(0).getDefiningOp(); + auto subviewOpRhs = + vectorReadOpRhs.getOperand(0).getDefiningOp(); - // Get M, N, and K dimension size - int64_t M = lhsType.getDimSize(lhsType.getRank() - 2); - int64_t N = rhsType.getDimSize(rhsType.getRank() - 1); - int64_t K = lhsType.getDimSize(lhsType.getRank() - 1); - int64_t vnni = 1; + if (!subviewOpAcc || !subviewOpLhs || 
!subviewOpRhs) + return failure(); - if (K != 1) - return rewriter.notifyMatchFailure(contractOp, "The k-dim should be 1"); + SmallVector subviews; + subviews.push_back(subviewOpLhs); + subviews.push_back(subviewOpRhs); + subviews.push_back(subviewOpAcc); + + // The M, N, K, and batch-reduce loop iv should match the iv's + // used in the subviews + auto checkLoops = checkNestedLoop(*loops, subviews, dimCount); + if (failed(checkLoops)) + return rewriter.notifyMatchFailure( + contractOp, "Loops doesn't match the iv in subviews"); + + auto elementType = + (cast(subviewOpLhs.getType())).getElementType(); + + // TODO: Support for BF16 Type + if (!elementType.isF32()) + return rewriter.notifyMatchFailure(contractOp, + "Only, FP32 type is supported"); + + auto lhsType = dyn_cast(vectorReadOpLhs.getType()); + auto rhsType = dyn_cast(vectorReadOpRhs.getType()); + + // Get M, N, and K dimension size + unsigned int M = lhsType.getDimSize(lhsType.getRank() - 2); + unsigned int N = rhsType.getDimSize(rhsType.getRank() - 1); + unsigned int K = lhsType.getDimSize(lhsType.getRank() - 1); + unsigned int vnni = 1; + + if (K != 1) + return rewriter.notifyMatchFailure(contractOp, "The k-dim should be 1"); + + if (dimCount == 4 && lhsType.getDimSize(lhsType.getRank() - 3) != 1) + return rewriter.notifyMatchFailure(contractOp, + "The reduction-dim should be 1"); + + if (dimCount == 4) + rewriter.setInsertionPoint(reductionForOp); + + if (dimCount == 3) + rewriter.setInsertionPoint(kForOp); + + // Load MxN C sub matrix into acc vectors (e.g, ) + SmallVector loopItrArgs = + loadAcc(loc, rewriter, elementType, M, N, vectorSize, subviewOpAcc); + + // Create the batch-reduce and K-loop with acc vectors as the loop + // iterargs (batch-reduce matmul) + nanokernel generation + scf::ForOp newLoop; + if (dimCount == 4) { + newLoop = rewriter.create( + reductionForOp.getLoc(), reductionForOp.getLowerBound(), + reductionForOp.getUpperBound(), reductionForOp.getStep(), loopItrArgs, + [&](OpBuilder &rewriterNewReductionForOp, + Location locNewReductionForOp, Value ivNewReductionForOp, + ValueRange iterArgsNewReductionForOp) { + scf::ForOp newKForOp = createLoop( + rewriter, loc, kForOp, vectorReadOpLhs, vectorReadOpRhs, + ivNewReductionForOp, elementType, vectorSize, vnni, M, N, + iterArgsNewReductionForOp, dimCount); + + rewriterNewReductionForOp.create( + locNewReductionForOp, newKForOp.getResults()); + }); + } - if (dimCount == 4 && lhsType.getDimSize(lhsType.getRank() - 3) != 1) - return rewriter.notifyMatchFailure(contractOp, - "The reduction-dim should be 1"); + // Create only the K-loop (batch matmul) + nanokernel generation + if (dimCount == 3) { + newLoop = createLoop(rewriter, loc, kForOp, vectorReadOpLhs, + vectorReadOpRhs, elementType, vectorSize, vnni, M, N, + loopItrArgs, dimCount); + } + // Combine all acc vectors into a MxN C matrix + auto vecType = VectorType::get({M * N}, rewriter.getF32Type()); + auto zeroAttr = + DenseElementsAttr::get(vecType, rewriter.getF32FloatAttr(0.0)); + Value accVec = rewriter.create(loc, vecType, zeroAttr); - if (dimCount == 4) - rewriter.setInsertionPoint(reductionForOp); + accVec = accVector(rewriter, loc, vecType, newLoop.getResults(), accVec, + vectorSize, M, N); - if (dimCount == 3) - rewriter.setInsertionPoint(kForOp); + auto accTy = dyn_cast(contractOp.getAccType()); + auto reshapeAcc = rewriter.create(loc, accTy, accVec); - // Load MxN C sub matrix into acc vectors (e.g, ) - SmallVector loopItrArgs = - loadAcc(loc, rewriter, elementType, M, N, vectorSize, subviewOpAcc); 
+ // Replace all the use of vector.contract with results of nanokernels + if (dimCount == 4) + rewriter.replaceAllUsesWith(reductionForOp.getResult(0), reshapeAcc); - // Create the batch-reduce and K-loop with acc vectors as the loop - // iterargs (batch-reduce matmul) + nanokernel generation - scf::ForOp newLoop; - if (dimCount == 4) { - newLoop = rewriter.create( - reductionForOp.getLoc(), reductionForOp.getLowerBound(), - reductionForOp.getUpperBound(), reductionForOp.getStep(), loopItrArgs, - [&](OpBuilder &rewriterNewReductionForOp, Location locNewReductionForOp, - Value ivNewReductionForOp, ValueRange iterArgsNewReductionForOp) { - scf::ForOp newKForOp = createLoop( - rewriter, loc, kForOp, vectorReadOpLhs, vectorReadOpRhs, - ivNewReductionForOp, elementType, vectorSize, vnni, M, N, - iterArgsNewReductionForOp, dimCount); - - rewriterNewReductionForOp.create( - locNewReductionForOp, newKForOp.getResults()); - }); - } + if (dimCount == 3) + rewriter.replaceAllUsesWith(kForOp.getResult(0), reshapeAcc); - // Create only the K-loop (batch matmul) + nanokernel generation - if (dimCount == 3) { - newLoop = - createLoop(rewriter, loc, kForOp, vectorReadOpLhs, vectorReadOpRhs, - elementType, vectorSize, vnni, M, N, loopItrArgs, dimCount); + return success(); } - - - // Combine all acc vectors into a MxN C matrix - auto vecType = VectorType::get({M * N}, rewriter.getF32Type()); - auto zeroAttr = - DenseElementsAttr::get(vecType, rewriter.getF32FloatAttr(0.0)); - Value accVec = rewriter.create(loc, vecType, zeroAttr); - - accVec = accVector(rewriter, loc, vecType, newLoop.getResults(), accVec, vectorSize, M, N); - - auto accTy = dyn_cast(contractOp.getAccType()); - auto reshapeAcc = rewriter.create(loc, accTy, accVec); - - // Replace all the use of vector.contract with results of nanokernels - if (dimCount == 4) - rewriter.replaceAllUsesWith(reductionForOp.getResult(0), reshapeAcc); - - if (dimCount == 3) - rewriter.replaceAllUsesWith(kForOp.getResult(0), reshapeAcc); - - return success(); -} + std::optional userVectorSize; }; - -void x86vector::populateVectorContractNanokernelLoweringPatterns(RewritePatternSet &patterns) { - patterns.add(patterns.getContext()); +void x86vector::populateVectorContractNanokernelLoweringPatterns( + RewritePatternSet &patterns, std::optional userVectorSize) { + patterns.add(patterns.getContext(), + userVectorSize); } diff --git a/mlir/lib/RegisterAllExtensions.cpp b/mlir/lib/RegisterAllExtensions.cpp index 3839172fd0b42..efcd09fc1b924 100644 --- a/mlir/lib/RegisterAllExtensions.cpp +++ b/mlir/lib/RegisterAllExtensions.cpp @@ -56,6 +56,7 @@ #include "mlir/Dialect/Transform/SMTExtension/SMTExtension.h" #include "mlir/Dialect/Transform/TuneExtension/TuneExtension.h" #include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h" +#include "mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h" #include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" @@ -112,6 +113,7 @@ void mlir::registerAllExtensions(DialectRegistry ®istry) { transform::registerSMTExtension(registry); transform::registerTuneExtension(registry); vector::registerTransformDialectExtension(registry); + x86vector::registerTransformDialectExtension(registry); arm_neon::registerTransformDialectExtension(registry); arm_sve::registerTransformDialectExtension(registry); diff --git a/mlir/test/Dialect/X86Vector/vector-contract-to-nanokernels.mlir 
b/mlir/test/Dialect/X86Vector/vector-contract-to-nanokernels.mlir new file mode 100644 index 0000000000000..78ff150bb776e --- /dev/null +++ b/mlir/test/Dialect/X86Vector/vector-contract-to-nanokernels.mlir @@ -0,0 +1,48 @@ +// RUN: mlir-opt %s -transform-interpreter -cse -split-input-file | FileCheck %s + +#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)> +#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)> +#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)> +module { + func.func @fp32_vectorSize_16(%arg0: memref<1x4x32xf32>, %arg1: memref<1x32x96xf32>, %arg2: memref<4x96xf32>) -> memref<4x96xf32> { + %0 = ub.poison : f32 + %c0 = arith.constant 0 : index + %c4 = arith.constant 4 : index + %c96 = arith.constant 96 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + scf.for %arg3 = %c0 to %c4 step %c4 { + scf.for %arg4 = %c0 to %c96 step %c96 { + %subview = memref.subview %arg2[%arg3, %arg4] [4, 96] [1, 1] : memref<4x96xf32> to memref<4x96xf32, strided<[96, 1], offset: ?>> + %1 = vector.transfer_read %subview[%c0, %c0], %0 {in_bounds = [true, true]} : memref<4x96xf32, strided<[96, 1], offset: ?>>, vector<4x96xf32> + %2 = scf.for %arg5 = %c0 to %c1 step %c1 iter_args(%arg6 = %1) -> (vector<4x96xf32>) { + %3 = scf.for %arg7 = %c0 to %c32 step %c1 iter_args(%arg8 = %arg6) -> (vector<4x96xf32>) { + %subview_0 = memref.subview %arg0[%arg5, %arg3, %arg7] [1, 4, 1] [1, 1, 1] : memref<1x4x32xf32> to memref<1x4x1xf32, strided<[128, 32, 1], offset: ?>> + %subview_1 = memref.subview %arg1[%arg5, %arg7, %arg4] [1, 1, 96] [1, 1, 1] : memref<1x32x96xf32> to memref<1x1x96xf32, strided<[3072, 96, 1], offset: ?>> + %4 = vector.transfer_read %subview_0[%c0, %c0, %c0], %0 {in_bounds = [true, true, true]} : memref<1x4x1xf32, strided<[128, 32, 1], offset: ?>>, vector<1x4x1xf32> + %5 = vector.transfer_read %subview_1[%c0, %c0, %c0], %0 {in_bounds = [true, true, true]} : memref<1x1x96xf32, strided<[3072, 96, 1], offset: ?>>, vector<1x1x96xf32> + %6 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind} %4, %5, %arg8 : vector<1x4x1xf32>, vector<1x1x96xf32> into vector<4x96xf32> + scf.yield %6 : vector<4x96xf32> + } + scf.yield %3 : vector<4x96xf32> + } + vector.transfer_write %2, %subview[%c0, %c0] {in_bounds = [true, true]} : vector<4x96xf32>, memref<4x96xf32, strided<[96, 1], offset: ?>> + } + } + return %arg2 : memref<4x96xf32> + } +} + +// CHECK-LABEL: func.func @fp32_vectorSize_16( +// CHECK-COUNT-24: vector.fma{{.*}}vector<16xf32> +// CHECK-NOT: vector.contract + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.x86vector.vector_contract_nanokernel_lowering vector_size = 16 + } : !transform.any_op + transform.yield + } +} From ca52bdc8ee1f06622053b10e1c0c48b8c734372e Mon Sep 17 00:00:00 2001 From: Arun Thangamani Date: Fri, 24 Oct 2025 09:09:37 -0700 Subject: [PATCH 5/7] rebase from the main branch --- .../X86Vector/Transforms/NanoKernels.cpp | 285 +++++++++++------- .../vector-contract-to-nanokernels.mlir | 45 +++ 2 files changed, 221 insertions(+), 109 deletions(-) diff --git a/mlir/lib/Dialect/X86Vector/Transforms/NanoKernels.cpp b/mlir/lib/Dialect/X86Vector/Transforms/NanoKernels.cpp index 583334333a49d..c8270c96daeea 
100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/NanoKernels.cpp
+++ b/mlir/lib/Dialect/X86Vector/Transforms/NanoKernels.cpp
@@ -90,12 +90,12 @@ static LogicalResult checkNestedLoop(SmallVector<scf::ForOp> loops,
   return success();
 }
 
-static SmallVector<Value> loadAcc(Location loc, RewriterBase &rewriter,
-                                  Type elementType, unsigned int M,
-                                  unsigned int N, unsigned int vectorSize,
-                                  Value subviewOpAcc) {
+static SmallVector<Value>
+loadAccumulatorBeforeGEMM(Location loc, RewriterBase &rewriter,
+                          Type elementType, unsigned int M, unsigned int N,
+                          unsigned int vectorSize, Value subviewOpAcc) {
 
-  SmallVector<Value> loopItrArgs;
+  SmallVector<Value> accumulators;
   unsigned int outerBound = M;
   unsigned int innerBound = N;
 
@@ -112,34 +112,74 @@ static SmallVector<Value> loadAcc(Location loc, RewriterBase &rewriter,
 
   for (unsigned int i = 0; i < outerBound; i = i + outerStep) {
     for (unsigned int j = 0; j < innerBound; j = j + innerStep) {
-      Value indexOp_A = rewriter.create<arith::ConstantIndexOp>(loc, i);
-      Value indexOp_B = rewriter.create<arith::ConstantIndexOp>(loc, j);
+      Value indexOp_A = arith::ConstantIndexOp::create(rewriter, loc, i);
+      Value indexOp_B = arith::ConstantIndexOp::create(rewriter, loc, j);
 
       if ((N / vectorSize) > M) {
         indexOp_A = indexOp_B;
-        indexOp_B = rewriter.create<arith::ConstantIndexOp>(loc, i);
+        indexOp_B = arith::ConstantIndexOp::create(rewriter, loc, i);
       }
 
-      auto valueCRow = rewriter.create<vector::LoadOp>(
-          loc, VectorType::get(vectorSize, elementType), subviewOpAcc,
+      auto valueCRow = vector::LoadOp::create(
+          rewriter, loc, VectorType::get(vectorSize, elementType), subviewOpAcc,
           ValueRange{indexOp_A, indexOp_B});
-      loopItrArgs.push_back(valueCRow);
+      accumulators.push_back(valueCRow);
     }
   }
 
-  return loopItrArgs;
+  return accumulators;
 }
 
-SmallVector<Value> nanoKernels(RewriterBase &rewriter, Location loc,
-                               Type elementType, unsigned int vectorSize,
-                               unsigned int vnni, unsigned int M,
-                               unsigned int N, ValueRange acc, Value matA,
-                               Value matB, unsigned int dimCount) {
-
-  SmallVector<Value> accVector;
+// Function accepts A Matrix, B Matrix, C Matrix (as vectors) and generate
+// equivalent target specific nanokernels. Returns the final accumulator as
+// output. Based on M tile, N tile, and vector size it generated optimized
+// nanokernels with condition of reduction and K dimension of the input matrix
+// are 1.
+//
+// Input: Matrix A, Matrix B, Accumulator as M*(N/vector size) vectors, M tile
+// size, N tile size, Vector size.
+//
+// Output:
+// case i: M > (N/vector size). For example, M=3; N=32; vector size = 16.
+// load_B0 = load B[0-15] into vector<16xf32>
+// load_B1 = load B[16-31] into vector<16xf32>
+// bcst_A0 = load A[0] and broadcast it into vector<16xf32>
+// o/p_Acc[0] = vector.fma bcst_A0, load_B0, i/p_Acc[0]
+// o/p_Acc[1] = vector.fma bcst_A0, load_B1, i/p_Acc[1]
+// bcst_A1 = load A[1] and broadcast it into vector<16xf32>
+// o/p_Acc[2] = vector.fma bcst_A1, load_B0, i/p_Acc[2]
+// o/p_Acc[3] = vector.fma bcst_A1, load_B1, i/p_Acc[3]
+// bcst_A2 = load A[2] and broadcast it into vector<16xf32>
+// o/p_Acc[4] = vector.fma bcst_A2, load_B0, i/p_Acc[4]
+// o/p_Acc[5] = vector.fma bcst_A2, load_B1, i/p_Acc[5]
+//
+// case ii: M <= (N/vector size). For example, M=2; N=48; vector size = 16.
+// bcst_A0 = load A[0] and broadcast it into vector<16xf32>
+// bcst_A1 = load A[1] and broadcast it into vector<16xf32>
+// load_B0 = load B[0-15] into vector<16xf32>
+// o/p_Acc[0] = vector.fma bcst_A0, load_B0, i/p_Acc[0]
+// o/p_Acc[1] = vector.fma bcst_A1, load_B0, i/p_Acc[1]
+// load_B1 = load B[16-31] into vector<16xf32>
+// o/p_Acc[2] = vector.fma bcst_A0, load_B1, i/p_Acc[2]
+// o/p_Acc[3] = vector.fma bcst_A1, load_B1, i/p_Acc[3]
+// load_B2 = load B[32-47] into vector<16xf32>
+// o/p_Acc[4] = vector.fma bcst_A0, load_B2, i/p_Acc[4]
+// o/p_Acc[5] = vector.fma bcst_A1, load_B2, i/p_Acc[5]
+//
+// return o/p_Acc;
+SmallVector<Value>
+generateNanokernels(RewriterBase &rewriter, Location loc, Type elementType,
+                    unsigned int vectorSize, unsigned int vnni, unsigned int M,
+                    unsigned int N, ValueRange acc, Value matA, Value matB,
+                    unsigned int dimCount) {
+
+  SmallVector<Value> accumulators;
   SmallVector<Value> matLoad;
-  Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+  Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
 
+  // Start with the assumption that the M tile size is smaller and create the
+  // helper variables
   unsigned int outerBound = M;
   unsigned int outerStep = 1;
@@ -154,6 +194,7 @@ SmallVector<Value> nanoKernels(RewriterBase &rewriter, Location loc,
 
   unsigned int fmaBound = M;
 
+  // Update the helper variables if the N tile size is smaller
   if ((N / vectorSize) < M) {
     outerBound = N;
     innerBound = M;
@@ -170,37 +211,52 @@ SmallVector<Value> nanoKernels(RewriterBase &rewriter, Location loc,
     fmaBound = N / vectorSize;
   }
 
+  // Load all the elements of the A or B matrix
   for (unsigned int i = 0; i < outerBound; i = i + outerStep) {
-    Value indexOp_i = rewriter.create<arith::ConstantIndexOp>(loc, i);
+    Value indexOp_i = arith::ConstantIndexOp::create(rewriter, loc, i);
     Value valueRow;
 
     if ((N / vectorSize) > M) {
 
+      // Assuming a batch-reduce matmul, initialize the reduction, M, and
+      // K dimensions.
       SmallVector<Value> index = {c0, indexOp_i, c0};
+
+      // Remove reduction dimension if it is a batch matmul
       if (dimCount == 3) {
         index.erase(index.begin());
       }
 
-      Value row = rewriter.create<vector::LoadOp>(
-          loc, VectorType::get(outerVectSize, elementType), outerMatrix, index);
-      valueRow = rewriter.create<vector::BroadcastOp>(
-          loc, VectorType::get(vectorSize, rewriter.getF32Type()), row);
+      // A Matrix load + broadcast
+      Value row = vector::LoadOp::create(
+          rewriter, loc, VectorType::get(outerVectSize, elementType),
+          outerMatrix, index);
+      valueRow = vector::BroadcastOp::create(
+          rewriter, loc, VectorType::get(vectorSize, rewriter.getF32Type()),
+          row);
     } else {
 
+      // Assuming a batch-reduce matmul, initialize the reduction, K, and
+      // N dimensions.
       SmallVector<Value> index = {c0, c0, indexOp_i};
+
+      // Remove reduction dimension if it is a batch matmul
       if (dimCount == 3) {
         index.erase(index.begin());
       }
 
-      valueRow = rewriter.create<vector::LoadOp>(
-          loc, VectorType::get(outerVectSize, elementType), outerMatrix, index);
+      // B Matrix load.
+      valueRow = vector::LoadOp::create(
+          rewriter, loc, VectorType::get(outerVectSize, elementType),
+          outerMatrix, index);
     }
     matLoad.push_back(valueRow);
   }
 
+  // Load elements of the A/B matrix one at a time and compute FMAs
   for (unsigned int j = 0, k = 0; j < innerBound; j = j + innerStep) {
-    Value indexOp_j = rewriter.create<arith::ConstantIndexOp>(loc, j);
+    Value indexOp_j = arith::ConstantIndexOp::create(rewriter, loc, j);
     Value valueRow;
 
@@ -208,11 +264,14 @@ SmallVector<Value> nanoKernels(RewriterBase &rewriter, Location loc,
       if (dimCount == 3) {
         index.erase(index.begin());
       }
-      Value row = rewriter.create<vector::LoadOp>(
-          loc, VectorType::get(innerVectSize, elementType), innerMatrix,
-          ValueRange(index));
-      valueRow = rewriter.create<vector::BroadcastOp>(
-          loc, VectorType::get(vectorSize, rewriter.getF32Type()), row);
+
+      // A Matrix load + broadcast
+      Value row = vector::LoadOp::create(
+          rewriter, loc, VectorType::get(innerVectSize, elementType),
+          innerMatrix, ValueRange(index));
+      valueRow = vector::BroadcastOp::create(
+          rewriter, loc, VectorType::get(vectorSize, rewriter.getF32Type()),
+          row);
     } else {
       SmallVector<Value> index = {c0, c0, indexOp_j};
 
@@ -220,58 +279,35 @@ SmallVector<Value> nanoKernels(RewriterBase &rewriter, Location loc,
         index.erase(index.begin());
       }
 
-      valueRow = rewriter.create<vector::LoadOp>(
-          loc, VectorType::get(innerVectSize, elementType), innerMatrix, index);
+      // B Matrix load
+      valueRow = vector::LoadOp::create(
+          rewriter, loc, VectorType::get(innerVectSize, elementType),
+          innerMatrix, index);
     }
 
+    // FMAs
     for (unsigned int i = 0; i < fmaBound; i = i + 1) {
       auto fmaOdd =
-          rewriter.create<vector::FMAOp>(loc, matLoad[i], valueRow, acc[k]);
+          vector::FMAOp::create(rewriter, loc, matLoad[i], valueRow, acc[k]);
       k++;
-      accVector.push_back(fmaOdd);
+      accumulators.push_back(fmaOdd);
     }
   }
 
-  return accVector;
-}
-
-Value accVector(RewriterBase &rewriter, Location loc, VectorType vecType,
-                SmallVector<Value> FMAs, Value accVec, unsigned int vecSize,
-                unsigned int M, unsigned int N) {
-
-  auto strides = rewriter.getI64ArrayAttr({1});
-  if ((N / vecSize) > M) {
-    for (unsigned int j = 0, k = 0; j < (N / vecSize); j++) {
-      for (unsigned int i = 0; i < M; i++) {
-        unsigned int off = (j * vecSize) + (i * N);
-        auto offsets = rewriter.getI64ArrayAttr({off});
-        accVec = rewriter.create<vector::InsertStridedSliceOp>(
-            loc, vecType, FMAs[k], accVec, offsets, strides);
-        k++;
-      }
-    }
-
-  } else {
-    for (unsigned int i = 0, k = 0; i < M * N; i = i + vecSize) {
-      auto offsets = rewriter.getI64ArrayAttr({i});
-      accVec = rewriter.create<vector::InsertStridedSliceOp>(
-          loc, vecType, FMAs[k], accVec, offsets, strides);
-      k++;
-    }
-  }
-  return accVec;
+  return accumulators;
 }
 
-scf::ForOp createLoop(RewriterBase &rewriter, Location loc, scf::ForOp kForOp,
-                      vector::TransferReadOp vectorReadOpLhs,
-                      vector::TransferReadOp vectorReadOpRhs,
-                      Value ivNewReductionForOp, Type elementType,
-                      unsigned int vectorSize, unsigned int vnni,
-                      unsigned int M, unsigned int N,
-                      ValueRange iterArgsNewReductionForOp,
-                      unsigned int dimCount) {
-  auto newKForOp = rewriter.create<scf::ForOp>(
-      kForOp.getLoc(), kForOp.getLowerBound(), kForOp.getUpperBound(),
+// Function to re-create the K-dimension loop with the accumulator as IterArgs
+// for lowering a batch-reduce vector contraction to system-specific
+// nanokernels.
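+//
+// For reference, the re-created loop is expected to have roughly the shape
+// below (an illustrative sketch with made-up SSA names, not verbatim output
+// of this rewrite; the tests only check the emitted vector.fma ops):
+//
+//   %res:4 = scf.for %k = %lb to %ub step %c1
+//       iter_args(%acc0 = %l0, %acc1 = %l1, %acc2 = %l2, %acc3 = %l3)
+//       -> (vector<16xf32>, vector<16xf32>, vector<16xf32>, vector<16xf32>) {
+//     // ... A/B loads and broadcasts, one vector.fma per accumulator ...
+//     scf.yield %fma0, %fma1, %fma2, %fma3 : vector<16xf32>, vector<16xf32>,
+//                                            vector<16xf32>, vector<16xf32>
+//   }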
+scf::ForOp createGEMMLoopsWithAccAsIterArgs(
+    RewriterBase &rewriter, Location loc, scf::ForOp kForOp,
+    vector::TransferReadOp vectorReadOpLhs,
+    vector::TransferReadOp vectorReadOpRhs, Value ivNewReductionForOp,
+    Type elementType, unsigned int vectorSize, unsigned int vnni,
+    unsigned int M, unsigned int N, ValueRange iterArgsNewReductionForOp,
+    unsigned int dimCount) {
+  auto newKForOp = scf::ForOp::create(
+      rewriter, kForOp.getLoc(), kForOp.getLowerBound(), kForOp.getUpperBound(),
       kForOp.getStep(), iterArgsNewReductionForOp,
       [&](OpBuilder &rewriterNewKForOp, Location locNewKForOp,
           Value ivNewKForOp, ValueRange iterArgsNewKForOp) {
@@ -291,27 +327,28 @@ scf::ForOp createLoop(RewriterBase &rewriter, Location loc, scf::ForOp kForOp,
         auto rhsClone = rewriterNewKForOp.clone(
             *vectorReadOpRhs.getBase().getDefiningOp(), rhsMapping);
 
-        auto evenFMAs =
-            nanoKernels(rewriter, kForOp.getLoc(), elementType, vectorSize,
-                        vnni, M, N, iterArgsNewKForOp, lhsClone->getResult(0),
-                        rhsClone->getResult(0), dimCount);
+        auto evenFMAs = generateNanokernels(
+            rewriter, kForOp.getLoc(), elementType, vectorSize, vnni, M, N,
+            iterArgsNewKForOp, lhsClone->getResult(0), rhsClone->getResult(0),
+            dimCount);
 
-        rewriterNewKForOp.create<scf::YieldOp>(locNewKForOp, evenFMAs);
+        scf::YieldOp::create(rewriterNewKForOp, locNewKForOp, evenFMAs);
       });
 
   return newKForOp;
 }
 
-scf::ForOp createLoop(RewriterBase &rewriter, Location loc, scf::ForOp kForOp,
-                      vector::TransferReadOp vectorReadOpLhs,
-                      vector::TransferReadOp vectorReadOpRhs, Type elementType,
-                      unsigned int vectorSize, unsigned int vnni,
-                      unsigned int M, unsigned int N,
-                      ValueRange iterArgsNewReductionForOp,
-                      unsigned int dimCount) {
-
-  auto newKForOp = rewriter.create<scf::ForOp>(
-      kForOp.getLoc(), kForOp.getLowerBound(), kForOp.getUpperBound(),
+// Function to re-create the K-dimension loop with the accumulator as IterArgs
+// for lowering a batch vector contraction to system-specific nanokernels.
+scf::ForOp createGEMMLoopsWithAccAsIterArgs(
+    RewriterBase &rewriter, Location loc, scf::ForOp kForOp,
+    vector::TransferReadOp vectorReadOpLhs,
+    vector::TransferReadOp vectorReadOpRhs, Type elementType,
+    unsigned int vectorSize, unsigned int vnni, unsigned int M, unsigned int N,
+    ValueRange iterArgsNewReductionForOp, unsigned int dimCount) {
+
+  auto newKForOp = scf::ForOp::create(
+      rewriter, kForOp.getLoc(), kForOp.getLowerBound(), kForOp.getUpperBound(),
       kForOp.getStep(), iterArgsNewReductionForOp,
       [&](OpBuilder &rewriterNewKForOp, Location locNewKForOp,
           Value ivNewKForOp, ValueRange iterArgsNewKForOp) {
@@ -328,16 +365,45 @@ scf::ForOp createLoop(RewriterBase &rewriter, Location loc, scf::ForOp kForOp,
             *vectorReadOpRhs.getBase().getDefiningOp(), rhsMapping);
 
         auto evenFMAs =
-            nanoKernels(rewriter, loc, elementType, vectorSize, vnni, M, N,
-                        iterArgsNewKForOp, lhsClone->getResult(0),
-                        rhsClone->getResult(0), dimCount);
+            generateNanokernels(rewriter, loc, elementType, vectorSize, vnni, M,
+                                N, iterArgsNewKForOp, lhsClone->getResult(0),
+                                rhsClone->getResult(0), dimCount);
 
-        rewriterNewKForOp.create<scf::YieldOp>(locNewKForOp, evenFMAs);
+        scf::YieldOp::create(rewriterNewKForOp, locNewKForOp, evenFMAs);
       });
 
   return newKForOp;
 }
 
+Value mergeAccumulatedVectorAsMatrix(RewriterBase &rewriter, Location loc,
+                                     VectorType vecType,
+                                     SmallVector<Value> FMAs, Value accVec,
+                                     unsigned int vecSize, unsigned int M,
+                                     unsigned int N) {
+
+  auto strides = rewriter.getI64ArrayAttr({1});
+  if ((N / vecSize) > M) {
+    for (unsigned int j = 0, k = 0; j < (N / vecSize); j++) {
+      for (unsigned int i = 0; i < M; i++) {
+        unsigned int off = (j * vecSize) + (i * N);
+        auto offsets = rewriter.getI64ArrayAttr({off});
+        accVec = vector::InsertStridedSliceOp::create(
+            rewriter, loc, vecType, FMAs[k], accVec, offsets, strides);
+        k++;
+      }
+    }
+
+  } else {
+    for (unsigned int i = 0, k = 0; i < M * N; i = i + vecSize) {
+      auto offsets = rewriter.getI64ArrayAttr({i});
+      accVec = vector::InsertStridedSliceOp::create(
+          rewriter, loc, vecType, FMAs[k], accVec, offsets, strides);
+      k++;
+    }
+  }
+  return accVec;
+}
+
 struct VectorContractNanokernelLowering
     : public OpRewritePattern<vector::ContractionOp> {
   VectorContractNanokernelLowering(MLIRContext *context,
@@ -450,47 +516,48 @@ struct VectorContractNanokernelLowering
       rewriter.setInsertionPoint(kForOp);
 
     // Load the MxN C sub-matrix into acc vectors
-    SmallVector<Value> loopItrArgs =
-        loadAcc(loc, rewriter, elementType, M, N, vectorSize, subviewOpAcc);
+    SmallVector<Value> accumulators = loadAccumulatorBeforeGEMM(
+        loc, rewriter, elementType, M, N, vectorSize, subviewOpAcc);
 
     // Create the batch-reduce and K-loop with acc vectors as the loop
     // iterargs (batch-reduce matmul) + nanokernel generation
     scf::ForOp newLoop;
     if (dimCount == 4) {
-      newLoop = rewriter.create<scf::ForOp>(
-          reductionForOp.getLoc(), reductionForOp.getLowerBound(),
-          reductionForOp.getUpperBound(), reductionForOp.getStep(), loopItrArgs,
+      newLoop = scf::ForOp::create(
+          rewriter, reductionForOp.getLoc(), reductionForOp.getLowerBound(),
+          reductionForOp.getUpperBound(), reductionForOp.getStep(),
+          accumulators,
           [&](OpBuilder &rewriterNewReductionForOp,
              Location locNewReductionForOp, Value ivNewReductionForOp,
              ValueRange iterArgsNewReductionForOp) {
-            scf::ForOp newKForOp = createLoop(
+            scf::ForOp newKForOp = createGEMMLoopsWithAccAsIterArgs(
                 rewriter, loc, kForOp, vectorReadOpLhs, vectorReadOpRhs,
                 ivNewReductionForOp, elementType, vectorSize, vnni, M, N,
                 iterArgsNewReductionForOp, dimCount);
 
-            rewriterNewReductionForOp.create<scf::YieldOp>(
-                locNewReductionForOp, newKForOp.getResults());
+            scf::YieldOp::create(rewriterNewReductionForOp,
+                                 locNewReductionForOp, newKForOp.getResults());
          });
    }
 
     // Create only the K-loop (batch matmul) + nanokernel generation
     if (dimCount == 3) {
-      newLoop = createLoop(rewriter, loc, kForOp, vectorReadOpLhs,
-                           vectorReadOpRhs, elementType, vectorSize, vnni, M, N,
-                           loopItrArgs, dimCount);
+      newLoop = createGEMMLoopsWithAccAsIterArgs(
+          rewriter, loc, kForOp, vectorReadOpLhs, vectorReadOpRhs, elementType,
+          vectorSize, vnni, M, N, accumulators, dimCount);
     }
 
     // Combine all acc vectors into a MxN C matrix
     auto vecType = VectorType::get({M * N}, rewriter.getF32Type());
     auto zeroAttr =
         DenseElementsAttr::get(vecType, rewriter.getF32FloatAttr(0.0));
-    Value accVec = rewriter.create<arith::ConstantOp>(loc, vecType, zeroAttr);
+    Value accVec = arith::ConstantOp::create(rewriter, loc, vecType, zeroAttr);
 
-    accVec = accVector(rewriter, loc, vecType, newLoop.getResults(), accVec,
-                       vectorSize, M, N);
+    accVec = mergeAccumulatedVectorAsMatrix(
+        rewriter, loc, vecType, newLoop.getResults(), accVec, vectorSize, M, N);
 
     auto accTy = dyn_cast<VectorType>(contractOp.getAccType());
-    auto reshapeAcc = rewriter.create<vector::ShapeCastOp>(loc, accTy, accVec);
+    auto reshapeAcc = vector::ShapeCastOp::create(rewriter, loc, accTy, accVec);
 
     // Replace all the uses of vector.contract with results of nanokernels
     if (dimCount == 4)
diff --git a/mlir/test/Dialect/X86Vector/vector-contract-to-nanokernels.mlir b/mlir/test/Dialect/X86Vector/vector-contract-to-nanokernels.mlir
index 78ff150bb776e..184ba346e8638 100644
--- a/mlir/test/Dialect/X86Vector/vector-contract-to-nanokernels.mlir
+++ b/mlir/test/Dialect/X86Vector/vector-contract-to-nanokernels.mlir
@@ -46,3 +46,48 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+module {
+  func.func @fp32_batch_matmul_vector_size_8(%arg0: memref<4x32xf32>, %arg1: memref<32x96xf32>, %arg2: memref<4x96xf32>) -> memref<4x96xf32> {
+    %0 = ub.poison : f32
+    %c0 = arith.constant 0 : index
+    %c4 = arith.constant 4 : index
+    %c96 = arith.constant 96 : index
+    %c1 = arith.constant 1 : index
+    %c32 = arith.constant 32 : index
+    scf.for %arg3 = %c0 to %c4 step %c4 {
+      scf.for %arg4 = %c0 to %c96 step %c96 {
+        %subview = memref.subview %arg2[%arg3, %arg4] [4, 96] [1, 1] : memref<4x96xf32> to memref<4x96xf32, strided<[96, 1], offset: ?>>
+        %1 = vector.transfer_read %subview[%c0, %c0], %0 {in_bounds = [true, true]} : memref<4x96xf32, strided<[96, 1], offset: ?>>, vector<4x96xf32>
+
+        %3 = scf.for %arg7 = %c0 to %c32 step %c1 iter_args(%arg8 = %1) -> (vector<4x96xf32>) {
+          %subview_0 = memref.subview %arg0[%arg3, %arg7] [4, 1] [1, 1] : memref<4x32xf32> to memref<4x1xf32, strided<[32, 1], offset: ?>>
+          %subview_1 = memref.subview %arg1[%arg7, %arg4] [1, 96] [1, 1] : memref<32x96xf32> to memref<1x96xf32, strided<[96, 1], offset: ?>>
+          %4 = vector.transfer_read %subview_0[%c0, %c0], %0 {in_bounds = [true, true]} : memref<4x1xf32, strided<[32, 1], offset: ?>>, vector<4x1xf32>
+          %5 = vector.transfer_read %subview_1[%c0, %c0], %0 {in_bounds = [true, true]} : memref<1x96xf32, strided<[96, 1], offset: ?>>, vector<1x96xf32>
+          %6 = vector.contract {indexing_maps = [affine_map<(d1, d2, d3) -> (d1, d3)>, affine_map<(d1, d2, d3) -> (d3, d2)>, affine_map<(d1, d2, d3) -> (d1, d2)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %4, %5, %arg8 : vector<4x1xf32>, vector<1x96xf32> into vector<4x96xf32>
+          scf.yield %6 : vector<4x96xf32>
+        }
+
+        vector.transfer_write %3, %subview[%c0, %c0] {in_bounds = [true, true]} :
vector<4x96xf32>, memref<4x96xf32, strided<[96, 1], offset: ?>>
+      }
+    }
+    return %arg2 : memref<4x96xf32>
+  }
+}
+
+// CHECK-LABEL: func.func @fp32_batch_matmul_vector_size_8(
+// CHECK-COUNT-48: vector.fma{{.*}}vector<8xf32>
+// CHECK-NOT: vector.contract
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.x86vector.vector_contract_nanokernel_lowering vector_size = 8
+    } : !transform.any_op
+    transform.yield
+  }
+}

From f34f50d987593f182dcd6e772000855549f5607e Mon Sep 17 00:00:00 2001
From: Arun Thangamani
Date: Mon, 27 Oct 2025 03:34:41 -0700
Subject: [PATCH 6/7] added more comments + more test-cases

---
 .../mlir/Dialect/X86Vector/Transforms.h       |   5 +-
 .../X86Vector/Transforms/NanoKernels.cpp      | 187 +++++++++++++-----
 .../vector-contract-to-nanokernels.mlir       | 126 +++++++++++-
 3 files changed, 262 insertions(+), 56 deletions(-)

diff --git a/mlir/include/mlir/Dialect/X86Vector/Transforms.h b/mlir/include/mlir/Dialect/X86Vector/Transforms.h
index a9487adba002a..6410c12265f12 100644
--- a/mlir/include/mlir/Dialect/X86Vector/Transforms.h
+++ b/mlir/include/mlir/Dialect/X86Vector/Transforms.h
@@ -84,7 +84,10 @@ struct MaskHelper {
 };
 
 //===----------------------------------------------------------------------===//
-// Nano-kernels
+// Patterns to lower a tiled batch or batch-reduce vector contraction into a
+// sequence of nanokernels. The lowering is tailored to the target machine
+// architecture and guided by the user-specified vector size.
 void populateVectorContractNanokernelLoweringPatterns(
     RewritePatternSet &patterns, std::optional<unsigned int> vectorSize = 8);
 
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/NanoKernels.cpp b/mlir/lib/Dialect/X86Vector/Transforms/NanoKernels.cpp
index c8270c96daeea..4d0906a2ec057 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/NanoKernels.cpp
+++ b/mlir/lib/Dialect/X86Vector/Transforms/NanoKernels.cpp
@@ -7,7 +7,8 @@
 //===----------------------------------------------------------------------===//
 //
 // This file implements matmul rewrites as nanokernels with respect to target
-// machine for FP32 and BF16 (TODO) types.
+// machine for FP32 (for select batch or batch-reduce matmul patterns) and
+// BF16 (TODO) types.
 //
 //===----------------------------------------------------------------------===//
 
@@ -31,12 +32,18 @@ using namespace mlir;
 using namespace mlir::vector;
 using namespace mlir::x86vector;
 
+// Enum to represent the type of matmul operation
+enum class MatMulType { Batch, BatchReduce, Others };
+
 static FailureOr<SmallVector<scf::ForOp>>
-getNestedLoop(vector::ContractionOp contractOp, unsigned int dimCount) {
+getTiledMatmulLoopNest(vector::ContractionOp contractOp,
+                       MatMulType matmulType) {
   SmallVector<scf::ForOp> list;
   Operation *current = contractOp;
-  // It is register tiled loop structure on batch reduce matmul
-  // (M->N->Batch-reduce->K).
+  unsigned int dimCount = matmulType == MatMulType::BatchReduce ? 4 : 3;
+
+  // It matches the register-tiled loop structure of a batch (or batch-reduce)
+  // matmul (M->N->(reduce)->K).
  for (unsigned int i = 0; i < dimCount; i++) {
     Operation *parent = current->getParentOfType<scf::ForOp>();
     if (!parent)
@@ -47,14 +54,14 @@ getNestedLoop(vector::ContractionOp contractOp, unsigned int dimCount) {
   return list;
 }
 
-static LogicalResult checkNestedLoop(SmallVector<scf::ForOp> loops,
-                                     SmallVector<memref::SubViewOp> subviews,
-                                     unsigned int dimCount) {
+static LogicalResult checkMatmulLoopAndSubviewOffsetsMatching(
+    SmallVector<scf::ForOp> loops, SmallVector<memref::SubViewOp> subviews,
+    MatMulType matmulType) {
   auto subviewOpLhsOffsets = subviews[0].getOffsets();
   auto subviewOpRhsOffsets = subviews[1].getOffsets();
   auto subviewOpAccOffsets = subviews[2].getOffsets();
 
-  if (dimCount == 4) {
+  if (matmulType == MatMulType::BatchReduce) {
     Value ivK = loops[0].getInductionVar();
     if (ivK != subviewOpLhsOffsets[2] || ivK != subviewOpRhsOffsets[1])
       return failure();
@@ -73,7 +80,7 @@ static LogicalResult checkNestedLoop(SmallVector<scf::ForOp> loops,
       return failure();
   }
 
-  if (dimCount == 3) {
+  if (matmulType == MatMulType::Batch) {
     Value ivK = loops[0].getInductionVar();
     if (ivK != subviewOpLhsOffsets[1] || ivK != subviewOpRhsOffsets[0])
       return failure();
@@ -96,13 +103,16 @@ loadAccumulatorBeforeGEMM(Location loc, RewriterBase &rewriter,
                           unsigned int vectorSize, Value subviewOpAcc) {
 
   SmallVector<Value> accumulators;
+
+  // Initialize local variables on the assumption that the M tile is larger
+  // than the N tile
   unsigned int outerBound = M;
   unsigned int innerBound = N;
 
   unsigned int outerStep = 1;
   unsigned int innerStep = vectorSize;
 
-  if ((N / vectorSize) > M) {
+  bool isNTileLarge = (N / vectorSize) > M;
+  if (isNTileLarge) {
     outerBound = N;
     innerBound = M;
 
@@ -115,7 +125,7 @@ loadAccumulatorBeforeGEMM(Location loc, RewriterBase &rewriter,
     Value indexOp_A = arith::ConstantIndexOp::create(rewriter, loc, i);
     Value indexOp_B = arith::ConstantIndexOp::create(rewriter, loc, j);
 
-    if ((N / vectorSize) > M) {
+    if (isNTileLarge) {
       indexOp_A = indexOp_B;
       indexOp_B = arith::ConstantIndexOp::create(rewriter, loc, i);
     }
@@ -130,17 +140,18 @@ loadAccumulatorBeforeGEMM(Location loc, RewriterBase &rewriter,
   return accumulators;
 }
 
-// Function accepts A Matrix, B Matrix, C Matrix (as vectors) and generate
-// equivalent target specific nanokernels. Returns the final accumulator as
-// output. Based on M tile, N tile, and vector size it generated optimized
-// nanokernels with condition of reduction and K dimension of the input matrix
-// are 1.
+// This function takes matrices A, B, and C (represented as vectors)
+// and generates equivalent target-specific nanokernels.
+// It returns the final accumulator as output.
+// Based on the M tile, N tile, and vector size, it generates optimized
+// nanokernels under the condition that the reduction and K dimension
+// of the input matrices are equal to 1.
 //
 // Input: Matrix A, Matrix B, Accumulator as M*(N/vector size) vectors, M tile
 // size, N tile size, Vector size.
 //
 // Output:
-// case i: M > (N/vector size). For example, M=3; N=32; vector size = 16.
+// case i: M >= (N/vector size). For example, M=3; N=32; vector size = 16.
 // load_B0 = load B[0-15] into vector<16xf32>
 // load_B1 = load B[16-31] into vector<16xf32>
 // bcst_A0 = load A[0] and broadcast it into vector<16xf32>
 // o/p_Acc[0] = vector.fma bcst_A0, load_B0, i/p_Acc[0]
@@ -153,7 +164,7 @@ loadAccumulatorBeforeGEMM(Location loc, RewriterBase &rewriter,
 // o/p_Acc[4] = vector.fma bcst_A2, load_B0, i/p_Acc[4]
 // o/p_Acc[5] = vector.fma bcst_A2, load_B1, i/p_Acc[5]
 //
-// case ii: M <= (N/vector size). For example, M=2; N=48; vector size = 16.
+// case ii: M < (N/vector size). For example, M=2; N=48; vector size = 16.
 // bcst_A0 = load A[0] and broadcast it into vector<16xf32>
 // bcst_A1 = load A[1] and broadcast it into vector<16xf32>
 // load_B0 = load B[0-15] into vector<16xf32>
 // o/p_Acc[0] = vector.fma bcst_A0, load_B0, i/p_Acc[0]
@@ -172,7 +183,7 @@ SmallVector<Value>
 generateNanokernels(RewriterBase &rewriter, Location loc, Type elementType,
                     unsigned int vectorSize, unsigned int vnni, unsigned int M,
                     unsigned int N, ValueRange acc, Value matA, Value matB,
-                    unsigned int dimCount) {
+                    MatMulType matmulType) {
 
   SmallVector<Value> accumulators;
   SmallVector<Value> matLoad;
@@ -195,7 +206,8 @@ generateNanokernels(RewriterBase &rewriter, Location loc, Type elementType,
   unsigned int fmaBound = M;
 
   // Update the helper variables if the N tile size is smaller
-  if ((N / vectorSize) < M) {
+  bool isNTileLarge = (N / vectorSize) > M;
+  if (!isNTileLarge) {
     outerBound = N;
     innerBound = M;
 
@@ -216,14 +228,14 @@ generateNanokernels(RewriterBase &rewriter, Location loc, Type elementType,
     Value indexOp_i = arith::ConstantIndexOp::create(rewriter, loc, i);
     Value valueRow;
 
-    if ((N / vectorSize) > M) {
+    if (isNTileLarge) {
 
       // Assuming a batch-reduce matmul, initialize the reduction, M, and
       // K dimensions.
       SmallVector<Value> index = {c0, indexOp_i, c0};
 
       // Remove reduction dimension if it is a batch matmul
-      if (dimCount == 3) {
+      if (matmulType == MatMulType::Batch) {
         index.erase(index.begin());
       }
 
@@ -241,7 +253,7 @@ generateNanokernels(RewriterBase &rewriter, Location loc, Type elementType,
       SmallVector<Value> index = {c0, c0, indexOp_i};
 
       // Remove reduction dimension if it is a batch matmul
-      if (dimCount == 3) {
+      if (matmulType == MatMulType::Batch) {
         index.erase(index.begin());
       }
 
@@ -259,9 +271,9 @@ generateNanokernels(RewriterBase &rewriter, Location loc, Type elementType,
     Value indexOp_j = arith::ConstantIndexOp::create(rewriter, loc, j);
     Value valueRow;
 
-    if ((N / vectorSize) < M) {
+    if (!isNTileLarge) {
       SmallVector<Value> index = {c0, indexOp_j, c0};
-      if (dimCount == 3) {
+      if (matmulType == MatMulType::Batch) {
         index.erase(index.begin());
       }
 
@@ -275,7 +287,7 @@ generateNanokernels(RewriterBase &rewriter, Location loc, Type elementType,
     } else {
       SmallVector<Value> index = {c0, c0, indexOp_j};
 
-      if (dimCount == 3) {
+      if (matmulType == MatMulType::Batch) {
         index.erase(index.begin());
       }
 
@@ -305,7 +317,7 @@ scf::ForOp createGEMMLoopsWithAccAsIterArgs(
     vector::TransferReadOp vectorReadOpRhs, Value ivNewReductionForOp,
     Type elementType, unsigned int vectorSize, unsigned int vnni,
     unsigned int M, unsigned int N, ValueRange iterArgsNewReductionForOp,
-    unsigned int dimCount) {
+    MatMulType matmulType) {
   auto newKForOp = scf::ForOp::create(
       rewriter, kForOp.getLoc(), kForOp.getLowerBound(), kForOp.getUpperBound(),
       kForOp.getStep(), iterArgsNewReductionForOp,
@@ -330,7 +342,7 @@ scf::ForOp createGEMMLoopsWithAccAsIterArgs(
         auto evenFMAs = generateNanokernels(
             rewriter, kForOp.getLoc(), elementType, vectorSize, vnni, M, N,
             iterArgsNewKForOp, lhsClone->getResult(0), rhsClone->getResult(0),
-            dimCount);
+            matmulType);
 
         scf::YieldOp::create(rewriterNewKForOp, locNewKForOp, evenFMAs);
       });
@@ -345,7 +357,7 @@ scf::ForOp createGEMMLoopsWithAccAsIterArgs(
     vector::TransferReadOp vectorReadOpLhs,
     vector::TransferReadOp vectorReadOpRhs, Type elementType,
     unsigned int vectorSize, unsigned int vnni, unsigned int M, unsigned int N,
-    ValueRange iterArgsNewReductionForOp, unsigned int dimCount) {
+    ValueRange iterArgsNewReductionForOp, MatMulType matmulType) {
 
   auto newKForOp = scf::ForOp::create(
       rewriter, kForOp.getLoc(), kForOp.getLowerBound(), kForOp.getUpperBound(),
@@ -367,7 +379,7 @@ scf::ForOp createGEMMLoopsWithAccAsIterArgs(
         auto evenFMAs =
             generateNanokernels(rewriter, loc, elementType, vectorSize, vnni, M,
                                 N, iterArgsNewKForOp, lhsClone->getResult(0),
-                                rhsClone->getResult(0), dimCount);
+                                rhsClone->getResult(0), matmulType);
 
         scf::YieldOp::create(rewriterNewKForOp, locNewKForOp, evenFMAs);
       });
@@ -378,14 +390,15 @@ scf::ForOp createGEMMLoopsWithAccAsIterArgs(
 Value mergeAccumulatedVectorAsMatrix(RewriterBase &rewriter, Location loc,
                                      VectorType vecType,
                                      SmallVector<Value> FMAs, Value accVec,
-                                     unsigned int vecSize, unsigned int M,
+                                     unsigned int vectorSize, unsigned int M,
                                      unsigned int N) {
 
   auto strides = rewriter.getI64ArrayAttr({1});
-  if ((N / vecSize) > M) {
-    for (unsigned int j = 0, k = 0; j < (N / vecSize); j++) {
+  bool isNTileLarge = (N / vectorSize) > M;
+  if (isNTileLarge) {
+    for (unsigned int j = 0, k = 0; j < (N / vectorSize); j++) {
       for (unsigned int i = 0; i < M; i++) {
-        unsigned int off = (j * vecSize) + (i * N);
+        unsigned int off = (j * vectorSize) + (i * N);
         auto offsets = rewriter.getI64ArrayAttr({off});
         accVec = vector::InsertStridedSliceOp::create(
             rewriter, loc, vecType, FMAs[k], accVec, offsets, strides);
@@ -394,7 +407,7 @@ Value mergeAccumulatedVectorAsMatrix(RewriterBase &rewriter, Location loc,
     }
 
   } else {
-    for (unsigned int i = 0, k = 0; i < M * N; i = i + vecSize) {
+    for (unsigned int i = 0, k = 0; i < M * N; i = i + vectorSize) {
       auto offsets = rewriter.getI64ArrayAttr({i});
       accVec = vector::InsertStridedSliceOp::create(
           rewriter, loc, vecType, FMAs[k], accVec, offsets, strides);
@@ -404,6 +417,53 @@ Value mergeAccumulatedVectorAsMatrix(RewriterBase &rewriter, Location loc,
   return accVec;
 }
 
+// Rewrite pattern for the vector.contract operation.
+// Input: vector.contract with tiled dimensions (batch or batch-reduce matmul)
+// Matching Pattern:
+// scf.for (0 to M) step m_tile {
+// scf.for (0 to N) step n_tile {
+// - Subview of Accumulator matrix - e.g., acc : memref
+// - %read = vector.transfer_read of the subview into a vector
+// %1 = scf.for (0 to reduce) iter_args_reduce = %read step reduce_tile {
+// %2 = scf.for (0 to K) iter_args_k = %iter_args_reduce step k_tile {
+// - Subview of A and B matrix
+// - Vector transfer read of A and B
+// - %acc = vector.contract %read_A %read_B %iter_args_k
+// scf.yield %acc
+// }
+// scf.yield %2
+// }
+// vector.transfer_write %1 into accumulator matrix
+// }
+// }
+//
+//
+// Rewrite IR:
+// scf.for (0 to M) step m_tile {
+// scf.for (0 to N) step n_tile {
+// - Subview of Accumulator matrix - e.g., acc : memref
+// - %a = (n_tile / vector_size) * m_tile;
+// // load the accumulator matrix as vectors
+// - %0 = load acc[0][0-15] into vector<16xf32>
+// - %1 = load acc[0][16-31] into vector<16xf32>
+// - %2 = load acc[1][0-15] into vector<16xf32>
+// .
+// .
+// .
+// - %a-1 = load acc[m_tile-1][*-n_tile-1] into vector<16xf32>
+// %3 = scf.for (0 to reduce) iter_args_reduce = %0 to %a-1 step reduce_tile {
+// %4 = scf.for (0 to K) iter_args_k = %iter_args_reduce step k_tile {
+// - emit nanokernels (as shown in the comments above the
+// generateNanokernels function)
+// scf.yield %acc[0] to %acc[a-1]
+// }
+// scf.yield %4: [0] to [a-1]
+// }
+// %5 = vector.insert %3: [0] to [a-1] into vector
+// vector.transfer_write %5 into accumulator matrix
+// }
+// }
 struct VectorContractNanokernelLowering
     : public OpRewritePattern<vector::ContractionOp> {
   VectorContractNanokernelLowering(MLIRContext *context,
@@ -417,7 +477,6 @@ struct VectorContractNanokernelLowering
     auto loc = contractOp.getLoc();
 
     unsigned int vectorSize = 8;
-
     if (userVectorSize)
       vectorSize = *userVectorSize;
 
@@ -426,14 +485,28 @@ struct VectorContractNanokernelLowering
                                          "Expects add combining kind");
     }
 
-    auto dimCount = contractOp.getRhsType().getRank() + 1;
+    SmallVector<vector::IteratorType> contractIteratorTypes =
+        contractOp.getIteratorTypesArray();
+
+    unsigned int reductionCount =
+        std::count(contractIteratorTypes.begin(), contractIteratorTypes.end(),
+                   vector::IteratorType::reduction);
 
-    if ((dimCount != 3) && (dimCount != 4))
+    MatMulType matmulType = MatMulType::Others;
+
+    if (reductionCount == 1)
+      matmulType = MatMulType::Batch;
+
+    if (reductionCount == 2)
+      matmulType = MatMulType::BatchReduce;
+
+    if ((matmulType != MatMulType::BatchReduce) &&
+        (matmulType != MatMulType::Batch))
       return rewriter.notifyMatchFailure(
           contractOp, "Expects batch-reduce or batch matmuls");
 
     // Get the M, N, K, and batch-reduce loops
-    auto loops = getNestedLoop(contractOp, dimCount);
+    auto loops = getTiledMatmulLoopNest(contractOp, matmulType);
     if (failed(loops))
       return rewriter.notifyMatchFailure(
           contractOp, "Invalid loop nest in contract pattern");
@@ -442,15 +515,21 @@ struct VectorContractNanokernelLowering
     scf::ForOp kForOp = nestedLoops[0];
     scf::ForOp reductionForOp;
 
+    if (contractOp.getAcc().getDefiningOp<vector::TransferReadOp>()) {
+      return rewriter.notifyMatchFailure(
+          contractOp, "The Accumulator matrix should be hoisted outside the K "
+                      "or reduction loop");
+    }
+
     vector::TransferReadOp vectorReadOpAcc;
 
-    if (dimCount == 4) {
+    if (matmulType == MatMulType::BatchReduce) {
       reductionForOp = nestedLoops[1];
       vectorReadOpAcc = reductionForOp.getInitArgs()[0]
                             .getDefiningOp<vector::TransferReadOp>();
     }
 
-    if (dimCount == 3) {
+    if (matmulType == MatMulType::Batch) {
       vectorReadOpAcc =
           kForOp.getInitArgs()[0].getDefiningOp<vector::TransferReadOp>();
    }
 
@@ -480,10 +559,11 @@ struct VectorContractNanokernelLowering
 
     // The M, N, K, and batch-reduce loop iv should match the iv's
     // used in the subviews
-    auto checkLoops = checkNestedLoop(*loops, subviews, dimCount);
+    auto checkLoops =
+        checkMatmulLoopAndSubviewOffsetsMatching(*loops, subviews, matmulType);
     if (failed(checkLoops))
       return rewriter.notifyMatchFailure(
-          contractOp, "Loops doesn't match the iv in subviews");
+          contractOp, "Tiled loops don't match the ivs in subviews");
 
     auto elementType =
         (cast<MemRefType>(subviewOpLhs.getType())).getElementType();
@@ -505,14 +585,15 @@ struct VectorContractNanokernelLowering
     if (K != 1)
       return rewriter.notifyMatchFailure(contractOp, "The k-dim should be 1");
 
-    if (dimCount == 4 && lhsType.getDimSize(lhsType.getRank() - 3) != 1)
+    if (matmulType == MatMulType::BatchReduce &&
+        lhsType.getDimSize(lhsType.getRank() - 3) != 1)
      return rewriter.notifyMatchFailure(contractOp,
                                          "The reduction-dim should be 1");
 
-    if (dimCount == 4)
+    if (matmulType == MatMulType::BatchReduce)
       rewriter.setInsertionPoint(reductionForOp);
 
-    if (dimCount == 3)
+    if (matmulType == MatMulType::Batch)
       rewriter.setInsertionPoint(kForOp);
 
     // Load the MxN C sub-matrix into acc vectors
     SmallVector<Value> accumulators = loadAccumulatorBeforeGEMM(
         loc, rewriter, elementType, M, N, vectorSize, subviewOpAcc);
 
     // Create the batch-reduce and K-loop with acc vectors as the loop
     // iterargs (batch-reduce matmul) + nanokernel generation
     scf::ForOp newLoop;
-    if (dimCount == 4) {
+    if (matmulType == MatMulType::BatchReduce) {
       newLoop = scf::ForOp::create(
           rewriter, reductionForOp.getLoc(), reductionForOp.getLowerBound(),
           reductionForOp.getUpperBound(), reductionForOp.getStep(),
           accumulators,
           [&](OpBuilder &rewriterNewReductionForOp,
               Location locNewReductionForOp, Value ivNewReductionForOp,
               ValueRange iterArgsNewReductionForOp) {
             scf::ForOp newKForOp = createGEMMLoopsWithAccAsIterArgs(
                 rewriter, loc, kForOp, vectorReadOpLhs, vectorReadOpRhs,
                 ivNewReductionForOp, elementType, vectorSize, vnni, M, N,
-                iterArgsNewReductionForOp, dimCount);
+                iterArgsNewReductionForOp, matmulType);
 
             scf::YieldOp::create(rewriterNewReductionForOp,
                                  locNewReductionForOp, newKForOp.getResults());
@@ -541,13 +622,13 @@ struct VectorContractNanokernelLowering
     }
 
     // Create only the K-loop (batch matmul) + nanokernel generation
-    if (dimCount == 3) {
+    if (matmulType == MatMulType::Batch) {
       newLoop = createGEMMLoopsWithAccAsIterArgs(
           rewriter, loc, kForOp, vectorReadOpLhs, vectorReadOpRhs, elementType,
-          vectorSize, vnni, M, N, accumulators, dimCount);
+          vectorSize, vnni, M, N, accumulators, matmulType);
     }
 
-    // Combine all acc vectors into a MxN C matrix
+    // Combine all output accumulator vectors into an m_tile x n_tile C matrix
     auto vecType = VectorType::get({M * N}, rewriter.getF32Type());
     auto zeroAttr =
         DenseElementsAttr::get(vecType, rewriter.getF32FloatAttr(0.0));
@@ -560,10 +641,10 @@ struct VectorContractNanokernelLowering
     auto reshapeAcc = vector::ShapeCastOp::create(rewriter, loc, accTy, accVec);
 
     // Replace all the uses of vector.contract with results of nanokernels
-    if (dimCount == 4)
+    if (matmulType == MatMulType::BatchReduce)
       rewriter.replaceAllUsesWith(reductionForOp.getResult(0), reshapeAcc);
 
-    if (dimCount == 3)
+    if (matmulType == MatMulType::Batch)
       rewriter.replaceAllUsesWith(kForOp.getResult(0), reshapeAcc);
 
     return success();
diff --git a/mlir/test/Dialect/X86Vector/vector-contract-to-nanokernels.mlir b/mlir/test/Dialect/X86Vector/vector-contract-to-nanokernels.mlir
index 184ba346e8638..32620657bd52d 100644
--- a/mlir/test/Dialect/X86Vector/vector-contract-to-nanokernels.mlir
+++ b/mlir/test/Dialect/X86Vector/vector-contract-to-nanokernels.mlir
@@ -4,7 +4,7 @@
 #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
 #map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
 module {
-  func.func @fp32_vectorSize_16(%arg0: memref<1x4x32xf32>, %arg1: memref<1x32x96xf32>, %arg2: memref<4x96xf32>) -> memref<4x96xf32> {
+  func.func @fp32_batch_reduce_matmul_vector_size_16(%arg0: memref<1x4x32xf32>, %arg1: memref<1x32x96xf32>, %arg2: memref<4x96xf32>) -> memref<4x96xf32> {
     %0 = ub.poison : f32
     %c0 = arith.constant 0 : index
     %c4 = arith.constant 4 : index
@@ -33,7 +33,7 @@ module {
   }
 }
 
-// CHECK-LABEL: func.func @fp32_vectorSize_16(
+// CHECK-LABEL: func.func @fp32_batch_reduce_matmul_vector_size_16(
 // CHECK-COUNT-24: vector.fma{{.*}}vector<16xf32>
 // CHECK-NOT: vector.contract
 
@@ -91,3 +91,125 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+module {
+  func.func @not_tiled_no_rewriting(%arg0: memref<1x4x32xf32>, %arg1: memref<1x32x96xf32>, %arg2: memref<4x96xf32>) -> memref<4x96xf32> {
+    %c0 = arith.constant 0 : index
+    %0 = ub.poison : f32
+    %1 = vector.transfer_read %arg0[%c0, %c0, %c0], %0 {in_bounds =
[true, true, true]} : memref<1x4x32xf32>, vector<1x4x32xf32> + %2 = vector.transfer_read %arg1[%c0, %c0, %c0], %0 {in_bounds = [true, true, true]} : memref<1x32x96xf32>, vector<1x32x96xf32> + %3 = vector.transfer_read %arg2[%c0, %c0], %0 {in_bounds = [true, true]} : memref<4x96xf32>, vector<4x96xf32> + %4 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d1, d2)>], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind} %1, %2, %3 : vector<1x4x32xf32>, vector<1x32x96xf32> into vector<4x96xf32> + vector.transfer_write %4, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<4x96xf32>, memref<4x96xf32> + return %arg2 : memref<4x96xf32> + } +} + +// CHECK-LABEL: func.func @not_tiled_no_rewriting( +// CHECK-NOT: vector.fma{{.*}}vector<8xf32> +// CHECK: vector.contract + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.x86vector.vector_contract_nanokernel_lowering vector_size = 8 + } : !transform.any_op + transform.yield + } +} + +// ----- + +module { + func.func @tensor_type_no_rewriting(%arg0: tensor<32x32x32xf32>, %arg1: tensor<32x32x32xf32>, %arg2: tensor<32x32xf32>) -> tensor<32x32xf32> { + %0 = ub.poison : f32 + %c0 = arith.constant 0 : index + %c32 = arith.constant 32 : index + %c4 = arith.constant 4 : index + %c16 = arith.constant 16 : index + %c1 = arith.constant 1 : index + %1 = scf.for %arg3 = %c0 to %c32 step %c4 iter_args(%arg4 = %arg2) -> (tensor<32x32xf32>) { + %2 = scf.for %arg5 = %c0 to %c32 step %c16 iter_args(%arg6 = %arg4) -> (tensor<32x32xf32>) { + %3 = scf.for %arg7 = %c0 to %c32 step %c1 iter_args(%arg8 = %arg6) -> (tensor<32x32xf32>) { + %4 = scf.for %arg9 = %c0 to %c32 step %c1 iter_args(%arg10 = %arg8) -> (tensor<32x32xf32>) { + %extracted_slice = tensor.extract_slice %arg0[%arg7, %arg3, %arg9] [1, 4, 1] [1, 1, 1] : tensor<32x32x32xf32> to tensor<1x4x1xf32> + %extracted_slice_0 = tensor.extract_slice %arg1[%arg7, %arg9, %arg5] [1, 1, 16] [1, 1, 1] : tensor<32x32x32xf32> to tensor<1x1x16xf32> + %extracted_slice_1 = tensor.extract_slice %arg10[%arg3, %arg5] [4, 16] [1, 1] : tensor<32x32xf32> to tensor<4x16xf32> + %5 = vector.transfer_read %extracted_slice[%c0, %c0, %c0], %0 {in_bounds = [true, true, true]} : tensor<1x4x1xf32>, vector<1x4x1xf32> + %6 = vector.transfer_read %extracted_slice_0[%c0, %c0, %c0], %0 {in_bounds = [true, true, true]} : tensor<1x1x16xf32>, vector<1x1x16xf32> + %7 = vector.transfer_read %extracted_slice_1[%c0, %c0], %0 {in_bounds = [true, true]} : tensor<4x16xf32>, vector<4x16xf32> + %8 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d1, d2)>], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind} %5, %6, %7 : vector<1x4x1xf32>, vector<1x1x16xf32> into vector<4x16xf32> + %9 = vector.transfer_write %8, %extracted_slice_1[%c0, %c0] {in_bounds = [true, true]} : vector<4x16xf32>, tensor<4x16xf32> + %inserted_slice = tensor.insert_slice %9 into %arg10[%arg3, %arg5] [4, 16] [1, 1] : tensor<4x16xf32> into tensor<32x32xf32> + scf.yield %inserted_slice : tensor<32x32xf32> + } + scf.yield %4 : 
tensor<32x32xf32> + } + scf.yield %3 : tensor<32x32xf32> + } + scf.yield %2 : tensor<32x32xf32> + } + return %1 : tensor<32x32xf32> + } +} + +// CHECK-LABEL: func.func @tensor_type_no_rewriting( +// CHECK-NOT: vector.fma{{.*}}vector<16xf32> +// CHECK: vector.contract + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.x86vector.vector_contract_nanokernel_lowering vector_size = 16 + } : !transform.any_op + transform.yield + } +} + +// ----- + +module { + func.func @accumulator_not_hoisted_outside_K_or_reduction_loop(%arg0: memref<1x4x32xf32>, %arg1: memref<1x32x96xf32>, %arg2: memref<4x96xf32>) -> memref<4x96xf32> { + %0 = ub.poison : f32 + %c0 = arith.constant 0 : index + %c4 = arith.constant 4 : index + %c96 = arith.constant 96 : index + %c32 = arith.constant 32 : index + %c1 = arith.constant 1 : index + scf.for %arg3 = %c0 to %c4 step %c4 { + scf.for %arg4 = %c0 to %c96 step %c32 { + %subview = memref.subview %arg2[%arg3, %arg4] [4, 32] [1, 1] : memref<4x96xf32> to memref<4x32xf32, strided<[96, 1], offset: ?>> + scf.for %arg5 = %c0 to %c1 step %c1 { + scf.for %arg6 = %c0 to %c32 step %c1 { + %subview_0 = memref.subview %arg0[%arg5, %arg3, %arg6] [1, 4, 1] [1, 1, 1] : memref<1x4x32xf32> to memref<1x4x1xf32, strided<[128, 32, 1], offset: ?>> + %subview_1 = memref.subview %arg1[%arg5, %arg6, %arg4] [1, 1, 32] [1, 1, 1] : memref<1x32x96xf32> to memref<1x1x32xf32, strided<[3072, 96, 1], offset: ?>> + %1 = vector.transfer_read %subview_0[%c0, %c0, %c0], %0 {in_bounds = [true, true, true]} : memref<1x4x1xf32, strided<[128, 32, 1], offset: ?>>, vector<1x4x1xf32> + %2 = vector.transfer_read %subview_1[%c0, %c0, %c0], %0 {in_bounds = [true, true, true]} : memref<1x1x32xf32, strided<[3072, 96, 1], offset: ?>>, vector<1x1x32xf32> + %3 = vector.transfer_read %subview[%c0, %c0], %0 {in_bounds = [true, true]} : memref<4x32xf32, strided<[96, 1], offset: ?>>, vector<4x32xf32> + %4 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d1, d2)>], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind} %1, %2, %3 : vector<1x4x1xf32>, vector<1x1x32xf32> into vector<4x32xf32> + vector.transfer_write %4, %subview[%c0, %c0] {in_bounds = [true, true]} : vector<4x32xf32>, memref<4x32xf32, strided<[96, 1], offset: ?>> + } + } + } + } + return %arg2 : memref<4x96xf32> + } +} + +// CHECK-LABEL: func.func @accumulator_not_hoisted_outside_K_or_reduction_loop( +// CHECK-NOT: vector.fma{{.*}}vector<16xf32> +// CHECK: vector.contract + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.x86vector.vector_contract_nanokernel_lowering vector_size = 16 + } : !transform.any_op + transform.yield + } +} From a3385e2377a42d039a61e799b46acdcecfb2d118 Mon Sep 17 00:00:00 2001 From: Arun Thangamani Date: Tue, 28 Oct 2025 02:16:55 -0700 Subject: [PATCH 7/7] renaming test-cases --- .../X86Vector/vector-contract-to-nanokernels.mlir | 12 ++++++------ 1 file changed, 6 
insertions(+), 6 deletions(-) diff --git a/mlir/test/Dialect/X86Vector/vector-contract-to-nanokernels.mlir b/mlir/test/Dialect/X86Vector/vector-contract-to-nanokernels.mlir index 32620657bd52d..3514358633c0d 100644 --- a/mlir/test/Dialect/X86Vector/vector-contract-to-nanokernels.mlir +++ b/mlir/test/Dialect/X86Vector/vector-contract-to-nanokernels.mlir @@ -95,7 +95,7 @@ module attributes {transform.with_named_sequence} { // ----- module { - func.func @not_tiled_no_rewriting(%arg0: memref<1x4x32xf32>, %arg1: memref<1x32x96xf32>, %arg2: memref<4x96xf32>) -> memref<4x96xf32> { + func.func @negative_not_tiled(%arg0: memref<1x4x32xf32>, %arg1: memref<1x32x96xf32>, %arg2: memref<4x96xf32>) -> memref<4x96xf32> { %c0 = arith.constant 0 : index %0 = ub.poison : f32 %1 = vector.transfer_read %arg0[%c0, %c0, %c0], %0 {in_bounds = [true, true, true]} : memref<1x4x32xf32>, vector<1x4x32xf32> @@ -107,7 +107,7 @@ module { } } -// CHECK-LABEL: func.func @not_tiled_no_rewriting( +// CHECK-LABEL: func.func @negative_not_tiled( // CHECK-NOT: vector.fma{{.*}}vector<8xf32> // CHECK: vector.contract @@ -124,7 +124,7 @@ module attributes {transform.with_named_sequence} { // ----- module { - func.func @tensor_type_no_rewriting(%arg0: tensor<32x32x32xf32>, %arg1: tensor<32x32x32xf32>, %arg2: tensor<32x32xf32>) -> tensor<32x32xf32> { + func.func @negative_tensor_type(%arg0: tensor<32x32x32xf32>, %arg1: tensor<32x32x32xf32>, %arg2: tensor<32x32xf32>) -> tensor<32x32xf32> { %0 = ub.poison : f32 %c0 = arith.constant 0 : index %c32 = arith.constant 32 : index @@ -156,7 +156,7 @@ module { } } -// CHECK-LABEL: func.func @tensor_type_no_rewriting( +// CHECK-LABEL: func.func @negative_tensor_type( // CHECK-NOT: vector.fma{{.*}}vector<16xf32> // CHECK: vector.contract @@ -173,7 +173,7 @@ module attributes {transform.with_named_sequence} { // ----- module { - func.func @accumulator_not_hoisted_outside_K_or_reduction_loop(%arg0: memref<1x4x32xf32>, %arg1: memref<1x32x96xf32>, %arg2: memref<4x96xf32>) -> memref<4x96xf32> { + func.func @negative_accumulator_not_hoisted_outside_K_or_reduction_loop(%arg0: memref<1x4x32xf32>, %arg1: memref<1x32x96xf32>, %arg2: memref<4x96xf32>) -> memref<4x96xf32> { %0 = ub.poison : f32 %c0 = arith.constant 0 : index %c4 = arith.constant 4 : index @@ -200,7 +200,7 @@ module { } } -// CHECK-LABEL: func.func @accumulator_not_hoisted_outside_K_or_reduction_loop( +// CHECK-LABEL: func.func @negative_accumulator_not_hoisted_outside_K_or_reduction_loop( // CHECK-NOT: vector.fma{{.*}}vector<16xf32> // CHECK: vector.contract
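
Note (editorial sketch, not part of the patch series): for the
fp32_batch_reduce_matmul_vector_size_16 test above (M=4, N=96, vector size 16,
so N/vectorSize = 6 > M), the body of the re-created K-loop is expected to look
roughly like the snippet below. The SSA names and exact op order here are
assumptions; the tests only verify the vector.fma count and the absence of
vector.contract.

  // Four A elements, each loaded and broadcast once per K step.
  %a0 = vector.load %subview_lhs[%c0, %c0, %c0]
      : memref<1x4x1xf32, strided<[128, 32, 1], offset: ?>>, vector<1xf32>
  %b0 = vector.broadcast %a0 : vector<1xf32> to vector<16xf32>
  // ... %b1, %b2, %b3 likewise ...
  // Six 16-wide B vectors; each feeds M = 4 FMAs against the broadcasts.
  %r0 = vector.load %subview_rhs[%c0, %c0, %c0]
      : memref<1x1x96xf32, strided<[3072, 96, 1], offset: ?>>, vector<16xf32>
  %f0 = vector.fma %b0, %r0, %acc0 : vector<16xf32>
  %f1 = vector.fma %b1, %r0, %acc1 : vector<16xf32>
  // ... in total 4 x 6 = 24 vector.fma ops per K iteration, matching
  // the CHECK-COUNT-24 line in the test.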