Skip to content

Commit baf96d9

Browse files
dorde-anticCopilot
andauthored
Define datatypes for attention dynamically based on chip type (#1894)
Set DATA_TYPES_ATTENTION dynamically based on chip type in perfRunner --------- Signed-off-by: Djordje Antic <[email protected]> Co-authored-by: Copilot <[email protected]>
1 parent b3732c6 commit baf96d9

File tree

1 file changed

+42
-29
lines changed

1 file changed

+42
-29
lines changed

mlir/utils/performance/perfRunner.py

Lines changed: 42 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,8 @@
3535
LAYOUTS = ['NHWC', 'NCHW']
3636

3737
DATA_TYPES_GEMM = ['f32', 'f16', 'bf16', 'i8', 'fp8']
38-
DATA_TYPES_ATTENTION = ['i8', 'f32', 'f16', 'bf16']
38+
DATA_TYPES_ATTENTION_WMMA = ['i8', 'f16', 'bf16']
39+
DATA_TYPES_ATTENTION_MFMA = ['i8', 'f32', 'f16', 'bf16']
3940
DATA_TYPES_GEMM_GEMM = ['f32', 'f16', 'bf16']
4041
DATA_TYPES_CONV_GEMM = ['f32', 'f16', 'bf16']
4142
OUTPUT_DATA_TYPES_MAP = {'f32': 'f32', 'f16': 'f16', 'bf16': 'bf16', 'i8': 'i32', 'fp8':'f32',
@@ -117,6 +118,43 @@ def find_mlir_build_dir() -> str:
117118
build_dir = Path(rocmlir_gen_path).parent.parent
118119
return str(build_dir)
119120

121+
def hip_check(call_result):
122+
err = call_result[0]
123+
result = call_result[1:]
124+
if len(result) == 1:
125+
result = result[0]
126+
if isinstance(err, hip.hipError_t) and err != hip.hipError_t.hipSuccess:
127+
raise RuntimeError(str(err))
128+
return result
129+
130+
def getArch() -> str:
131+
agents = set()
132+
device_count = hip_check(hip.hipGetDeviceCount())
133+
for device in range(device_count):
134+
props = hip.hipDeviceProp_t()
135+
hip_check(hip.hipGetDeviceProperties(props,device))
136+
agent = props.gcnArchName.decode('utf-8')
137+
agents.add(agent)
138+
if(len(agents) > 1):
139+
print(f"WARNING: Found {len(agents)} different kinds of agents on the same machine : {', '.join(agents)}")
140+
print("WARNING: Using the first agent by default. If you want to use a different agent, please set the HIP_VISIBLE_DEVICES environment variable.")
141+
# select first agent by default
142+
return list(agents)[0]
143+
144+
def getChip():
145+
arch = getArch()
146+
chip = GFX_CHIP_RE.search(arch).group(0)
147+
return chip
148+
149+
DATA_TYPES_ATTENTION = None
150+
151+
def initializeDataTypesAttention():
152+
global DATA_TYPES_ATTENTION
153+
if getChip().startswith('gfx9'):
154+
DATA_TYPES_ATTENTION = DATA_TYPES_ATTENTION_MFMA
155+
else:
156+
DATA_TYPES_ATTENTION = DATA_TYPES_ATTENTION_WMMA
157+
120158
def create_paths(config_file_path, mlir_build_dir_path) -> Paths:
121159
"""Creates the composite Paths structure using build dir paths"""
122160

@@ -686,6 +724,8 @@ def getGemmGemmConfigurations(fileName):
686724
return configs
687725

688726
def getAttentionConfigurations(fileName):
727+
if DATA_TYPES_ATTENTION is None:
728+
initializeDataTypesAttention()
689729
bool_space = ['false', 'true']
690730
default_test_space = {
691731
"-t": DATA_TYPES_ATTENTION,
@@ -1740,29 +1780,6 @@ def tuneMLIRKernels(configs, arch, numCU):
17401780
print("MIOpen tuning timed out")
17411781
_, errs = p1.communicate()
17421782

1743-
def hip_check(call_result):
1744-
err = call_result[0]
1745-
result = call_result[1:]
1746-
if len(result) == 1:
1747-
result = result[0]
1748-
if isinstance(err, hip.hipError_t) and err != hip.hipError_t.hipSuccess:
1749-
raise RuntimeError(str(err))
1750-
return result
1751-
1752-
def getArch() -> str:
1753-
agents = set()
1754-
device_count = hip_check(hip.hipGetDeviceCount())
1755-
for device in range(device_count):
1756-
props = hip.hipDeviceProp_t()
1757-
hip_check(hip.hipGetDeviceProperties(props,device))
1758-
agent = props.gcnArchName.decode('utf-8')
1759-
agents.add(agent)
1760-
if(len(agents) > 1):
1761-
print(f"WARNING: Found {len(agents)} different kinds of agents on the same machine : {', '.join(agents)}")
1762-
print("WARNING: Using the first agent by default. If you want to use a different agent, please set the HIP_VISIBLE_DEVICES environment variable.")
1763-
# select first agent by default
1764-
return list(agents)[0]
1765-
17661783
def parseDataTypes(data_types):
17671784
if not data_types:
17681785
return DATA_TYPES_GEMM, OUTPUT_DATA_TYPES_MAP
@@ -1780,11 +1797,6 @@ def parseDataTypes(data_types):
17801797
outMap[dt[0]] = 'f32'
17811798
return datatypes, outMap
17821799

1783-
def getChip():
1784-
arch = getArch()
1785-
chip = GFX_CHIP_RE.search(arch).group(0)
1786-
return chip
1787-
17881800
def getNumCU(chip):
17891801
try:
17901802
rocminfo = subprocess.check_output("/opt/rocm/bin/rocminfo",
@@ -1842,6 +1854,7 @@ def main(args=None):
18421854
arch = getArch()
18431855
chip = getChip()
18441856
numCU = getNumCU(chip)
1857+
initializeDataTypesAttention()
18451858

18461859
root_dir = str(subprocess.check_output(['git', 'rev-parse', '--show-toplevel']).decode().strip())
18471860
default_conv_configs = root_dir + '/mlir/utils/jenkins/performance/configs/tier1-conv-configs'

0 commit comments

Comments
 (0)