Skip to content

Commit a5dd470

Browse files
committed
Add Gemm+Elementwise+Gemm support
1 parent a8ac1fb commit a5dd470

File tree

18 files changed

+1581
-645
lines changed

18 files changed

+1581
-645
lines changed

mlir/include/mlir/Dialect/Rock/IR/Rock.h

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -49,12 +49,6 @@ class FusionRoot : public TraitBase<ConcreteType, FusionRoot> {};
4949
} // namespace OpTrait
5050
} // namespace mlir
5151

52-
// Following ifdef could be used to change
53-
// the attention operator to be a fused gemm-gemm
54-
// kernel for debugging purposes. This will also
55-
// adjust the test harness to verify the same as well
56-
// #define ROCK_DEBUG_ATTENTION_REMOVE_SOFTMAX
57-
5852
namespace mlir {
5953
namespace rock {
6054
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/Rock/IR/RockAttrDefs.td

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,11 +63,12 @@ def KernelTypeConvBwdData : I32EnumAttrCase<"ConvBwdData", 1>;
6363
def KernelTypeConvBwdWeight : I32EnumAttrCase<"ConvBwdWeight", 2>;
6464
def KernelTypeGemm : I32EnumAttrCase<"Gemm", 3>;
6565
def KernelTypeAttention : I32EnumAttrCase<"Attention", 4>;
66+
def KernelTypeGemmElementwiseGemm : I32EnumAttrCase<"GemmElementwiseGemm", 5>;
6667

6768
def KernelType : Rock_I32Enum<"KernelType", "Any of the possible types of a rock kernel",
6869
[KernelTypeConv, KernelTypeConvBwdData,
6970
KernelTypeConvBwdWeight, KernelTypeGemm,
70-
KernelTypeAttention]>;
71+
KernelTypeAttention, KernelTypeGemmElementwiseGemm]>;
7172

7273
/// TransformType
7374
def PassThrough : I32EnumAttrCase<"PassThrough", 0>;

mlir/include/mlir/Dialect/Rock/IR/RockOps.td

Lines changed: 56 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,7 @@ def Rock_AttentionOp :
236236
and the number of compute units (numCu) available. The parameters
237237
`gridSize`, and `blockSize` are optional as they can be inferred by
238238
a tuning process or a heuristic, but they must be set before the `attention` is
239-
lowered into the `gridwise_attention` stage of the code generation pipeline.
239+
lowered into the `gridwise_attention_accel` stage of the code generation pipeline.
240240

241241
`features` specifies what hardware features can be used in the generated code.
242242
}];
@@ -255,6 +255,54 @@ def Rock_AttentionOp :
255255
}];
256256
}
257257

258+
def Rock_GemmElementwiseGemmOp:
259+
Rock_Op<"gemm_elementwise_gemm", [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>, RockFusionRoot]>,
260+
Arguments<(ins
261+
TensorOrMemRefOf<GemmInputTypes>:$a,
262+
TensorOrMemRefOf<GemmInputTypes>:$b,
263+
TensorOrMemRefOf<GemmInputTypes>:$c,
264+
Variadic<AnyTensorOrMemRef>:$elemwiseInputs,
265+
TensorOrMemRefOf<GemmAccumulatorTypes>:$out,
266+
UnitAttr:$aTransposed,
267+
UnitAttr:$bTransposed,
268+
UnitAttr:$cTransposed,
269+
UnitAttr:$oTransposed,
270+
StrAttr:$arch,
271+
Rock_GemmFeaturesAttr:$features,
272+
OptionalAttr<I32Attr>:$numCU,
273+
OptionalAttr<RockTuningParamAttrInterface>:$params0,
274+
OptionalAttr<RockTuningParamAttrInterface>:$params1,
275+
I32Attr:$firstGemmIdx
276+
)>,
277+
Results<(outs Optional<TensorOf<[F32, F16, BF16]>>:$result)> {
278+
let summary = "GEMM-elementwise-GEMM operation";
279+
let description = [{
280+
Performs the operation out = elementwise(a * b) * c, where the elementwise region is optional.
281+
282+
This operation performs fused GEMM-elementwise-GEMM.
283+
284+
Those creating a `rock.gemm_elementwise_gemm` must specify the GPU architecture being targeted
285+
and the number of compute units (numCU) available. The parameters
286+
`gridSize`, and `blockSize` are optional as they can be inferred by
287+
a tuning process or a heuristic, but they must be set before the `gemm_elementwise_gemm` is
288+
lowered into the `gridwise_attention_accel` stage of the code generation pipeline.
289+
290+
`features` specifies what hardware features can be used in the generated code.
291+
}];
292+
let hasVerifier = 1;
293+
let regions = (region AnyRegion:$preSecondGemmBody);
294+
let assemblyFormat = [{
295+
`{` `\n`
296+
` ` `ab` `=` (`tr` $aTransposed^)? $a `*` (`tr` $bTransposed^)? $b `:` type($a) `,` type($b) `\n`
297+
(`ab` `=` `elementwise` (`otherIns` `(` $elemwiseInputs^ `:` type($elemwiseInputs) `)`)? $preSecondGemmBody^ `\n`)?
298+
(`tr` $oTransposed^)? $out `=` `softmax` `(` `qk` `)` `*` (`tr` $cTransposed^)? $c `:` type($c) `->` type($out) `\n`
299+
`}` attr-dict (`->` type($result)^)?
300+
}];
301+
let extraClassDeclaration = [{
302+
::mlir::OpOperand* getOutArgument() { return &(*this)->getOpOperands().back(); }
303+
}];
304+
}
305+
258306
def Rock_InitKernelOp :
259307
Rock_Op<"init_kernel", []>,
260308
Arguments<(ins AnyTensorOrMemRef:$buffer,
@@ -434,12 +482,12 @@ def Rock_GridwiseGemmAccelOp :
434482
// gridwise_attention_accel
435483
def Rock_GridwiseAttentionAccelOp :
436484
Rock_Op<"gridwise_attention_accel", [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>, RockFusionRoot, AttrSizedOperandSegments]>,
437-
Arguments<(ins MemRefRankOf<[F32, F16, BF16, I8], [3]>:$queries,
438-
MemRefRankOf<[F32, F16, BF16, I8], [3]>:$keys,
439-
MemRefRankOf<[F32, F16, BF16,], [3]>:$values,
440-
Variadic<TensorOrMemRefOf<[F32, F16, BF16, I8]>>:$preSoftmaxElemWiseInputs,
485+
Arguments<(ins MemRefRankOf<GemmInputTypes, [3]>:$queries,
486+
MemRefRankOf<GemmInputTypes, [3]>:$keys,
487+
MemRefRankOf<GemmInputTypes, [3]>:$values,
488+
Variadic<AnyTensorOrMemRef>:$preSoftmaxElemWiseInputs,
441489
Optional<MemRefRankOf<[I32], [1]>>:$currentSeqLen,
442-
MemRefRankOf<[F32, F16, BF16], [3]>:$out,
490+
MemRefRankOf<GemmAccumulatorTypes, [3]>:$out,
443491
StrAttr:$arch,
444492
Rock_GemmFeaturesAttr:$features,
445493
I32Attr:$blockSize,
@@ -449,7 +497,8 @@ def Rock_GridwiseAttentionAccelOp :
449497
OptionalAttr<IndexAttr>:$prePadG0N,
450498
RockAccelTuningParamAttrInterface:$params0,
451499
RockAccelTuningParamAttrInterface:$params1,
452-
I32Attr:$firstGemmIdx)> {
500+
I32Attr:$firstGemmIdx,
501+
DefaultValuedOptionalAttr<BoolAttr, "true">:$enableSoftmax)> {
453502
let summary = "Gridwise attention accelerated version";
454503
let description = [{
455504
The `rock.gridwise_attention_accel` op computes gridwise attention with acceleration.

0 commit comments

Comments
 (0)