@@ -205,38 +205,33 @@ def Rock_ReduceOp :
205205 }];
206206}
207207
208- def Rock_AttentionOp :
209- Rock_Op<"attention", [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>, RockFusionRoot, AttrSizedOperandSegments]>,
210- Arguments<(ins
211- TensorOrMemRefOf<[F32, F16, BF16, I8]>:$queries,
212- TensorOrMemRefOf<[F32, F16, BF16, I8]>:$keys,
213- TensorOrMemRefOf<[F32, F16, BF16]>:$values,
214- Variadic<AnyTensorOrMemRef>:$preSoftmaxElemWiseInputs,
215- Optional<TensorOrMemRefOf<[I32]>>:$currentSeqLen,
216- TensorOrMemRefOf<[F32, BF16, F16]>:$out,
217- UnitAttr:$qTransposed,
218- UnitAttr:$kTransposed,
219- UnitAttr:$vTransposed,
220- UnitAttr:$oTransposed,
221- StrAttr:$arch,
222- Rock_GemmFeaturesAttr:$features,
223- OptionalAttr<I32Attr>:$numCU,
224- OptionalAttr<RockTuningParamAttrInterface>:$params0,
225- OptionalAttr<RockTuningParamAttrInterface>:$params1,
226- I32Attr:$firstGemmIdx
227- )>,
228- Results<(outs Optional<TensorOf<[F32, F16, BF16]>>:$result)> {
208+ def Rock_AttentionOp
209+ : Rock_Op<"attention", [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
210+ RockFusionRoot, AttrSizedOperandSegments]>,
211+ Arguments<(ins TensorOrMemRefOf<[F32, F16, BF16, I8]>:$queries,
212+ TensorOrMemRefOf<[F32, F16, BF16, I8]>:$keys,
213+ TensorOrMemRefOf<[F32, F16, BF16]>:$values,
214+ Variadic<AnyTensorOrMemRef>:$preSoftmaxElemWiseInputs,
215+ Optional<TensorOrMemRefOf<[I32]>>:$currentSeqLen,
216+ TensorOrMemRefOf<[F32, F16, BF16]>:$out, UnitAttr:$qTransposed,
217+ UnitAttr:$kTransposed, UnitAttr:$vTransposed, UnitAttr:$oTransposed,
218+ StrAttr:$arch, Rock_GemmFeaturesAttr:$features,
219+ OptionalAttr<I32Attr>:$numCU,
220+ OptionalAttr<RockTuningParamAttrInterface>:$params0,
221+ OptionalAttr<RockTuningParamAttrInterface>:$params1,
222+ I32Attr:$firstGemmIdx)>,
223+ Results<(outs Optional<TensorOf<[F32, F16, BF16]>>:$result)> {
229224 let summary = "Attention operation of transformer models";
230225 let description = [{
231- Performs the operation out = SOFTMAX(( queries * keys) .* scale ) * values.
226+ Performs the operation out = SOFTMAX(queries * keys) * values.
232227
233228 This operation performs the attention mechanism of transformer models.
234229
235230 Those creating a `rock.attention` must specify the GPU architecture being targeted
236231 and the number of compute units (numCU) available. The parameters
237232 `gridSize` and `blockSize` are optional as they can be inferred by
238233 a tuning process or a heuristic, but they must be set before the `attention` is
239- lowered into the `gridwise_attention ` stage of the code generation pipeline.
234+ lowered into the `gridwise_attention_accel` stage of the code generation pipeline.
240235
241236 `features` specifies what hardware features can be used in the generated code.
242237 }];
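
A sketch of the full computation, assuming the operand names mean what they say (this expansion is not taken from the op description itself):

    out = SOFTMAX(elementwise(queries * keys, preSoftmaxElemWiseInputs)) * values

where the optional elementwise stage is applied to the queries * keys product before the softmax, and `currentSeqLen`, when present, presumably holds the per-batch valid sequence length that limits which key/value positions take part.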
@@ -255,6 +250,50 @@ def Rock_AttentionOp :
255250 }];
256251}
257252
253+ def Rock_GemmElementwiseGemmOp
254+ : Rock_Op<"gemm_elementwise_gemm", [DeclareOpInterfaceMethods<
255+ MemoryEffectsOpInterface>,
256+ RockFusionRoot,
257+ AllElementTypesMatch<["a", "b", "c"]>]>,
258+ Arguments<(ins TensorOrMemRefOf<[F32]>:$a, TensorOrMemRefOf<[F32]>:$b,
259+ TensorOrMemRefOf<[F32]>:$c,
260+ Variadic<AnyTensorOrMemRef>:$elemwiseInputs,
261+ TensorOrMemRefOf<[F32]>:$out, UnitAttr:$aTransposed,
262+ UnitAttr:$bTransposed, UnitAttr:$cTransposed, UnitAttr:$oTransposed,
263+ StrAttr:$arch, Rock_GemmFeaturesAttr:$features,
264+ OptionalAttr<I32Attr>:$numCU,
265+ OptionalAttr<RockTuningParamAttrInterface>:$params0,
266+ OptionalAttr<RockTuningParamAttrInterface>:$params1,
267+ I32Attr:$firstGemmIdx)>,
268+ Results<(outs Optional<TensorOf<[F32]>>:$result)> {
269+ let summary = "GEMM-elementwise-GEMM operation";
270+ let description = [{
271+ Performs the operation out = elementwise(a * b) * c, where the optional elementwise stage is the `preSecondGemmBody` region applied to the result of the first GEMM.
272+
273+ This operation performs a fused GEMM-elementwise-GEMM.
274+
275+ Those creating a `rock.gemm_elementwise_gemm` must specify the GPU architecture being targeted
276+ and the number of compute units (numCU) available. The parameters
277+ `gridSize` and `blockSize` are optional as they can be inferred by
278+ a tuning process or a heuristic, but they must be set before the `gemm_elementwise_gemm` is
279+ lowered into the `gridwise_attention_accel` stage of the code generation pipeline.
280+
281+ `features` specifies what hardware features can be used in the generated code.
282+ }];
283+ let hasVerifier = 1;
284+ let regions = (region AnyRegion:$preSecondGemmBody);
285+ let assemblyFormat = [{
286+ `{` `\n`
287+ ` ` `ab` `=` (`tr` $aTransposed^)? $a `*` (`tr` $bTransposed^)? $b `:` type($a) `,` type($b) `\n`
288+ (`ab` `=` `elementwise` (`otherIns` `(` $elemwiseInputs^ `:` type($elemwiseInputs) `)`)? $preSecondGemmBody^ `\n`)?
289+ (`tr` $oTransposed^)? $out `=` `ab` `*` (`tr` $cTransposed^)? $c `:` type($c) `->` type($out) `\n`
290+ `}` attr-dict (`->` type($result)^)?
291+ }];
292+ let extraClassDeclaration = [{
293+ ::mlir::OpOperand* getOutArgument() { return &(*this)->getOpOperands().back(); }
294+ }];
295+ }
296+
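As a rough illustration of the custom assembly the format above describes (a sketch only: value names, memref shapes, and attribute values are placeholders, the optional `elementwise ... otherIns(...)` line and the tuning/feature attributes are omitted, and this snippet has not been round-tripped through the parser), a memref-form `rock.gemm_elementwise_gemm` would print roughly as:

    // Hypothetical printout; shapes and attribute values are placeholders.
    rock.gemm_elementwise_gemm {
      ab = %a * tr %b : memref<1x384x64xf32>, memref<1x384x64xf32>
      %out = ab * %c : memref<1x384x32xf32> -> memref<1x384x32xf32>
    } {arch = "gfx90a", firstGemmIdx = 0 : i32}

The optional second line of the format, `ab = elementwise otherIns(...) { ... }`, would carry the fused `preSecondGemmBody` region between the two GEMMs.
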
258297def Rock_InitKernelOp :
259298 Rock_Op<"init_kernel", []>,
260299 Arguments<(ins AnyTensorOrMemRef:$buffer,
@@ -432,24 +471,23 @@ def Rock_GridwiseGemmAccelOp :
432471}
433472
434473// gridwise_attention_accel
435- def Rock_GridwiseAttentionAccelOp :
436- Rock_Op<"gridwise_attention_accel", [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>, RockFusionRoot, AttrSizedOperandSegments]>,
437- Arguments<(ins MemRefRankOf<[F32, F16, BF16, I8], [3]>:$queries,
438- MemRefRankOf<[F32, F16, BF16, I8], [3]>:$keys,
439- MemRefRankOf<[F32, F16, BF16,], [3]>:$values,
440- Variadic<AnyTensorOrMemRef>:$preSoftmaxElemWiseInputs,
441- Optional<MemRefRankOf<[I32], [1]>>:$currentSeqLen,
442- MemRefRankOf<[F32, F16, BF16], [3]>:$out,
443- StrAttr:$arch,
444- Rock_GemmFeaturesAttr:$features,
445- I32Attr:$blockSize,
446- I32Attr:$gridSize,
447- UnitAttr:$disableQBypassLDS,
448- OptionalAttr<IndexAttr>:$prePadG0M,
449- OptionalAttr<IndexAttr>:$prePadG0N,
450- RockAccelTuningParamAttrInterface:$params0,
451- RockAccelTuningParamAttrInterface:$params1,
452- I32Attr:$firstGemmIdx)> {
474+ def Rock_GridwiseAttentionAccelOp
475+ : Rock_Op<"gridwise_attention_accel",
476+ [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
477+ RockFusionRoot, AttrSizedOperandSegments]>,
478+ Arguments<(ins MemRefRankOf<[F32, F16, BF16, I8], [3]>:$queries,
479+ MemRefRankOf<[F32, F16, BF16, I8], [3]>:$keys,
480+ MemRefRankOf<[F32, F16, BF16], [3]>:$values,
481+ Variadic<AnyTensorOrMemRef>:$preSoftmaxElemWiseInputs,
482+ Optional<MemRefRankOf<[I32], [1]>>:$currentSeqLen,
483+ MemRefRankOf<[F32, F16, BF16], [3]>:$out, StrAttr:$arch,
484+ Rock_GemmFeaturesAttr:$features, I32Attr:$blockSize,
485+ I32Attr:$gridSize, UnitAttr:$disableQBypassLDS,
486+ OptionalAttr<IndexAttr>:$prePadG0M,
487+ OptionalAttr<IndexAttr>:$prePadG0N,
488+ RockAccelTuningParamAttrInterface:$params0,
489+ RockAccelTuningParamAttrInterface:$params1, I32Attr:$firstGemmIdx,
490+ DefaultValuedOptionalAttr<BoolAttr, "true">:$enableSoftmax)> {
453491 let summary = "Gridwise attention accelerated version";
454492 let description = [{
455493 The `rock.gridwise_attention_accel` op computes gridwise attention with acceleration.