Skip to content

Commit abba872

Browse files
authored
Fix generateMlirDriverCommandLine for attention in perfRunner
Fix generateMlirDriverCommandLine for attention in perfRunner
2 parents 078a100 + 940760f commit abba872

File tree

1 file changed

+6
-8
lines changed

1 file changed

+6
-8
lines changed

mlir/utils/performance/perfRunner.py

Lines changed: 6 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -37,6 +37,7 @@
3737
OUTPUT_DATA_TYPES_MAP = {'f32': 'f32', 'f16': 'f16', 'bf16': 'bf16', 'i8': 'i32', 'fp8':'f32',
3838
'fp8_fp8': 'f32', 'fp8_bf8': 'f32', 'bf8_fp8': 'f32',
3939
'bf8_bf8': 'f32'}
40+
MLIR_N_REPEATS = 5
4041

4142
# Compiled regexp object used for extracting elapsed time from MIOpenDriver's output
4243
ELAPSED_TIME_RE = re.compile(r"Elapsed: ([0-9\.]*) ms")
@@ -367,7 +368,7 @@ def generateMlirDriverCommandLine(self, rocmlir_gen_flags):
367368
'--padding_h', str(self.paddingH),
368369
'--padding_w', str(self.paddingW),
369370
'--groupsize', str(self.group),
370-
'--kernel-repeats', str(self.MLIR_N_REPEATS),
371+
'--kernel-repeats', str(MLIR_N_REPEATS),
371372
f"--perf_config={self.perfConfig}"])
372373
result += ' '
373374
if rocmlir_gen_flags != '':
@@ -496,7 +497,6 @@ def __init__(self, dtype: str, direction: str, filterLayout: str, inputLayout:st
496497
if direction not in {"fwd", "bwd", "wrw"}:
497498
raise ValueError(f"Invalid direction: {direction}")
498499

499-
self.MLIR_N_REPEATS = 5
500500
self.dataType = dtype
501501
self.direction = direction
502502

@@ -666,7 +666,7 @@ def generateMlirDriverCommandLine(self, rocmlir_gen_flags):
666666
'-n', str(self.n),
667667
f"-transA={self.transA}",
668668
f"-transB={self.transB}",
669-
'--kernel-repeats', str(self.MLIR_N_REPEATS),
669+
'--kernel-repeats', str(MLIR_N_REPEATS),
670670
f"--perf_config={self.perfConfig}"])
671671

672672
result += ' '
@@ -723,7 +723,6 @@ def __init__(self, dtype: str, outDataType: str, g: int, m: int, k: int, n: int,
723723
transA: bool, transB: bool, arch: str, numCU: int, perf_config: str = ''):
724724
if dtype not in {"f16", "f32", "bf16", "i8", "fp8"}:
725725
raise ValueError(f"Invalid datatype: {dtype}")
726-
self.MLIR_N_REPEATS = 5
727726
self.dataType = dtype
728727
self.outDataType = outDataType
729728
self.g = g
@@ -759,7 +758,6 @@ def __init__(self, dtype: str, g: int, seq_len_q: int, seq_len_k: int, head_dim_
759758
self.arch = arch
760759
self.chip = GFX_CHIP_RE.search(arch).group(0)
761760
self.numCU = numCU
762-
self.MLIR_N_REPEATS = 5
763761
self.perfConfig = perf_config
764762

765763
def computeTFlops(self, ns, only_matmul_flops=True):
@@ -826,9 +824,9 @@ def generateMlirDriverCommandLine(self, rocmlir_gen_flags):
826824
f"-with-attn-scale={self.with_attn_scale}",
827825
f"-transQ={self.transQ}",
828826
f"-transK={self.transK}",
829-
f"-transQ={self.transV}",
830-
f"-transK={self.transO}",
831-
'--kernel-repeats', str(self.MLIR_N_REPEATS),
827+
f"-transV={self.transV}",
828+
f"-transO={self.transO}",
829+
'--kernel-repeats', str(MLIR_N_REPEATS),
832830
f"--perf_config={self.perfConfig}"])
833831
result += ' '
834832
if rocmlir_gen_flags != '':

0 commit comments

Comments
 (0)