@@ -236,7 +236,7 @@ def Rock_AttentionOp :
236236 and the number of compute units (numCu) available. The parameters
237237 `gridSize`, and `blockSize` are optional as they can be inferred by
238238 a tuning process or a heuristic, but they must be set before the `attention` is
239- lowered into the `gridwise_attention ` stage of the code generation pipeline.
239+ lowered into the `gridwise_attention_accel ` stage of the code generation pipeline.
240240
241241 `features` specifies what hardware features can be used in the generated code.
242242 }];
@@ -255,6 +255,54 @@ def Rock_AttentionOp :
255255 }];
256256}
257257
258+ def Rock_GemmElementwiseGemmOp:
259+ Rock_Op<"gemm_elementwise_gemm", [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>, RockFusionRoot]>,
260+ Arguments<(ins
261+ TensorOrMemRefOf<GemmInputTypes>:$a,
262+ TensorOrMemRefOf<GemmInputTypes>:$b,
263+ TensorOrMemRefOf<GemmInputTypes>:$c,
264+ Variadic<AnyTensorOrMemRef>:$elemwiseInputs,
265+ TensorOrMemRefOf<GemmAccumulatorTypes>:$out,
266+ UnitAttr:$aTransposed,
267+ UnitAttr:$bTransposed,
268+ UnitAttr:$cTransposed,
269+ UnitAttr:$oTransposed,
270+ StrAttr:$arch,
271+ Rock_GemmFeaturesAttr:$features,
272+ OptionalAttr<I32Attr>:$numCU,
273+ OptionalAttr<RockTuningParamAttrInterface>:$params0,
274+ OptionalAttr<RockTuningParamAttrInterface>:$params1,
275+ I32Attr:$firstGemmIdx
276+ )>,
277+ Results<(outs Optional<TensorOf<[F32, F16, BF16]>>:$result)> {
278+ let summary = "GEMM-elementwise-GEMM operation";
279+ let description = [{
280+ Performs the operation out = (a * b) * c.
281+
282+ This operation performs fused GEMM-elementwise-GEMM.
283+
284+ Those creating a `rock.gemm_elementwise_gemm` must specify the GPU architecture being targetted
285+ and the number of compute units (numCu) available. The parameters
286+ `gridSize`, and `blockSize` are optional as they can be inferred by
287+ a tuning process or a heuristic, but they must be set before the `gemm_elementwise_gemm` is
288+ lowered into the `gridwise_attention_accel` stage of the code generation pipeline.
289+
290+ `features` specifies what hardware features can be used in the generated code.
291+ }];
292+ let hasVerifier = 1;
293+ let regions = (region AnyRegion:$preSecondGemmBody);
294+ let assemblyFormat = [{
295+ `{` `\n`
296+ ` ` `ab` `=` (`tr` $aTransposed^)? $a `*` (`tr` $bTransposed^)? $b `:` type($a) `,` type($b) `\n`
297+ (`ab` `=` `elementwise` (`otherIns` `(` $elemwiseInputs^ `:` type($elemwiseInputs) `)`)? $preSecondGemmBody^ `\n`)?
298+ (`tr` $oTransposed^)? $out `=` `softmax` `(` `qk` `)` `*` (`tr` $cTransposed^)? $c `:` type($c) `->` type($out) `\n`
299+ `}` attr-dict (`->` type($result)^)?
300+ }];
301+ let extraClassDeclaration = [{
302+ ::mlir::OpOperand* getOutArgument() { return &(*this)->getOpOperands().back(); }
303+ }];
304+ }
305+
258306def Rock_InitKernelOp :
259307 Rock_Op<"init_kernel", []>,
260308 Arguments<(ins AnyTensorOrMemRef:$buffer,
@@ -434,12 +482,12 @@ def Rock_GridwiseGemmAccelOp :
434482// gridwise_attention_accel
435483def Rock_GridwiseAttentionAccelOp :
436484 Rock_Op<"gridwise_attention_accel", [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>, RockFusionRoot, AttrSizedOperandSegments]>,
437- Arguments<(ins MemRefRankOf<[F32, F16, BF16, I8] , [3]>:$queries,
438- MemRefRankOf<[F32, F16, BF16, I8] , [3]>:$keys,
439- MemRefRankOf<[F32, F16, BF16,] , [3]>:$values,
485+ Arguments<(ins MemRefRankOf<GemmInputTypes , [3]>:$queries,
486+ MemRefRankOf<GemmInputTypes , [3]>:$keys,
487+ MemRefRankOf<GemmInputTypes , [3]>:$values,
440488 Variadic<AnyTensorOrMemRef>:$preSoftmaxElemWiseInputs,
441489 Optional<MemRefRankOf<[I32], [1]>>:$currentSeqLen,
442- MemRefRankOf<[F32, F16, BF16] , [3]>:$out,
490+ MemRefRankOf<GemmAccumulatorTypes , [3]>:$out,
443491 StrAttr:$arch,
444492 Rock_GemmFeaturesAttr:$features,
445493 I32Attr:$blockSize,
@@ -449,7 +497,8 @@ def Rock_GridwiseAttentionAccelOp :
449497 OptionalAttr<IndexAttr>:$prePadG0N,
450498 RockAccelTuningParamAttrInterface:$params0,
451499 RockAccelTuningParamAttrInterface:$params1,
452- I32Attr:$firstGemmIdx)> {
500+ I32Attr:$firstGemmIdx,
501+ DefaultValuedOptionalAttr<BoolAttr, "true">:$enableSoftmax)> {
453502 let summary = "Gridwise attention accelerated version";
454503 let description = [{
455504 The `rock.gridwise_attention_accel` op computes gridwise attention with acceleration.
0 commit comments