 from pathlib import Path
 
 
-def min_dot_size(target: GPUTarget):
-    # If some given configuration is not supported in hardware we fallback to FMA and cast arguments
-    return lambda lhsType, rhsType: (1, 1, 1)
+def get_min_dot_size(target: GPUTarget):
+    # We fall back to FMA and cast arguments if certain configurations are
+    # not supported natively by matrix core units.
+    return lambda lhs_type, rhs_type: (1, 1, 1)
 
 
-def is_pingpong_enabled(arch):
+def is_pingpong_schedule_enabled(arch):
     default = "1" if arch == "gfx942" else "0"
     return os.getenv("TRITON_HIP_USE_BLOCK_PINGPONG", default) == "1"
 
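A minimal usage sketch of the pingpong toggle above, assuming the renamed is_pingpong_schedule_enabled from this file and the gfx942-only default shown in the hunk (illustrative only, not part of the patch):

import os

os.environ["TRITON_HIP_USE_BLOCK_PINGPONG"] = "0"        # explicit override wins
assert is_pingpong_schedule_enabled("gfx942") is False

del os.environ["TRITON_HIP_USE_BLOCK_PINGPONG"]          # back to the arch default
assert is_pingpong_schedule_enabled("gfx942") is True    # on by default for gfx942
assert is_pingpong_schedule_enabled("gfx90a") is False   # off elsewhere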
@@ -68,27 +69,29 @@ class HIPOptions:
     schedule_hint: str = 'none'
 
     def __post_init__(self):
-        default_libdir = Path(__file__).parent / 'lib'
-        extern_libs = {} if self.extern_libs is None else dict(self.extern_libs)
-        # Ignore user-defined warp size for gfx9
-        warp_size = 32 if 'gfx10' in self.arch or 'gfx11' in self.arch or 'gfx12' in self.arch else 64
+        gfx_major = int(self.arch[3:-2])  # Drop "gfx" prefix and minor/patch number
+        warp_size = 32 if gfx_major >= 10 else 64
         object.__setattr__(self, 'warp_size', warp_size)
+
         # Error out if max threads per block is exceeded.
         # This is theoretically architecture specific but in reality they are all 1024.
         max_threads = 1024
         assert self.num_warps * warp_size <= max_threads, \
             f"{self.num_warps} warps * {warp_size} warp size" \
             f" must not exceed the max threads per block limit ({max_threads})"
-        # Only kpack=1 is supported on gfx950
-        kpack = 1 if self.arch == 'gfx950' else self.kpack
-        object.__setattr__(self, 'kpack', kpack)
-        libs = ["ocml", "ockl"]
-        for lib in libs:
-            extern_libs[lib] = str(default_libdir / f'{lib}.bc')
-        object.__setattr__(self, 'extern_libs', tuple(extern_libs.items()))
+
         assert self.num_warps > 0 and (self.num_warps & (self.num_warps - 1)) == 0, \
             "num_warps must be a power of 2"
 
+        if self.arch == 'gfx950':
+            assert self.kpack == 1, "gfx950 only accepts kpack == 1"
+
+        default_libdir = Path(__file__).parent / 'lib'
+        extern_libs = {} if self.extern_libs is None else dict(self.extern_libs)
+        for lib in ["ocml", "ockl"]:
+            extern_libs[lib] = str(default_libdir / f'{lib}.bc')
+        object.__setattr__(self, 'extern_libs', tuple(extern_libs.items()))
+
     def hash(self):
         key = '_'.join([f'{name}-{val}' for name, val in self.__dict__.items()])
         return hashlib.sha256(key.encode("utf-8")).hexdigest()
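For clarity, a short sketch of the arch-string parsing now used for the warp size, assuming arch names of the form "gfx" + major digits + two minor/stepping characters (e.g. "gfx942", "gfx1100"); any gfx10+ family therefore defaults to wave32:

for arch in ("gfx908", "gfx942", "gfx1100", "gfx1201"):
    gfx_major = int(arch[3:-2])               # "gfx942" -> 9, "gfx1100" -> 11
    warp_size = 32 if gfx_major >= 10 else 64
    print(f"{arch}: major={gfx_major}, warp_size={warp_size}")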
@@ -109,22 +112,23 @@ def parse_options(self, opts) -> Any:
         args = {'arch': os.getenv("TRITON_OVERRIDE_ARCH", self.target.arch)}
 
         # Enable XF32 (TF32) for CDNA3 GPUs
-        if self.target.arch in ('gfx940', 'gfx941', 'gfx942'):
+        if self.target.arch == 'gfx942':
             allowed_dot_input_precisions = set(HIPOptions.allowed_dot_input_precisions)
             allowed_dot_input_precisions.update({'tf32'})
             args["allowed_dot_input_precisions"] = tuple(sorted(allowed_dot_input_precisions))
 
         if "supported_fp8_dtypes" not in opts:
             supported_fp8_dtypes = set(HIPOptions.supported_fp8_dtypes)
-            if self.target.arch in ('gfx940', 'gfx941', 'gfx942'):
+            if self.target.arch == 'gfx942':
                 supported_fp8_dtypes.update({'fp8e4nv', 'fp8e4b8', 'fp8e5b16'})
-            elif self.target.arch in ('gfx950'):
+            elif self.target.arch == 'gfx950':
                 supported_fp8_dtypes.update({'fp8e4nv', 'fp8e5'})
             args["supported_fp8_dtypes"] = tuple(sorted(supported_fp8_dtypes))
 
         if "enable_fp_fusion" not in opts:
             args["enable_fp_fusion"] = os.getenv("TRITON_DEFAULT_FP_FUSION", "1") == "1"
-        args.update({k: opts[k] for k in HIPOptions.__dataclass_fields__.keys() if k in opts and opts[k] is not None})
+        args.update({k: opts[k] for k in HIPOptions.__dataclass_fields__.keys() \
+                     if k in opts and opts[k] is not None})
         return HIPOptions(**args)
 
     def pack_metadata(self, metadata):
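A sketch of the resulting option resolution (backend stands for an instantiated HIPBackend and the values are hypothetical; only the gfx942/gfx950 branches shown above are assumed):

# Architecture-specific fp8 support is layered on top of the common baseline.
baseline = set(HIPOptions.supported_fp8_dtypes)
gfx942_fp8 = tuple(sorted(baseline | {'fp8e4nv', 'fp8e4b8', 'fp8e5b16'}))
gfx950_fp8 = tuple(sorted(baseline | {'fp8e4nv', 'fp8e5'}))

# Explicit user options still win, since args.update() runs last; None entries
# are dropped by the `opts[k] is not None` filter.
options = backend.parse_options({"num_warps": 8, "kpack": None})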
@@ -138,8 +142,7 @@ def pack_metadata(self, metadata):
         )
 
     def get_codegen_implementation(self, options):
-        codegen_fns = {"min_dot_size": min_dot_size(self.target)}
-        return codegen_fns
+        return {"min_dot_size": get_min_dot_size(self.target)}
 
     def get_module_map(self) -> Dict[str, ModuleType]:
         from triton.language.extra.hip import libdevice
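And a hedged sketch of how the simplified codegen hook is consumed (the caller and the lhs_type/rhs_type values are hypothetical; the real call sites live elsewhere in Triton):

codegen_fns = backend.get_codegen_implementation(options)
# With the fallback lambda above, every dtype pair maps to (1, 1, 1), i.e. no
# minimum M/N/K constraint is imposed on dot shapes.
m, n, k = codegen_fns["min_dot_size"](lhs_type, rhs_type)
assert (m, n, k) == (1, 1, 1)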
@@ -248,17 +251,10 @@ def make_ttgir(mod, metadata, options):
     if options.schedule_hint == "local-prefetch":
         global_prefetch = local_prefetch = 1
 
-    if amd.has_matrix_core_feature(options.arch):
-        assert options.num_stages != 0, ("Triton AMD backend pipeliner has been updated. "
-                                         "We used to trigger software pipelining with "
-                                         "num_stages == 0. Now it will not happen anymore; "
-                                         "please update to use num_stages == 2 for "
-                                         "equivalent behavior in the past.")
-        amd.passes.ttgpuir.add_stream_pipeline(pm, options.num_stages, global_prefetch, local_prefetch,
-                                               use_async_copy)
-        if use_async_copy:
-            amd.passes.ttgpuir.add_coalesce_async_copy(pm, options.arch)
-        passes.common.add_canonicalizer(pm)
+    amd.passes.ttgpuir.add_stream_pipeline(pm, options.num_stages, global_prefetch, local_prefetch, use_async_copy)
+    if use_async_copy:
+        amd.passes.ttgpuir.add_coalesce_async_copy(pm, options.arch)
+    passes.common.add_canonicalizer(pm)
     if options.schedule_hint.lower() != "none":
         amd.passes.ttgpuir.insert_instruction_sched_hints(pm, options.schedule_hint)
     passes.ttgpuir.add_optimize_dot_operands(pm, True)
@@ -267,16 +263,16 @@ def make_ttgir(mod, metadata, options):
     if is_in_thread_transpose_enabled(options.arch):
         amd.passes.ttgpuir.add_in_thread_transpose(pm)
     passes.ttgpuir.add_remove_layout_conversions(pm)
-    if amd.has_matrix_core_feature(options.arch):
-        amd.passes.ttgpuir.add_reorder_instructions(pm)
-        use_block_pingpong = is_pingpong_enabled(options.arch)
-        if use_block_pingpong and options.num_stages == 2:
-            amd.passes.ttgpuir.add_block_pingpong(pm, options.num_stages)
+    amd.passes.ttgpuir.add_reorder_instructions(pm)
+    use_block_pingpong = is_pingpong_schedule_enabled(options.arch)
+    if use_block_pingpong and options.num_stages == 2:
+        amd.passes.ttgpuir.add_block_pingpong(pm, options.num_stages)
 
     if HIPBackend.use_buffer_ops():
         amd.passes.ttgpuir.add_canonicalize_pointers(pm)
         passes.common.add_canonicalizer(pm)
         amd.passes.ttgpuir.add_convert_to_buffer_ops(pm, options.arch)
+
     amd.passes.ttgpuir.add_fold_true_cmpi(pm)
     passes.common.add_canonicalizer(pm)
     passes.common.add_cse(pm)