From 8d27a5cf6f1fd081637577da2be743ade33d68d9 Mon Sep 17 00:00:00 2001 From: Daniel Hernandez-Juarez Date: Wed, 12 Mar 2025 08:13:57 -0400 Subject: [PATCH] Add Gemm+Elementwise+Gemm support --- .../mlir/Dialect/Rock/IR/CMakeLists.txt | 1 + .../mlir/Dialect/Rock/IR/GemmGemmSize.h | 40 ++ mlir/include/mlir/Dialect/Rock/IR/Rock.h | 8 +- .../mlir/Dialect/Rock/IR/RockAttrDefs.td | 11 +- .../Rock/IR/RockGemmGemmWrapperInterface.h | 26 ++ .../Rock/IR/RockGemmGemmWrapperInterface.td | 252 ++++++++++ mlir/include/mlir/Dialect/Rock/IR/RockOps.td | 143 ++++-- mlir/lib/Dialect/Rock/IR/CMakeLists.txt | 2 + .../Rock/IR/RockAcceptingViewOpInterface.cpp | 2 +- .../lib/Dialect/Rock/IR/RockConvInterface.cpp | 2 +- mlir/lib/Dialect/Rock/IR/RockDialect.cpp | 190 ++++++-- .../Rock/IR/RockGemmGemmWrapperInterface.cpp | 22 + .../Rock/Transforms/AffixTuningParameters.cpp | 41 +- .../BufferizableOpInterfaceImpl.cpp | 2 + .../Rock/Transforms/GemmToGridwise.cpp | 289 +++++++----- .../Transforms/GridwiseGemmToBlockwise.cpp | 434 +++++++++--------- .../Dialect/Rock/Transforms/Regularize.cpp | 2 +- .../Dialect/Rock/Tuning/RockTuningImpl.cpp | 166 ++++--- .../Dialect/Rock/affix_tuning_params.mlir | 39 +- .../Rock/affix_tuning_params_invalid.mlir | 9 + mlir/test/Dialect/Rock/gemm_to_gridwise.mlir | 42 ++ ...owering_sort_dimensions_memory_layout.mlir | 2 +- mlir/test/e2e/CMakeLists.txt | 1 + mlir/test/e2e/PrAttentionF32.toml | 15 - mlir/test/e2e/PrGemmElementwiseGemmF32.cfg | 2 + mlir/test/e2e/PrGemmElementwiseGemmF32.toml | 41 ++ mlir/test/rocmlir-gen/attention-kernel.mlir | 1 - .../gemm-elementwise-gemm-kernel.mlir | 24 + mlir/tools/rocmlir-gen/rocmlir-gen.cpp | 291 +++++++++++- .../rocmlir-tuning-driver.cpp | 12 +- mlir/utils/performance/perfCommonUtils.py | 3 + mlir/utils/performance/perfRunner.py | 196 +++++++- mlir/utils/performance/reportUtils.py | 1 + mlir/utils/performance/tuningRunner.py | 7 +- 34 files changed, 1753 insertions(+), 566 deletions(-) create mode 100644 mlir/include/mlir/Dialect/Rock/IR/GemmGemmSize.h create mode 100644 mlir/include/mlir/Dialect/Rock/IR/RockGemmGemmWrapperInterface.h create mode 100644 mlir/include/mlir/Dialect/Rock/IR/RockGemmGemmWrapperInterface.td create mode 100644 mlir/lib/Dialect/Rock/IR/RockGemmGemmWrapperInterface.cpp create mode 100644 mlir/test/e2e/PrGemmElementwiseGemmF32.cfg create mode 100644 mlir/test/e2e/PrGemmElementwiseGemmF32.toml create mode 100644 mlir/test/rocmlir-gen/gemm-elementwise-gemm-kernel.mlir diff --git a/mlir/include/mlir/Dialect/Rock/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Rock/IR/CMakeLists.txt index 896a3c5ca631..64aaf50771db 100644 --- a/mlir/include/mlir/Dialect/Rock/IR/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/Rock/IR/CMakeLists.txt @@ -20,6 +20,7 @@ mlir_tablegen(RockAccelTuningParamAttrInterface.h.inc -gen-attr-interface-decls) mlir_tablegen(RockAccelTuningParamAttrInterface.cpp.inc -gen-attr-interface-defs) add_public_tablegen_target(MLIRRockAccelTuningParamAttrInterfaceIncGen) +add_mlir_interface(RockGemmGemmWrapperInterface) add_mlir_interface(RockGemmWrapperInterface) add_mlir_interface(RockConvInterface) add_mlir_interface(RockAcceptingViewOpInterface) diff --git a/mlir/include/mlir/Dialect/Rock/IR/GemmGemmSize.h b/mlir/include/mlir/Dialect/Rock/IR/GemmGemmSize.h new file mode 100644 index 000000000000..b8ad9816fd40 --- /dev/null +++ b/mlir/include/mlir/Dialect/Rock/IR/GemmGemmSize.h @@ -0,0 +1,40 @@ +//===--------- GemmGemmSize.h - utility struct for gemm+gemm ----------===// +// +// Part of the MLIR Project, under the Apache 
License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines a utility struct, GemmGemmSize, that packages the sizes of +// gemm+gemm to ensure a cleaner API. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_ROCK_IR_GEMMGEMMCONTEXT_H +#define MLIR_DIALECT_ROCK_IR_GEMMGEMMCONTEXT_H + +#include + +namespace mlir { +namespace rock { + +/// Structure for holding the sizes of a matrix multiplication operation. +struct GemmGemmSize { + int64_t g; + int64_t m; + int64_t k; + int64_t n; + int64_t o; + + GemmGemmSize(int64_t g, int64_t m, int64_t k, int64_t n, int64_t o) + : g(g), m(m), k(k), n(n), o(o) {} + + bool operator==(const GemmGemmSize &other) { + return (g == other.g) && (m == other.m) && (k == other.k) && + (n == other.n) && (o == other.o); + } +}; +} // end namespace rock +} // end namespace mlir +#endif // MLIR_DIALECT_ROCK_IR_GEMMGEMMCONTEXT_H diff --git a/mlir/include/mlir/Dialect/Rock/IR/Rock.h b/mlir/include/mlir/Dialect/Rock/IR/Rock.h index 2dc77fceb654..f222f8dd7375 100644 --- a/mlir/include/mlir/Dialect/Rock/IR/Rock.h +++ b/mlir/include/mlir/Dialect/Rock/IR/Rock.h @@ -38,6 +38,7 @@ class PatternRewriter; #include "mlir/Dialect/Rock/IR/RockTypes.h" #include "mlir/Dialect/Rock/IR/ConvolutionDims.h" +#include "mlir/Dialect/Rock/IR/GemmGemmSize.h" #include "mlir/Dialect/Rock/IR/GemmSize.h" namespace mlir { @@ -49,12 +50,6 @@ class FusionRoot : public TraitBase {}; } // namespace OpTrait } // namespace mlir -// Following ifdef could be used to change -// the attention operator to be a fused gemm-gemm -// kernel for debugging purposes. 
This will also -// adjust the test harness to verify the same as well -// #define ROCK_DEBUG_ATTENTION_REMOVE_SOFTMAX - namespace mlir { namespace rock { //===----------------------------------------------------------------------===// @@ -87,6 +82,7 @@ constexpr int64_t maxHardwareWorkgroupSize = 1024; #include "mlir/Dialect/Rock/IR/RockAcceptingViewOpInterface.h" #include "mlir/Dialect/Rock/IR/RockConvInterface.h" +#include "mlir/Dialect/Rock/IR/RockGemmGemmWrapperInterface.h" #include "mlir/Dialect/Rock/IR/RockGemmWrapperInterface.h" #include "mlir/Dialect/Rock/IR/RockWriterOpInterface.h" diff --git a/mlir/include/mlir/Dialect/Rock/IR/RockAttrDefs.td b/mlir/include/mlir/Dialect/Rock/IR/RockAttrDefs.td index a33fd153a466..5508ec411301 100644 --- a/mlir/include/mlir/Dialect/Rock/IR/RockAttrDefs.td +++ b/mlir/include/mlir/Dialect/Rock/IR/RockAttrDefs.td @@ -10,6 +10,7 @@ #define ROCK_ATTRS include "mlir/Dialect/Rock/IR/RockBase.td" +include "mlir/Dialect/Rock/IR/RockGemmGemmWrapperInterface.td" include "mlir/Dialect/Rock/IR/RockGemmWrapperInterface.td" include "mlir/Dialect/Rock/IR/RockTuningParamAttrInterface.td" include "mlir/Dialect/Rock/IR/RockAccelTuningParamAttrInterface.td" @@ -63,11 +64,13 @@ def KernelTypeConvBwdData : I32EnumAttrCase<"ConvBwdData", 1>; def KernelTypeConvBwdWeight : I32EnumAttrCase<"ConvBwdWeight", 2>; def KernelTypeGemm : I32EnumAttrCase<"Gemm", 3>; def KernelTypeAttention : I32EnumAttrCase<"Attention", 4>; +def KernelTypeGemmElementwiseGemm : I32EnumAttrCase<"GemmElementwiseGemm", 5>; -def KernelType : Rock_I32Enum<"KernelType", "Any of the possible types of a rock kernel", - [KernelTypeConv, KernelTypeConvBwdData, - KernelTypeConvBwdWeight, KernelTypeGemm, - KernelTypeAttention]>; +def KernelType + : Rock_I32Enum<"KernelType", "Any of the possible types of a rock kernel", + [KernelTypeConv, KernelTypeConvBwdData, + KernelTypeConvBwdWeight, KernelTypeGemm, + KernelTypeAttention, KernelTypeGemmElementwiseGemm]>; /// TransformType def PassThrough : I32EnumAttrCase<"PassThrough", 0>; diff --git a/mlir/include/mlir/Dialect/Rock/IR/RockGemmGemmWrapperInterface.h b/mlir/include/mlir/Dialect/Rock/IR/RockGemmGemmWrapperInterface.h new file mode 100644 index 000000000000..3151c4dbf2d6 --- /dev/null +++ b/mlir/include/mlir/Dialect/Rock/IR/RockGemmGemmWrapperInterface.h @@ -0,0 +1,26 @@ +//===- RockGemmGemmWrapperInterface.h - ops that wrap rock.attention -*- C++ +//-*-===// +// +// Part of the rocMLIR Project, under the Apache License v2.0 with LLVM +// Exceptions. See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +// Copyright (c) 2025 Advanced Micro Devices INc. +//===----------------------------------------------------------------------===// +// +// This file defines RockGemmGemmWrapperInterface, which abstracts attention and +// gemm+gemm to allow code to operate on them generically. 
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_ROCK_IR_ROCKGEMMGEMMWRAPPERINTERFACE_H +#define MLIR_DIALECT_ROCK_IR_ROCKGEMMGEMMWRAPPERINTERFACE_H + +#include "mlir/Dialect/Rock/IR/GemmGemmSize.h" +#include "mlir/IR/OpDefinition.h" + +#include "mlir/Dialect/Rock/IR/RockTypes.h" + +#include "mlir/Dialect/Rock/IR/RockGemmGemmWrapperInterface.h.inc" + +#endif // MLIR_DIALECT_ROCK_IR_ROCKGEMMGEMMWRAPPERINTERFACE_H diff --git a/mlir/include/mlir/Dialect/Rock/IR/RockGemmGemmWrapperInterface.td b/mlir/include/mlir/Dialect/Rock/IR/RockGemmGemmWrapperInterface.td new file mode 100644 index 000000000000..659ae812d3e2 --- /dev/null +++ b/mlir/include/mlir/Dialect/Rock/IR/RockGemmGemmWrapperInterface.td @@ -0,0 +1,252 @@ +//===- RockGemmGemmWrapperInterface.td - ops that wrap rock.attention +//---------===// +// +// Part of the rocMLIR Project, under the Apache License v2.0 with LLVM +// Exceptions. See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +// Copyright (c) 2025 Advanced Micro Devices INc. +//===----------------------------------------------------------------------===// +// +// This file defines RockGemmGemmWrapperInterface, which abstracts attention and +// gemm+gemm and friends (conv+gemm, ...) to allow code to operate on them +// generically. +// +//===----------------------------------------------------------------------===// + +#ifndef ROCK_GEMM_GEMM_WRAPPER_INTERFACE +#define ROCK_GEMM_GEMM_WRAPPER_INTERFACE + +include "mlir/IR/OpBase.td" + +def RockGemmGemmWrapperInterface : OpInterface<"RockGemmGemmWrapperInterface"> { + let description = [{ + Interface to abstract away gemm+gemm-wrapping operators in the rock dialect, + which mainly include attention and gemm+gemm and friends that can be implemented + with flash attention. + + This should include functions to get common attributes. + }]; + let cppNamespace = "::mlir::rock"; + + let methods = [ + InterfaceMethod< + /*desc=*/[{ + Return the KernelType of this op + }], + /*retType=*/"::mlir::rock::KernelType", + /*methodName=*/"getKernelType", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/"" + >, + InterfaceMethod< + /*desc=*/[{ + Return the arch string of this op + }], + /*retType=*/"StringRef", + /*methodName=*/"getArch", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/"" + >, + InterfaceMethod< + /*desc=*/[{ + Return the OpOperand that corresponds to the operand argument + that corresponds to the output result of the operation. + }], + /*retType=*/"OpOperand *", + /*methodName=*/"getOutArgument", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/"" + >, + InterfaceMethod< + /*desc=*/[{ + Return the size of the matrix multiplication that this op will eventually + perform. + }], + /*retType=*/"::mlir::rock::GemmGemmSize", + /*methodName=*/"getGemmGemmSize", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/"" + >, + InterfaceMethod< + /*desc=*/[{ + Return the element type of [what will become] matrix A for this operation. + }], + /*retType=*/"::mlir::Type", + /*methodName=*/"getAType", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/"" + >, + InterfaceMethod< + /*desc=*/[{ + Return the element type of [what will become] matrix B for this operation. 
+ }], + /*retType=*/"::mlir::Type", + /*methodName=*/"getBType", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/"" + >, + InterfaceMethod< + /*desc=*/[{ + Return the element type of [what will become] matrix C for this operation. + }], + /*retType=*/"::mlir::Type", + /*methodName=*/"getCType", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/"" + >, + InterfaceMethod< + /*desc=*/[{ + Return the element type of [what will become] output matrix for this operation. + }], + /*retType=*/"::mlir::Type", + /*methodName=*/"getOutType", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/"" + >, + InterfaceMethod< + /*desc=*/[{ + Return the whether matrix A is transposed. + }], + /*retType=*/"bool", + /*methodName=*/"getTransposedA", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/"" + >, + InterfaceMethod< + /*desc=*/[{ + Return the whether matrix B is transposed. + }], + /*retType=*/"bool", + /*methodName=*/"getTransposedB", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/"" + >, + InterfaceMethod< + /*desc=*/[{ + Return the whether matrix C is transposed. + }], + /*retType=*/"bool", + /*methodName=*/"getTransposedC", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/"" + >, + InterfaceMethod< + /*desc=*/[{ + Return the whether output matrix is transposed. + }], + /*retType=*/"bool", + /*methodName=*/"getTransposedOut", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/"" + >, + InterfaceMethod< + /*desc=*/[{ + Return the features attribute of this op. + }], + /*retType=*/"::mlir::rock::GemmFeatures", + /*methodName=*/"getGemmFeatures", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return $_op.getFeatures(); + }] + >, + InterfaceMethod< + /*desc=*/[{ + Return the optional number of Compute Units the GPU provides. + }], + /*retType=*/"std::optional", + /*methodName=*/"getNumCU", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/ "" + >, + + InterfaceMethod< + /*desc=*/[{ + Set the tuning parameters attribute of the first GEMM + }], + /*retType=*/"void", + /*methodName=*/"setGemm0ParamsAttr", + /*args=*/(ins "::mlir::Attribute":$params), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + $_op->setAttr($_op.getParams0AttrName(), params); + }] + >, + InterfaceMethod< + /*desc=*/[{ + Set the tuning parameters attribute of the second GEMM + }], + /*retType=*/"void", + /*methodName=*/"setGemm1ParamsAttr", + /*args=*/(ins "::mlir::Attribute":$params), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + $_op->setAttr($_op.getParams1AttrName(), params); + }] + >, + InterfaceMethod< + /*desc=*/[{ + Get the tuning parameters attribute of the first GEMM + }], + /*retType=*/"std::optional", + /*methodName=*/"getGemm0Params", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return $_op.getParams0(); + }] + >, + InterfaceMethod< + /*desc=*/[{ + Get the tuning parameters attribute of the second GEMM + }], + /*retType=*/"std::optional", + /*methodName=*/"getGemm1Params", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return $_op.getParams1(); + }] + >, + InterfaceMethod< + /*desc=*/[{ + Return the index of the elementwise region argument that comes from the first GEMM. 
+ }], + /*retType=*/"uint32_t", + /*methodName=*/"getFirstGemmIndex", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/ "" + >, + + // TODO: more methods here as needed + ]; + + let verify = [{ + auto concreteOp = ::mlir::cast($_op); + if ($_op->getNumResults() == 1) { + if ($_op->getResult(0).getType() != + concreteOp.getOutArgument()->get().getType()) { + return $_op->emitOpError("result type must match output argument type"); + } + } + return ::mlir::success(); + }]; +} + +#endif // ROCK_GEMM_GEMM_WRAPPER_INTERFACE diff --git a/mlir/include/mlir/Dialect/Rock/IR/RockOps.td b/mlir/include/mlir/Dialect/Rock/IR/RockOps.td index 3fc9b6f33d0a..a578ff8b3b3f 100644 --- a/mlir/include/mlir/Dialect/Rock/IR/RockOps.td +++ b/mlir/include/mlir/Dialect/Rock/IR/RockOps.td @@ -15,6 +15,7 @@ include "mlir/Dialect/Rock/IR/RockAttrDefs.td" include "mlir/Dialect/Rock/IR/RockConvInterface.td" +include "mlir/Dialect/Rock/IR/RockGemmGemmWrapperInterface.td" include "mlir/Dialect/Rock/IR/RockGemmWrapperInterface.td" include "mlir/Dialect/Rock/IR/RockAcceptingViewOpInterface.td" include "mlir/Dialect/Rock/IR/RockWriterOpInterface.td" @@ -204,39 +205,44 @@ def Rock_ReduceOp : ::mlir::OpOperand* getOutArgument() { return &(*this)->getOpOperand(1); } }]; } - -def Rock_AttentionOp : - Rock_Op<"attention", [DeclareOpInterfaceMethods, RockFusionRoot, AttrSizedOperandSegments]>, - Arguments<(ins - TensorOrMemRefOf<[F32, F16, BF16, I8]>:$queries, - TensorOrMemRefOf<[F32, F16, BF16, I8]>:$keys, - TensorOrMemRefOf<[F32, F16, BF16]>:$values, - Variadic:$preSoftmaxElemWiseInputs, - Optional>:$currentSeqLen, - TensorOrMemRefOf<[F32, BF16, F16]>:$out, - UnitAttr:$qTransposed, - UnitAttr:$kTransposed, - UnitAttr:$vTransposed, - UnitAttr:$oTransposed, - StrAttr:$arch, - Rock_GemmFeaturesAttr:$features, - OptionalAttr:$numCU, - OptionalAttr:$params0, - OptionalAttr:$params1, - I32Attr:$firstGemmIdx - )>, - Results<(outs Optional>:$result)> { +def Rock_AttentionOp + : Rock_Op< + "attention", [DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, + RockFusionRoot, AttrSizedOperandSegments]>, + Arguments<(ins TensorOrMemRefOf<[F32, F16, BF16, I8]>:$queries, + TensorOrMemRefOf<[F32, F16, BF16, I8]>:$keys, + TensorOrMemRefOf<[F32, F16, BF16]>:$values, + Variadic:$preSoftmaxElemWiseInputs, + Optional>:$currentSeqLen, + TensorOrMemRefOf<[F32, F16, BF16]>:$out, UnitAttr:$qTransposed, + UnitAttr:$kTransposed, UnitAttr:$vTransposed, UnitAttr:$oTransposed, + StrAttr:$arch, Rock_GemmFeaturesAttr:$features, + OptionalAttr:$numCU, + OptionalAttr:$params0, + OptionalAttr:$params1, + I32Attr:$firstGemmIdx)>, + Results<(outs Optional>:$result)> { let summary = "Attention operation of transformer models"; let description = [{ - Performs the operation out = SOFTMAX((queries * keys) .* scale) * values. + Performs the operation out = SOFTMAX(preSoftmaxBody(queries * keys, preSoftmaxElemWiseInputs)) * values. - This operation performs attention mechanism of transformer models. + This operation performs attention mechanism of transformer models. There is an optional element-wise + fusion just before the softmax, defined by `preSoftmaxBody` with inputs `preSoftmaxElemWiseInputs`. + + If none of the `transposed` attributes are set, then `queries` is [G] x seq_q x head_qk, + `keys` is [G] x head_qk x seq_k, `values` is [G] x seq_k x head_v and `out` is [G] x seq_q x head_v, + where G is the optional group dimension (which is assumed to be 1 if not set). 
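+
+    For instance (illustrative sizes, assuming G = 1, seq_q = seq_k = 384 and
+    head_qk = head_v = 64): `queries` would be `memref<1x384x64xf32>`, `keys`
+    `memref<1x64x384xf32>`, `values` `memref<1x384x64xf32>` and `out`
+    `memref<1x384x64xf32>`.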
+ + The transpose attributes allow for the non-group dimensions of the matrix to be + transposed. For example, if `qTransposed` is set, then the argument `queries` should be + a [G] x head_qk x seq_q memory. Those creating a `rock.attention` must specify the GPU architecture being targetted and the number of compute units (numCu) available. The parameters `gridSize`, and `blockSize` are optional as they can be inferred by a tuning process or a heuristic, but they must be set before the `attention` is - lowered into the `gridwise_attention` stage of the code generation pipeline. + lowered into the `gridwise_attention_accel` stage of the code generation pipeline. `features` specifies what hardware features can be used in the generated code. }]; @@ -250,8 +256,56 @@ def Rock_AttentionOp : (`tr` $oTransposed^)? $out `=` `softmax` `(` `qk` `)` `*` (`tr` $vTransposed^)? $values `:` type($values) `->` type($out) `\n` `}` attr-dict (`->` type($result)^)? }]; - let extraClassDeclaration = [{ - ::mlir::OpOperand* getOutArgument() { return &(*this)->getOpOperands().back(); } +} + +def Rock_GemmElementwiseGemmOp + : Rock_Op<"gemm_elementwise_gemm", + [DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, + RockFusionRoot]>, + AllElementTypesMatch<["a", "b", "c"]>, + Arguments<(ins TensorOrMemRefOf<[F32]>:$a, TensorOrMemRefOf<[F32]>:$b, + TensorOrMemRefOf<[F32]>:$c, + Variadic:$elemwiseInputs, + TensorOrMemRefOf<[F32]>:$out, UnitAttr:$aTransposed, + UnitAttr:$bTransposed, UnitAttr:$cTransposed, UnitAttr:$oTransposed, + StrAttr:$arch, Rock_GemmFeaturesAttr:$features, + OptionalAttr:$numCU, + OptionalAttr:$params0, + OptionalAttr:$params1, + I32Attr:$firstGemmIdx)>, + Results<(outs Optional>:$result)> { + let summary = "GEMM-elementwise-GEMM operation"; + let description = [{ + Performs the operation out = preSecondGemmBody(a * b, elemwiseInputs) * c. + + This operation performs fused GEMM-elementwise-GEMM. There is an optional element-wise + fusion just before the second GEMM, defined by `preSecondGemmBody` with inputs `elemwiseInputs`. + + If none of the `transposed` attributes are set, then `a` is [G] x M x K, + `b` is [G] x K x N, `c` is [G] x N x O and `out` is [G] x M x O, where G is the + optional group dimension (which is assumed to be 1 if not set). + + The transpose attributes allow for the non-group dimensions of the matrix to be + transposed. For example, if `aTransposed` is set, then the argument `a` should be + a [G] x K x M memory. + + Those creating a `rock.gemm_elementwise_gemm` must specify the GPU architecture being targetted + and the number of compute units (numCu) available. The parameters + `gridSize`, and `blockSize` are optional as they can be inferred by + a tuning process or a heuristic, but they must be set before the `gemm_elementwise_gemm` is + lowered into the `gridwise_attention_accel` stage of the code generation pipeline. + + `features` specifies what hardware features can be used in the generated code. + }]; + let hasVerifier = 1; + let regions = (region AnyRegion:$preSecondGemmBody); + let assemblyFormat = [{ + `{` `\n` + ` ` `ab` `=` (`tr` $aTransposed^)? $a `*` (`tr` $bTransposed^)? $b `:` type($a) `,` type($b) `\n` + (`ab` `=` `elementwise` (`otherIns` `(` $elemwiseInputs^ `:` type($elemwiseInputs) `)`)? $preSecondGemmBody^ `\n`)? + (`tr` $oTransposed^)? $out `=` `ab` `*` (`tr` $cTransposed^)? $c `:` type($c) `->` type($out) `\n` + `}` attr-dict (`->` type($result)^)? 
}]; } @@ -432,24 +486,23 @@ def Rock_GridwiseGemmAccelOp : } // gridwise_attention_accel -def Rock_GridwiseAttentionAccelOp : - Rock_Op<"gridwise_attention_accel", [DeclareOpInterfaceMethods, RockFusionRoot, AttrSizedOperandSegments]>, - Arguments<(ins MemRefRankOf<[F32, F16, BF16, I8], [3]>:$queries, - MemRefRankOf<[F32, F16, BF16, I8], [3]>:$keys, - MemRefRankOf<[F32, F16, BF16,], [3]>:$values, - Variadic:$preSoftmaxElemWiseInputs, - Optional>:$currentSeqLen, - MemRefRankOf<[F32, F16, BF16], [3]>:$out, - StrAttr:$arch, - Rock_GemmFeaturesAttr:$features, - I32Attr:$blockSize, - I32Attr:$gridSize, - UnitAttr:$disableQBypassLDS, - OptionalAttr:$prePadG0M, - OptionalAttr:$prePadG0N, - RockAccelTuningParamAttrInterface:$params0, - RockAccelTuningParamAttrInterface:$params1, - I32Attr:$firstGemmIdx)> { +def Rock_GridwiseAttentionAccelOp + : Rock_Op<"gridwise_attention_accel", + [DeclareOpInterfaceMethods, + RockFusionRoot, AttrSizedOperandSegments]>, + Arguments<(ins MemRefRankOf<[F32, F16, BF16, I8], [3]>:$queries, + MemRefRankOf<[F32, F16, BF16, I8], [3]>:$keys, + MemRefRankOf<[F32, F16, BF16], [3]>:$values, + Variadic:$preSoftmaxElemWiseInputs, + Optional>:$currentSeqLen, + MemRefRankOf<[F32, F16, BF16], [3]>:$out, StrAttr:$arch, + Rock_GemmFeaturesAttr:$features, I32Attr:$blockSize, + I32Attr:$gridSize, UnitAttr:$disableQBypassLDS, + OptionalAttr:$prePadG0M, + OptionalAttr:$prePadG0N, + RockAccelTuningParamAttrInterface:$params0, + RockAccelTuningParamAttrInterface:$params1, I32Attr:$firstGemmIdx, + DefaultValuedOptionalAttr:$enableSoftmax)> { let summary = "Gridwise attention accelerated version"; let description = [{ The `rock.gridwise_attention_accel` op computes gridwise attention with acceleration. diff --git a/mlir/lib/Dialect/Rock/IR/CMakeLists.txt b/mlir/lib/Dialect/Rock/IR/CMakeLists.txt index 64d44c8a8bbe..d5a6d91379fc 100644 --- a/mlir/lib/Dialect/Rock/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Rock/IR/CMakeLists.txt @@ -3,6 +3,7 @@ add_rocmlir_dialect_library(MLIRRockOps TransformMapBuilder.cpp RockDialect.cpp RockGemmWrapperInterface.cpp + RockGemmGemmWrapperInterface.cpp RockConvInterface.cpp RockTuningParamAttrInterface.cpp RockAccelTuningParamAttrInterface.cpp @@ -17,6 +18,7 @@ add_rocmlir_dialect_library(MLIRRockOps DEPENDS MLIRRockAttrDefsIncGen MLIRRockGemmWrapperInterfaceIncGen + MLIRRockGemmGemmWrapperInterfaceIncGen MLIRRockTuningParamAttrInterfaceIncGen MLIRRockAccelTuningParamAttrInterfaceIncGen MLIRRockAcceptingViewOpInterfaceIncGen diff --git a/mlir/lib/Dialect/Rock/IR/RockAcceptingViewOpInterface.cpp b/mlir/lib/Dialect/Rock/IR/RockAcceptingViewOpInterface.cpp index d58b9aeba86a..69399f21e3f3 100644 --- a/mlir/lib/Dialect/Rock/IR/RockAcceptingViewOpInterface.cpp +++ b/mlir/lib/Dialect/Rock/IR/RockAcceptingViewOpInterface.cpp @@ -1,4 +1,4 @@ -//===- RockGemmWrapperInterface.cpp - -------===// +//===- RockAcceptingViewOpInterface.cpp - -------===// // // Part of the rocMLIR Project, under the Apache License v2.0 with LLVM // Exceptions. See https://llvm.org/LICENSE.txt for license information. 
diff --git a/mlir/lib/Dialect/Rock/IR/RockConvInterface.cpp b/mlir/lib/Dialect/Rock/IR/RockConvInterface.cpp index bc99eea8924a..c45a0b898d3a 100644 --- a/mlir/lib/Dialect/Rock/IR/RockConvInterface.cpp +++ b/mlir/lib/Dialect/Rock/IR/RockConvInterface.cpp @@ -1,4 +1,4 @@ -//===- RockGemmWrapperInterface.cpp - ops that wrap rock.gemm -------===// +//===- RockConvInterface.cpp - ops that wrap rock.gemm -------===// // // Part of the rocMLIR Project, under the Apache License v2.0 with LLVM // Exceptions. See https://llvm.org/LICENSE.txt for license information. diff --git a/mlir/lib/Dialect/Rock/IR/RockDialect.cpp b/mlir/lib/Dialect/Rock/IR/RockDialect.cpp index 4278708f2863..76a46a3b2cf7 100644 --- a/mlir/lib/Dialect/Rock/IR/RockDialect.cpp +++ b/mlir/lib/Dialect/Rock/IR/RockDialect.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Rock/IR/Rock.h" +#include "mlir/Dialect/Rock/IR/RockGemmGemmWrapperInterface.h" #include "mlir/Dialect/Rock/IR/RockGemmWrapperInterface.h" #include "mlir/Dialect/Rock/IR/RockTypes.h" #include "mlir/Dialect/Rock/utility/math.h" @@ -33,6 +34,7 @@ #include "mlir/IR/TypeRange.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/IR/Value.h" +#include "mlir/IR/ValueRange.h" #include "mlir/Parser/Parser.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" @@ -46,6 +48,7 @@ #include "llvm/ADT/StringRef.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/LogicalResult.h" #include "llvm/Support/MathExtras.h" #include "llvm/Support/SMLoc.h" #include @@ -486,10 +489,13 @@ ConvOpType mlir::rock::convOpTypeFromKernelType(KernelType kernelType) { return ConvOpType::BwdWeight; case KernelType::Gemm: llvm_unreachable( - "Gemm ops shouldn't be in convolution-specific lowering passes"); + "GEMM ops shouldn't be in convolution-specific lowering passes"); case KernelType::Attention: llvm_unreachable( "Attention ops shouldn't be in convolution-specific lowering passes"); + case KernelType::GemmElementwiseGemm: + llvm_unreachable( + "gemm+gemm ops shouldn't be in convolution-specific lowering passes"); } llvm_unreachable("Unsuppported KernelType"); } @@ -566,17 +572,20 @@ static LogicalResult verifyGemmTypes(Operation *op, GemmFeatures features, "Mfma gridwise does not support E4M3/E5M2 data types "); } } - if (isa(elemTypeA) && !isa(elemTypeC)) { - return op->emitOpError("floating-point input type ") - << elemTypeA - << " requires a floating-point output type, but the output type is " - << elemTypeC; - } - if (isa(elemTypeA) && !isa(elemTypeC)) { - return op->emitOpError("integer input type ") - << elemTypeA - << " requires an integer output type, but the output type is " - << elemTypeC; + if (elemTypeC) { + if (isa(elemTypeA) && !isa(elemTypeC)) { + return op->emitOpError("floating-point input type ") + << elemTypeA + << " requires a floating-point output type, but the output type " + "is " + << elemTypeC; + } + if (isa(elemTypeA) && !isa(elemTypeC)) { + return op->emitOpError("integer input type ") + << elemTypeA + << " requires an integer output type, but the output type is " + << elemTypeC; + } } return success(); } @@ -2068,77 +2077,188 @@ LogicalResult BlockwiseFillOp::verify() { } //===-----------------------------------------------------===// -// AttentionOp +// GemmElementwiseGemmOp //===-----------------------------------------------------===// -LogicalResult AttentionOp::verify() { - ShapedType qType = getQueries().getType(); +OpOperand 
*GemmElementwiseGemmOp::getOutArgument() { + return &(*this)->getOpOperand(getNumOperands() - 1); +} + +Type GemmElementwiseGemmOp::getOutType() { return getOut().getType(); } + +Type GemmElementwiseGemmOp::getAType() { return getA().getType(); } + +Type GemmElementwiseGemmOp::getBType() { return getB().getType(); } + +Type GemmElementwiseGemmOp::getCType() { return getC().getType(); } + +bool GemmElementwiseGemmOp::getTransposedA() { return getATransposed(); } + +bool GemmElementwiseGemmOp::getTransposedB() { return getBTransposed(); } + +bool GemmElementwiseGemmOp::getTransposedC() { return getCTransposed(); } + +bool GemmElementwiseGemmOp::getTransposedOut() { return getOTransposed(); } + +KernelType GemmElementwiseGemmOp::getKernelType() { + return KernelType::GemmElementwiseGemm; +} + +uint32_t GemmElementwiseGemmOp::getFirstGemmIndex() { + return getFirstGemmIdx(); +} + +GemmGemmSize GemmElementwiseGemmOp::getGemmGemmSize() { + ShapedType typeA = getA().getType(), typeB = getB().getType(), + typeC = getC().getType(); + ArrayRef dimsA = typeA.getShape(), dimsB = typeB.getShape(), + dimsC = typeC.getShape(); + int64_t offsetA = dimsA.size() == 2 ? 0 : 1, + offsetB = dimsB.size() == 2 ? 0 : 1, + offsetC = dimsC.size() == 2 ? 0 : 1; + int64_t g = offsetA ? dimsA[0] : 1, + m = dimsA[offsetA + (getATransposed() ? 1 : 0)], + k = dimsA[offsetA + (getATransposed() ? 0 : 1)], + n = dimsB[offsetB + (getBTransposed() ? 0 : 1)], + o = dimsC[offsetC + (getCTransposed() ? 1 : 0)]; + return GemmGemmSize(g, m, k, n, o); +} + +static LogicalResult verifyGemmPlusGemmLikeOp(RockGemmGemmWrapperInterface op, + Value currentSeqLen) { + ShapedType qType = cast(op.getAType()); int64_t qBatchDim = qType.getShape().size() == 3 ? qType.getShape()[0] : 1; ArrayRef qLastDims = qType.getShape().slice(qType.getRank() - 2); - auto [queryM, queryK] = getQTransposed() + auto [queryM, queryK] = op.getTransposedA() ? std::tuple{qLastDims[1], qLastDims[0]} : std::tuple{qLastDims[0], qLastDims[1]}; - ShapedType kType = getKeys().getType(); + ShapedType kType = cast(op.getBType()); int64_t kBatchDim = kType.getShape().size() == 3 ? kType.getShape()[0] : 1; ArrayRef kLastDims = kType.getShape().slice(kType.getRank() - 2); - auto [keyK, keyN] = getKTransposed() ? std::tuple{kLastDims[1], kLastDims[0]} - : std::tuple{kLastDims[0], kLastDims[1]}; + auto [keyK, keyN] = op.getTransposedB() + ? std::tuple{kLastDims[1], kLastDims[0]} + : std::tuple{kLastDims[0], kLastDims[1]}; - ShapedType vType = getValues().getType(); + ShapedType vType = cast(op.getCType()); int64_t vBatchDim = vType.getShape().size() == 3 ? vType.getShape()[0] : 1; ArrayRef vLastDims = vType.getShape().slice(vType.getRank() - 2); - auto [valueK, valueN] = getVTransposed() + auto [valueK, valueN] = op.getTransposedC() ? std::tuple{vLastDims[1], vLastDims[0]} : std::tuple{vLastDims[0], vLastDims[1]}; if (qBatchDim != kBatchDim || kBatchDim != vBatchDim) { - return emitError("Batch dimensions do not match"); + return op.emitError("Batch dimensions do not match"); } if (queryK != keyK) { - return emitError("reduction dimensions of first gemm do not match"); + return op.emitError("reduction dimensions of first gemm do not match"); } if (keyN != valueK) { - return emitError("reduction dimensions of second gemm do not match"); + return op.emitError("reduction dimensions of second gemm do not match"); } // check output type - ShapedType oType = getOut().getType(); + ShapedType oType = cast(op.getOutType()); int64_t oBatchDim = oType.getShape().size() == 3 ? 
oType.getShape()[0] : 1; ArrayRef oLastDims = oType.getShape().slice(oType.getRank() - 2); auto [outputSeqLen, outputHeadDim] = - getOTransposed() ? std::tuple{oLastDims[1], oLastDims[0]} - : std::tuple{oLastDims[0], oLastDims[1]}; + op.getTransposedOut() ? std::tuple{oLastDims[1], oLastDims[0]} + : std::tuple{oLastDims[0], oLastDims[1]}; if (qType.getShape().size() != oType.getShape().size()) { - return emitError("Number of dimensions do not match (Q and Output)"); + return op.emitError("Number of dimensions do not match (Q and Output)"); } if (qBatchDim != oBatchDim) { - return emitError("Batch dimensions do not match (Q and Output)"); + return op.emitError("Batch dimensions do not match (Q and Output)"); } if (queryM != outputSeqLen) { - return emitError("Sequence length does not match (Q and Output)"); + return op.emitError("Sequence length does not match (Q and Output)"); } if (valueN != outputHeadDim) { - return emitError("Head dimensions do not match (V and Output)"); + return op.emitError("Head dimensions do not match (V and Output)"); } // check currentSeqLen (KV Cache) - auto currentSeqLen = getCurrentSeqLen(); if (currentSeqLen) { - ShapedType seqLenType = currentSeqLen.getType(); + ShapedType seqLenType = cast(currentSeqLen.getType()); if (seqLenType.getShape().size() != 1) { - return emitError("Number of dimensions is not one (currentSeqLen)"); + return op.emitError("Number of dimensions is not one (currentSeqLen)"); } if (seqLenType.getShape()[0] != oBatchDim) { - return emitError( + return op.emitError( "Batch dimensions do not match (currentSeqLen and Output)"); } } return success(); } +LogicalResult GemmElementwiseGemmOp::verify() { + return verifyGemmPlusGemmLikeOp(*this, /*currentSeqLen=*/nullptr); +} + +void GemmElementwiseGemmOp::getEffects( + SmallVectorImpl &effects) { + auto *read = MemoryEffects::Read::get(); + auto *write = MemoryEffects::Write::get(); + effects.emplace_back(read, &getOutMutable()); + effects.emplace_back(write, &getOutMutable()); + + effects.emplace_back(read, &getAMutable()); + effects.emplace_back(read, &getBMutable()); + effects.emplace_back(read, &getCMutable()); + for (auto ®ionArg : getElemwiseInputsMutable()) + effects.emplace_back(read, ®ionArg); +} + +//===-----------------------------------------------------===// +// AttentionOp +//===-----------------------------------------------------===// + +OpOperand *AttentionOp::getOutArgument() { + return &(*this)->getOpOperand(getNumOperands() - 1); +} + +Type AttentionOp::getOutType() { return getOut().getType(); } + +Type AttentionOp::getAType() { return getQueries().getType(); } + +Type AttentionOp::getBType() { return getKeys().getType(); } + +Type AttentionOp::getCType() { return getValues().getType(); } + +bool AttentionOp::getTransposedA() { return getQTransposed(); } + +bool AttentionOp::getTransposedB() { return getKTransposed(); } + +bool AttentionOp::getTransposedC() { return getVTransposed(); } + +bool AttentionOp::getTransposedOut() { return getOTransposed(); } + +KernelType AttentionOp::getKernelType() { return KernelType::Attention; } + +uint32_t AttentionOp::getFirstGemmIndex() { return getFirstGemmIdx(); } + +GemmGemmSize AttentionOp::getGemmGemmSize() { + ShapedType typeA = getQueries().getType(), typeB = getKeys().getType(), + typeC = getValues().getType(); + ArrayRef dimsA = typeA.getShape(), dimsB = typeB.getShape(), + dimsC = typeC.getShape(); + int64_t offsetA = dimsA.size() == 2 ? 0 : 1, + offsetB = dimsB.size() == 2 ? 0 : 1, + offsetC = dimsC.size() == 2 ? 
0 : 1; + int64_t g = offsetA ? dimsA[0] : 1, + m = dimsA[offsetA + (getQTransposed() ? 1 : 0)], + k = dimsA[offsetA + (getQTransposed() ? 0 : 1)], + n = dimsB[offsetB + (getKTransposed() ? 0 : 1)], + o = dimsC[offsetC + (getVTransposed() ? 1 : 0)]; + return GemmGemmSize(g, m, k, n, o); +} + +LogicalResult AttentionOp::verify() { + return verifyGemmPlusGemmLikeOp(*this, getCurrentSeqLen()); +} + void AttentionOp::getEffects( SmallVectorImpl &effects) { auto *read = MemoryEffects::Read::get(); diff --git a/mlir/lib/Dialect/Rock/IR/RockGemmGemmWrapperInterface.cpp b/mlir/lib/Dialect/Rock/IR/RockGemmGemmWrapperInterface.cpp new file mode 100644 index 000000000000..f0d9af4a2026 --- /dev/null +++ b/mlir/lib/Dialect/Rock/IR/RockGemmGemmWrapperInterface.cpp @@ -0,0 +1,22 @@ +//===- RockGemmGemmWrapperInterface.cpp - ops that wrap rock.attention +//-------===// +// +// Part of the rocMLIR Project, under the Apache License v2.0 with LLVM +// Exceptions. See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +// Copyright (c) 2025 Advanced Micro Devices INc. +//===----------------------------------------------------------------------===// +// +// This file defines RockGemmGemmWrapperInterface, which abstracts attention and +// gemm+gemm to allow code to operate on them generically. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Rock/IR/Rock.h" + +namespace mlir { +namespace rock { +#include "mlir/Dialect/Rock/IR/RockGemmGemmWrapperInterface.cpp.inc" +} // namespace rock +} // namespace mlir diff --git a/mlir/lib/Dialect/Rock/Transforms/AffixTuningParameters.cpp b/mlir/lib/Dialect/Rock/Transforms/AffixTuningParameters.cpp index 5f71d499029c..e84eede62ce8 100644 --- a/mlir/lib/Dialect/Rock/Transforms/AffixTuningParameters.cpp +++ b/mlir/lib/Dialect/Rock/Transforms/AffixTuningParameters.cpp @@ -1,6 +1,7 @@ #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Rock/IR/GemmSize.h" #include "mlir/Dialect/Rock/IR/Rock.h" +#include "mlir/Dialect/Rock/IR/RockGemmGemmWrapperInterface.h" #include "mlir/Dialect/Rock/IR/RockGemmWrapperInterface.h" #include "mlir/Dialect/Rock/Passes.h" #include "mlir/Dialect/Rock/Tuning/GridwiseGemmParams.h" @@ -14,6 +15,7 @@ #include "mlir/Pass/Pass.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/LogicalResult.h" #include "llvm/Support/raw_ostream.h" namespace mlir { @@ -40,7 +42,7 @@ struct AffixTuningParameters private: // Actual implementation. 
void affixTuningParametersImpl(RockGemmWrapperInterface op); - void affixTuningParametersImpl(AttentionOp op); + void affixTuningParametersImpl(RockGemmGemmWrapperInterface op); template void setUtilityKernelSizes(Value arg, T utilityOp); @@ -52,7 +54,8 @@ void AffixTuningParameters::runOnOperation() { func.walk( [&](RockGemmWrapperInterface op) { affixTuningParametersImpl(op); }); - func.walk([&](AttentionOp op) { affixTuningParametersImpl(op); }); + func.walk( + [&](RockGemmGemmWrapperInterface op) { affixTuningParametersImpl(op); }); func.walk([&](ReduceOp op) { func::FuncOp funcOp = getOperation(); if (!funcOp->hasAttr("block_size")) { @@ -208,15 +211,16 @@ void AffixTuningParameters::affixTuningParametersImpl( b.getI32IntegerAttr(validParams.blockSize)); } } + static RockAccelTuningParamAttrInterface -deriveGemm1TuningParams(OpBuilder &builder, AttentionOp op, +deriveGemm1TuningParams(OpBuilder &builder, RockGemmGemmWrapperInterface op, AttnPerfConfigAttr attnPerfConfig) { auto gemm0TuningParams = - cast(op.getParams0().value()); + cast(op.getGemm0Params().value()); int64_t gemm1KPack = gemm0TuningParams.getKpack(); int64_t gemmNPerWaveOrMnPerXdl = gemm0TuningParams.getNPerWave(); if (auto gemm0XdlDerivedParams = - dyn_cast(op.getParams0().value())) { + dyn_cast(op.getGemm0Params().value())) { gemmNPerWaveOrMnPerXdl = gemm0XdlDerivedParams.getMnPerXdl(); return XdlopsGemmDerivedParamsAttr::get( builder.getContext(), gemm0TuningParams.getMPerBlock() / gemm1KPack, @@ -240,18 +244,17 @@ deriveGemm1TuningParams(OpBuilder &builder, AttentionOp op, gemm0TuningParams.getOutputSwizzle(), gemm0TuningParams.getForceUnroll()); } -void AffixTuningParameters::affixTuningParametersImpl(AttentionOp op) { +void AffixTuningParameters::affixTuningParametersImpl( + RockGemmGemmWrapperInterface op) { OpBuilder builder(op.getContext()); - Type elemTypeQ = cast(op.getQueries().getType()).getElementType(); - Type elemTypeK = cast(op.getKeys().getType()).getElementType(); - Type elemTypeV = cast(op.getValues().getType()).getElementType(); - bool isAccel = rock::isAccel(op.getFeatures()); + bool isAccel = rock::isAccel(op.getGemmFeatures()); if (!isAccel) { - op.emitError("Currently, attention op is only supported on GPUs " + op.emitError("Currently, attention/gemm+gemm op is only " + "supported on GPUs " "with matrix accelerator extentions"); return signalPassFailure(); } - Attribute params0 = op.getParams0().value_or(nullptr); + Attribute params0 = op.getGemm0Params().value_or(nullptr); // set a default one if params is not provided StringAttr perfConfigStrAttr = builder.getStringAttr("attn:v1:32,32,32,32,32,32,1,1"); @@ -266,7 +269,7 @@ void AffixTuningParameters::affixTuningParametersImpl(AttentionOp op) { op.emitError("perf config string has an incorrect format."); return signalPassFailure(); } - GemmFeatures features = op.getFeatures(); + GemmFeatures features = op.getGemmFeatures(); RockAccelTuningParamAttrInterface accelParams0; if (bitEnumContainsAny(features, GemmFeatures::mfma)) { auto xdlopsParams0 = XdlopsGemmParamsAttr::get( @@ -284,7 +287,7 @@ void AffixTuningParameters::affixTuningParametersImpl(AttentionOp op) { attnPerfConfig.getMnPerXdl(), 1, attnPerfConfig.getScheduleVersion(), 2, attnPerfConfig.getForceUnroll()); } - op.setParams0Attr(accelParams0); + op.setGemm0ParamsAttr(accelParams0); if (attnPerfConfig.getMPerBlockG0() > attnPerfConfig.getMPerBlockG1()) { op.emitError( "The MPerBlockG0 should be larger or equal to getMPerBlockG1."); @@ -292,8 +295,8 @@ void 
AffixTuningParameters::affixTuningParametersImpl(AttentionOp op) { } RockAccelTuningParamAttrInterface accelParams1 = deriveGemm1TuningParams(builder, op, attnPerfConfig); - op.setParams1Attr(accelParams1); - int64_t waveSize = rock::lookupArchInfo(op.getArchAttr()).waveSize; + op.setGemm1ParamsAttr(accelParams1); + int64_t waveSize = rock::lookupArchInfo(op.getArch()).waveSize; int64_t blockSize = waveSize * accelParams0.getNPerBlock() * accelParams0.getMPerBlock() / (accelParams0.getMPerWave() * accelParams0.getNPerWave()); @@ -302,12 +305,14 @@ void AffixTuningParameters::affixTuningParametersImpl(AttentionOp op) { LLVM_DEBUG(llvm::dbgs() << "accelParams1=" << accelParams1 << "\n"); LogicalResult isValidBlockwiseGemm0 = populateParamsAccelPtr->isValidBlockwiseGemm( - accelParams0, elemTypeQ, elemTypeK, op.getArch(), + accelParams0, cast(op.getAType()).getElementType(), + cast(op.getBType()).getElementType(), op.getArch(), /*enableBlockSizeUpperLimit=*/false, /*enableDPerWaveFiltering=*/false); LogicalResult isValidBlockwiseGemm1 = populateParamsAccelPtr->isValidBlockwiseGemm( - accelParams1, elemTypeV, elemTypeV, op.getArch(), + accelParams1, cast(op.getCType()).getElementType(), + cast(op.getCType()).getElementType(), op.getArch(), /*enableBlockSizeUpperLimit=*/false, /*enableDPerWaveFiltering=*/false); if (isValidBlockwiseGemm0.failed() || isValidBlockwiseGemm1.failed()) { diff --git a/mlir/lib/Dialect/Rock/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Rock/Transforms/BufferizableOpInterfaceImpl.cpp index d7f5d292d459..9fde520d4ada 100644 --- a/mlir/lib/Dialect/Rock/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Rock/Transforms/BufferizableOpInterfaceImpl.cpp @@ -227,6 +227,8 @@ void mlir::rock::registerBufferizableOpInterfaceExternalModels( ConvertingCopyKernelOp::attachInterface< GemmLikeInterface>(*ctx); AttentionOp::attachInterface>(*ctx); + GemmElementwiseGemmOp::attachInterface< + GemmLikeInterface>(*ctx); TransformOp::attachInterface(*ctx); TensorUntransformCastOp::attachInterface( diff --git a/mlir/lib/Dialect/Rock/Transforms/GemmToGridwise.cpp b/mlir/lib/Dialect/Rock/Transforms/GemmToGridwise.cpp index 4c0d1f145a7c..1cdb1f9834b1 100644 --- a/mlir/lib/Dialect/Rock/Transforms/GemmToGridwise.cpp +++ b/mlir/lib/Dialect/Rock/Transforms/GemmToGridwise.cpp @@ -24,6 +24,7 @@ #include "mlir/Dialect/MHAL/IR/MHAL.h" #include "mlir/Dialect/Rock/IR/GemmSize.h" #include "mlir/Dialect/Rock/IR/Rock.h" +#include "mlir/Dialect/Rock/IR/RockGemmGemmWrapperInterface.h" #include "mlir/Dialect/Rock/IR/RockTypes.h" #include "mlir/Dialect/Rock/IR/TransformMapBuilder.h" #include "mlir/Dialect/Rock/Passes.h" @@ -43,6 +44,7 @@ #include "mlir/Transforms/DialectConversion.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/LogicalResult.h" #include #include #include @@ -84,6 +86,18 @@ struct GemmRewritePattern : public OpConversionPattern { const BufferDependencyAnalysis &bufferDeps; }; +struct GemmElementwiseGemmRewritePattern + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(GemmElementwiseGemmOp op, + GemmElementwiseGemmOpAdaptor adaptor, + ConversionPatternRewriter &rw) const override; + + LogicalResult computeGridSize(ConversionPatternRewriter &rw, + GemmElementwiseGemmOp op, Value a, Value b, + Value c) const; +}; + struct AttentionRewritePattern : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(AttentionOp op, AttentionOpAdaptor 
adaptor, @@ -93,6 +107,135 @@ struct AttentionRewritePattern : public OpConversionPattern { Value queries, Value keys, Value values) const; }; +template +static LogicalResult +computeGridSizeAttentionGemmElmtGemm(ConversionPatternRewriter &rw, Op op, + Value a, Value b, Value c) { + RockAccelTuningParamAttrInterface accelParams0 = + cast(op.getGemm0Params().value()); + + SmallVector aShape = + llvm::to_vector<3>(cast(a.getType()).getShape()); + + SmallVector bShape = + llvm::to_vector<3>(cast(b.getType()).getShape()); + + SmallVector cShape = + llvm::to_vector<3>(cast(c.getType()).getShape()); + + GemmSize gemm0Size(/*g=*/aShape[0], /*m=*/bShape[2], + /*k=*/aShape[1], + /*n=*/aShape[2]); + + int64_t gridSize = + ((gemm0Size.n) / accelParams0.getNPerBlock()) * gemm0Size.g; + + IntegerAttr gridSizeAttr = rw.getI32IntegerAttr(gridSize); + func::FuncOp funcOp = cast(op->getParentOp()); + funcOp->setAttr("grid_size", gridSizeAttr); + return success(); +} + +static LogicalResult +commonAttentionGemmElmtGemm(ConversionPatternRewriter &rw, + RockGemmGemmWrapperInterface op, Value a, Value b, + Value c, Value out, Value currentSeqLen, + ValueRange elementwiseInputs, + Region &preSecondOpRegion, bool enableSoftmax) { + Location loc = op->getLoc(); + + if (!isa(op.getAType())) + return op.emitOpError("Cannot lower unbufferized gemm to gridwise"); + + bool isAccel = rock::isAccel(op.getGemmFeatures()); + if (!isAccel) { + return op.emitError("Currently, op is only supported on GPUs " + "with matrix accelerator extentions"); + } + if (!op.getGemm0Params().has_value()) { + return op.emitError("gemm0 params is missing and it should've been " + "assigned by affix-tuing-params"); + } + RockAccelTuningParamAttrInterface params0 = + cast(op.getGemm0Params().value()); + if (!op.getGemm1Params().has_value()) { + return op.emitError("gemm1 params is missing and it should've been " + "assigned by affix-tuing-params"); + } + RockAccelTuningParamAttrInterface params1 = + cast(op.getGemm1Params().value()); + + // Note: the gridwise ops take K x M and K x N, so A must be transposed if + // it's in the natural M x K form + a = normalizeMatrix(a, rw, loc, !op.getTransposedA(), "gemm0K", "gemm0M"); + b = normalizeMatrix(b, rw, loc, op.getTransposedB(), "gemm0K", "gemm0N"); + c = normalizeMatrix(c, rw, loc, op.getTransposedC(), "gemm1K", "gemm1N"); + out = + normalizeMatrix(out, rw, loc, op.getTransposedOut(), "gemm1M", "gemm1N"); + + // Note, matrix dimension correctness is handled in the verifier + ArrayRef aShape = cast(a.getType()).getShape(); + ArrayRef bShape = cast(b.getType()).getShape(); + ArrayRef cShape = cast(c.getType()).getShape(); + GemmSize gemm0Size(/*g=*/aShape[0], /*m=*/bShape[2], + /*k=*/aShape[1], + /*n=*/aShape[2]); + GemmSize gemm0ExtraPad = + requiredPadding(params0, gemm0Size).value_or(GemmSize{0, 0, 0, 0}); + GemmSize gemm1Size(/*g=*/aShape[0], /*m=*/cShape[2], + /*k=*/cShape[1], + /*n=*/aShape[2]); + GemmSize gemm1ExtraPad = + requiredPadding(params1, gemm1Size).value_or(GemmSize{0, 0, 0, 0}); + + a = padMatrix(a, rw, loc, "gemm0K", gemm0ExtraPad.k, "gemm0N", + gemm0ExtraPad.n); + b = padMatrix(b, rw, loc, "gemm0K", gemm0ExtraPad.k, "gemm0M", + gemm0ExtraPad.m); + c = padMatrix(c, rw, loc, "gemm1K", gemm1ExtraPad.k, "gemm1M", + gemm1ExtraPad.m); + // In the transposed layout, from a tuning params point of view + // the output dimensions are swapped. Though we will only be + // swapping them inside gridwise lowering to keep the surrounding + // fusions legit. 
So the extra pad needs to be swapped and applied. + out = padMatrix(out, rw, loc, "gemm1N", gemm1ExtraPad.n, "gemm1M", + gemm1ExtraPad.m); + + if (failed(computeGridSizeAttentionGemmElmtGemm(rw, op, a, b, c))) { + return op.emitError("failed to compute the grid size of " + "`GemmElementwiseGemmOp`/`AttentionOp`"); + } + + func::FuncOp func = op->template getParentOfType(); + IntegerAttr blockSizeAttr = cast(func->getAttr("block_size")); + IntegerAttr gridSizeAttr = cast(func->getAttr("grid_size")); + IntegerAttr prePadG0MAttr; + if (gemm0ExtraPad.m) { + prePadG0MAttr = rw.getIndexAttr(gemm0Size.m); + } + IntegerAttr prePadG0NAttr; + if (gemm0ExtraPad.n) { + prePadG0NAttr = rw.getIndexAttr(gemm0Size.n); + } + auto newOp = rw.create( + loc, a, b, c, elementwiseInputs, currentSeqLen, out, + rw.getStringAttr(op.getArch()), + rw.getAttr(op.getGemmFeatures()), blockSizeAttr, + gridSizeAttr, + /*disableQBypassLDS=*/nullptr, prePadG0MAttr, prePadG0NAttr, params0, + params1, rw.getI32IntegerAttr(op.getFirstGemmIndex()), + rw.getBoolAttr(enableSoftmax)); + bool linalgOpFound = false; + preSecondOpRegion.walk( + [&](linalg::GenericOp genOp) { linalgOpFound = true; }); + if (linalgOpFound) { + rw.inlineRegionBefore(preSecondOpRegion, newOp.getPreSoftmaxBody(), + newOp.getPreSoftmaxBody().begin()); + } + rw.replaceOp(op, newOp); + return success(); +} + static Type getSmallestType(Type type1, Type type2) { return (type1.getIntOrFloatBitWidth() > type2.getIntOrFloatBitWidth()) ? type2 @@ -440,143 +583,41 @@ LogicalResult AttentionRewritePattern::matchAndRewrite(AttentionOp op, AttentionOpAdaptor adaptor, ConversionPatternRewriter &rw) const { - Location loc = op->getLoc(); - - if (!isa(adaptor.getQueries().getType())) - return op.emitOpError("Cannot lower unbufferized gemm to gridwise"); - - bool isAccel = rock::isAccel(op.getFeatures()); - if (!isAccel) { - return op.emitError("Currently, attention op is only supported on GPUs " - "with matrix accelerator extentions"); - } - if (!op.getParams0().has_value()) { - return op.emitError("gemm0 params is missing and it should've been " - "assigned by affix-tuing-params"); - } - RockAccelTuningParamAttrInterface params0 = - cast(op.getParams0Attr()); - if (!op.getParams1().has_value()) { - return op.emitError("gemm1 params is missing and it should've been " - "assigned by affix-tuing-params"); - } - RockAccelTuningParamAttrInterface params1 = - cast(op.getParams1Attr()); - - Value queries = adaptor.getQueries(); - Value keys = adaptor.getKeys(); - Value values = adaptor.getValues(); - Value out = adaptor.getOut(); - - // Note: the gridwise ops take K x M and K x N, so A must be transposed if - // it's in the natural M x K form - queries = normalizeMatrix(queries, rw, loc, !op.getQTransposed(), "gemm0K", - "gemm0M"); - keys = - normalizeMatrix(keys, rw, loc, op.getKTransposed(), "gemm0K", "gemm0N"); - values = - normalizeMatrix(values, rw, loc, op.getVTransposed(), "gemm1K", "gemm1N"); - out = normalizeMatrix(out, rw, loc, op.getOTransposed(), "gemm1M", "gemm1N"); - - // Note, matrix dimension correctness is handled in the verifier - ArrayRef queriesShape = - cast(queries.getType()).getShape(); - ArrayRef keysShape = cast(keys.getType()).getShape(); - ArrayRef valuesShape = cast(values.getType()).getShape(); - GemmSize gemm0Size(/*g=*/queriesShape[0], /*m=*/keysShape[2], - /*k=*/queriesShape[1], - /*n=*/queriesShape[2]); - GemmSize gemm0ExtraPad = - requiredPadding(params0, gemm0Size).value_or(GemmSize{0, 0, 0, 0}); - GemmSize gemm1Size(/*g=*/queriesShape[0], 
/*m=*/valuesShape[2], - /*k=*/valuesShape[1], - /*n=*/queriesShape[2]); - GemmSize gemm1ExtraPad = - requiredPadding(params1, gemm1Size).value_or(GemmSize{0, 0, 0, 0}); - - queries = padMatrix(queries, rw, loc, "gemm0K", gemm0ExtraPad.k, "gemm0N", - gemm0ExtraPad.n); - keys = padMatrix(keys, rw, loc, "gemm0K", gemm0ExtraPad.k, "gemm0M", - gemm0ExtraPad.m); - values = padMatrix(values, rw, loc, "gemm1K", gemm1ExtraPad.k, "gemm1M", - gemm1ExtraPad.m); - // In the transposed layout, from a tuning params point of view - // the output dimensions are swapped. Though we will only be - // swapping them inside gridwise lowering to keep the surrounding - // fusions legit. So the extra pad needs to be swapped and applied. - out = padMatrix(out, rw, loc, "gemm1N", gemm1ExtraPad.n, "gemm1M", - gemm1ExtraPad.m); - - if (failed(computeGridSize(rw, op, queries, keys, values))) { - return op.emitError("failed to compute the grid size of `AttentionOp`"); - } - - func::FuncOp func = op->getParentOfType(); - IntegerAttr blockSizeAttr = cast(func->getAttr("block_size")); - IntegerAttr gridSizeAttr = cast(func->getAttr("grid_size")); - IntegerAttr prePadG0MAttr; - if (gemm0ExtraPad.m) { - prePadG0MAttr = rw.getIndexAttr(gemm0Size.m); - } - IntegerAttr prePadG0NAttr; - if (gemm0ExtraPad.n) { - prePadG0NAttr = rw.getIndexAttr(gemm0Size.n); - } - auto newOp = rw.create( - loc, queries, keys, values, adaptor.getPreSoftmaxElemWiseInputs(), - op.getCurrentSeqLen(), out, op.getArchAttr(), op.getFeaturesAttr(), - blockSizeAttr, gridSizeAttr, - /*disableQBypassLDS=*/nullptr, prePadG0MAttr, prePadG0NAttr, params0, - params1, op.getFirstGemmIdxAttr()); - bool linalgOpFound = false; - op.getPreSoftmaxBody().walk( - [&](linalg::GenericOp genOp) { linalgOpFound = true; }); - if (linalgOpFound) { - rw.inlineRegionBefore(op.getPreSoftmaxBody(), newOp.getPreSoftmaxBody(), - newOp.getPreSoftmaxBody().begin()); - } - rw.replaceOp(op, newOp); - return success(); + return commonAttentionGemmElmtGemm( + rw, op, adaptor.getQueries(), adaptor.getKeys(), adaptor.getValues(), + adaptor.getOut(), adaptor.getCurrentSeqLen(), + adaptor.getPreSoftmaxElemWiseInputs(), op.getPreSoftmaxBody(), + /*enableSoftmax=*/true); } LogicalResult AttentionRewritePattern::computeGridSize(ConversionPatternRewriter &rw, AttentionOp op, Value queries, Value keys, Value values) const { + return computeGridSizeAttentionGemmElmtGemm(rw, op, queries, keys, values); +} - RockAccelTuningParamAttrInterface accelParams0 = - cast(op.getParams0Attr()); - - SmallVector queriesShape = - llvm::to_vector<3>(cast(queries.getType()).getShape()); - - SmallVector keysShape = - llvm::to_vector<3>(cast(keys.getType()).getShape()); - - SmallVector valuesShape = - llvm::to_vector<3>(cast(values.getType()).getShape()); - - GemmSize gemm0Size(/*g=*/queriesShape[0], /*m=*/keysShape[2], - /*k=*/queriesShape[1], - /*n=*/queriesShape[2]); - GemmSize gemm1Size(/*g=*/queriesShape[0], /*m=*/valuesShape[2], - /*k=*/valuesShape[1], - /*n=*/queriesShape[2]); - - int64_t gridSize = - ((gemm0Size.n) / accelParams0.getNPerBlock()) * gemm0Size.g; +LogicalResult GemmElementwiseGemmRewritePattern::matchAndRewrite( + GemmElementwiseGemmOp op, GemmElementwiseGemmOpAdaptor adaptor, + ConversionPatternRewriter &rw) const { + return commonAttentionGemmElmtGemm( + rw, op, adaptor.getA(), adaptor.getB(), adaptor.getC(), adaptor.getOut(), + /*currentSeqLen=*/nullptr, adaptor.getElemwiseInputs(), + op.getPreSecondGemmBody(), /*enableSoftmax=*/false); +} - IntegerAttr gridSizeAttr = 
rw.getI32IntegerAttr(gridSize); - func::FuncOp funcOp = cast(op->getParentOp()); - funcOp->setAttr("grid_size", gridSizeAttr); - return success(); +LogicalResult GemmElementwiseGemmRewritePattern::computeGridSize( + ConversionPatternRewriter &rw, GemmElementwiseGemmOp op, Value a, Value b, + Value c) const { + return computeGridSizeAttentionGemmElmtGemm(rw, op, a, b, c); } void RockGemmToGridwisePass::runOnOperation() { MLIRContext *ctx = &getContext(); ConversionTarget target(*ctx); - target.addIllegalOp(); + target.addIllegalOp(); target.addLegalOp(ctx, bufferDeps); - patterns.add(ctx); + patterns.add(ctx); if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) { diff --git a/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp b/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp index a52968b6295c..cfb000608825 100644 --- a/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp +++ b/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp @@ -1496,11 +1496,11 @@ struct GridwiseAttentionAccelRewritePattern // > invertTr(linalg input to gemmOutput maps) // > (linalgOtherInput to op arg maps) ArrayAttr linalgGridSubTileMaps = gemm0OutViews.gridSubTile; - ArrayAttr GemmOutToLinalgMaps = + ArrayAttr gemmOutToLinalgMaps = invertTransforms(rewriter, loc, linalgToGemmOutMaps); - if (!GemmOutToLinalgMaps.empty()) { + if (!gemmOutToLinalgMaps.empty()) { linalgGridSubTileMaps = prependUpperViews( - rewriter, linalgGridSubTileMaps, GemmOutToLinalgMaps); + rewriter, linalgGridSubTileMaps, gemmOutToLinalgMaps); } for (auto [idx, otherInput] : @@ -1516,13 +1516,13 @@ struct GridwiseAttentionAccelRewritePattern ArrayAttr linalgToOtherInputMaps; std::tie(std::ignore, linalgToOtherInputMaps, std::ignore) = untransform(rewriter, genOpInput); - ArrayAttr GemmOutToOtherInputMaps = linalgGridSubTileMaps; + ArrayAttr gemmOutToOtherInputMaps = linalgGridSubTileMaps; if (!linalgToOtherInputMaps.empty()) { - GemmOutToOtherInputMaps = prependUpperViews( + gemmOutToOtherInputMaps = prependUpperViews( rewriter, linalgGridSubTileMaps, linalgToOtherInputMaps); } rewriter.create( - loc, otherInput, tileBuffer, GemmOutToOtherInputMaps, + loc, otherInput, tileBuffer, gemmOutToOtherInputMaps, ValueRange{gridCoords.g_block, gridCoords.m_block, gridCoords.n_block, tid}, true, true); @@ -1841,33 +1841,30 @@ struct GridwiseAttentionAccelRewritePattern Value accRegBufferGemm0 = createBufferForAccelGemmOut(loc, accelParamsGemm0, rewriter); // Currently, there is a working assumption that this kernel is meant - // support fp32/fp16/bf16. This should be guranteed by op verifiers. + // support fp32/fp16/bf16. This should be guaranteed by op verifiers. 
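+    // The gemm0 output buffer uses the QxK element type, widened to i32 only
+    // when the inputs are i8; the elementwise-fusion output buffer keeps the
+    // QxK element type.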
Type gemmOutElemType = elemTypeQxK; - Type softmaxInElemType = elemTypeQxK; + Type fusionOutElemType = elemTypeQxK; if (elemTypeQ == rewriter.getI8Type()) { gemmOutElemType = rewriter.getI32Type(); } Value gemm0OutBuffer = createBufferForGemmOut(loc, gemmOutElemType, accelParamsGemm0, rewriter); - Value softmaxInBuffer = createBufferForGemmOut(loc, softmaxInElemType, - accelParamsGemm0, rewriter); - - // Buffers for reductions SmallVector bidGridOrder = {"g_block", "m_block", "n_block"}; - Value gemm0OutBufferMax = - createBufferForGemmOut(loc, elemTypeQxK, accelParamsGemm0, rewriter); - Value gemm0OutBufferExp = - createBufferForGemmOut(loc, elemTypeQxK, accelParamsGemm0, rewriter); - Value gemm0OutBufferSum = - createBufferForGemmOut(loc, elemTypeQxK, accelParamsGemm0, rewriter); - + Value fusionOutBuffer = createBufferForGemmOut(loc, fusionOutElemType, + accelParamsGemm0, rewriter); + // Buffers for reductions and softmax input + Value gemm0OutBufferMax, gemm0OutBufferExp, gemm0OutBufferSum; + if (op.getEnableSoftmax()) { + gemm0OutBufferMax = + createBufferForGemmOut(loc, elemTypeQxK, accelParamsGemm0, rewriter); + gemm0OutBufferExp = + createBufferForGemmOut(loc, elemTypeQxK, accelParamsGemm0, rewriter); + gemm0OutBufferSum = + createBufferForGemmOut(loc, elemTypeQxK, accelParamsGemm0, rewriter); + } // Buffers for gemm 1 - Value gemm1RegBufferB = gemm0OutBufferExp; -#ifdef ROCK_DEBUG_ATTENTION_REMOVE_SOFTMAX - llvm::errs() << "Lowering attention op as a gemm-gemm op...\n"; - gemm1RegBufferB = gemm0OutBuffer; -#endif + Value gemm1RegBufferB; if (elemTypeV != elemTypeQxK) { gemm1RegBufferB = createBufferForGemmOut(loc, elemTypeV, accelParamsGemm0, rewriter); @@ -1876,18 +1873,20 @@ struct GridwiseAttentionAccelRewritePattern createBufferForGemmOut(loc, elemTypeV, accelParamsGemm0, rewriter); auto [preAccelRegBufferV, preAccelRegBufferQxK] = createRegInterrimBufferForAccel(loc, accelParamsGemm1, rewriter); - Value accRegBufferGemm1 = - createBufferForAccelGemmOut(loc, accelParamsGemm1, rewriter); -#ifdef ROCK_DEBUG_ATTENTION_REMOVE_SOFTMAX - accRegBufferGemm1 = createBufferForAccelGemmOut(loc, accelParamsGemm1, - rewriter, gemm1MBlocks); -#endif - Value gemm1OutBuffer = - createBufferForGemmOut(loc, elemTypeQxK, accelParamsGemm1, rewriter); -#ifdef ROCK_DEBUG_ATTENTION_REMOVE_SOFTMAX - gemm1OutBuffer = createBufferForGemmOut(loc, elemTypeQxK, accelParamsGemm1, - rewriter, gemm1MBlocks); -#endif + + Value accRegBufferGemm1; + Value gemm1OutBuffer; + if (op.getEnableSoftmax()) { + accRegBufferGemm1 = + createBufferForAccelGemmOut(loc, accelParamsGemm1, rewriter); + gemm1OutBuffer = + createBufferForGemmOut(loc, elemTypeQxK, accelParamsGemm1, rewriter); + } else { + accRegBufferGemm1 = createBufferForAccelGemmOut(loc, accelParamsGemm1, + rewriter, gemm1MBlocks); + gemm1OutBuffer = createBufferForGemmOut( + loc, elemTypeQxK, accelParamsGemm1, rewriter, gemm1MBlocks); + } SmallVector gemm1BidGridLengths = {gemm0G, gemm1MBlocks, gemm1NBlocks}; @@ -1923,38 +1922,45 @@ struct GridwiseAttentionAccelRewritePattern // o buffer; this is exactly same as gemm1OutBuffer; // we just need another buffer to do the special accumulation - Value attentionOutAccBuffer = createBufferForGemmOut( - loc, elemTypeQxK, accelParamsGemm1, rewriter, gemm1MBlocks); - Value attentionOutAccBufferOutTyped = attentionOutAccBuffer; - if (elemTypeQxK != elemTypeOut) { - attentionOutAccBufferOutTyped = - createBufferForGemmOut(loc, elemTypeOut, accelParamsGemm1, rewriter); - } - ArrayAttr 
attentionOutAccBufferThreadSubTileViewMaps = - invertTransforms(rewriter, loc, gemm1OutSubTileViewsTr.threadSubTile); - // m buffer; this only contains a reduced single value per row - auto reducedBufferType = - MemRefType::get({gemm1MPerThread}, elemTypeQxK, AffineMap{}, - /*memorySpace=*/privateMemoryAddressSpace); - auto negInfSumTyped = - createConstantFloatOp(rewriter, loc, reducedBufferType.getElementType(), - reducedBufferType.getElementType(), - -std::numeric_limits::infinity()); - auto maxRowBuffer = - rewriter.create(loc, reducedBufferType); - auto expMaxDiffRowBuffer = - rewriter.create(loc, reducedBufferType); - rewriter.create(loc, maxRowBuffer, negInfSumTyped); - // l buffer; this only contains a reduced single value per row - Value sumRowBuffer = - rewriter.create(loc, reducedBufferType); - rewriter.create(loc, sumRowBuffer, - createZeroConstantOp(rewriter, loc, elemTypeQxK)); - - zeroAccBuffer(rewriter, loc, attentionOutAccBuffer); -#ifdef ROCK_DEBUG_ATTENTION_REMOVE_SOFTMAX - zeroAccBuffer(rewriter, loc, accRegBufferGemm1); -#endif + Value attentionOutAccBuffer, outAccBufferOutTyped, sumRowBuffer, + maxRowBuffer, expMaxDiffRowBuffer; + ArrayAttr attentionOutAccBufferThreadSubTileViewMaps; + if (op.getEnableSoftmax()) { + attentionOutAccBuffer = createBufferForGemmOut( + loc, elemTypeQxK, accelParamsGemm1, rewriter, gemm1MBlocks); + outAccBufferOutTyped = attentionOutAccBuffer; + if (elemTypeQxK != elemTypeOut) { + outAccBufferOutTyped = createBufferForGemmOut( + loc, elemTypeOut, accelParamsGemm1, rewriter); + } + attentionOutAccBufferThreadSubTileViewMaps = + invertTransforms(rewriter, loc, gemm1OutSubTileViewsTr.threadSubTile); + // m buffer; this only contains a reduced single value per row + auto reducedBufferType = + MemRefType::get({gemm1MPerThread}, elemTypeQxK, AffineMap{}, + /*memorySpace=*/privateMemoryAddressSpace); + auto negInfSumTyped = createConstantFloatOp( + rewriter, loc, reducedBufferType.getElementType(), + reducedBufferType.getElementType(), + -std::numeric_limits::infinity()); + maxRowBuffer = rewriter.create(loc, reducedBufferType); + expMaxDiffRowBuffer = + rewriter.create(loc, reducedBufferType); + rewriter.create(loc, maxRowBuffer, negInfSumTyped); + // l buffer; this only contains a reduced single value per row + sumRowBuffer = rewriter.create(loc, reducedBufferType); + rewriter.create(loc, sumRowBuffer, + createZeroConstantOp(rewriter, loc, elemTypeQxK)); + + zeroAccBuffer(rewriter, loc, attentionOutAccBuffer); + } else { + outAccBufferOutTyped = gemm1OutBuffer; + if (elemTypeQxK != elemTypeOut) { + outAccBufferOutTyped = createBufferForGemmOut( + loc, elemTypeOut, accelParamsGemm1, rewriter, gemm1MBlocks); + } + zeroAccBuffer(rewriter, loc, accRegBufferGemm1); + } // If gemm0K is equal to gemm0KPerBlock that means // effectively there is no K loop. 
Therefore, we // can prefetch the Q tile into regs outside of the @@ -1976,7 +1982,7 @@ struct GridwiseAttentionAccelRewritePattern fromGlobalRegBufferQ, toLDSRegBufferQ, preAccelRegBuffersQ, "n", gemm0kpack, gemm0KpacksPerBlock, gemm0NPerBlock, blockSize, gridSize, bidGridOrder, gemm0BidGridLengths, forceUnroll, rewriter, - *accelEmitterPtrGemm0.get(), ldsLayoutCfgNG0); + *accelEmitterPtrGemm0, ldsLayoutCfgNG0); if (failed(statusLoadQTile)) { return failure(); } @@ -1988,7 +1994,7 @@ struct GridwiseAttentionAccelRewritePattern fromGlobalRegBufferQ, toLDSRegBufferQ, ldsByteBufferQ, "n", gemm0kpack, gemm0KpacksPerBlock, gemm0NPerBlock, blockSize, gridSize, bidGridOrder, gemm0BidGridLengths, forceUnroll, rewriter, - *accelEmitterPtrGemm0.get(), ldsLayoutCfgNG0); + *accelEmitterPtrGemm0, ldsLayoutCfgNG0); if (failed(statusLoadQ)) { return failure(); } @@ -1996,10 +2002,10 @@ struct GridwiseAttentionAccelRewritePattern TypedValue ldsTileBufferQ = viewBufferAs( rewriter, ldsByteBufferQ, vectorTypeOrSelf(elemTypeQ, gemm0kpack)); - loadGemmOperandsFromLDSToRegs( - rewriter, loc, ldsTileBufferQ, preAccelRegBuffersQ, "n", blockSize, - gemm0InNPerThread, *accelEmitterPtrGemm0.get(), - ldsLayoutCfgNG0.doRotateWithK); + loadGemmOperandsFromLDSToRegs(rewriter, loc, ldsTileBufferQ, + preAccelRegBuffersQ, "n", blockSize, + gemm0InNPerThread, *accelEmitterPtrGemm0, + ldsLayoutCfgNG0.doRotateWithK); rewriter.create(loc, ldsByteBufferQ); } } @@ -2120,7 +2126,7 @@ struct GridwiseAttentionAccelRewritePattern toLDSRegBufferQ, ldsByteBufferQ, "n", gemm0kpack, gemm0KpacksPerBlock, gemm0NPerBlock, blockSize, gridSize, bidGridOrder, gemm0BidGridLengths, forceUnroll, rewriter, - *accelEmitterPtrGemm0.get(), ldsLayoutCfgNG0); + *accelEmitterPtrGemm0, ldsLayoutCfgNG0); if (failed(statusLoadQ)) { return failure(); } @@ -2135,7 +2141,7 @@ struct GridwiseAttentionAccelRewritePattern toLDSRegBufferK, ldsByteBufferK, "m", gemm0kpack, gemm0KpacksPerBlock, gemm0MPerBlock, blockSize, gridSize, bidGridOrder, gemm0BidGridLengths, forceUnroll, rewriter, - *accelEmitterPtrGemm0.get(), ldsLayoutCfgMG0); + *accelEmitterPtrGemm0, ldsLayoutCfgMG0); if (failed(statusLoadKTile)) { return failure(); } @@ -2224,99 +2230,109 @@ struct GridwiseAttentionAccelRewritePattern // Align the preSoftmaxElementWise (if any) linalg.generic to // be performed on the output of the first gemm. - FailureOr maybeSoftmaxInBuffer = postProcessFirstGemm( - rewriter, loc, op, gridCoordsGemm0, gemm0OutBuffer, softmaxInBuffer, + FailureOr maybeFusionOutBuffer = postProcessFirstGemm( + rewriter, loc, op, gridCoordsGemm0, gemm0OutBuffer, fusionOutBuffer, gemm0OutSubTileViewsTrUnPadded); - if (failed(maybeSoftmaxInBuffer)) { + if (failed(maybeFusionOutBuffer)) { return op.emitError("post processing first gemm failed.\n"); } - gemm0OutBuffer = maybeSoftmaxInBuffer.value(); - // Scale gemm0 output by (1/ln2) - // So that we can use exp2 instead of exp. -#ifndef ROCK_DEBUG_ATTENTION_REMOVE_SOFTMAX - Value ln2Recip = createConstantFloatOp( - rewriter, loc, elemTypeQxK, elemTypeQxK, 1.44269504f, - elemTypeQxK.isF32() ? 
APFloat::opOK : APFloat::opInexact); - postProcessFirstGemmSplat( - rewriter, loc, gridCoordsGemm0, gemm0OutBuffer, gemm0OutSubTileViews, - ln2Recip.getDefiningOp().getValue()); - - // Handle padding - bool hasPadding = - op.getPrePadG0M().has_value() || op.getPrePadG0N().has_value(); - if (hasPadding) { - bool isGfx11 = arch.contains("gfx11"); - createFirstGemmNegInfPadding(rewriter, loc, gridCoordsGemm0, - gemm0OutBuffer, - gemm0OutSubTileViewsTrUnPadded, isGfx11); + gemm0OutBuffer = maybeFusionOutBuffer.value(); + + // Softmax + if (op.getEnableSoftmax()) { + // Scale gemm0 output by (1/ln2) + // So that we can use exp2 instead of exp. + Value ln2Recip = createConstantFloatOp( + rewriter, loc, elemTypeQxK, elemTypeQxK, 1.44269504f, + elemTypeQxK.isF32() ? APFloat::opOK : APFloat::opInexact); + postProcessFirstGemmSplat( + rewriter, loc, gridCoordsGemm0, gemm0OutBuffer, + gemm0OutSubTileViews, + ln2Recip.getDefiningOp().getValue()); + + // Handle padding + bool hasPadding = + op.getPrePadG0M().has_value() || op.getPrePadG0N().has_value(); + if (hasPadding) { + bool isGfx11 = arch.contains("gfx11"); + createFirstGemmNegInfPadding(rewriter, loc, gridCoordsGemm0, + gemm0OutBuffer, + gemm0OutSubTileViewsTrUnPadded, isGfx11); + } + // Negative Infinite for extra values (KV cache) + setGemm0OutputOutOfScopeKVCache(rewriter, loc, gridCoordsGemm0, + gemm0OutBuffer, gemm0OutSubTileViewsTr, + currentSeqLen, mLoopIV, + gemm0MBlocksLastIter); + + APInt reductionAxis = APInt(64, 1); + // Softmax max reduction + Value ldsReductionWorkspaceByteBuffer = createLDSByteBuffer( + rewriter, loc, reductionWorkspaceSize, elemTypeQxK); + TypedValue ldsReductionWorkspaceBuffer = viewBufferAs( + rewriter, ldsReductionWorkspaceByteBuffer, elemTypeQxK); + rewriter.create( + loc, gemm0OutBuffer, ldsReductionWorkspaceBuffer, gemm0OutBufferMax, + /*extraOut=*/nullptr, reductionAxis, rock::ReduceMethod::Max, + gemm0OutSubTileViewsTr.blockSubTile, + gemm0OutSubTileViewsTr.blockSubTileTidSlice.value(), + gemm0OutSubTileViewsTr.threadSubTile, /*extraViews=*/nullptr, + blockSize); + rewriter.create(loc, ldsReductionWorkspaceByteBuffer); + + // softmax normalization. 
+ Value gemm0MNThreadwiseView = + transform(rewriter, gemm0OutBuffer, + invertTransforms(rewriter, loc, + gemm0OutSubTileViewsTr.threadSubTile)); + Value gemm0MNExpThreadwiseView = + transform(rewriter, gemm0OutBufferExp, + invertTransforms(rewriter, loc, + gemm0OutSubTileViewsTr.threadSubTile)); + Value gemm0MNMaxThreadwiseView = + transform(rewriter, gemm0OutBufferMax, + invertTransforms(rewriter, loc, + gemm0OutSubTileViewsTr.threadSubTile)); + expSubstractMaxFromGemm0(rewriter, loc, gemm0MNThreadwiseView, + gemm0MNExpThreadwiseView, + gemm0MNMaxThreadwiseView, maxRowBuffer); + + // Softmax sum reduction + Value ldsReductionWorkspaceByteSecondBuffer = createLDSByteBuffer( + rewriter, loc, reductionWorkspaceSize, elemTypeQxK); + TypedValue ldsReductionWorkspaceSecondBuffer = viewBufferAs( + rewriter, ldsReductionWorkspaceByteSecondBuffer, elemTypeQxK); + rewriter.create( + loc, gemm0OutBufferExp, ldsReductionWorkspaceSecondBuffer, + gemm0OutBufferSum, /*extraOut=*/nullptr, reductionAxis, + rock::ReduceMethod::Sum, gemm0OutSubTileViewsTr.blockSubTile, + gemm0OutSubTileViewsTr.blockSubTileTidSlice.value(), + gemm0OutSubTileViewsTr.threadSubTile, + /*extraViews=*/nullptr, blockSize); + rewriter.create(loc, + ldsReductionWorkspaceByteSecondBuffer); + Value gemm0SumThreadwiseView = + transform(rewriter, gemm0OutBufferSum, + invertTransforms(rewriter, loc, + gemm0OutSubTileViewsTr.threadSubTile)); + Value gemm0MaxThreadwiseView = + transform(rewriter, gemm0OutBufferMax, + invertTransforms(rewriter, loc, + gemm0OutSubTileViewsTr.threadSubTile)); + updateRowSum(rewriter, loc, gemm0SumThreadwiseView, + gemm0MaxThreadwiseView, sumRowBuffer, maxRowBuffer, + expMaxDiffRowBuffer); } - // Negative Infinite for extra values (KV cache) - setGemm0OutputOutOfScopeKVCache( - rewriter, loc, gridCoordsGemm0, gemm0OutBuffer, - gemm0OutSubTileViewsTr, currentSeqLen, mLoopIV, gemm0MBlocksLastIter); -#endif - - APInt reductionAxis = APInt(64, 1); - APInt nrDimPerThread = APInt(64, gemm0MPerBlock / gemm0MPerThread); - - Value ldsReductionWorkspaceByteBuffer = createLDSByteBuffer( - rewriter, loc, reductionWorkspaceSize, elemTypeQxK); - TypedValue ldsReductionWorkspaceBuffer = - viewBufferAs(rewriter, ldsReductionWorkspaceByteBuffer, elemTypeQxK); - rewriter.create( - loc, gemm0OutBuffer, ldsReductionWorkspaceBuffer, gemm0OutBufferMax, - /*extraOut=*/nullptr, reductionAxis, rock::ReduceMethod::Max, - gemm0OutSubTileViewsTr.blockSubTile, - gemm0OutSubTileViewsTr.blockSubTileTidSlice.value(), - gemm0OutSubTileViewsTr.threadSubTile, /*extraViews=*/nullptr, - blockSize); - rewriter.create(loc, ldsReductionWorkspaceByteBuffer); - // softmax normalization. 
- Value gemm0MNThreadwiseView = - transform(rewriter, gemm0OutBuffer, - invertTransforms(rewriter, loc, - gemm0OutSubTileViewsTr.threadSubTile)); - Value gemm0MNExpThreadwiseView = - transform(rewriter, gemm0OutBufferExp, - invertTransforms(rewriter, loc, - gemm0OutSubTileViewsTr.threadSubTile)); - Value gemm0MNMaxThreadwiseView = - transform(rewriter, gemm0OutBufferMax, - invertTransforms(rewriter, loc, - gemm0OutSubTileViewsTr.threadSubTile)); - expSubstractMaxFromGemm0(rewriter, loc, gemm0MNThreadwiseView, - gemm0MNExpThreadwiseView, - gemm0MNMaxThreadwiseView, maxRowBuffer); - - Value ldsReductionWorkspaceByteSecondBuffer = createLDSByteBuffer( - rewriter, loc, reductionWorkspaceSize, elemTypeQxK); - TypedValue ldsReductionWorkspaceSecondBuffer = viewBufferAs( - rewriter, ldsReductionWorkspaceByteSecondBuffer, elemTypeQxK); - rewriter.create( - loc, gemm0OutBufferExp, ldsReductionWorkspaceSecondBuffer, - gemm0OutBufferSum, /*extraOut=*/nullptr, reductionAxis, - rock::ReduceMethod::Sum, gemm0OutSubTileViewsTr.blockSubTile, - gemm0OutSubTileViewsTr.blockSubTileTidSlice.value(), - gemm0OutSubTileViewsTr.threadSubTile, - /*extraViews=*/nullptr, blockSize); - rewriter.create(loc, ldsReductionWorkspaceByteSecondBuffer); - Value gemm0SumThreadwiseView = - transform(rewriter, gemm0OutBufferSum, - invertTransforms(rewriter, loc, - gemm0OutSubTileViewsTr.threadSubTile)); - Value gemm0MaxThreadwiseView = - transform(rewriter, gemm0OutBufferMax, - invertTransforms(rewriter, loc, - gemm0OutSubTileViewsTr.threadSubTile)); - updateRowSum(rewriter, loc, gemm0SumThreadwiseView, - gemm0MaxThreadwiseView, sumRowBuffer, maxRowBuffer, - expMaxDiffRowBuffer); // Emit blockwise GEMM 1. { + auto gemm0Out = + op.getEnableSoftmax() ? gemm0OutBufferExp : gemm0OutBuffer; if (elemTypeV != elemTypeQxK) { - createTypeConversionLaGeneric(rewriter, loc, gemm0OutBufferExp, + createTypeConversionLaGeneric(rewriter, loc, gemm0Out, gemm1RegBufferB); + } else { + gemm1RegBufferB = gemm0Out; } Value wrappedLDSBufferForLoadB; Value gemm1LDSByteBufferB; @@ -2360,14 +2376,14 @@ struct GridwiseAttentionAccelRewritePattern OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(g1MLoopOp.getBody()); Value g1MLoopIndVar = g1MLoopOp.getInductionVar(); -#ifndef ROCK_DEBUG_ATTENTION_REMOVE_SOFTMAX - zeroAccBuffer(rewriter, loc, accRegBufferGemm1); -#else - if (gemm1MBlocks > 1) { - accRegBufferGemm1 = createSliceOfFirstDim( - rewriter, loc, accRegBufferGemm1, g1MLoopIndVar); + if (op.getEnableSoftmax()) { + zeroAccBuffer(rewriter, loc, accRegBufferGemm1); + } else { + if (gemm1MBlocks > 1) { + accRegBufferGemm1 = createSliceOfFirstDim( + rewriter, loc, accRegBufferGemm1, g1MLoopIndVar); + } } -#endif auto gridCoordsGemm1 = layout::makeGxNGridLayout( rewriter, loc, bid, g1MLoopIndVar, gemm1NBlocks, gridSize, arch); @@ -2379,7 +2395,7 @@ struct GridwiseAttentionAccelRewritePattern toLDSRegBufferV, ldsByteBufferV, "m", gemm1kpack, gemm1KpacksPerBlock, gemm1MPerBlock, blockSize, gridSize, bidGridOrder, gemm1BidGridLengths, forceUnroll, rewriter, - *accelEmitterPtrGemm1.get(), ldsLayoutCfgMG1); + *accelEmitterPtrGemm1, ldsLayoutCfgMG1); if (failed(statusLoadVTile)) { return failure(); } @@ -2460,37 +2476,38 @@ struct GridwiseAttentionAccelRewritePattern // There is no second k-loop // Therefore can get the output straight away Value gemm1OutBufferPerG1MBlock = gemm1OutBuffer; -#ifdef ROCK_DEBUG_ATTENTION_REMOVE_SOFTMAX - if (gemm1MBlocks > 1) { + if (!op.getEnableSoftmax() && gemm1MBlocks > 1) { 
gemm1OutBufferPerG1MBlock = createSliceOfFirstDim( rewriter, loc, gemm1OutBuffer, g1MLoopIndVar); } -#endif + accelEmitterPtrGemm1->computeOutputConversion( rewriter, loc, accRegBufferGemm1, gemm1OutBufferPerG1MBlock, forceUnroll); - Value attentionOutAccBufferPerG1MBlock = attentionOutAccBuffer; - if (gemm1MBlocks > 1) { - attentionOutAccBufferPerG1MBlock = createSliceOfFirstDim( - rewriter, loc, attentionOutAccBuffer, g1MLoopIndVar); + if (op.getEnableSoftmax()) { + Value attentionOutAccBufferPerG1MBlock = attentionOutAccBuffer; + if (gemm1MBlocks > 1) { + attentionOutAccBufferPerG1MBlock = createSliceOfFirstDim( + rewriter, loc, attentionOutAccBuffer, g1MLoopIndVar); + } + ArrayAttr invertedGemm1threadSubTileMaps = invertTransforms( + rewriter, loc, gemm1OutSubTileViewsTr.threadSubTile); + Value gemm1MNThreadwiseView = + transform(rewriter, gemm1OutBufferPerG1MBlock, + invertedGemm1threadSubTileMaps); + // Rescale/correct output, rowMax and rowSums + Value attentionOutAccBufferView = + transform(rewriter, attentionOutAccBufferPerG1MBlock, + attentionOutAccBufferThreadSubTileViewMaps); + createAttentionRowStateCorrections( + rewriter, loc, gemm1MNThreadwiseView, attentionOutAccBufferView, + expMaxDiffRowBuffer); } - ArrayAttr invertedGemm1threadSubTileMaps = invertTransforms( - rewriter, loc, gemm1OutSubTileViewsTr.threadSubTile); - Value gemm1MNThreadwiseView = - transform(rewriter, gemm1OutBufferPerG1MBlock, - invertedGemm1threadSubTileMaps); - // Rescale/correct output, rowMax and rowSums - Value attentionOutAccBufferView = - transform(rewriter, attentionOutAccBufferPerG1MBlock, - attentionOutAccBufferThreadSubTileViewMaps); - createAttentionRowStateCorrections( - rewriter, loc, gemm1MNThreadwiseView, attentionOutAccBufferView, - expMaxDiffRowBuffer); } } } - { + if (op.getEnableSoftmax()) { affine::AffineForOp g1MLoopOp = rewriter.create(loc, 0, gemm1MBlocks, 1); { @@ -2509,31 +2526,28 @@ struct GridwiseAttentionAccelRewritePattern sumRowBuffer); } } + Value outAccBuffer = + op.getEnableSoftmax() ? attentionOutAccBuffer : gemm1OutBuffer; if (elemTypeQxK != elemTypeOut) { - createTypeConversionLaGeneric(rewriter, loc, attentionOutAccBuffer, - attentionOutAccBufferOutTyped); + createTypeConversionLaGeneric(rewriter, loc, outAccBuffer, + outAccBufferOutTyped); } -#ifdef ROCK_DEBUG_ATTENTION_REMOVE_SOFTMAX - attentionOutAccBufferOutTyped = gemm1OutBuffer; -#endif + // We flatten output buffer in case gemm1MBlocks > 1 // where those are iterated. 
- Value attentionOutAccBufferOutTypedFlat = attentionOutAccBufferOutTyped; - MemRefType attentionOutAccBufferOutType = - cast(attentionOutAccBufferOutTyped.getType()); - int64_t numElementsAttnOut = attentionOutAccBufferOutType.getNumElements(); - if (attentionOutAccBufferOutType.getRank() > 1) { - Type attentionOutAccBufferOutTypedElType = - attentionOutAccBufferOutType.getElementType(); - auto attentionOutAccBufferOutTypedFlatType = MemRefType::get( - {numElementsAttnOut}, attentionOutAccBufferOutTypedElType, - AffineMap{}, privateMemoryAddressSpace); - auto reassociation = - getReassociationForFlattening(attentionOutAccBufferOutType); - attentionOutAccBufferOutTypedFlat = - rewriter.create( - loc, attentionOutAccBufferOutTypedFlatType, - attentionOutAccBufferOutTyped, reassociation); + Value outAccBufferOutTypedFlat = outAccBufferOutTyped; + MemRefType outAccBufferOutType = + cast(outAccBufferOutTyped.getType()); + int64_t numElementsAttnOut = outAccBufferOutType.getNumElements(); + if (outAccBufferOutType.getRank() > 1) { + Type outAccBufferOutTypedElType = outAccBufferOutType.getElementType(); + auto outAccBufferOutTypedFlatType = + MemRefType::get({numElementsAttnOut}, outAccBufferOutTypedElType, + AffineMap{}, privateMemoryAddressSpace); + auto reassociation = getReassociationForFlattening(outAccBufferOutType); + outAccBufferOutTypedFlat = rewriter.create( + loc, outAccBufferOutTypedFlatType, outAccBufferOutTyped, + reassociation); } // This map will create an upper view [gblock, nblock, flatiter] -> [gblock, // miter, nblock, iter] @@ -2547,7 +2561,7 @@ struct GridwiseAttentionAccelRewritePattern auto gridCoordsGemm1 = layout::makeGxNGridLayout( rewriter, loc, bid, zero, gemm1NBlocks, gridSize, arch); rewriter.create( - loc, attentionOutAccBufferOutTypedFlat, trOut, outGridSubTile, + loc, outAccBufferOutTypedFlat, trOut, outGridSubTile, /*extraIndices=*/ ValueRange{gridCoordsGemm1.g_block, gridCoordsGemm1.n_block, tid}, op.getFeatures(), rock::StoreMethod::Set, forceUnroll, diff --git a/mlir/lib/Dialect/Rock/Transforms/Regularize.cpp b/mlir/lib/Dialect/Rock/Transforms/Regularize.cpp index 68f472edb432..93dcd2f89596 100644 --- a/mlir/lib/Dialect/Rock/Transforms/Regularize.cpp +++ b/mlir/lib/Dialect/Rock/Transforms/Regularize.cpp @@ -606,7 +606,7 @@ LogicalResult findFusionRoots(func::FuncOp kernel, collectInputFusionWriteOperands(readOperand, bufferDeps, state); } - if (isa(op)) { + if (isa(op)) { // The linalg.generic inside the attention's body will be expected to // write out a global tensor as if it were an output fusion, so its // write should be added to the set of output fusion writes lest we diff --git a/mlir/lib/Dialect/Rock/Tuning/RockTuningImpl.cpp b/mlir/lib/Dialect/Rock/Tuning/RockTuningImpl.cpp index 78aff40a6a3a..95ba25d10d0d 100644 --- a/mlir/lib/Dialect/Rock/Tuning/RockTuningImpl.cpp +++ b/mlir/lib/Dialect/Rock/Tuning/RockTuningImpl.cpp @@ -11,6 +11,9 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Dialect/Rock/IR/Rock.h" +#include "mlir/Dialect/Rock/IR/RockGemmGemmWrapperInterface.h" +#include "mlir/Dialect/Rock/IR/RockGemmWrapperInterface.h" #include "mlir/Dialect/Rock/IR/RockTuningParamAttrInterface.h" #include "mlir/Dialect/Rock/Tuning/GridwiseGemmParams.h" #include "mlir/Dialect/Rock/Tuning/RockTuning.h" @@ -28,8 +31,8 @@ namespace mlir { namespace rock { // The full space is a brute-force search for attention kernels -static void createAttnTuningRangeBF(TuningParamSet *newSpace, - AttentionOp attnOp, 
+template +static void createAttnTuningRangeBF(TuningParamSet *newSpace, Op attnOp, TuningParamSetKind kind) { static const std::vector> validRangeAttnParamsMFMA = { /*gemm0MPerBlock=*/{32, 64, 128, 256}, @@ -47,7 +50,7 @@ static void createAttnTuningRangeBF(TuningParamSet *newSpace, /*mPerWave=*/{32, 64}, /*nPerWave=*/{32, 64}, /*kPack=*/{4, 8, 16}}; - GemmFeatures features = attnOp.getFeatures(); + GemmFeatures features = attnOp.getGemmFeatures(); int64_t numEUPerCU = rock::lookupArchInfo(attnOp.getArch()).numEUPerCU; std::vector> validRangeAttnParams; bool isWMMA = false; @@ -401,12 +404,11 @@ static void createQuickTuningRange(TuningParamSet *newSpace, // This is temporary workaround to make MIGraphX integration // work until the tuning is setup for attention ops properly. -static void createAttnTuningRangeQuick(TuningParamSet *newSpace, - AttentionOp attnOp) { +template +static void createAttnTuningRangeQuick(TuningParamSet *newSpace, Op attnOp, + Type elemType) { OpBuilder b(attnOp.getContext()); - GemmFeatures currentFeatures = attnOp.getFeatures(); - Type elemType = - cast(attnOp.getQueries().getType()).getElementType(); + GemmFeatures currentFeatures = attnOp.getGemmFeatures(); // g0Mpb, g1Mpb, g0Npb, Kpb, mPw, mnPxdl, kpack using PerfConfigVals = std::tuple; @@ -482,20 +484,22 @@ TuningParamSet *createTunableParamSpace(ModuleOp mod, TuningParamSetKind kind) { newSpace->primaryOpType = op.getKernelType(); return WalkResult::interrupt(); }); - WalkResult findAttention = mod->walk([&](rock::AttentionOp op) -> WalkResult { - switch (kind) { - case TuningParamSetKind::Full: - case TuningParamSetKind::Exhaustive: - createAttnTuningRangeBF(newSpace, op, kind); - break; - case TuningParamSetKind::Quick: - createAttnTuningRangeQuick(newSpace, op); - } - return WalkResult::interrupt(); - }); - if (!findPrimary.wasInterrupted() && !findAttention.wasInterrupted()) { - llvm::report_fatal_error( - "Expected to find GEMM, convolution, or attention op, and didn't."); + WalkResult findGemmGemm = + mod->walk([&](rock::RockGemmGemmWrapperInterface op) -> WalkResult { + Type elemType = cast(op.getAType()).getElementType(); + switch (kind) { + case TuningParamSetKind::Full: + case TuningParamSetKind::Exhaustive: + createAttnTuningRangeBF(newSpace, op, kind); + break; + case TuningParamSetKind::Quick: + createAttnTuningRangeQuick(newSpace, op, elemType); + } + return WalkResult::interrupt(); + }); + if (!findPrimary.wasInterrupted() && !findGemmGemm.wasInterrupted()) { + llvm::report_fatal_error("Expected to find GEMM, convolution, attention or " + "gemm+gemm op, and didn't."); } return newSpace; } @@ -519,15 +523,16 @@ bool tuningSetParam(ModuleOp &mod, ParamEntry *paramEntry) { op->setAttr("perf_config", attr); return WalkResult::interrupt(); }); - WalkResult setAttn = mod->walk([&](rock::AttentionOp op) -> WalkResult { - auto *ctx = op.getContext(); - SmallString<64> perfConfig; - paramEntry->param.getPerfConfigStr(perfConfig); - StringAttr attr = StringAttr::get(ctx, perfConfig); - op->setAttr("perf_config", attr); - return WalkResult::interrupt(); - }); - return setPrimary.wasInterrupted() || setAttn.wasInterrupted(); + WalkResult setGemmGemm = + mod->walk([&](rock::RockGemmGemmWrapperInterface op) -> WalkResult { + auto *ctx = op.getContext(); + SmallString<64> perfConfig; + paramEntry->param.getPerfConfigStr(perfConfig); + StringAttr attr = StringAttr::get(ctx, perfConfig); + op->setAttr("perf_config", attr); + return WalkResult::interrupt(); + }); + return setPrimary.wasInterrupted() || 
setGemmGemm.wasInterrupted(); } bool tuningSetStr(ModuleOp &mod, StringRef perfConfig) { @@ -538,13 +543,14 @@ bool tuningSetStr(ModuleOp &mod, StringRef perfConfig) { op->setAttr("perf_config", attr); return WalkResult::interrupt(); }); - WalkResult setAttn = mod->walk([&](rock::AttentionOp op) -> WalkResult { - auto *ctx = op.getContext(); - StringAttr attr = StringAttr::get(ctx, perfConfig); - op->setAttr("perf_config", attr); - return WalkResult::interrupt(); - }); - return setPrimary.wasInterrupted() || setAttn.wasInterrupted(); + WalkResult setGemmGemm = + mod->walk([&](rock::RockGemmGemmWrapperInterface op) -> WalkResult { + auto *ctx = op.getContext(); + StringAttr attr = StringAttr::get(ctx, perfConfig); + op->setAttr("perf_config", attr); + return WalkResult::interrupt(); + }); + return setPrimary.wasInterrupted() || setGemmGemm.wasInterrupted(); } TuningTable *tuningTableCreate() { @@ -552,11 +558,12 @@ TuningTable *tuningTableCreate() { return newTable; } -static LogicalResult getTuningProblemStr(rock::AttentionOp attnOp, - SmallVectorImpl &out) { - int32_t numCU = rock::lookupArchInfo(attnOp.getArch()).minNumCU; - if (attnOp.getNumCU().has_value()) { - numCU = attnOp.getNumCU().value(); +static LogicalResult +getTuningProblemStr(RockGemmGemmWrapperInterface gemmGemmOp, + SmallVectorImpl &out) { + int32_t numCU = rock::lookupArchInfo(gemmGemmOp.getArch()).minNumCU; + if (gemmGemmOp.getNumCU().has_value()) { + numCU = gemmGemmOp.getNumCU().value(); } constexpr char sep = ' '; constexpr char tab = '\t'; @@ -566,35 +573,37 @@ static LogicalResult getTuningProblemStr(rock::AttentionOp attnOp, int64_t seqLenK; llvm::raw_svector_ostream problemOS(out); // ARCH string - problemOS << attnOp.getArch() << tab; + problemOS << gemmGemmOp.getArch() << tab; // Num of Compute Units problemOS << numCU << tab; - TypedValue queries = attnOp.getQueries(); - TypedValue keys = attnOp.getKeys(); - TypedValue values = attnOp.getValues(); - ArrayRef qShape = queries.getType().getShape(); - ArrayRef kShape = keys.getType().getShape(); - ArrayRef vShape = values.getType().getShape(); + ArrayRef qShape = cast(gemmGemmOp.getAType()).getShape(); + ArrayRef kShape = cast(gemmGemmOp.getBType()).getShape(); + ArrayRef vShape = cast(gemmGemmOp.getCType()).getShape(); int64_t g = qShape[0]; - Type elemTypeQ = queries.getType().getElementType(); + bool isAttention = isa(gemmGemmOp); + + Type elemTypeQ = cast(gemmGemmOp.getAType()).getElementType(); problemOS << "-t "; if (elemTypeQ.isF32()) { problemOS << "f32" << sep; - } else if (elemTypeQ.isF16()) { + } else if (elemTypeQ.isF16() && isAttention) { problemOS << "f16" << sep; - } else if (elemTypeQ.isBF16()) { + } else if (elemTypeQ.isBF16() && isAttention) { problemOS << "bf16" << sep; - } else if (elemTypeQ.isInteger(8)) { + } else if (elemTypeQ.isInteger(8) && isAttention) { problemOS << "i8" << sep; } else { - return attnOp.emitError("invalid type:") << elemTypeQ << "\n"; + return gemmGemmOp.emitError("invalid type:") << elemTypeQ << "\n"; } // TransQ - problemOS << "-transQ "; - if (attnOp.getQTransposed()) { + if (isAttention) + problemOS << "-transQ "; + else + problemOS << "-transA "; + if (gemmGemmOp.getTransposedA()) { seqLenQ = qShape[2]; headDimQK = qShape[1]; problemOS << "true" << sep; @@ -605,8 +614,11 @@ static LogicalResult getTuningProblemStr(rock::AttentionOp attnOp, } // TransK - problemOS << "-transK "; - if (attnOp.getKTransposed()) { + if (isAttention) + problemOS << "-transK "; + else + problemOS << "-transB "; + if 
(gemmGemmOp.getTransposedB()) { seqLenK = kShape[1]; problemOS << "true" << sep; } else { @@ -615,8 +627,11 @@ static LogicalResult getTuningProblemStr(rock::AttentionOp attnOp, } // TransV - problemOS << "-transV "; - if (attnOp.getVTransposed()) { + if (isAttention) + problemOS << "-transV "; + else + problemOS << "-transC "; + if (gemmGemmOp.getTransposedC()) { headDimV = vShape[1]; problemOS << "true" << sep; } else { @@ -626,16 +641,23 @@ static LogicalResult getTuningProblemStr(rock::AttentionOp attnOp, // TransO problemOS << "-transO "; - if (attnOp.getOTransposed()) + if (gemmGemmOp.getTransposedOut()) problemOS << "true" << sep; else problemOS << "false" << sep; problemOS << "-g " << g << sep; - problemOS << "-seq_len_q " << seqLenQ << sep; - problemOS << "-seq_len_k " << seqLenK << sep; - problemOS << "-head_dim_qk " << headDimQK << sep; - problemOS << "-head_dim_v " << headDimV; + if (isAttention) { + problemOS << "-seq_len_q " << seqLenQ << sep; + problemOS << "-seq_len_k " << seqLenK << sep; + problemOS << "-head_dim_qk " << headDimQK << sep; + problemOS << "-head_dim_v " << headDimV; + } else { + problemOS << "-m " << seqLenQ << sep; + problemOS << "-n " << seqLenK << sep; + problemOS << "-k " << headDimQK << sep; + problemOS << "-gemmO " << headDimV; + } return success(); } @@ -896,14 +918,14 @@ LogicalResult getTuningProblemStr(ModuleOp mod, SmallVectorImpl &out) { return getTuningProblemStr(gemmIF, out); } { - rock::AttentionOp attnOp; - WalkResult findAttention = - mod->walk([&](rock::AttentionOp op) -> WalkResult { - attnOp = op; + rock::RockGemmGemmWrapperInterface gemmGemmOp; + WalkResult findGemmGemm = + mod->walk([&](rock::RockGemmGemmWrapperInterface op) -> WalkResult { + gemmGemmOp = op; return WalkResult::interrupt(); }); - if (findAttention.wasInterrupted()) - return getTuningProblemStr(attnOp, out); + if (findGemmGemm.wasInterrupted()) + return getTuningProblemStr(gemmGemmOp, out); } return failure(); } diff --git a/mlir/test/Dialect/Rock/affix_tuning_params.mlir b/mlir/test/Dialect/Rock/affix_tuning_params.mlir index 35ba6ca4d30b..85f1374178a6 100644 --- a/mlir/test/Dialect/Rock/affix_tuning_params.mlir +++ b/mlir/test/Dialect/Rock/affix_tuning_params.mlir @@ -442,9 +442,9 @@ func.func @rock_attention_large(%arg0: memref<1x16384x512xf32>, %arg1: memref<1x return } -// CHECK-LABEL: func.func @rock_attention_mperblockg1 +// CHECK-LABEL: func.func @rock_attention_mperblockg1_wmma // CHECK-SAME: block_size = 128 -// GRID-LABEL: func.func @rock_attention_mperblockg1 +// GRID-LABEL: func.func @rock_attention_mperblockg1_wmma // GRID-SAME: grid_size = 3 func.func @rock_attention_mperblockg1_wmma(%arg0: memref<1x384x64xf16>, %arg1: memref<1x384x64xf16>, %arg2: memref<1x384x64xf16>, %arg3: memref<1x384x64xf16>) attributes {kernel, mhal.arch = "amdgcn-amd-amdhsa:gfx1100"} { // CHECK: rock.attention @@ -457,9 +457,9 @@ func.func @rock_attention_mperblockg1_wmma(%arg0: memref<1x384x64xf16>, %arg1: m return } -// CHECK-LABEL: func.func @rock_attention_mperblockg1 +// CHECK-LABEL: func.func @rock_attention_mperblockg1_mfma // CHECK-SAME: block_size = 256 -// GRID-LABEL: func.func @rock_attention_mperblockg1 +// GRID-LABEL: func.func @rock_attention_mperblockg1_mfma // GRID-SAME: grid_size = 3 func.func @rock_attention_mperblockg1_mfma(%arg0: memref<1x384x64xf16>, %arg1: memref<1x384x64xf16>, %arg2: memref<1x384x64xf16>, %arg3: memref<1x384x64xf16>) attributes {kernel, mhal.arch = "amdgcn-amd-amdhsa:gfx1100"} { // CHECK: rock.attention @@ -472,6 +472,37 @@ func.func 
@rock_attention_mperblockg1_mfma(%arg0: memref<1x384x64xf16>, %arg1: m return } +// CHECK-LABEL: func.func @rock_gemm_gemm_large +// CHECK-SAME: block_size = 256 +// GRID-LABEL: func.func @rock_gemm_gemm_large +// GRID-SAME: grid_size = 128 +func.func @rock_gemm_gemm_large(%arg0: memref<1x16384x512xf32>, %arg1: memref<1x512x16384xf32>, %arg2: memref<1x16384x512xf32>, %arg3: memref<1x16384x512xf32>) { + %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x16384x512xf32> + // CHECK: rock.gemm_elementwise_gemm + // CHECK: params0 = #rock.xdlops_gemm_derived_params + // CHECK: params1 = #rock.xdlops_gemm_derived_params + rock.gemm_elementwise_gemm{ + ab = %arg0 * %arg1 : memref<1x16384x512xf32>, memref<1x512x16384xf32> + %arg3 = ab * %arg2 : memref<1x16384x512xf32> -> memref<1x16384x512xf32> + } {arch = "gfx942:sramecc+:xnack-", features = #rock, perf_config = "attn:v1:128,128,128,2,64,64,8,1", firstGemmIdx = 0 : i32} + return +} + +// CHECK-LABEL: func.func @rock_gemm_gemm_mperblockg1_mfma +// CHECK-SAME: block_size = 256 +// GRID-LABEL: func.func @rock_gemm_gemm_mperblockg1_mfma +// GRID-SAME: grid_size = 3 +func.func @rock_gemm_gemm_mperblockg1_mfma(%arg0: memref<1x384x64xf32>, %arg1: memref<1x384x64xf32>, %arg2: memref<1x384x64xf32>, %arg3: memref<1x384x64xf32>) attributes {kernel, mhal.arch = "amdgcn-amd-amdhsa:gfx1100"} { + // CHECK: rock.gemm_elementwise_gemm + // CHECK: #rock.xdlops_gemm_derived_params + // CHECK: #rock.xdlops_gemm_derived_params + rock.gemm_elementwise_gemm{ + ab = %arg0 * tr %arg1 : memref<1x384x64xf32>, memref<1x384x64xf32> + %arg3 = ab * %arg2 : memref<1x384x64xf32> -> memref<1x384x64xf32> + } {arch = "gfx942:sramecc+:xnack-", features = #rock, perf_config = "attn:v1:128,256,128,2,64,64,8,1", firstGemmIdx = 0 : i32} + return +} + // CHECK-LABEL: func.func @rock_conv_tuning // GRID-LABEL: func.func @rock_conv_tuning func.func @rock_conv_tuning(%arg0: memref<1x1x1x3x3xf32>, %arg1: memref<64x1x1x14x14xf32>, %arg2: memref<64x1x1x14x14xf32>) attributes {kernel = 0 : i32, mhal.arch = "amdgcn-amd-amdhsa:gfx90a:sramecc+:xnack-"} { diff --git a/mlir/test/Dialect/Rock/affix_tuning_params_invalid.mlir b/mlir/test/Dialect/Rock/affix_tuning_params_invalid.mlir index 1508b722ff8f..832b0b5f4bc7 100644 --- a/mlir/test/Dialect/Rock/affix_tuning_params_invalid.mlir +++ b/mlir/test/Dialect/Rock/affix_tuning_params_invalid.mlir @@ -10,3 +10,12 @@ func.func @rock_attention_invalid_perf_config(%arg0: memref<1x384x64xf16>, %arg1 } {arch = "amdgcn-amd-amdhsa:gfx1100", features = #rock, perf_config = "attn:v1:128,128,16,8,32,64,8,1", firstGemmIdx = 0 : i32} return } + +func.func @rock_gemm_gemm_invalid_perf_config(%arg0: memref<1x384x64xf32>, %arg1: memref<1x384x64xf32>, %arg2: memref<1x384x64xf32>, %arg3: memref<1x384x64xf32>) attributes {kernel, mhal.arch = "amdgcn-amd-amdhsa:gfx1100"} { + // expected-error @+1 {{The provided perf config is not valid}} + rock.gemm_elementwise_gemm{ + ab = %arg0 * tr %arg1 : memref<1x384x64xf32>, memref<1x384x64xf32> + %arg3 = ab * %arg2 : memref<1x384x64xf32> -> memref<1x384x64xf32> + } {arch = "amdgcn-amd-amdhsa:gfx1100", features = #rock, perf_config = "attn:v1:128,128,16,8,32,64,8,1", firstGemmIdx = 0 : i32} + return +} diff --git a/mlir/test/Dialect/Rock/gemm_to_gridwise.mlir b/mlir/test/Dialect/Rock/gemm_to_gridwise.mlir index c118f64190c9..e9d87f094e2b 100644 --- a/mlir/test/Dialect/Rock/gemm_to_gridwise.mlir +++ b/mlir/test/Dialect/Rock/gemm_to_gridwise.mlir @@ -267,3 +267,45 @@ func.func @rock_attention_kvcache(%arg0: memref<1x64x1024xf32>, 
%arg1: memref<1x } return } + +// CHECK-LABEL: func.func @rock_gemmelementwisegemm_simple +// CHECK-SAME: (%[[a:.*]]: memref<1x64x1024xf32>, %[[b:.*]]: memref<1x64x1024xf32>, %[[c:.*]]: memref<1x1024x64xf32>, %[[o:.*]]: memref<1x1024x64xf32>) +func.func @rock_gemmelementwisegemm_simple(%arg0: memref<1x64x1024xf32>, %arg1: memref<1x64x1024xf32>, %arg2: memref<1x1024x64xf32>, %arg3: memref<1x1024x64xf32>) attributes {kernel, mhal.arch = "amdgcn-amd-amdhsa:gfx908", block_size = 64 : i32, grid_size = 1024 : i32} { + // CHECK: rock.gridwise_attention_accel(%[[a]], %[[b]], %[[c]], %[[o]]) + // CHECK-NEXT: enableSoftmax = false + rock.gemm_elementwise_gemm{ + ab = tr %arg0 * %arg1 : memref<1x64x1024xf32>, memref<1x64x1024xf32> + %arg3 = ab * %arg2 : memref<1x1024x64xf32> -> memref<1x1024x64xf32> + } { + arch = "amdgcn-amd-amdhsa:gfx942:sramecc+:xnack-", + features = #rock, + params0 = #xldops_attn_params_g0, + params1 = #xldops_attn_params_g1, + firstGemmIdx = 0 : i32 + } + return +} + +// CHECK-LABEL: func.func @rock_gemmelementwisegemm_tr_padded +// CHECK-SAME: (%[[a:.*]]: memref<1x49x7xf32>, %[[b:.*]]: memref<1x7x49xf32>, %[[c:.*]]: memref<1x49x7xf32>, %[[o:.*]]: memref<1x49x7xf32>) +func.func @rock_gemmelementwisegemm_tr_padded(%arg0: memref<1x49x7xf32>, %arg1: memref<1x7x49xf32>, %arg2: memref<1x49x7xf32>, %arg3: memref<1x49x7xf32>) attributes {kernel, mhal.arch = "amdgcn-amd-amdhsa:gfx908", block_size = 64 : i32, grid_size = 2 : i32} { + // CHECK-DAG: %[[trA:.*]] = rock.transform %[[a]] by {{.*}} : memref<1x49x7xf32> to memref<1x7x49xf32> + // CHECK-DAG: %[[paddedTrA:.*]] = rock.transform %[[trA]] by {{.*}} : memref<1x7x49xf32> to memref<1x8x64xf32> + // CHECK-DAG: %[[paddedB:.*]] = rock.transform %[[b]] by {{.*}} : memref<1x7x49xf32> to memref<1x8x64xf32> + // CHECK-DAG: %[[paddedC:.*]] = rock.transform %[[c]] by {{.*}} : memref<1x49x7xf32> to memref<1x64x32xf32> + // CHECK-DAG: %[[paddedO:.*]] = rock.transform %[[o]] by {{.*}} : memref<1x49x7xf32> to memref<1x64x32xf32> + // CHECK: rock.gridwise_attention_accel(%[[paddedTrA]], %[[paddedB]], %[[paddedC]], %[[paddedO]]) + // CHECK-NEXT: enableSoftmax = false + // CHECK-SAME: prePadG0M = 49 : index, prePadG0N = 49 : index + rock.gemm_elementwise_gemm{ + ab = %arg0 * %arg1 : memref<1x49x7xf32>, memref<1x7x49xf32> + %arg3 = ab * %arg2 : memref<1x49x7xf32> -> memref<1x49x7xf32> + } { + arch = "amdgcn-amd-amdhsa:gfx942:sramecc+:xnack-", + features = #rock, + params0 = #xldops_attn_params_g0, + params1 = #xldops_attn_params_g1, + firstGemmIdx = 0 : i32 + } + return +} diff --git a/mlir/test/Dialect/Rock/lowering_sort_dimensions_memory_layout.mlir b/mlir/test/Dialect/Rock/lowering_sort_dimensions_memory_layout.mlir index ea59d4099cee..3d3aa6492848 100644 --- a/mlir/test/Dialect/Rock/lowering_sort_dimensions_memory_layout.mlir +++ b/mlir/test/Dialect/Rock/lowering_sort_dimensions_memory_layout.mlir @@ -220,4 +220,4 @@ func.func @test_mlir_slice_add_literal_weights_convolution(%arg0: memref<1638400 %10 = rock.transform %9 by (0, 0, d0 floordiv 10240, (d0 mod 10240) floordiv 128, d0 mod 128)> by [ ["col0", "col1", "col2", "col3", "col4"] at [0, 1, 2, 3, 4]>] bounds = [819200] -> [1, 1, 80, 80, 128]> : memref<1x1x80x80x128xf16> to memref<819200xf16> memref.copy %10, %arg1 : memref<819200xf16> to memref<819200xf16> return -} \ No newline at end of file +} diff --git a/mlir/test/e2e/CMakeLists.txt b/mlir/test/e2e/CMakeLists.txt index 9a5ee10ab774..441b6f233caf 100644 --- a/mlir/test/e2e/CMakeLists.txt +++ b/mlir/test/e2e/CMakeLists.txt @@ -33,6 +33,7 @@ 
if (ROCMLIR_DRIVER_PR_E2E_TEST_ENABLED) PrAttentionBF16 PrAttentionI8 PrGemmSplitK + PrGemmElementwiseGemmF32 ) set(GEN_MODE "") endif() diff --git a/mlir/test/e2e/PrAttentionF32.toml b/mlir/test/e2e/PrAttentionF32.toml index fcbe816ea5d2..6b46ed3ec914 100644 --- a/mlir/test/e2e/PrAttentionF32.toml +++ b/mlir/test/e2e/PrAttentionF32.toml @@ -2,26 +2,11 @@ directory = "PrAttentionF32" prefix = "rocmlir-gen" suffix = "--operation attention -t f32 --arch %arch -pv %random_data %rocmlir_gen_flags -relDiff_threshold 0.00005 | rocmlir-driver -c | mlir-runner -O2 --shared-libs=%linalg_test_lib_dir/libmlir_rocm_runtime%shlibext,%conv_validation_wrapper_library_dir/libconv-validation-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_float16_utils%shlibext --entry-point-result=void | FileCheck %s --check-prefix=" -[[axis]] -name = "transQ" -values = ["true", "false"] -prefix = "--transQ=" - [[axis]] name = "transK" values = ["true", "false"] prefix = "--transK=" -[[axis]] -name = "transV" -values = ["true", "false"] -prefix = "--transV=" - -[[axis]] -name = "transO" -values = ["true", "false"] -prefix = "--transO=" - ## attention variant [[suite]] name = "pr_attention_f32" diff --git a/mlir/test/e2e/PrGemmElementwiseGemmF32.cfg b/mlir/test/e2e/PrGemmElementwiseGemmF32.cfg new file mode 100644 index 000000000000..3d7864d49b56 --- /dev/null +++ b/mlir/test/e2e/PrGemmElementwiseGemmF32.cfg @@ -0,0 +1,2 @@ +if (not config.arch_support_mfma): + config.unsupported = True diff --git a/mlir/test/e2e/PrGemmElementwiseGemmF32.toml b/mlir/test/e2e/PrGemmElementwiseGemmF32.toml new file mode 100644 index 000000000000..f5521c91f188 --- /dev/null +++ b/mlir/test/e2e/PrGemmElementwiseGemmF32.toml @@ -0,0 +1,41 @@ +directory = "PrGemmElementwiseGemmF32" +prefix = "rocmlir-gen" +suffix = "--operation gemm_gemm -t f32 --arch %arch -pv %random_data %rocmlir_gen_flags | rocmlir-driver -c | mlir-runner -O2 --shared-libs=%linalg_test_lib_dir/libmlir_rocm_runtime%shlibext,%conv_validation_wrapper_library_dir/libconv-validation-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_float16_utils%shlibext --entry-point-result=void | FileCheck %s --check-prefix=" + +[[axis]] +name = "transA" +values = ["true", "false"] +prefix = "--transA=" + +[[axis]] +name = "transB" +values = ["true", "false"] +prefix = "--transB=" + +[[axis]] +name = "transC" +values = ["true", "false"] +prefix = "--transC=" + +[[axis]] +name = "transO" +values = ["true", "false"] +prefix = "--transO=" + +## gemm+gemm variant +[[suite]] +name = "pr_gemm_gemm_f32" + +[[suite.test]] +config = "-m 384 -n 384 -k 64 -gemmO 64" + +[[suite.test]] +config = "-m 64 -n 64 -k 64 -gemmO 64 -perf_config attn:v1:32,32,64,32,32,32,4,1" + +## This one test kPerBlock (16 x 4) == head_dim case +[[suite.test]] +config = "-m 64 -n 64 -k 64 -gemmO 64 -perf_config attn:v1:64,64,64,16,32,32,4,1" + +[[suite.test]] +config = "-m 64 -n 64 -k 64 -gemmO 64" + diff --git a/mlir/test/rocmlir-gen/attention-kernel.mlir b/mlir/test/rocmlir-gen/attention-kernel.mlir index 0727e449a173..e22e51669845 100644 --- a/mlir/test/rocmlir-gen/attention-kernel.mlir +++ b/mlir/test/rocmlir-gen/attention-kernel.mlir @@ -1,5 +1,4 @@ // RUN: rocmlir-gen --arch gfx90a:sramecc+:xnack- --operation attention -seq_len_q 1024 -seq_len_k 1024 -head_dim_qk 32 -head_dim_v 32 --with-attn-scale -t f32 -pv --apply-bufferization-pipeline=false | rocmlir-opt | FileCheck %s --enable-var-scope --check-prefixes=CHECK_SCALE -// 
RUN: rocmlir-gen --arch gfx90a:sramecc+:xnack- --operation attention -seq_len_q 1024 -seq_len_k 1024 -head_dim_qk 32 -head_dim_v 32 -t f32 -pv --apply-bufferization-pipeline=false | rocmlir-opt | FileCheck %s --enable-var-scope --check-prefixes=CHECK_NO_SCALE // CHECK_SCALE: module attributes {mhal.arch = "[[$ARCH:.*]]"} diff --git a/mlir/test/rocmlir-gen/gemm-elementwise-gemm-kernel.mlir b/mlir/test/rocmlir-gen/gemm-elementwise-gemm-kernel.mlir new file mode 100644 index 000000000000..60388cc5c716 --- /dev/null +++ b/mlir/test/rocmlir-gen/gemm-elementwise-gemm-kernel.mlir @@ -0,0 +1,24 @@ +// RUN: rocmlir-gen --arch gfx942:sramecc+:xnack- --operation gemm_gemm -m 1024 -n 1024 -k 32 -gemmO 32 -t f32 -pv --apply-bufferization-pipeline=false | rocmlir-opt | FileCheck %s --enable-var-scope + +// CHECK: module attributes {mhal.arch = "[[$ARCH:.*]]"} + +// CHECK-LABEL: func.func @rock_gemm_gemm +// CHECK-SAME: (%[[aRaw:.*0]]: memref<32768xf32>, +// CHECK-SAME: %[[bRaw:.*1]]: memref<32768xf32>, +// CHECK-SAME: %[[cRaw:.*2]]: memref<32768xf32>, +// CHECK-SAME: %[[outputRaw:.*3]]: memref<32768xf32>) +// CHECK-SAME: attributes {kernel, mhal.arch = "[[$ARCH]]"} +// CHECK-NEXT: %[[a:.*]] = rock.transform %[[aRaw]] {{.*}} : memref<32768xf32> to memref<1x1024x32xf32> +// CHECK-NEXT: %[[b:.*]] = rock.transform %[[bRaw]] {{.*}} : memref<32768xf32> to memref<1x32x1024xf32> +// CHECK-NEXT: %[[c:.*]] = rock.transform %[[cRaw]] {{.*}} : memref<32768xf32> to memref<1x1024x32xf32> +// CHECK-NEXT: %[[output:.*]] = rock.transform %[[outputRaw]] {{.*}} : memref<32768xf32> to memref<1x1024x32xf32> + +// CHECK-NEXT: rock.gemm_elementwise_gemm +// CHECK-NEXT: ab = %[[a]] * %[[b]] +// CHECK: %[[output]] = ab * %[[c]] +// CHECK: return + +// CHECK-LABEL: func.func @host_naive_gemm_gemm +// CHECK: %[[abTensor:.*]] = tosa.matmul %[[aTensor:.*]], %[[bTensor:.*]], %{{.*}}, %{{.*}} : ([[aShape:tensor<.*>]], [[bShape:tensor<.*>]], tensor<1xf32>, tensor<1xf32>) -> [[squareShape:tensor<.*>]] +// CHECK-DAG: %[[resultTensor:.*]] = tosa.matmul %[[abTensor]], %[[cTensor:.*]], %{{.*}}, %{{.*}} : ([[squareShape]], [[cShape:tensor<.*>]], tensor<1xf32>, tensor<1xf32>) -> [[cShape]] +// CHECK: return diff --git a/mlir/tools/rocmlir-gen/rocmlir-gen.cpp b/mlir/tools/rocmlir-gen/rocmlir-gen.cpp index 12d9a7e87d3d..bf05e2dc06df 100644 --- a/mlir/tools/rocmlir-gen/rocmlir-gen.cpp +++ b/mlir/tools/rocmlir-gen/rocmlir-gen.cpp @@ -99,7 +99,9 @@ static llvm::cl::opt operation( "Backpropogate convolution weights"), clEnumValN(rock::KernelType::Gemm, "gemm", "Matrix multiplication"), clEnumValN(rock::KernelType::Attention, "attention", - "Attention operation of transformer models")), + "Attention operation of transformer models"), + clEnumValN(rock::KernelType::GemmElementwiseGemm, "gemm_gemm", + "gemm+elementwise+gemm operation")), llvm::cl::value_desc("kernel type"), llvm::cl::init(rock::KernelType::Conv)); @@ -309,6 +311,11 @@ static llvm::cl::opt gemmN("n", llvm::cl::value_desc("positive integer"), llvm::cl::init(-1)); +/// gemm+elementwise+gemm options +static llvm::cl::opt + gemmO("gemmO", llvm::cl::desc("N dimension of the second gemm()"), + llvm::cl::value_desc("positive integer"), llvm::cl::init(-1)); + static llvm::cl::opt transposeA("transA", llvm::cl::desc("whether matrix A is GxMxK (default) or GxKxM"), @@ -1065,7 +1072,9 @@ static void verifyConvLayout() { static void populateDefaults() { const bool isGemm = operation == rock::KernelType::Gemm; const bool isAttention = operation == rock::KernelType::Attention; - const bool isConv = 
!(isGemm || isAttention); + const bool isGemmElntwiseGemm = + operation == rock::KernelType::GemmElementwiseGemm; + const bool isConv = !(isGemm || isAttention || isGemmElntwiseGemm); // Default f32 if we passed no `-t` arguments at all. if (outputDataType.empty()) { if (filterDataType != inputDataType) { @@ -1082,6 +1091,13 @@ static void populateDefaults() { gemmK = 769; gemmN = 512; } + if (isGemmElntwiseGemm) { + groupSize = 1; + gemmM = 1024; + gemmK = 769; + gemmN = 512; + gemmO = 769; + } if (isAttention) { groupSize = 1; sequenceLengthQ = 1024; @@ -1170,6 +1186,11 @@ auto getRequiredArgs(std::optional kernelType) { &gemmK, &gemmN}; return requiredGemmArgs; } + case rock::KernelType::GemmElementwiseGemm: { + const static RequiredArgsType requiredGemmElntwiseGemmArgs = { + &groupSize, &gemmM, &gemmK, &gemmN, &gemmO}; + return requiredGemmElntwiseGemmArgs; + } case rock::KernelType::Attention: { const static RequiredArgsType requiredAttenArgs = { &groupSize, &sequenceLengthQ, &sequenceLengthK, &headDimQK, &headDimV}; @@ -1193,9 +1214,11 @@ static LogicalResult detectMissingArguments() { } } - if (operation == rock::KernelType::Attention) { + if (operation == rock::KernelType::Attention || + operation == rock::KernelType::GemmElementwiseGemm) { if (dataTypeAlias.getValue().empty()) { - llvm::errs() << "Type of the Attention operation is not specified\n"; + llvm::errs() + << "Type of the Attention/gemm+gemm operation is not specified\n"; return failure(); } } @@ -2139,11 +2162,9 @@ createCPUConvFunc(ModuleOp module, static void getGemmTypes(ArrayRef elemTypes, SmallVectorImpl &result, bool isCpuVerifier) { Type cElemType = elemTypes[2]; - if (elemTypes[0].isInteger(8)) { - // Verify in int64_t to detect overflow - if (isCpuVerifier) - cElemType = IntegerType::get(cElemType.getContext(), 64); - } + // Verify in int64_t to detect overflow + if (elemTypes[0].isInteger(8) && isCpuVerifier) + cElemType = IntegerType::get(cElemType.getContext(), 64); SmallVector aDims = {groupSize, transposeA ? gemmK : gemmM, transposeA ? gemmM : gemmK}, @@ -2341,6 +2362,51 @@ getAttentionDimNames(SmallVectorImpl> &result, result.emplace_back(SmallVector{gName, seqQName, headVName}); } +static void getGemmElentwiseGemmTypes(SmallVectorImpl &result, + ArrayRef elemTypes) { + SmallVector aDims = {groupSize, transposeA ? gemmK : gemmM, + transposeA ? gemmM : gemmK}, + bDims = {groupSize, transposeB ? gemmN : gemmK, + transposeB ? gemmK : gemmN}, + cDims = {groupSize, transposeC ? gemmO : gemmN, + transposeC ? gemmN : gemmO}, + outDims = {groupSize, transposeO ? gemmO : gemmM, + transposeO ? 
gemmM : gemmO}; + + MemRefType aType = MemRefType::get(aDims, elemTypes[0]), + bType = MemRefType::get(bDims, elemTypes[1]), + cType = MemRefType::get(cDims, elemTypes[2]), + outType = MemRefType::get(outDims, elemTypes[3]); + result.push_back(aType); + result.push_back(bType); + result.push_back(cType); + result.push_back(outType); +} + +static void +getGemmElentwiseGemmDimNames(SmallVectorImpl> &result, + ArrayRef elementTypes) { + result.reserve(elementTypes.size()); + constexpr StringLiteral gName = "g", m = "m", n = "n", k = "k", + gemmO = "gemmO"; + if (transposeA) + result.emplace_back(SmallVector{gName, k, m}); + else + result.emplace_back(SmallVector{gName, m, k}); + if (transposeB) + result.emplace_back(SmallVector{gName, n, k}); + else + result.emplace_back(SmallVector{gName, k, n}); + if (transposeC) + result.emplace_back(SmallVector{gName, gemmO, n}); + else + result.emplace_back(SmallVector{gName, n, gemmO}); + if (transposeO) + result.emplace_back(SmallVector{gName, gemmO, m}); + else + result.emplace_back(SmallVector{gName, m, gemmO}); +} + template static TosaOp createOpAndInfer(OpBuilder &builder, Location loc, Type elemType, Args &&...args) { @@ -2359,8 +2425,9 @@ static TosaOp createOpAndInfer(OpBuilder &builder, Location loc, Type elemType, return op; } -Value addTensorArgToBlock(OpBuilder &builder, Location loc, - Block *preSoftmaxElemwiseBlock, Value funcArg) { +static Value addTensorArgToBlock(OpBuilder &builder, Location loc, + Block *preSoftmaxElemwiseBlock, + Value funcArg) { ShapedType funcArgType = cast(funcArg.getType()); Value funcArgMemRef = preSoftmaxElemwiseBlock->addArgument( MemRefType::get(funcArgType.getShape(), funcArgType.getElementType()), @@ -2696,6 +2763,87 @@ static func::FuncOp createGpuAttentionKernel(ModuleOp module, return func; } +static func::FuncOp createGpuGemmElentwiseGemmKernel(ModuleOp module, + const GenParams ¶ms) { + MLIRContext *ctx = module.getContext(); + Location loc = module->getLoc(); + OpBuilder builder(ctx); + + // Set mhal.arch on module to make compilation pipeline work + StringAttr archAttr = builder.getStringAttr(params.arch); + if (!module->hasAttr("mhal.arch")) + module->setAttr("mhal.arch", archAttr); + + SmallVector argTypes; + getGemmElentwiseGemmTypes(argTypes, params.types); + SmallVector flatArgTypes = + llvm::map_to_vector(argTypes, rock::getFlattenedType); + + SmallVector funcAttrs = { + builder.getNamedAttr("kernel", builder.getUnitAttr()), + builder.getNamedAttr("mhal.arch", archAttr)}; + + constexpr StringLiteral kernelName("rock_gemm_gemm"); + auto func = builder.create( + loc, kernelName, builder.getFunctionType(flatArgTypes, {}), funcAttrs); + if (reverse_grid) { + func->setAttr(rock::ReverseGridAttrAttr::getMnemonic(), + builder.getUnitAttr()); + } + + Block *block = func.addEntryBlock(); + builder.setInsertionPointToStart(block); + + SmallVector unflattenedArgs; + SmallVector> allNames; + getGemmElentwiseGemmDimNames(allNames, params.types); + rock::expandFlatFunctionArguments(builder, func, allNames, argTypes, + unflattenedArgs); + + Value a = unflattenedArgs[0]; + Value b = unflattenedArgs[1]; + Value c = unflattenedArgs[2]; + Value output = unflattenedArgs[3]; + SmallVector elemwiseInputs; + + IntegerAttr numCUAttr = + (num_cu.getNumOccurrences() > 0 ? 
builder.getI32IntegerAttr(num_cu) + : nullptr); + auto gemmElntGemm = builder.create( + loc, TypeRange{}, a, b, c, elemwiseInputs, output, transposeA, transposeB, + transposeC, transposeO, archAttr, params.features, numCUAttr, + /*params0=*/nullptr, /*params1=*/nullptr, /*firstGemmIdx=*/0); + { + Block *preSecondGemmBlock = + &gemmElntGemm.getPreSecondGemmBody().emplaceBlock(); + PatternRewriter::InsertionGuard guard(builder); + builder.setInsertionPointToStart(preSecondGemmBlock); + ShapedType aType = cast(a.getType()); + ArrayRef aShape = aType.getShape(); + Type abElemType = aType.getElementType(); + MemRefType abMemRefType = + MemRefType::get({aShape[0], gemmM, gemmN}, abElemType); + Value abMemRef = preSecondGemmBlock->addArgument(abMemRefType, loc); + Value abTensor = rock::getAsTensor(builder, loc, abMemRef); + MemRefType resMemRefType = + MemRefType::get({aShape[0], gemmM, gemmN}, + cast(abTensor.getType()).getElementType()); + Value resMemref = + builder.create(loc, resMemRefType, abTensor); + Value outMemref = preSecondGemmBlock->addArgument(resMemRefType, loc); + builder.create(loc, resMemref, outMemref); + builder.create(loc); + } + + if (!params.perfConfig.empty()) + gemmElntGemm->setAttr("perf_config", + builder.getStringAttr(params.perfConfig)); + + builder.create(loc); + module.push_back(func); + return func; +} + static func::FuncOp createCpuGemmKernelWithMlir(ModuleOp module, const GenParams ¶ms) { MLIRContext *ctx = module.getContext(); @@ -2808,6 +2956,95 @@ static Type getAccType(Type inputType, OpBuilder builder) { return accType; } +static func::FuncOp +createCpuGemmElementwiseGemmKernelWithMlir(ModuleOp module, + const GenParams ¶ms) { + MLIRContext *ctx = module.getContext(); + OpBuilder builder(ctx); + Location loc = module->getLoc(); + + SmallVector argTypes; + getGemmElentwiseGemmTypes(argTypes, params.types); + SmallVector flatArgTypes = + llvm::map_to_vector(argTypes, rock::getFlattenedType); + + constexpr llvm::StringLiteral cpuKernName("host_naive_gemm_gemm"); + auto func = builder.create( + loc, cpuKernName, builder.getFunctionType(flatArgTypes, {})); + + Block *block = func.addEntryBlock(); + builder.setInsertionPointToStart(block); + + auto getTensorForBlockArg = [&builder, &loc, &block, + &argTypes](unsigned blockArgIndex, + bool isWritable = false) { + constexpr bool isRestrict{true}; + Value flatTensor = builder.create( + loc, block->getArgument(blockArgIndex), isRestrict, isWritable); + ArrayRef origShape = + cast(argTypes[blockArgIndex]).getShape(); + + Value reshapedTensor; + ImplicitLocOpBuilder implicitBuilder(loc, builder); + if (origShape.size() == 2) { + SmallVector expShape(origShape.size() + 1, 0); + expShape[0] = 1; + llvm::copy(origShape, expShape.begin() + 1); + auto shapeValue = tosa::getTosaConstShape(implicitBuilder, expShape); + reshapedTensor = + builder.create(loc, flatTensor, shapeValue); + } else { + auto shapeValue = tosa::getTosaConstShape(implicitBuilder, origShape); + reshapedTensor = + builder.create(loc, flatTensor, shapeValue); + } + return reshapedTensor; + }; + + auto aTensor = getTensorForBlockArg(0); + if (transposeA) { + aTensor = transposeMatrix(builder, loc, aTensor, {0, 2, 1}); + } + auto bTensor = getTensorForBlockArg(1); + if (transposeB) { + bTensor = transposeMatrix(builder, loc, bTensor, {0, 2, 1}); + } + auto cTensor = getTensorForBlockArg(2); + if (transposeC) { + cTensor = transposeMatrix(builder, loc, cTensor, {0, 2, 1}); + } + + Type firstGemmOutElemType = params.types[2]; + Value abTensor = 
createOpAndInfer( + builder, loc, firstGemmOutElemType, aTensor, bTensor); + + Type secondGemmOutElemType = params.types[3]; + Value resultTensor = createOpAndInfer( + builder, loc, secondGemmOutElemType, abTensor, cTensor); + + if (transposeO) { + resultTensor = transposeMatrix(builder, loc, resultTensor, {0, 2, 1}); + } + + Value output = block->getArguments().back(); + auto outputType = cast(output.getType()); + + ImplicitLocOpBuilder implicitBuilder(loc, builder); + auto shapeValue = + tosa::getTosaConstShape(implicitBuilder, outputType.getShape()); + auto flatResultTensor = + builder.create(loc, resultTensor, shapeValue); + + auto flatResultMemref = builder.create( + loc, outputType, flatResultTensor); + + builder.create(loc, flatResultMemref, output); + + builder.create(loc); + module.push_back(func); + return func; +} + static func::FuncOp createCpuAttentionKernelWithMlir(ModuleOp module, const GenParams ¶ms) { MLIRContext *ctx = module.getContext(); @@ -2992,9 +3229,6 @@ static func::FuncOp createCpuAttentionKernelWithMlir(ModuleOp module, Value softmaxTensor = createOpAndInfer( builder, loc, cast(expsSums.getType()).getElementType(), expsTensor, invExpsSums, /*shift=*/constZero); -#ifdef ROCK_DEBUG_ATTENTION_REMOVE_SOFTMAX - softmaxTensor = qkTensor; -#endif auto resultOutElementType = cast(softmaxTensor.getType()).getElementType(); auto softmaxZp = @@ -3498,6 +3732,15 @@ static void insertValidationCalls(const GenParams &genParams, OpBuilder &b, auto cpuAttentionFunc = createCpuAttentionKernelWithMlir(module, genParams); b.create(loc, cpuAttentionFunc, valVars); + } else if (genParams.operation == rock::KernelType::GemmElementwiseGemm) { + if (validationType == "cpp") { + llvm::errs() + << "External gemm elementwise gemm validator is not available\n"; + exit(1); + } + auto cpuGemmElementwiseGemmFunc = + createCpuGemmElementwiseGemmKernelWithMlir(module, genParams); + b.create(loc, cpuGemmElementwiseGemmFunc, valVars); } else { llvm::errs() << "Validation generation requested, but no operation specified\n"; @@ -3602,6 +3845,9 @@ static LogicalResult populateHostHarnessLogic( case rock::KernelType::ConvBwdWeight: outIndices.push_back(0); break; + case rock::KernelType::GemmElementwiseGemm: + outIndices.push_back(3); + break; case rock::KernelType::Attention: isAttention = true; int32_t optionalArgsCounter{3}; @@ -3845,11 +4091,14 @@ static void generateKernel(MLIRContext *context, GenParams &genParams, const bool isGemm = operation == rock::KernelType::Gemm; const bool isAttention = operation == rock::KernelType::Attention; - const bool isConv = !(isGemm || isAttention); + const bool isGemmElntwiseGemm = + operation == rock::KernelType::GemmElementwiseGemm; + const bool isConv = !(isGemm || isAttention || isGemmElntwiseGemm); auto convConfigStr = populateConvConfig.getValue(); if (!convConfigStr.empty() && !isConv) { - llvm::errs() << "Cannot use --conv-config with gemm/attention operations\n"; + llvm::errs() << "Cannot use --conv-config with gemm/attention/gemm+gemm " + "operations\n"; exit(1); } @@ -3950,6 +4199,16 @@ static void generateKernel(MLIRContext *context, GenParams &genParams, genParams.types.push_back(typeFromString(arg, context)); genParams.convConfig = std::nullopt; (void)createGpuGemmKernel(module, genParams); + } else if (isGemmElntwiseGemm) { + constexpr size_t numArgs{4}; + // Note: In the current implementation, all operands have the same type. + // This behaviour enforced by `-t`. 
See, detectMissingArguments() + auto elemType = typeFromString(inputDataType.getValue(), context); + for (size_t argIdx{0}; argIdx < numArgs; ++argIdx) { + genParams.types.push_back(elemType); + } + genParams.convConfig = std::nullopt; + (void)createGpuGemmElentwiseGemmKernel(module, genParams); } else if (isAttention) { auto elemType = typeFromString(inputDataType.getValue(), context); // We only support first-gemm i8 version of attention diff --git a/mlir/tools/rocmlir-tuning-driver/rocmlir-tuning-driver.cpp b/mlir/tools/rocmlir-tuning-driver/rocmlir-tuning-driver.cpp index d4473899924c..f6d2a75d1e7e 100644 --- a/mlir/tools/rocmlir-tuning-driver/rocmlir-tuning-driver.cpp +++ b/mlir/tools/rocmlir-tuning-driver/rocmlir-tuning-driver.cpp @@ -19,6 +19,7 @@ #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/Rock/IR/Rock.h" +#include "mlir/Dialect/Rock/IR/RockGemmGemmWrapperInterface.h" #include "mlir/Dialect/Rock/Pipelines/Pipelines.h" #include "mlir/Dialect/Rock/Tuning/RockTuning.h" #include "mlir/Dialect/Rock/utility/fusionUtils.h" @@ -210,11 +211,10 @@ extractKernelDataType(ModuleOp op, SmallVectorImpl &kernels) { toTuneType = gemmLike.getAType(); outputType = gemmLike.getCType(); }); - } - if (!toTuneType) { - f.walk([&toTuneType, &outputType](rock::AttentionOp attnOp) { - toTuneType = attnOp.getQueries().getType().getElementType(); - outputType = toTuneType; + f.walk([&toTuneType, + &outputType](rock::RockGemmGemmWrapperInterface attnLike) { + toTuneType = cast(attnLike.getAType()).getElementType(); + outputType = cast(attnLike.getOutType()).getElementType(); }); } }); @@ -328,7 +328,7 @@ static LogicalResult runTuningLoop(ModuleOp source) { tuneCopy->walk([&perfConfigAttr](rock::RockGemmWrapperInterface op) { op->setAttr("perf_config", perfConfigAttr); }); - tuneCopy->walk([&perfConfigAttr](rock::AttentionOp op) { + tuneCopy->walk([&perfConfigAttr](rock::RockGemmGemmWrapperInterface op) { op->setAttr("perf_config", perfConfigAttr); }); diff --git a/mlir/utils/performance/perfCommonUtils.py b/mlir/utils/performance/perfCommonUtils.py index 080310c558bb..a1a839b9c1a2 100644 --- a/mlir/utils/performance/perfCommonUtils.py +++ b/mlir/utils/performance/perfCommonUtils.py @@ -6,6 +6,7 @@ class Operation(enum.IntEnum): GEMM = 2 FUSION = 3 ATTENTION = 4 + GEMM_GEMM = 5 @staticmethod def fromName(name: str) -> 'self': @@ -16,6 +17,8 @@ def fromName(name: str) -> 'self': return Operation.GEMM elif name == 'attention': return Operation.ATTENTION + elif name == 'gemm_gemm': + return Operation.GEMM_GEMM elif name == 'fusion': return Operation.FUSION else: diff --git a/mlir/utils/performance/perfRunner.py b/mlir/utils/performance/perfRunner.py index 82a8c4ccc727..dffa3ae01873 100644 --- a/mlir/utils/performance/perfRunner.py +++ b/mlir/utils/performance/perfRunner.py @@ -35,6 +35,7 @@ DATA_TYPES_GEMM = ['f32', 'f16', 'bf16', 'i8', 'fp8'] DATA_TYPES_ATTENTION = ['f32', 'f16', 'bf16'] +DATA_TYPES_GEMM_GEMM = ['f32'] OUTPUT_DATA_TYPES_MAP = {'f32': 'f32', 'f16': 'f16', 'bf16': 'bf16', 'i8': 'i32', 'fp8':'f32', 'fp8_fp8': 'f32', 'fp8_bf8': 'f32', 'bf8_fp8': 'f32', 'bf8_bf8': 'f32'} @@ -608,6 +609,39 @@ def getGemmConfigurations(fileName, dataTypes=DATA_TYPES_GEMM, outDataTypeMap=OU configs.append(oneConfig) return configs +def getGemmGemmConfigurations(fileName): + bool_space = ['false', 'true'] + default_test_space = { + "-t": DATA_TYPES_GEMM_GEMM, + "-transA": bool_space, + "-transB": bool_space, + "-transC": bool_space, + "-transO": bool_space, 
+ } + configs = [] + if fileName: + with open(fileName, 'r') as configFile: + lines = configFile.readlines() + for line in lines: + line = line.strip() + # Skip empty lines + if len(line) == 0 or line[0] == '#': + continue + test_space = [] + args = [] + for arg in default_test_space.keys(): + if arg not in line: + test_space.append(default_test_space[arg]) + args.append(arg) + for test_vector in itertools.product(*test_space): + # Strip to avoid spurious spaces + oneConfig = line.strip() + for arg, value in zip(args, test_vector): + oneConfig = f"{arg} {value} {oneConfig}" + if oneConfig not in configs: + configs.append(oneConfig) + return configs + def getAttentionConfigurations(fileName): bool_space = ['false', 'true'] default_test_space = { @@ -699,6 +733,7 @@ def fromCommandLine(cls, argv, arch, numCU): n = None transA = None transB = None + outDataType = None perf_config = '' for i in range(0, len(argv), 2): opt = argv[i] @@ -718,7 +753,7 @@ def fromCommandLine(cls, argv, arch, numCU): elif opt.endswith("-transB"): transB = (val.lower() in ["1", "true"]) elif opt.endswith("-out_datatype"): - outDataType =val.lower() + outDataType = val.lower() elif opt.endswith("-perf_config"): perf_config = val else: @@ -736,8 +771,9 @@ def toCommandLine(self): def __init__(self, dtype: str, outDataType: str, g: int, m: int, k: int, n: int, transA: bool, transB: bool, arch: str, numCU: int, perf_config: str = ''): - if dtype not in {"f16", "f32", "bf16", "i8", "fp8"}: + if dtype not in DATA_TYPES_GEMM: raise ValueError(f"Invalid datatype: {dtype}") + self.dataType = dtype self.outDataType = outDataType self.g = g @@ -752,12 +788,151 @@ def __init__(self, dtype: str, outDataType: str, g: int, m: int, k: int, n: int, self.chip = GFX_CHIP_RE.search(arch).group(0) self.numCU = numCU +class GemmGemmConfiguration(PerfConfiguration): + TABLE_COLUMNS = reportUtils.GEMM_GEMM_TEST_PARAMETERS + ['TFlops'] + def __init__(self, dtype: str, g: int, m: int, k: int, n: int, o: int, + transA: bool, transB: bool, transC: bool, transO: bool, arch: str, numCU: int, perf_config: str = ''): + if dtype not in DATA_TYPES_GEMM_GEMM: + raise ValueError(f"Invalid datatype for a: {dtype}") + + self.dataType = dtype + self.g = g + self.m = m + self.k = k + self.n = n + self.o = o + self.transA = transA + self.transB = transB + self.transC = transC + self.transO = transO + + self.arch = arch + self.chip = GFX_CHIP_RE.search(arch).group(0) + self.numCU = numCU + self.perfConfig = perf_config + + def computeTFlops(self, ns): + # NaN will propagate as expected + # Repeats are handled by the fact that we're using avarageNs + first_matmul_flops = 2.0 * self.g * self.m * self.k * self.n + second_matmul_flops = 2.0 * self.g * self.m * self.n * self.o + total_flops = first_matmul_flops + second_matmul_flops + + return total_flops / (float(ns) * 1e-9) / 1e12 + + def tableEntry(self, nanoSeconds): + result = {} + values = [ + self.dataType, + self.chip, + self.numCU, + self.transA, + self.transB, + self.transC, + self.transO, + self.g, + self.m, + self.k, + self.n, + self.o, + self.perfConfig, + self.computeTFlops(nanoSeconds) + ] + assert(len(self.TABLE_COLUMNS) == len(values)) + for k, v in zip(self.TABLE_COLUMNS, values): + result[k] = v + return result + + def __repr__(self): + attrs = ', '.join(f"{key}={repr(value)!r}" for key, value in self.__dict__.items()) + return f"{self.__class__.__name__}({attrs})" + + def setPerfConfig(self, perf_config): + self.perfConfig = perf_config + + def generateMlirDriverCommandLine(self, 
rocmlir_gen_flags): + result = ' '.join(['-operation', 'gemm_gemm', + '-t', self.dataType, + '--arch', self.arch, + '--num_cu', str(self.numCU), + '-g', str(self.g), + '-m', str(self.m), + '-k', str(self.k), + '-n', str(self.n), + '-gemmO', str(self.o), + f"-transA={self.transA}", + f"-transB={self.transB}", + f"-transC={self.transC}", + f"-transO={self.transO}", + '--kernel-repeats', str(MLIR_N_REPEATS), + f"--perf_config={self.perfConfig}"]) + result += ' ' + if rocmlir_gen_flags != '': + result += ' '.join(rocmlir_gen_flags.split()) + return result + + @classmethod + def fromCommandLine(cls, argv, arch, numCU): + # optional defaults + perf_config = '' + dtype = None + g = None + m = None + k = None + n = None + o = None + transA = False + transB = False + transC = False + transO = False + # Please keep this in sync with mlir::rock::getTuningProblemStr() + for i in range(0, len(argv), 2): + opt = argv[i] + val = argv[i + 1] + if opt.endswith("-t"): + dtype = val + elif opt.endswith("-g"): + g = int(val) + elif opt.endswith("-m"): + m = int(val) + elif opt.endswith("-k"): + k = int(val) + elif opt.endswith("-n"): + n = int(val) + elif opt.endswith("-gemmO"): + o = int(val) + elif opt.endswith("-transA"): + transA = (val.lower() in ["1", "true"]) + elif opt.endswith("-transB"): + transB = (val.lower() in ["1", "true"]) + elif opt.endswith("-transC"): + transC = (val.lower() in ["1", "true"]) + elif opt.endswith("-transO"): + transO = (val.lower() in ["1", "true"]) + elif opt.endswith("-perf_config"): + perf_config = val + else: + raise ValueError(f"Unknown GEMM+GEMM config argument {opt} -> {val}") + for v in [dtype, g, m, k, n, o, transA, transB, transC, transO]: + if v is None: + raise ValueError("Incomplete GEMM+GEMM configuration") + + return cls(dtype, g, m, k, n, o, transA, transB, transC, transO, arch, numCU, perf_config) + + def toCommandLine(self): + return (f"-t {self.dataType} " + + f"-transA {str(self.transA).lower()} -transB {str(self.transB).lower()} " + + f"-transC {str(self.transC).lower()} -transO {str(self.transO).lower()} " + + f"-g {self.g} " + + f"-m {str(self.m)} -k {str(self.k)} -n {str(self.n)} -gemmO {str(self.o)}") + class AttentionConfiguration(PerfConfiguration): TABLE_COLUMNS = reportUtils.ATTN_TEST_PARAMETERS + ['TFlops'] def __init__(self, dtype: str, g: int, seq_len_q: int, seq_len_k: int, head_dim_qk: int, head_dim_v: int, with_attn_scale: int, transQ: bool, transK: bool, transV: bool, transO: bool, arch: str, numCU: int, perf_config: str = ''): - if dtype not in {"f16", "f32"}: - raise ValueError(f"Invalid datatype: {dtype}") + if dtype not in DATA_TYPES_ATTENTION: + raise ValueError(f"Invalid datatype for a: {dtype}") + self.dataType = dtype self.g = g self.seq_len_q = seq_len_q @@ -852,6 +1027,12 @@ def generateMlirDriverCommandLine(self, rocmlir_gen_flags): def fromCommandLine(cls, argv, arch, numCU): # optional defaults perf_config = '' + dtype = None + g = None + seq_len_q = None + seq_len_k = None + head_dim_qk = None + head_dim_v = None transQ = False transK = False transV = False @@ -1364,7 +1545,7 @@ def main(args=None): allow_abbrev=False, ) - parser.add_argument("--op", "--operation", choices=['conv', 'gemm', 'fusion', 'attention'], + parser.add_argument("--op", "--operation", choices=['conv', 'gemm', 'fusion', 'attention', 'gemm_gemm'], default='conv', help="Operation to benchmark") @@ -1495,6 +1676,9 @@ def main(args=None): elif opType == Operation.ATTENTION: confClass = AttentionConfiguration externalLib = None + elif opType == 
Operation.GEMM_GEMM: + confClass = GemmGemmConfiguration + externalLib = None configs_path = None if parsed_args.config else parsed_args.configs_file paths = create_paths(configs_path, parsed_args.mlir_build_dir) @@ -1506,6 +1690,8 @@ def main(args=None): configs = getGemmConfigurations(paths.configuration_file_path, datatypes, outputTypeMap) elif opType == Operation.ATTENTION: configs = getAttentionConfigurations(paths.configuration_file_path) + elif opType == Operation.GEMM_GEMM: + configs = getGemmGemmConfigurations(paths.configuration_file_path) if parsed_args.external or parsed_args.batch_external or parsed_args.batch_all: if not foundExternalTool(paths, opType, externalLib): diff --git a/mlir/utils/performance/reportUtils.py b/mlir/utils/performance/reportUtils.py index 5e00711a7587..85d65ff04fba 100644 --- a/mlir/utils/performance/reportUtils.py +++ b/mlir/utils/performance/reportUtils.py @@ -24,6 +24,7 @@ 'PaddingH', 'PaddingW', 'PerfConfig'] GEMM_TEST_PARAMETERS = ['DataType', 'OutDataType', 'Chip', 'numCU', 'TransA', 'TransB', 'G', 'M', 'K', 'N', 'PerfConfig'] ATTN_TEST_PARAMETERS = ['DataType', 'Chip', 'numCU', 'TransQ', 'TransK', 'TransV', 'TransO', 'WithAttnScale', 'G', 'SeqLenQ', 'SeqLenK', 'HeadDimQK', 'HeadDimV', 'PerfConfig'] +GEMM_GEMM_TEST_PARAMETERS = ['DataType', 'Chip', 'numCU', 'TransA', 'TransB', 'TransC', 'TransO', 'G', 'M', 'K', 'N', 'O', 'PerfConfig'] ROUND_DIGITS = 2 def geoMean(data): diff --git a/mlir/utils/performance/tuningRunner.py b/mlir/utils/performance/tuningRunner.py index 686e79ee3aa8..09e6ae178585 100755 --- a/mlir/utils/performance/tuningRunner.py +++ b/mlir/utils/performance/tuningRunner.py @@ -19,6 +19,7 @@ from perfRunner import ConvConfiguration from perfRunner import GemmConfiguration from perfRunner import AttentionConfiguration +from perfRunner import GemmGemmConfiguration from perfRunner import Paths from perfRunner import getChip from perfCommonUtils import CORRECT_RESULT_RE @@ -256,7 +257,7 @@ def main(args=None): allow_abbrev=False, ) - parser.add_argument("--op", "--operation", choices=['conv', 'gemm', 'fusion', 'attention'], + parser.add_argument("--op", "--operation", choices=['conv', 'gemm', 'fusion', 'attention', 'gemm_gemm'], default='conv', help="Operation for tuning") @@ -381,6 +382,8 @@ def main(args=None): confClass = GemmConfiguration elif opType == Operation.ATTENTION: confClass = AttentionConfiguration + elif opType == Operation.GEMM_GEMM: + confClass = GemmGemmConfiguration else: raise RuntimeError("Tuning operation was not provided/found") @@ -393,6 +396,8 @@ def main(args=None): configs = perfRunner.getGemmConfigurations(paths.configuration_file_path, datatypes, outputMap) elif opType == Operation.ATTENTION: configs = perfRunner.getAttentionConfigurations(paths.configuration_file_path) + elif opType == Operation.GEMM_GEMM: + configs = perfRunner.getGemmGemmConfigurations(paths.configuration_file_path) winners, allData = tuneMLIRKernels(configs, confClass, paths, options)
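
A minimal usage sketch of the GemmGemmConfiguration class added to mlir/utils/performance/perfRunner.py by this patch. The problem sizes, the gfx90a target, and the 104-CU count below are illustrative assumptions, not values taken from the patch; the sketch also assumes perfRunner.py and its dependencies (reportUtils, perfCommonUtils) are importable from the checkout with this patch applied.

    # Sketch: drive the new gemm_gemm configuration path by hand.
    from perfRunner import GemmGemmConfiguration

    # A config line of the shape produced by getGemmGemmConfigurations();
    # all sizes here are made up for illustration.
    args = ("-t f32 -transA false -transB false -transC false -transO false "
            "-g 1 -m 1024 -k 512 -n 1024 -gemmO 64").split()
    conf = GemmGemmConfiguration.fromCommandLine(args, arch="gfx90a", numCU=104)

    # rocmlir-gen invocation assembled for this problem
    # (-operation gemm_gemm -t f32 ... -gemmO 64 -transA=False ...).
    print(conf.generateMlirDriverCommandLine(rocmlir_gen_flags=""))

    # TFlops model from computeTFlops(): 2*g*m*k*n for the first GEMM
    # plus 2*g*m*n*o for the second, divided by the measured time.
    print(conf.computeTFlops(1.0e6))  # hypothetical 1 ms average kernel time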