# Layout names.
LAYOUTS = [
    'NHWC',
    'NCHW',
]

# Element types for GEMM; also the default list returned by parseDataTypes
# when no data types are requested.
DATA_TYPES_GEMM = [
    'f32',
    'f16',
    'bf16',
    'i8',
    'fp8',
]
38- DATA_TYPES_ATTENTION = ['i8' , 'f32' , 'f16' , 'bf16' ]
# Attention element types. initializeDataTypesAttention() picks the MFMA list
# when the chip name starts with 'gfx9' and the WMMA list otherwise; the WMMA
# list differs only in that it has no 'f32' entry.
DATA_TYPES_ATTENTION_WMMA = [
    'i8',
    'f16',
    'bf16',
]
DATA_TYPES_ATTENTION_MFMA = [
    'i8',
    'f32',
    'f16',
    'bf16',
]

# Element types for the GEMM+GEMM and CONV+GEMM configurations.
DATA_TYPES_GEMM_GEMM = [
    'f32',
    'f16',
    'bf16',
]
DATA_TYPES_CONV_GEMM = [
    'f32',
    'f16',
    'bf16',
]
4142OUTPUT_DATA_TYPES_MAP = {'f32' : 'f32' , 'f16' : 'f16' , 'bf16' : 'bf16' , 'i8' : 'i32' , 'fp8' :'f32' ,
@@ -117,6 +118,43 @@ def find_mlir_build_dir() -> str:
117118 build_dir = Path (rocmlir_gen_path ).parent .parent
118119 return str (build_dir )
119120
def hip_check(call_result):
    """Unpack the result of a HIP call, raising if it reports an error.

    The first element of ``call_result`` is treated as the status and the
    remaining elements as the payload.

    Returns:
        The single payload element when there is exactly one, otherwise the
        whole remaining slice (which may be empty).

    Raises:
        RuntimeError: if the status is a ``hip.hipError_t`` other than
            ``hipSuccess``; the message is ``str(status)``.
    """
    status = call_result[0]
    payload = call_result[1:]
    failed = isinstance(status, hip.hipError_t) and status != hip.hipError_t.hipSuccess
    if failed:
        raise RuntimeError(str(status))
    # A lone payload value is handed back bare rather than as a 1-element slice.
    return payload[0] if len(payload) == 1 else payload
129+
def getArch() -> str:
    """Return the ``gcnArchName`` of the first device visible to HIP.

    Every device reported by ``hipGetDeviceCount`` is queried and the distinct
    architecture names are collected in device order. If more than one kind
    is present, a warning is printed and the architecture of the
    lowest-numbered device is returned.

    Raises:
        RuntimeError: if a HIP call fails (raised by ``hip_check``) or if no
            device is reported.
    """
    # A list (not a set) keeps device order, so "first agent" is really the
    # architecture of device 0 and does not depend on string hashing.
    agents = []
    device_count = hip_check(hip.hipGetDeviceCount())
    for device in range(device_count):
        props = hip.hipDeviceProp_t()
        hip_check(hip.hipGetDeviceProperties(props, device))
        agent = props.gcnArchName.decode('utf-8')
        if agent not in agents:
            agents.append(agent)
    if not agents:
        raise RuntimeError("No HIP devices found")
    if len(agents) > 1:
        print(f"WARNING: Found {len(agents)} different kinds of agents on the same machine : {', '.join(agents)} ")
        print("WARNING: Using the first agent by default. If you want to use a different agent, please set the HIP_VISIBLE_DEVICES environment variable.")
    # select first agent by default
    return agents[0]
143+
def getChip():
    """Return the chip name extracted from the architecture string.

    The string returned by ``getArch()`` is searched with ``GFX_CHIP_RE``
    (defined elsewhere in this file) and the whole matched text is returned.

    Raises:
        ValueError: if ``GFX_CHIP_RE`` does not match the architecture string.
    """
    arch = getArch()
    match = GFX_CHIP_RE.search(arch)
    if match is None:
        # Fail with the offending string instead of an AttributeError on None.
        raise ValueError(f"Could not find a chip name in architecture string '{arch}'")
    return match.group(0)
148+
# Filled in by initializeDataTypesAttention(); None means "not chosen yet".
DATA_TYPES_ATTENTION = None


def initializeDataTypesAttention():
    """Set the module-level DATA_TYPES_ATTENTION for the detected chip.

    Chips whose name starts with 'gfx9' get DATA_TYPES_ATTENTION_MFMA; every
    other chip gets DATA_TYPES_ATTENTION_WMMA. The chip name comes from
    getChip().
    """
    global DATA_TYPES_ATTENTION
    on_gfx9 = getChip().startswith('gfx9')
    DATA_TYPES_ATTENTION = DATA_TYPES_ATTENTION_MFMA if on_gfx9 else DATA_TYPES_ATTENTION_WMMA
157+
120158def create_paths (config_file_path , mlir_build_dir_path ) -> Paths :
121159 """Creates the composite Paths structure using build dir paths"""
122160
@@ -686,6 +724,8 @@ def getGemmGemmConfigurations(fileName):
686724 return configs
687725
688726def getAttentionConfigurations (fileName ):
727+ if DATA_TYPES_ATTENTION is None :
728+ initializeDataTypesAttention ()
689729 bool_space = ['false' , 'true' ]
690730 default_test_space = {
691731 "-t" : DATA_TYPES_ATTENTION ,
@@ -1740,29 +1780,6 @@ def tuneMLIRKernels(configs, arch, numCU):
17401780 print ("MIOpen tuning timed out" )
17411781 _ , errs = p1 .communicate ()
17421782
1743- def hip_check (call_result ):
1744- err = call_result [0 ]
1745- result = call_result [1 :]
1746- if len (result ) == 1 :
1747- result = result [0 ]
1748- if isinstance (err , hip .hipError_t ) and err != hip .hipError_t .hipSuccess :
1749- raise RuntimeError (str (err ))
1750- return result
1751-
1752- def getArch () -> str :
1753- agents = set ()
1754- device_count = hip_check (hip .hipGetDeviceCount ())
1755- for device in range (device_count ):
1756- props = hip .hipDeviceProp_t ()
1757- hip_check (hip .hipGetDeviceProperties (props ,device ))
1758- agent = props .gcnArchName .decode ('utf-8' )
1759- agents .add (agent )
1760- if (len (agents ) > 1 ):
1761- print (f"WARNING: Found { len (agents )} different kinds of agents on the same machine : { ', ' .join (agents )} " )
1762- print ("WARNING: Using the first agent by default. If you want to use a different agent, please set the HIP_VISIBLE_DEVICES environment variable." )
1763- # select first agent by default
1764- return list (agents )[0 ]
1765-
17661783def parseDataTypes (data_types ):
17671784 if not data_types :
17681785 return DATA_TYPES_GEMM , OUTPUT_DATA_TYPES_MAP
@@ -1780,11 +1797,6 @@ def parseDataTypes(data_types):
17801797 outMap [dt [0 ]] = 'f32'
17811798 return datatypes , outMap
17821799
1783- def getChip ():
1784- arch = getArch ()
1785- chip = GFX_CHIP_RE .search (arch ).group (0 )
1786- return chip
1787-
17881800def getNumCU (chip ):
17891801 try :
17901802 rocminfo = subprocess .check_output ("/opt/rocm/bin/rocminfo" ,
@@ -1842,6 +1854,7 @@ def main(args=None):
18421854 arch = getArch ()
18431855 chip = getChip ()
18441856 numCU = getNumCU (chip )
1857+ initializeDataTypesAttention ()
18451858
18461859 root_dir = str (subprocess .check_output (['git' , 'rev-parse' , '--show-toplevel' ]).decode ().strip ())
18471860 default_conv_configs = root_dir + '/mlir/utils/jenkins/performance/configs/tier1-conv-configs'
0 commit comments