
Commit 54237b0

Add Gemm+Elementwise+Gemm support
1 parent b6c726a commit 54237b0

25 files changed: +1371, -486 lines

mlir/include/mlir/Dialect/Rock/IR/Rock.h

Lines changed: 0 additions & 6 deletions
@@ -49,12 +49,6 @@ class FusionRoot : public TraitBase<ConcreteType, FusionRoot> {};
 } // namespace OpTrait
 } // namespace mlir

-// Following ifdef could be used to change
-// the attention operator to be a fused gemm-gemm
-// kernel for debugging purposes. This will also
-// adjust the test harness to verify the same as well
-// #define ROCK_DEBUG_ATTENTION_REMOVE_SOFTMAX
-
 namespace mlir {
 namespace rock {
 //===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/Rock/IR/RockAttrDefs.td

Lines changed: 6 additions & 4 deletions
@@ -63,11 +63,13 @@ def KernelTypeConvBwdData : I32EnumAttrCase<"ConvBwdData", 1>;
 def KernelTypeConvBwdWeight : I32EnumAttrCase<"ConvBwdWeight", 2>;
 def KernelTypeGemm : I32EnumAttrCase<"Gemm", 3>;
 def KernelTypeAttention : I32EnumAttrCase<"Attention", 4>;
+def KernelTypeGemmElementwiseGemm : I32EnumAttrCase<"GemmElementwiseGemm", 5>;

-def KernelType : Rock_I32Enum<"KernelType", "Any of the possible types of a rock kernel",
-                   [KernelTypeConv, KernelTypeConvBwdData,
-                    KernelTypeConvBwdWeight, KernelTypeGemm,
-                    KernelTypeAttention]>;
+def KernelType
+    : Rock_I32Enum<"KernelType", "Any of the possible types of a rock kernel",
+                   [KernelTypeConv, KernelTypeConvBwdData,
+                    KernelTypeConvBwdWeight, KernelTypeGemm,
+                    KernelTypeAttention, KernelTypeGemmElementwiseGemm]>;

 /// TransformType
 def PassThrough : I32EnumAttrCase<"PassThrough", 0>;

mlir/include/mlir/Dialect/Rock/IR/RockOps.td

Lines changed: 79 additions & 41 deletions
@@ -205,38 +205,33 @@ def Rock_ReduceOp :
   }];
 }

-def Rock_AttentionOp :
-    Rock_Op<"attention", [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>, RockFusionRoot, AttrSizedOperandSegments]>,
-    Arguments<(ins
-    TensorOrMemRefOf<[F32, F16, BF16, I8]>:$queries,
-    TensorOrMemRefOf<[F32, F16, BF16, I8]>:$keys,
-    TensorOrMemRefOf<[F32, F16, BF16]>:$values,
-    Variadic<AnyTensorOrMemRef>:$preSoftmaxElemWiseInputs,
-    Optional<TensorOrMemRefOf<[I32]>>:$currentSeqLen,
-    TensorOrMemRefOf<[F32, BF16, F16]>:$out,
-    UnitAttr:$qTransposed,
-    UnitAttr:$kTransposed,
-    UnitAttr:$vTransposed,
-    UnitAttr:$oTransposed,
-    StrAttr:$arch,
-    Rock_GemmFeaturesAttr:$features,
-    OptionalAttr<I32Attr>:$numCU,
-    OptionalAttr<RockTuningParamAttrInterface>:$params0,
-    OptionalAttr<RockTuningParamAttrInterface>:$params1,
-    I32Attr:$firstGemmIdx
-    )>,
-    Results<(outs Optional<TensorOf<[F32, F16, BF16]>>:$result)> {
+def Rock_AttentionOp
+    : Rock_Op<"attention", [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+                            RockFusionRoot, AttrSizedOperandSegments]>,
+      Arguments<(ins TensorOrMemRefOf<[F32, F16, BF16, I8]>:$queries,
+          TensorOrMemRefOf<[F32, F16, BF16, I8]>:$keys,
+          TensorOrMemRefOf<[F32, F16, BF16]>:$values,
+          Variadic<AnyTensorOrMemRef>:$preSoftmaxElemWiseInputs,
+          Optional<TensorOrMemRefOf<[I32]>>:$currentSeqLen,
+          TensorOrMemRefOf<[F32, F16, BF16]>:$out, UnitAttr:$qTransposed,
+          UnitAttr:$kTransposed, UnitAttr:$vTransposed, UnitAttr:$oTransposed,
+          StrAttr:$arch, Rock_GemmFeaturesAttr:$features,
+          OptionalAttr<I32Attr>:$numCU,
+          OptionalAttr<RockTuningParamAttrInterface>:$params0,
+          OptionalAttr<RockTuningParamAttrInterface>:$params1,
+          I32Attr:$firstGemmIdx)>,
+      Results<(outs Optional<TensorOf<[F32, F16, BF16]>>:$result)> {
   let summary = "Attention operation of transformer models";
   let description = [{
-    Performs the operation out = SOFTMAX((queries * keys) .* scale) * values.
+    Performs the operation out = SOFTMAX(queries * keys) * values.

     This operation performs attention mechanism of transformer models.

     Those creating a `rock.attention` must specify the GPU architecture being targetted
     and the number of compute units (numCu) available. The parameters
     `gridSize`, and `blockSize` are optional as they can be inferred by
     a tuning process or a heuristic, but they must be set before the `attention` is
-    lowered into the `gridwise_attention` stage of the code generation pipeline.
+    lowered into the `gridwise_attention_accel` stage of the code generation pipeline.

     `features` specifies what hardware features can be used in the generated code.
   }];
@@ -255,6 +250,50 @@ def Rock_AttentionOp :
   }];
 }

+def Rock_GemmElementwiseGemmOp
+    : Rock_Op<"gemm_elementwise_gemm", [DeclareOpInterfaceMethods<
+                                            MemoryEffectsOpInterface>,
+                                        RockFusionRoot]>,
+      AllElementTypesMatch<["a", "b", "c"]>,
+      Arguments<(ins TensorOrMemRefOf<[F32]>:$a, TensorOrMemRefOf<[F32]>:$b,
+          TensorOrMemRefOf<[F32]>:$c,
+          Variadic<AnyTensorOrMemRef>:$elemwiseInputs,
+          TensorOrMemRefOf<[F32]>:$out, UnitAttr:$aTransposed,
+          UnitAttr:$bTransposed, UnitAttr:$cTransposed, UnitAttr:$oTransposed,
+          StrAttr:$arch, Rock_GemmFeaturesAttr:$features,
+          OptionalAttr<I32Attr>:$numCU,
+          OptionalAttr<RockTuningParamAttrInterface>:$params0,
+          OptionalAttr<RockTuningParamAttrInterface>:$params1,
+          I32Attr:$firstGemmIdx)>,
+      Results<(outs Optional<TensorOf<[F32]>>:$result)> {
+  let summary = "GEMM-elementwise-GEMM operation";
+  let description = [{
+    Performs the operation out = (a * b) * c.
+
+    This operation performs fused GEMM-elementwise-GEMM.
+
+    Those creating a `rock.gemm_elementwise_gemm` must specify the GPU architecture being targetted
+    and the number of compute units (numCu) available. The parameters
+    `gridSize`, and `blockSize` are optional as they can be inferred by
+    a tuning process or a heuristic, but they must be set before the `gemm_elementwise_gemm` is
+    lowered into the `gridwise_attention_accel` stage of the code generation pipeline.
+
+    `features` specifies what hardware features can be used in the generated code.
+  }];
+  let hasVerifier = 1;
+  let regions = (region AnyRegion:$preSecondGemmBody);
+  let assemblyFormat = [{
+    `{` `\n`
+    ` ` `ab` `=` (`tr` $aTransposed^)? $a `*` (`tr` $bTransposed^)? $b `:` type($a) `,` type($b) `\n`
+    (`ab` `=` `elementwise` (`otherIns` `(` $elemwiseInputs^ `:` type($elemwiseInputs) `)`)? $preSecondGemmBody^ `\n`)?
+    (`tr` $oTransposed^)? $out `=` `ab` `*` (`tr` $cTransposed^)? $c `:` type($c) `->` type($out) `\n`
+    `}` attr-dict (`->` type($result)^)?
+  }];
+  let extraClassDeclaration = [{
+    ::mlir::OpOperand* getOutArgument() { return &(*this)->getOpOperands().back(); }
+  }];
+}
+
 def Rock_InitKernelOp :
     Rock_Op<"init_kernel", []>,
     Arguments<(ins AnyTensorOrMemRef:$buffer,
@@ -432,24 +471,23 @@ def Rock_GridwiseGemmAccelOp :
 }

 // gridwise_attention_accel
-def Rock_GridwiseAttentionAccelOp :
-    Rock_Op<"gridwise_attention_accel", [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>, RockFusionRoot, AttrSizedOperandSegments]>,
-    Arguments<(ins MemRefRankOf<[F32, F16, BF16, I8], [3]>:$queries,
-    MemRefRankOf<[F32, F16, BF16, I8], [3]>:$keys,
-    MemRefRankOf<[F32, F16, BF16,], [3]>:$values,
-    Variadic<AnyTensorOrMemRef>:$preSoftmaxElemWiseInputs,
-    Optional<MemRefRankOf<[I32], [1]>>:$currentSeqLen,
-    MemRefRankOf<[F32, F16, BF16], [3]>:$out,
-    StrAttr:$arch,
-    Rock_GemmFeaturesAttr:$features,
-    I32Attr:$blockSize,
-    I32Attr:$gridSize,
-    UnitAttr:$disableQBypassLDS,
-    OptionalAttr<IndexAttr>:$prePadG0M,
-    OptionalAttr<IndexAttr>:$prePadG0N,
-    RockAccelTuningParamAttrInterface:$params0,
-    RockAccelTuningParamAttrInterface:$params1,
-    I32Attr:$firstGemmIdx)> {
+def Rock_GridwiseAttentionAccelOp
+    : Rock_Op<"gridwise_attention_accel",
+              [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+               RockFusionRoot, AttrSizedOperandSegments]>,
+      Arguments<(ins MemRefRankOf<[F32, F16, BF16, I8], [3]>:$queries,
+          MemRefRankOf<[F32, F16, BF16, I8], [3]>:$keys,
+          MemRefRankOf<[F32, F16, BF16, ], [3]>:$values,
+          Variadic<AnyTensorOrMemRef>:$preSoftmaxElemWiseInputs,
+          Optional<MemRefRankOf<[I32], [1]>>:$currentSeqLen,
+          MemRefRankOf<[F32, F16, BF16], [3]>:$out, StrAttr:$arch,
+          Rock_GemmFeaturesAttr:$features, I32Attr:$blockSize,
+          I32Attr:$gridSize, UnitAttr:$disableQBypassLDS,
+          OptionalAttr<IndexAttr>:$prePadG0M,
+          OptionalAttr<IndexAttr>:$prePadG0N,
+          RockAccelTuningParamAttrInterface:$params0,
+          RockAccelTuningParamAttrInterface:$params1, I32Attr:$firstGemmIdx,
+          DefaultValuedOptionalAttr<BoolAttr, "true">:$enableSoftmax)> {
   let summary = "Gridwise attention accelerated version";
   let description = [{
     The `rock.gridwise_attention_accel` op computes gridwise attention with acceleration.
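Note: based on the `assemblyFormat` added above, a hand-written `rock.gemm_elementwise_gemm` would look roughly like the sketch below. The shapes, SSA names, `gfx942` target, and attribute spellings are illustrative assumptions (not taken from this commit's tests), and the optional `elementwise` clause with its `preSecondGemmBody` region is omitted for brevity.

// Minimal sketch (assumed operands): out = (a * b) * c, with b stored transposed.
// a : 1x128x64, tr b : 1x32x64 (shared K = 64), c : 1x32x16, out : 1x128x16.
rock.gemm_elementwise_gemm {
 ab = %a * tr %b : memref<1x128x64xf32>, memref<1x32x64xf32>
 %out = ab * %c : memref<1x32x16xf32> -> memref<1x128x16xf32>
} {arch = "gfx942", features = #rock<GemmFeatures mfma|dot|atomic_add>, firstGemmIdx = 0 : i32}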

mlir/lib/Dialect/Rock/IR/RockDialect.cpp

Lines changed: 77 additions & 39 deletions
@@ -33,6 +33,7 @@
 #include "mlir/IR/TypeRange.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/IR/Value.h"
+#include "mlir/IR/ValueRange.h"
 #include "mlir/Parser/Parser.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Support/LogicalResult.h"
@@ -46,6 +47,7 @@
 #include "llvm/ADT/StringRef.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/Debug.h"
+#include "llvm/Support/LogicalResult.h"
 #include "llvm/Support/MathExtras.h"
 #include "llvm/Support/SMLoc.h"
 #include <algorithm>
@@ -486,10 +488,13 @@ ConvOpType mlir::rock::convOpTypeFromKernelType(KernelType kernelType) {
     return ConvOpType::BwdWeight;
   case KernelType::Gemm:
     llvm_unreachable(
-        "Gemm ops shouldn't be in convolution-specific lowering passes");
+        "GEMM ops shouldn't be in convolution-specific lowering passes");
   case KernelType::Attention:
     llvm_unreachable(
         "Attention ops shouldn't be in convolution-specific lowering passes");
+  case KernelType::GemmElementwiseGemm:
+    llvm_unreachable(
+        "GEMM+GEMM ops shouldn't be in convolution-specific lowering passes");
   }
   llvm_unreachable("Unsuppported KernelType");
 }
@@ -566,17 +571,20 @@ static LogicalResult verifyGemmTypes(Operation *op, GemmFeatures features,
           "Mfma gridwise does not support E4M3/E5M2 data types ");
     }
   }
-  if (isa<FloatType>(elemTypeA) && !isa<FloatType>(elemTypeC)) {
-    return op->emitOpError("floating-point input type ")
-           << elemTypeA
-           << " requires a floating-point output type, but the output type is "
-           << elemTypeC;
-  }
-  if (isa<IntegerType>(elemTypeA) && !isa<IntegerType>(elemTypeC)) {
-    return op->emitOpError("integer input type ")
-           << elemTypeA
-           << " requires an integer output type, but the output type is "
-           << elemTypeC;
+  if (elemTypeC) {
+    if (isa<FloatType>(elemTypeA) && !isa<FloatType>(elemTypeC)) {
+      return op->emitOpError("floating-point input type ")
+             << elemTypeA
+             << " requires a floating-point output type, but the output type "
+                "is "
+             << elemTypeC;
+    }
+    if (isa<IntegerType>(elemTypeA) && !isa<IntegerType>(elemTypeC)) {
+      return op->emitOpError("integer input type ")
+             << elemTypeA
+             << " requires an integer output type, but the output type is "
+             << elemTypeC;
+    }
   }
   return success();
 }
@@ -2068,77 +2076,107 @@ LogicalResult BlockwiseFillOp::verify() {
 }

 //===-----------------------------------------------------===//
-// AttentionOp
+// GemmElementwiseGemmOp
 //===-----------------------------------------------------===//

-LogicalResult AttentionOp::verify() {
-  ShapedType qType = getQueries().getType();
+template <typename Op>
+static LogicalResult verifyAttentionOp(Op op, Value q, Value k, Value v,
+                                       Value currentSeqLen, bool qTransposed,
+                                       bool kTransposed, bool vTransposed) {
+  ShapedType qType = cast<ShapedType>(q.getType());
   int64_t qBatchDim = qType.getShape().size() == 3 ? qType.getShape()[0] : 1;
   ArrayRef<int64_t> qLastDims = qType.getShape().slice(qType.getRank() - 2);
-  auto [queryM, queryK] = getQTransposed()
-                              ? std::tuple{qLastDims[1], qLastDims[0]}
-                              : std::tuple{qLastDims[0], qLastDims[1]};
+  auto [queryM, queryK] = qTransposed ? std::tuple{qLastDims[1], qLastDims[0]}
+                                      : std::tuple{qLastDims[0], qLastDims[1]};

-  ShapedType kType = getKeys().getType();
+  ShapedType kType = cast<ShapedType>(k.getType());
   int64_t kBatchDim = kType.getShape().size() == 3 ? kType.getShape()[0] : 1;
   ArrayRef<int64_t> kLastDims = kType.getShape().slice(kType.getRank() - 2);
-  auto [keyK, keyN] = getKTransposed() ? std::tuple{kLastDims[1], kLastDims[0]}
-                                       : std::tuple{kLastDims[0], kLastDims[1]};
+  auto [keyK, keyN] = kTransposed ? std::tuple{kLastDims[1], kLastDims[0]}
+                                  : std::tuple{kLastDims[0], kLastDims[1]};

-  ShapedType vType = getValues().getType();
+  ShapedType vType = cast<ShapedType>(v.getType());
   int64_t vBatchDim = vType.getShape().size() == 3 ? vType.getShape()[0] : 1;
   ArrayRef<int64_t> vLastDims = vType.getShape().slice(vType.getRank() - 2);
-  auto [valueK, valueN] = getVTransposed()
-                              ? std::tuple{vLastDims[1], vLastDims[0]}
-                              : std::tuple{vLastDims[0], vLastDims[1]};
+  auto [valueK, valueN] = vTransposed ? std::tuple{vLastDims[1], vLastDims[0]}
                                       : std::tuple{vLastDims[0], vLastDims[1]};

   if (qBatchDim != kBatchDim || kBatchDim != vBatchDim) {
-    return emitError("Batch dimensions do not match");
+    return op.emitError("Batch dimensions do not match");
   }
   if (queryK != keyK) {
-    return emitError("reduction dimensions of first gemm do not match");
+    return op.emitError("reduction dimensions of first gemm do not match");
   }
   if (keyN != valueK) {
-    return emitError("reduction dimensions of second gemm do not match");
+    return op.emitError("reduction dimensions of second gemm do not match");
   }

   // check output type
-  ShapedType oType = getOut().getType();
+  ShapedType oType = op.getOut().getType();
   int64_t oBatchDim = oType.getShape().size() == 3 ? oType.getShape()[0] : 1;

   ArrayRef<int64_t> oLastDims = oType.getShape().slice(oType.getRank() - 2);
   auto [outputSeqLen, outputHeadDim] =
-      getOTransposed() ? std::tuple{oLastDims[1], oLastDims[0]}
-                       : std::tuple{oLastDims[0], oLastDims[1]};
+      op.getOTransposed() ? std::tuple{oLastDims[1], oLastDims[0]}
                           : std::tuple{oLastDims[0], oLastDims[1]};

   if (qType.getShape().size() != oType.getShape().size()) {
-    return emitError("Number of dimensions do not match (Q and Output)");
+    return op.emitError("Number of dimensions do not match (Q and Output)");
   }
   if (qBatchDim != oBatchDim) {
-    return emitError("Batch dimensions do not match (Q and Output)");
+    return op.emitError("Batch dimensions do not match (Q and Output)");
   }
   if (queryM != outputSeqLen) {
-    return emitError("Sequence length does not match (Q and Output)");
+    return op.emitError("Sequence length does not match (Q and Output)");
   }
   if (valueN != outputHeadDim) {
-    return emitError("Head dimensions do not match (V and Output)");
+    return op.emitError("Head dimensions do not match (V and Output)");
   }

   // check currentSeqLen (KV Cache)
-  auto currentSeqLen = getCurrentSeqLen();
   if (currentSeqLen) {
-    ShapedType seqLenType = currentSeqLen.getType();
+    ShapedType seqLenType = cast<ShapedType>(currentSeqLen.getType());
     if (seqLenType.getShape().size() != 1) {
-      return emitError("Number of dimensions is not one (currentSeqLen)");
+      return op.emitError("Number of dimensions is not one (currentSeqLen)");
    }
    if (seqLenType.getShape()[0] != oBatchDim) {
-      return emitError(
+      return op.emitError(
           "Batch dimensions do not match (currentSeqLen and Output)");
    }
  }
  return success();
 }

+LogicalResult GemmElementwiseGemmOp::verify() {
+  return verifyAttentionOp(*this, getA(), getB(), getC(),
+                           /*currentSeqLen=*/nullptr, getATransposed(),
+                           getBTransposed(), getCTransposed());
+}
+
+void GemmElementwiseGemmOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  auto *read = MemoryEffects::Read::get();
+  auto *write = MemoryEffects::Write::get();
+  effects.emplace_back(read, &getOutMutable());
+  effects.emplace_back(write, &getOutMutable());
+
+  effects.emplace_back(read, &getAMutable());
+  effects.emplace_back(read, &getBMutable());
+  effects.emplace_back(read, &getCMutable());
+  for (auto &regionArg : getElemwiseInputsMutable())
+    effects.emplace_back(read, &regionArg);
+}
+
+//===-----------------------------------------------------===//
+// AttentionOp
+//===-----------------------------------------------------===//
+
+LogicalResult AttentionOp::verify() {
+  return verifyAttentionOp(*this, getQueries(), getKeys(), getValues(),
+                           getCurrentSeqLen(), getQTransposed(),
+                           getKTransposed(), getVTransposed());
+}
+
 void AttentionOp::getEffects(
     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
   auto *read = MemoryEffects::Read::get();

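Note: the shared `verifyAttentionOp` template above treats both ops as a chained pair of GEMMs and only checks shape compatibility. As a hypothetical sketch (shapes, names, and attributes invented for illustration), the op below would be rejected by `GemmElementwiseGemmOp::verify()` with "reduction dimensions of first gemm do not match", since A contributes K = 64 while B contributes K = 48; the same template also drives `AttentionOp::verify()`, with queries/keys/values in the roles of a/b/c plus the extra 1-D `currentSeqLen` check.

// Hypothetical invalid input: first-GEMM K dims disagree (64 vs. 48).
rock.gemm_elementwise_gemm {
 ab = %a * %b : memref<1x128x64xf32>, memref<1x48x256xf32>
 %out = ab * %c : memref<1x256x32xf32> -> memref<1x128x32xf32>
} {arch = "gfx942", features = #rock<GemmFeatures mfma|dot|atomic_add>, firstGemmIdx = 0 : i32}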