Commit e487edc

Add Gemm+Elementwise+Gemm support

1 parent: 75be71c

51 files changed: 2279 additions, 647 deletions


mlir/include/mlir/Dialect/Rock/IR/Rock.h
Lines changed: 0 additions & 6 deletions

@@ -49,12 +49,6 @@ class FusionRoot : public TraitBase<ConcreteType, FusionRoot> {};
 } // namespace OpTrait
 } // namespace mlir
 
-// Following ifdef could be used to change
-// the attention operator to be a fused gemm-gemm
-// kernel for debugging purposes. This will also
-// adjust the test harness to verify the same as well
-// #define ROCK_DEBUG_ATTENTION_REMOVE_SOFTMAX
-
 namespace mlir {
 namespace rock {
 //===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/Rock/IR/RockAttrDefs.td
Lines changed: 2 additions & 1 deletion

@@ -63,11 +63,12 @@ def KernelTypeConvBwdData : I32EnumAttrCase<"ConvBwdData", 1>;
 def KernelTypeConvBwdWeight : I32EnumAttrCase<"ConvBwdWeight", 2>;
 def KernelTypeGemm : I32EnumAttrCase<"Gemm", 3>;
 def KernelTypeAttention : I32EnumAttrCase<"Attention", 4>;
+def KernelTypeGemmElementwiseGemm : I32EnumAttrCase<"GemmElementwiseGemm", 5>;
 
 def KernelType : Rock_I32Enum<"KernelType", "Any of the possible types of a rock kernel",
                               [KernelTypeConv, KernelTypeConvBwdData,
                                KernelTypeConvBwdWeight, KernelTypeGemm,
-                               KernelTypeAttention]>;
+                               KernelTypeAttention, KernelTypeGemmElementwiseGemm]>;
 
 /// TransformType
 def PassThrough : I32EnumAttrCase<"PassThrough", 0>;
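
Because `KernelType` is generated from these `I32EnumAttrCase` defs, the new case also appears in the generated C++ `mlir::rock::KernelType` enum, so exhaustive switches over kernel types downstream need a new arm. A minimal sketch, assuming `Rock_I32Enum` wraps the standard mlir-tblgen `I32EnumAttr` machinery; the `isFusedKernel` helper and the exact header that exposes the enum are illustrative, not part of this commit:

// Minimal sketch, not from this commit: the enum cases mirror the
// I32EnumAttrCase defs above (Conv=0 ... GemmElementwiseGemm=5),
// assuming standard mlir-tblgen enum generation.
#include "mlir/Dialect/Rock/IR/Rock.h" // assumed to pull in the generated enum
#include "llvm/Support/ErrorHandling.h"

using mlir::rock::KernelType;

// Hypothetical helper: classifies which kernel types are fused multi-GEMM
// kernels after this commit.
static bool isFusedKernel(KernelType kt) {
  switch (kt) {
  case KernelType::Attention:
  case KernelType::GemmElementwiseGemm: // new case added by this commit
    return true;
  case KernelType::Conv:
  case KernelType::ConvBwdData:
  case KernelType::ConvBwdWeight:
  case KernelType::Gemm:
    return false;
  }
  llvm_unreachable("unhandled rock::KernelType");
}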

mlir/include/mlir/Dialect/Rock/IR/RockOps.td
Lines changed: 54 additions & 4 deletions

@@ -213,7 +213,7 @@ def Rock_AttentionOp :
       TensorOrMemRefOf<[F32, F16, BF16]>:$values,
       Variadic<AnyTensorOrMemRef>:$preSoftmaxElemWiseInputs,
       Optional<TensorOrMemRefOf<[I32]>>:$currentSeqLen,
-      TensorOrMemRefOf<[F32, BF16, F16]>:$out,
+      TensorOrMemRefOf<[F32, F16, BF16]>:$out,
       UnitAttr:$qTransposed,
       UnitAttr:$kTransposed,
       UnitAttr:$vTransposed,
@@ -228,15 +228,15 @@
     Results<(outs Optional<TensorOf<[F32, F16, BF16]>>:$result)> {
   let summary = "Attention operation of transformer models";
   let description = [{
-    Performs the operation out = SOFTMAX((queries * keys) .* scale) * values.
+    Performs the operation out = SOFTMAX(queries * keys) * values.
 
     This operation performs the attention mechanism of transformer models.
 
     Those creating a `rock.attention` must specify the GPU architecture being targeted
     and the number of compute units (numCu) available. The parameters
     `gridSize` and `blockSize` are optional as they can be inferred by
     a tuning process or a heuristic, but they must be set before the `attention` is
-    lowered into the `gridwise_attention` stage of the code generation pipeline.
+    lowered into the `gridwise_attention_accel` stage of the code generation pipeline.
 
     `features` specifies what hardware features can be used in the generated code.
   }];
@@ -255,6 +255,55 @@ def Rock_AttentionOp :
   }];
 }
 
+def Rock_GemmElementwiseGemmOp :
+    Rock_Op<"gemm_elementwise_gemm", [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>, RockFusionRoot]>,
+    AllElementTypesMatch<["a", "b", "c"]>,
+    Arguments<(ins
+      TensorOrMemRefOf<[F32]>:$a,
+      TensorOrMemRefOf<[F32]>:$b,
+      TensorOrMemRefOf<[F32]>:$c,
+      Variadic<AnyTensorOrMemRef>:$elemwiseInputs,
+      TensorOrMemRefOf<[F32]>:$out,
+      UnitAttr:$aTransposed,
+      UnitAttr:$bTransposed,
+      UnitAttr:$cTransposed,
+      UnitAttr:$oTransposed,
+      StrAttr:$arch,
+      Rock_GemmFeaturesAttr:$features,
+      OptionalAttr<I32Attr>:$numCU,
+      OptionalAttr<RockTuningParamAttrInterface>:$params0,
+      OptionalAttr<RockTuningParamAttrInterface>:$params1,
+      I32Attr:$firstGemmIdx
+    )>,
+    Results<(outs Optional<TensorOf<[F32]>>:$result)> {
+  let summary = "GEMM-elementwise-GEMM operation";
+  let description = [{
+    Performs the operation out = elementwise(a * b) * c.
+
+    This operation performs a fused GEMM-elementwise-GEMM.
+
+    Those creating a `rock.gemm_elementwise_gemm` must specify the GPU architecture being targeted
+    and the number of compute units (numCu) available. The parameters
+    `gridSize` and `blockSize` are optional as they can be inferred by
+    a tuning process or a heuristic, but they must be set before the `gemm_elementwise_gemm` is
+    lowered into the `gridwise_attention_accel` stage of the code generation pipeline.
+
+    `features` specifies what hardware features can be used in the generated code.
+  }];
+  let hasVerifier = 1;
+  let regions = (region AnyRegion:$preSecondGemmBody);
+  let assemblyFormat = [{
+    `{` `\n`
+    ` ` `ab` `=` (`tr` $aTransposed^)? $a `*` (`tr` $bTransposed^)? $b `:` type($a) `,` type($b) `\n`
+    (`ab` `=` `elementwise` (`otherIns` `(` $elemwiseInputs^ `:` type($elemwiseInputs) `)`)? $preSecondGemmBody^ `\n`)?
+    (`tr` $oTransposed^)? $out `=` `ab` `*` (`tr` $cTransposed^)? $c `:` type($c) `->` type($out) `\n`
+    `}` attr-dict (`->` type($result)^)?
+  }];
+  let extraClassDeclaration = [{
+    ::mlir::OpOperand* getOutArgument() { return &(*this)->getOpOperands().back(); }
+  }];
+}
+
 def Rock_InitKernelOp :
     Rock_Op<"init_kernel", []>,
     Arguments<(ins AnyTensorOrMemRef:$buffer,
@@ -449,7 +498,8 @@ def Rock_GridwiseAttentionAccelOp :
       OptionalAttr<IndexAttr>:$prePadG0N,
       RockAccelTuningParamAttrInterface:$params0,
      RockAccelTuningParamAttrInterface:$params1,
-      I32Attr:$firstGemmIdx)> {
+      I32Attr:$firstGemmIdx,
+      DefaultValuedOptionalAttr<BoolAttr, "true">:$enableSoftmax)> {
   let summary = "Gridwise attention accelerated version";
   let description = [{
     The `rock.gridwise_attention_accel` op computes gridwise attention with acceleration.
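
To make the new op's custom assembly format concrete, here is a hedged sketch of the textual IR the `assemblyFormat` above should accept. The SSA names, shapes, and attribute values are invented, and the elementwise region's body convention (scalar block arguments, `rock.yield` terminator) is an assumption this diff does not pin down:

// Hypothetical rock.gemm_elementwise_gemm instance (memref form, no
// tensor result): ab = a * b, then a fused bias-add, then out = ab * c.
rock.gemm_elementwise_gemm {
  ab = %a * %b : memref<1x128x64xf32>, memref<1x64x128xf32>
  ab = elementwise otherIns(%bias : memref<1x128x128xf32>) {
  ^bb0(%v: f32, %biasVal: f32):      // assumed region convention
    %s = arith.addf %v, %biasVal : f32
    rock.yield %s : f32              // assumed terminator
  }
  %out = ab * %c : memref<1x128x32xf32> -> memref<1x128x32xf32>
} {arch = "gfx942", features = #rock<GemmFeatures mfma>, firstGemmIdx = 0 : i32}
// attribute values illustrative only

Note how the pieces fit together: the op's description says it is lowered into the `gridwise_attention_accel` stage, and the last hunk above adds a default-true `enableSoftmax` attribute to `rock.gridwise_attention_accel`, which presumably lets the fused GEMM-GEMM path reuse the attention lowering with the softmax stage disabled.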
