@@ -225,9 +225,18 @@ def Rock_AttentionOp
 Results<(outs Optional<TensorOf<[F32, F16, BF16]>>:$result)> {
 let summary = "Attention operation of transformer models";
 let description = [{
-  Performs the operation out = SOFTMAX(queries * keys) * values.
+  Performs the operation out = SOFTMAX(preSoftmaxBody(queries * keys, preSoftmaxElemWiseInputs)) * values.
 
-  This operation performs the attention mechanism of transformer models.
+  This operation performs the attention mechanism of transformer models. There is an optional
+  element-wise fusion just before the softmax, defined by `preSoftmaxBody` with inputs
+  `preSoftmaxElemWiseInputs`.
+
+  If none of the `transposed` attributes are set, then `queries` is [G] x seq_q x head_qk,
+  `keys` is [G] x head_qk x seq_k, `values` is [G] x seq_k x head_v and `out` is [G] x seq_q x head_v,
+  where G is the optional group dimension (which is assumed to be 1 if not set).
+
+  The transpose attributes allow the non-group dimensions of a matrix to be
+  transposed. For example, if `qTransposed` is set, then the argument `queries` should be
+  a [G] x head_qk x seq_q memory.
 
 Those creating a `rock.attention` must specify the GPU architecture being targeted
 and the number of compute units (numCu) available. The parameters
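The semantics described in this hunk can be sketched in NumPy. This is a minimal reference model based only on the description above, not the actual rocMLIR lowering; `pre_softmax_body` and `extra_inputs` are hypothetical stand-ins for the op's `preSoftmaxBody` region and `preSoftmaxElemWiseInputs`:

```python
import numpy as np

def softmax(x, axis=-1):
    # Numerically stable softmax along the given axis.
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def attention(queries, keys, values, pre_softmax_body=None, extra_inputs=()):
    # queries: G x seq_q x head_qk, keys: G x head_qk x seq_k,
    # values:  G x seq_k x head_v  (G is the optional group dim, 1 if absent).
    scores = queries @ keys                      # G x seq_q x seq_k
    if pre_softmax_body is not None:
        # Optional element-wise fusion just before the softmax.
        scores = pre_softmax_body(scores, *extra_inputs)
    return softmax(scores, axis=-1) @ values     # G x seq_q x head_v

G, seq_q, seq_k, head_qk, head_v = 2, 4, 5, 8, 3
q = np.random.rand(G, seq_q, head_qk)
k = np.random.rand(G, head_qk, seq_k)
v = np.random.rand(G, seq_k, head_v)
bias = np.random.rand(G, seq_q, seq_k)
# Example fusion: add a bias to the pre-softmax scores.
out = attention(q, k, v, pre_softmax_body=lambda s, b: s + b, extra_inputs=(bias,))
assert out.shape == (G, seq_q, head_v)
```

If `qTransposed` were set, `q` would instead be laid out as G x head_qk x seq_q and transposed before the first product.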
@@ -268,9 +277,18 @@ def Rock_GemmElementwiseGemmOp
 Results<(outs Optional<TensorOf<[F32]>>:$result)> {
 let summary = "GEMM-elementwise-GEMM operation";
 let description = [{
-  Performs the operation out = (a * b) * c.
+  Performs the operation out = preSecondGemmBody(a * b, elemwiseInputs) * c.
+
+  This operation performs a fused GEMM-elementwise-GEMM. There is an optional element-wise
+  fusion just before the second GEMM, defined by `preSecondGemmBody` with inputs `elemwiseInputs`.
 
-  This operation performs a fused GEMM-elementwise-GEMM.
+  If none of the `transposed` attributes are set, then `a` is [G] x M x K,
+  `b` is [G] x K x N, `c` is [G] x N x O and `out` is [G] x M x O, where G is the
+  optional group dimension (which is assumed to be 1 if not set).
+
+  The transpose attributes allow the non-group dimensions of a matrix to be
+  transposed. For example, if `aTransposed` is set, then the argument `a` should be
+  a [G] x K x M memory.
 
 Those creating a `rock.gemm_elementwise_gemm` must specify the GPU architecture being targeted
 and the number of compute units (numCu) available. The parameters
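The shape conventions and fusion point of this hunk can likewise be sketched in NumPy. Again a reference model derived from the description only; `pre_second_gemm_body` is a hypothetical stand-in for the op's `preSecondGemmBody` region:

```python
import numpy as np

def gemm_elementwise_gemm(a, b, c, pre_second_gemm_body=None, extra_inputs=()):
    # a: G x M x K, b: G x K x N, c: G x N x O (G is the optional group dim).
    acc = a @ b                                  # G x M x N
    if pre_second_gemm_body is not None:
        # Optional element-wise fusion just before the second GEMM.
        acc = pre_second_gemm_body(acc, *extra_inputs)
    return acc @ c                               # G x M x O

G, M, K, N, O = 2, 4, 8, 5, 3
a = np.random.rand(G, M, K)
b = np.random.rand(G, K, N)
c = np.random.rand(G, N, O)
# Example fusion: a ReLU between the two GEMMs.
out = gemm_elementwise_gemm(a, b, c, pre_second_gemm_body=lambda x: np.maximum(x, 0.0))
assert out.shape == (G, M, O)
```

With non-negative inputs the ReLU is the identity here, so `out` matches the plain `(a @ b) @ c`; the hook only matters when the first GEMM can produce negative values.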