Skip to content

Commit 4821b84

Browse files
committed
Specialize problem string for attention and gemm+gemm
1 parent d7ad776 commit 4821b84

File tree

1 file changed

+28
-10
lines changed

1 file changed

+28
-10
lines changed

mlir/lib/Dialect/Rock/Tuning/RockTuningImpl.cpp

Lines changed: 28 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -582,22 +582,27 @@ getTuningProblemStr(RockGemmGemmWrapperInterface gemmGemmOp,
582582
ArrayRef<int64_t> vShape = cast<MemRefType>(gemmGemmOp.getCType()).getShape();
583583
int64_t g = qShape[0];
584584

585+
bool isAttention = isa<AttentionOp>(gemmGemmOp);
586+
585587
Type elemTypeQ = cast<MemRefType>(gemmGemmOp.getAType()).getElementType();
586588
problemOS << "-t ";
587589
if (elemTypeQ.isF32()) {
588590
problemOS << "f32" << sep;
589-
} else if (elemTypeQ.isF16()) {
591+
} else if (elemTypeQ.isF16() && isAttention) {
590592
problemOS << "f16" << sep;
591-
} else if (elemTypeQ.isBF16()) {
593+
} else if (elemTypeQ.isBF16() && isAttention) {
592594
problemOS << "bf16" << sep;
593-
} else if (elemTypeQ.isInteger(8)) {
595+
} else if (elemTypeQ.isInteger(8) && isAttention) {
594596
problemOS << "i8" << sep;
595597
} else {
596598
return gemmGemmOp.emitError("invalid type:") << elemTypeQ << "\n";
597599
}
598600

599601
// TransQ
600-
problemOS << "-transQ ";
602+
if (isAttention)
603+
problemOS << "-transQ ";
604+
else
605+
problemOS << "-transA ";
601606
if (gemmGemmOp.getTransposedA()) {
602607
seqLenQ = qShape[2];
603608
headDimQK = qShape[1];
@@ -609,7 +614,10 @@ getTuningProblemStr(RockGemmGemmWrapperInterface gemmGemmOp,
609614
}
610615

611616
// TransK
612-
problemOS << "-transK ";
617+
if (isAttention)
618+
problemOS << "-transK ";
619+
else
620+
problemOS << "-transB ";
613621
if (gemmGemmOp.getTransposedB()) {
614622
seqLenK = kShape[1];
615623
problemOS << "true" << sep;
@@ -619,7 +627,10 @@ getTuningProblemStr(RockGemmGemmWrapperInterface gemmGemmOp,
619627
}
620628

621629
// TransV
622-
problemOS << "-transV ";
630+
if (isAttention)
631+
problemOS << "-transV ";
632+
else
633+
problemOS << "-transC ";
623634
if (gemmGemmOp.getTransposedC()) {
624635
headDimV = vShape[1];
625636
problemOS << "true" << sep;
@@ -636,10 +647,17 @@ getTuningProblemStr(RockGemmGemmWrapperInterface gemmGemmOp,
636647
problemOS << "false" << sep;
637648

638649
problemOS << "-g " << g << sep;
639-
problemOS << "-seq_len_q " << seqLenQ << sep;
640-
problemOS << "-seq_len_k " << seqLenK << sep;
641-
problemOS << "-head_dim_qk " << headDimQK << sep;
642-
problemOS << "-head_dim_v " << headDimV;
650+
if (isAttention) {
651+
problemOS << "-seq_len_q " << seqLenQ << sep;
652+
problemOS << "-seq_len_k " << seqLenK << sep;
653+
problemOS << "-head_dim_qk " << headDimQK << sep;
654+
problemOS << "-head_dim_v " << headDimV;
655+
} else {
656+
problemOS << "-m " << seqLenQ << sep;
657+
problemOS << "-n " << seqLenK << sep;
658+
problemOS << "-k " << headDimQK << sep;
659+
problemOS << "-gemmO " << headDimV;
660+
}
643661
return success();
644662
}
645663

0 commit comments

Comments
 (0)