Skip to content

Commit bde9061

Browse files
committed
Addressing PR comments
1 parent 6fb9726 commit bde9061

File tree

4 files changed

+29
-14
lines changed

4 files changed

+29
-14
lines changed

mlir/include/mlir/Dialect/Rock/IR/RockGemmGemmWrapperInterface.td

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -178,8 +178,6 @@ def RockGemmGemmWrapperInterface : OpInterface<"RockGemmGemmWrapperInterface"> {
178178
InterfaceMethod<
179179
/*desc=*/[{
180180
Set the tuning parameters attribute of the first GEMM
181-
182-
This is needed for --affix-tuning-params to work and can go away if it does
183181
}],
184182
/*retType=*/"void",
185183
/*methodName=*/"setGemm0ParamsAttr",
@@ -192,8 +190,6 @@ def RockGemmGemmWrapperInterface : OpInterface<"RockGemmGemmWrapperInterface"> {
192190
InterfaceMethod<
193191
/*desc=*/[{
194192
Set the tuning parameters attribute of the second GEMM
195-
196-
This is needed for --affix-tuning-params to work and can go away if it does
197193
}],
198194
/*retType=*/"void",
199195
/*methodName=*/"setGemm1ParamsAttr",

mlir/include/mlir/Dialect/Rock/IR/RockOps.td

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -225,9 +225,18 @@ def Rock_AttentionOp
225225
Results<(outs Optional<TensorOf<[F32, F16, BF16]>>:$result)> {
226226
let summary = "Attention operation of transformer models";
227227
let description = [{
228-
Performs the operation out = SOFTMAX(queries * keys) * values.
228+
Performs the operation out = SOFTMAX(preSoftmaxBody(queries * keys, preSoftmaxElemWiseInputs)) * values.
229229

230-
This operation performs attention mechanism of transformer models.
230+
This operation performs attention mechanism of transformer models. There is an optional element-wise
231+
fusion just before the softmax, defined by `preSoftmaxBody` with inputs `preSoftmaxElemWiseInputs`.
232+
233+
If none of the `transposed` attributes are set, then `queries` is [G] x seq_q x head_qk,
234+
`keys` is [G] x head_qk x seq_k, `values` is [G] x seq_k x head_v and `out` is [G] x seq_q x head_v,
235+
where G is the optional group dimension (which is assumed to be 1 if not set).
236+
237+
The transpose attributes allow for the non-group dimensions of the matrix to be
238+
transposed. For example, if `qTransposed` is set, then the argument `queries` should be
239+
a [G] x head_qk x seq_q memory.
231240

232241
Those creating a `rock.attention` must specify the GPU architecture being targetted
233242
and the number of compute units (numCu) available. The parameters
@@ -268,9 +277,18 @@ def Rock_GemmElementwiseGemmOp
268277
Results<(outs Optional<TensorOf<[F32]>>:$result)> {
269278
let summary = "GEMM-elementwise-GEMM operation";
270279
let description = [{
271-
Performs the operation out = (a * b) * c.
280+
Performs the operation out = preSecondGemmBody(a * b, elemwiseInputs) * c.
281+
282+
This operation performs fused GEMM-elementwise-GEMM. There is an optional element-wise
283+
fusion just before the second GEMM, defined by `preSecondGemmBody` with inputs `elemwiseInputs`.
272284

273-
This operation performs fused GEMM-elementwise-GEMM.
285+
If none of the `transposed` attributes are set, then `a` is [G] x M x K,
286+
`b` is [G] x K x N, `c` is [G] x N x O and `out` is [G] x M x O, where G is the
287+
optional group dimension (which is assumed to be 1 if not set).
288+
289+
The transpose attributes allow for the non-group dimensions of the matrix to be
290+
transposed. For example, if `aTransposed` is set, then the argument `a` should be
291+
a [G] x K x M memory.
274292

275293
Those creating a `rock.gemm_elementwise_gemm` must specify the GPU architecture being targetted
276294
and the number of compute units (numCu) available. The parameters

mlir/lib/Dialect/Rock/IR/RockDialect.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2124,8 +2124,8 @@ GemmGemmSize GemmElementwiseGemmOp::getGemmGemmSize() {
21242124
return GemmGemmSize(g, m, k, n, o);
21252125
}
21262126

2127-
static LogicalResult verifyAttentionOp(RockGemmGemmWrapperInterface op,
2128-
Value currentSeqLen) {
2127+
static LogicalResult verifyGemmPlusGemmLikeOp(RockGemmGemmWrapperInterface op,
2128+
Value currentSeqLen) {
21292129
ShapedType qType = cast<ShapedType>(op.getAType());
21302130
int64_t qBatchDim = qType.getShape().size() == 3 ? qType.getShape()[0] : 1;
21312131
ArrayRef<int64_t> qLastDims = qType.getShape().slice(qType.getRank() - 2);
@@ -2194,7 +2194,7 @@ static LogicalResult verifyAttentionOp(RockGemmGemmWrapperInterface op,
21942194
}
21952195

21962196
LogicalResult GemmElementwiseGemmOp::verify() {
2197-
return verifyAttentionOp(*this, /*currentSeqLen=*/nullptr);
2197+
return verifyGemmPlusGemmLikeOp(*this, /*currentSeqLen=*/nullptr);
21982198
}
21992199

22002200
void GemmElementwiseGemmOp::getEffects(
@@ -2256,7 +2256,7 @@ GemmGemmSize AttentionOp::getGemmGemmSize() {
22562256
}
22572257

22582258
LogicalResult AttentionOp::verify() {
2259-
return verifyAttentionOp(*this, getCurrentSeqLen());
2259+
return verifyGemmPlusGemmLikeOp(*this, getCurrentSeqLen());
22602260
}
22612261

22622262
void AttentionOp::getEffects(

mlir/tools/rocmlir-gen/rocmlir-gen.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2425,8 +2425,9 @@ static TosaOp createOpAndInfer(OpBuilder &builder, Location loc, Type elemType,
24252425
return op;
24262426
}
24272427

2428-
Value addTensorArgToBlock(OpBuilder &builder, Location loc,
2429-
Block *preSoftmaxElemwiseBlock, Value funcArg) {
2428+
static Value addTensorArgToBlock(OpBuilder &builder, Location loc,
2429+
Block *preSoftmaxElemwiseBlock,
2430+
Value funcArg) {
24302431
ShapedType funcArgType = cast<ShapedType>(funcArg.getType());
24312432
Value funcArgMemRef = preSoftmaxElemwiseBlock->addArgument(
24322433
MemRefType::get(funcArgType.getShape(), funcArgType.getElementType()),

0 commit comments

Comments
 (0)