|
37 | 37 | OUTPUT_DATA_TYPES_MAP = {'f32': 'f32', 'f16': 'f16', 'bf16': 'bf16', 'i8': 'i32', 'fp8':'f32', |
38 | 38 | 'fp8_fp8': 'f32', 'fp8_bf8': 'f32', 'bf8_fp8': 'f32', |
39 | 39 | 'bf8_bf8': 'f32'} |
| 40 | +MLIR_N_REPEATS = 5 |
40 | 41 |
|
41 | 42 | # Compiled regexp object used for extracting elapsed time from MIOpenDriver's output |
42 | 43 | ELAPSED_TIME_RE = re.compile(r"Elapsed: ([0-9\.]*) ms") |
@@ -367,7 +368,7 @@ def generateMlirDriverCommandLine(self, rocmlir_gen_flags): |
367 | 368 | '--padding_h', str(self.paddingH), |
368 | 369 | '--padding_w', str(self.paddingW), |
369 | 370 | '--groupsize', str(self.group), |
370 | | - '--kernel-repeats', str(self.MLIR_N_REPEATS), |
| 371 | + '--kernel-repeats', str(MLIR_N_REPEATS), |
371 | 372 | f"--perf_config={self.perfConfig}"]) |
372 | 373 | result += ' ' |
373 | 374 | if rocmlir_gen_flags != '': |
@@ -496,7 +497,6 @@ def __init__(self, dtype: str, direction: str, filterLayout: str, inputLayout:st |
496 | 497 | if direction not in {"fwd", "bwd", "wrw"}: |
497 | 498 | raise ValueError(f"Invalid direction: {direction}") |
498 | 499 |
|
499 | | - self.MLIR_N_REPEATS = 5 |
500 | 500 | self.dataType = dtype |
501 | 501 | self.direction = direction |
502 | 502 |
|
@@ -666,7 +666,7 @@ def generateMlirDriverCommandLine(self, rocmlir_gen_flags): |
666 | 666 | '-n', str(self.n), |
667 | 667 | f"-transA={self.transA}", |
668 | 668 | f"-transB={self.transB}", |
669 | | - '--kernel-repeats', str(self.MLIR_N_REPEATS), |
| 669 | + '--kernel-repeats', str(MLIR_N_REPEATS), |
670 | 670 | f"--perf_config={self.perfConfig}"]) |
671 | 671 |
|
672 | 672 | result += ' ' |
@@ -723,7 +723,6 @@ def __init__(self, dtype: str, outDataType: str, g: int, m: int, k: int, n: int, |
723 | 723 | transA: bool, transB: bool, arch: str, numCU: int, perf_config: str = ''): |
724 | 724 | if dtype not in {"f16", "f32", "bf16", "i8", "fp8"}: |
725 | 725 | raise ValueError(f"Invalid datatype: {dtype}") |
726 | | - self.MLIR_N_REPEATS = 5 |
727 | 726 | self.dataType = dtype |
728 | 727 | self.outDataType = outDataType |
729 | 728 | self.g = g |
@@ -759,7 +758,6 @@ def __init__(self, dtype: str, g: int, seq_len_q: int, seq_len_k: int, head_dim_ |
759 | 758 | self.arch = arch |
760 | 759 | self.chip = GFX_CHIP_RE.search(arch).group(0) |
761 | 760 | self.numCU = numCU |
762 | | - self.MLIR_N_REPEATS = 5 |
763 | 761 | self.perfConfig = perf_config |
764 | 762 |
|
765 | 763 | def computeTFlops(self, ns, only_matmul_flops=True): |
@@ -826,9 +824,9 @@ def generateMlirDriverCommandLine(self, rocmlir_gen_flags): |
826 | 824 | f"-with-attn-scale={self.with_attn_scale}", |
827 | 825 | f"-transQ={self.transQ}", |
828 | 826 | f"-transK={self.transK}", |
829 | | - f"-transQ={self.transV}", |
830 | | - f"-transK={self.transO}", |
831 | | - '--kernel-repeats', str(self.MLIR_N_REPEATS), |
| 827 | + f"-transV={self.transV}", |
| 828 | + f"-transO={self.transO}", |
| 829 | + '--kernel-repeats', str(MLIR_N_REPEATS), |
832 | 830 | f"--perf_config={self.perfConfig}"]) |
833 | 831 | result += ' ' |
834 | 832 | if rocmlir_gen_flags != '': |
|
0 commit comments