@@ -582,22 +582,27 @@ getTuningProblemStr(RockGemmGemmWrapperInterface gemmGemmOp,
582582 ArrayRef<int64_t > vShape = cast<MemRefType>(gemmGemmOp.getCType ()).getShape ();
583583 int64_t g = qShape[0 ];
584584
585+ bool isAttention = isa<AttentionOp>(gemmGemmOp);
586+
585587 Type elemTypeQ = cast<MemRefType>(gemmGemmOp.getAType ()).getElementType ();
586588 problemOS << " -t " ;
587589 if (elemTypeQ.isF32 ()) {
588590 problemOS << " f32" << sep;
589- } else if (elemTypeQ.isF16 ()) {
591+ } else if (elemTypeQ.isF16 () && isAttention ) {
590592 problemOS << " f16" << sep;
591- } else if (elemTypeQ.isBF16 ()) {
593+ } else if (elemTypeQ.isBF16 () && isAttention ) {
592594 problemOS << " bf16" << sep;
593- } else if (elemTypeQ.isInteger (8 )) {
595+ } else if (elemTypeQ.isInteger (8 ) && isAttention ) {
594596 problemOS << " i8" << sep;
595597 } else {
596598 return gemmGemmOp.emitError (" invalid type:" ) << elemTypeQ << " \n " ;
597599 }
598600
599601 // TransQ
600- problemOS << " -transQ " ;
602+ if (isAttention)
603+ problemOS << " -transQ " ;
604+ else
605+ problemOS << " -transA " ;
601606 if (gemmGemmOp.getTransposedA ()) {
602607 seqLenQ = qShape[2 ];
603608 headDimQK = qShape[1 ];
@@ -609,7 +614,10 @@ getTuningProblemStr(RockGemmGemmWrapperInterface gemmGemmOp,
609614 }
610615
611616 // TransK
612- problemOS << " -transK " ;
617+ if (isAttention)
618+ problemOS << " -transK " ;
619+ else
620+ problemOS << " -transB " ;
613621 if (gemmGemmOp.getTransposedB ()) {
614622 seqLenK = kShape [1 ];
615623 problemOS << " true" << sep;
@@ -619,7 +627,10 @@ getTuningProblemStr(RockGemmGemmWrapperInterface gemmGemmOp,
619627 }
620628
621629 // TransV
622- problemOS << " -transV " ;
630+ if (isAttention)
631+ problemOS << " -transV " ;
632+ else
633+ problemOS << " -transC " ;
623634 if (gemmGemmOp.getTransposedC ()) {
624635 headDimV = vShape[1 ];
625636 problemOS << " true" << sep;
@@ -636,10 +647,17 @@ getTuningProblemStr(RockGemmGemmWrapperInterface gemmGemmOp,
636647 problemOS << " false" << sep;
637648
638649 problemOS << " -g " << g << sep;
639- problemOS << " -seq_len_q " << seqLenQ << sep;
640- problemOS << " -seq_len_k " << seqLenK << sep;
641- problemOS << " -head_dim_qk " << headDimQK << sep;
642- problemOS << " -head_dim_v " << headDimV;
650+ if (isAttention) {
651+ problemOS << " -seq_len_q " << seqLenQ << sep;
652+ problemOS << " -seq_len_k " << seqLenK << sep;
653+ problemOS << " -head_dim_qk " << headDimQK << sep;
654+ problemOS << " -head_dim_v " << headDimV;
655+ } else {
656+ problemOS << " -m " << seqLenQ << sep;
657+ problemOS << " -n " << seqLenK << sep;
658+ problemOS << " -k " << headDimQK << sep;
659+ problemOS << " -gemmO " << headDimV;
660+ }
643661 return success ();
644662}
645663
0 commit comments