@@ -213,7 +213,7 @@ def Rock_AttentionOp :
       TensorOrMemRefOf<[F32, F16, BF16]>:$values,
       Variadic<AnyTensorOrMemRef>:$preSoftmaxElemWiseInputs,
       Optional<TensorOrMemRefOf<[I32]>>:$currentSeqLen,
-      TensorOrMemRefOf<[F32, BF16, F16 ]>:$out,
+      TensorOrMemRefOf<[F32, F16, BF16 ]>:$out,
       UnitAttr:$qTransposed,
       UnitAttr:$kTransposed,
       UnitAttr:$vTransposed,
@@ -228,15 +228,15 @@ def Rock_AttentionOp :
     Results<(outs Optional<TensorOf<[F32, F16, BF16]>>:$result)> {
   let summary = "Attention operation of transformer models";
   let description = [{
-    Performs the operation out = SOFTMAX(( queries * keys) .* scale ) * values.
+    Performs the operation out = SOFTMAX(queries * keys) * values.

     This operation performs attention mechanism of transformer models.

     Those creating a `rock.attention` must specify the GPU architecture being targetted
     and the number of compute units (numCu) available. The parameters
     `gridSize`, and `blockSize` are optional as they can be inferred by
     a tuning process or a heuristic, but they must be set before the `attention` is
-    lowered into the `gridwise_attention ` stage of the code generation pipeline.
+    lowered into the `gridwise_attention_accel ` stage of the code generation pipeline.

     `features` specifies what hardware features can be used in the generated code.
   }];
@@ -255,6 +255,55 @@ def Rock_AttentionOp :
   }];
 }

+def Rock_GemmElementwiseGemmOp:
+    Rock_Op<"gemm_elementwise_gemm", [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>, RockFusionRoot]>,
+    AllElementTypesMatch<["a", "b", "c"]>,
+    Arguments<(ins
+      TensorOrMemRefOf<[F32]>:$a,
+      TensorOrMemRefOf<[F32]>:$b,
+      TensorOrMemRefOf<[F32]>:$c,
+      Variadic<AnyTensorOrMemRef>:$elemwiseInputs,
+      TensorOrMemRefOf<[F32]>:$out,
+      UnitAttr:$aTransposed,
+      UnitAttr:$bTransposed,
+      UnitAttr:$cTransposed,
+      UnitAttr:$oTransposed,
+      StrAttr:$arch,
+      Rock_GemmFeaturesAttr:$features,
+      OptionalAttr<I32Attr>:$numCU,
+      OptionalAttr<RockTuningParamAttrInterface>:$params0,
+      OptionalAttr<RockTuningParamAttrInterface>:$params1,
+      I32Attr:$firstGemmIdx
+    )>,
+    Results<(outs Optional<TensorOf<[F32]>>:$result)> {
+  let summary = "GEMM-elementwise-GEMM operation";
+  let description = [{
+    Performs the operation out = (a * b) * c.
+
+    This operation performs fused GEMM-elementwise-GEMM.
+
+    Those creating a `rock.gemm_elementwise_gemm` must specify the GPU architecture being targeted
+    and the number of compute units (numCU) available. The parameters
+    `gridSize` and `blockSize` are optional as they can be inferred by
+    a tuning process or a heuristic, but they must be set before the `gemm_elementwise_gemm` is
+    lowered into the `gridwise_attention_accel` stage of the code generation pipeline.
+
+    `features` specifies what hardware features can be used in the generated code.
+  }];
+  let hasVerifier = 1;
+  let regions = (region AnyRegion:$preSecondGemmBody);
+  let assemblyFormat = [{
+    `{` `\n`
+    ` ` `ab` `=` (`tr` $aTransposed^)? $a `*` (`tr` $bTransposed^)? $b `:` type($a) `,` type($b) `\n`
+    (`ab` `=` `elementwise` (`otherIns` `(` $elemwiseInputs^ `:` type($elemwiseInputs) `)`)? $preSecondGemmBody^ `\n`)?
+    (`tr` $oTransposed^)? $out `=` `ab` `*` (`tr` $cTransposed^)? $c `:` type($c) `->` type($out) `\n`
+    `}` attr-dict (`->` type($result)^)?
+  }];
+  let extraClassDeclaration = [{
+    ::mlir::OpOperand* getOutArgument() { return &(*this)->getOpOperands().back(); }
+  }];
+}
+
 def Rock_InitKernelOp :
     Rock_Op<"init_kernel", []>,
     Arguments<(ins AnyTensorOrMemRef:$buffer,
@@ -449,7 +498,8 @@ def Rock_GridwiseAttentionAccelOp :
       OptionalAttr<IndexAttr>:$prePadG0N,
       RockAccelTuningParamAttrInterface:$params0,
       RockAccelTuningParamAttrInterface:$params1,
-      I32Attr:$firstGemmIdx)> {
+      I32Attr:$firstGemmIdx,
+      DefaultValuedOptionalAttr<BoolAttr, "true">:$enableSoftmax)> {
   let summary = "Gridwise attention accelerated version";
   let description = [{
     The `rock.gridwise_attention_accel` op computes gridwise attention with acceleration.
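For reference, a minimal sketch of how the new `rock.gemm_elementwise_gemm` might print under the declarative assembly format added above. The SSA names, memref shapes, `arch` string, and the exact spelling of the `features` attribute are illustrative assumptions, not taken from the patch, and the optional `elementwise` stage (the `preSecondGemmBody` region) is omitted:

```mlir
// Hypothetical example of the custom assembly form defined by assemblyFormat.
// First GEMM: ab = a * b^T (bTransposed set); second GEMM: out = ab * c.
// Attribute values and the #rock<features ...> spelling are illustrative only.
rock.gemm_elementwise_gemm {
  ab = %a * tr %b : memref<1x384x64xf32>, memref<1x384x64xf32>
  %out = ab * %c : memref<1x384x64xf32> -> memref<1x384x64xf32>
} {arch = "gfx90a", features = #rock<features mfma|dot|atomic_add>, firstGemmIdx = 0 : i32}
```

Because `$out` is passed as a destination operand and the `$result` value is optional, the memref form above produces no SSA result; the trailing `-> type($result)` only appears when the op is used on tensors.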