         triton.Config(
             {'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'waves_per_eu': 0},
             num_warps=8, num_stages=2),
-        triton.Config(
-            {'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 4, 'waves_per_eu': 2},
-            num_warps=4, num_stages=2),
         triton.Config(
             {'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 1, 'waves_per_eu': 2},
             num_warps=8, num_stages=2),
-        triton.Config(
-            {'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 32, 'waves_per_eu': 2},
-            num_warps=4, num_stages=2),
     ],
     key=['M', 'N', 'K'],
     use_cuda_graph=True,
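The two deleted `triton.Config` entries simply narrow the autotune search space to the candidates kept above. As background, here is a minimal sketch (a hypothetical element-wise kernel, not this file's `matmul_kernel`) of how such a config list drives `triton.autotune`: the tuner times every `Config` the first time it sees a new `key` tuple and caches the winner, `use_cuda_graph=True` runs those timing passes under graph capture to reduce launch overhead, and `waves_per_eu` is an AMD-backend occupancy hint (0 lets the compiler decide).

```python
import torch
import triton
import triton.language as tl

@triton.autotune(
    configs=[
        # Each Config pairs meta-parameters with launch settings; the tuner
        # benchmarks all of them and caches the fastest one per key tuple.
        triton.Config({'BLOCK_SIZE': 1024, 'waves_per_eu': 0}, num_warps=8, num_stages=2),
        triton.Config({'BLOCK_SIZE': 2048, 'waves_per_eu': 2}, num_warps=4, num_stages=2),
    ],
    key=['n_elements'],   # re-tune whenever the problem size changes
    use_cuda_graph=True,  # time the candidates under graph capture
)
@triton.jit
def double_kernel(x_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offs < n_elements
    tl.store(x_ptr + offs, tl.load(x_ptr + offs, mask=mask) * 2.0, mask=mask)

x = torch.randn(4096, device='cuda')
grid = lambda meta: (triton.cdiv(x.numel(), meta['BLOCK_SIZE']),)
double_kernel[grid](x, x.numel())
```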
@@ -122,7 +116,7 @@ def matmul_kernel(
         else:
             a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
             b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
-        accumulator += tl.dot(a, b)
+        accumulator += tl.dot(a, b, input_precision="ieee")

         # Advance the ptrs to the next K block.
         a_ptrs += BLOCK_SIZE_K * stride_ak
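Pinning `input_precision="ieee"` keeps the backend from routing fp32 `tl.dot` inputs through a reduced-precision fast path such as TF32 (where the hardware has one), which matters when Triton output is validated numerically against a BLAS reference. A self-contained sketch of the same knob, using a hypothetical kernel and shapes:

```python
import torch
import triton
import triton.language as tl

# "ieee" forces full IEEE fp32 multiplies in the dot; the alternative
# values ("tf32", "tf32x3") trade input precision for throughput.
@triton.jit
def small_dot_kernel(a_ptr, b_ptr, c_ptr, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    a = tl.load(a_ptr + offs[:, None] * BLOCK + offs[None, :])
    b = tl.load(b_ptr + offs[:, None] * BLOCK + offs[None, :])
    c = tl.dot(a, b, input_precision="ieee")  # bit-accurate multiply path
    tl.store(c_ptr + offs[:, None] * BLOCK + offs[None, :], c)

a = torch.randn(32, 32, device='cuda', dtype=torch.float32)
b = torch.randn(32, 32, device='cuda', dtype=torch.float32)
c = torch.empty_like(a)
small_dot_kernel[(1,)](a, b, c, BLOCK=32)
assert torch.allclose(c, a @ b, atol=1e-3)  # matches the torch fp32 result
```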
@@ -179,29 +173,36 @@ def matmul(a, b, c, a_scale, b_scale, scale_a8_b8=False, activation=""):
     )


+def is_cdna4():
+    return triton.runtime.driver.active.get_current_target().arch == 'gfx950'
+
+
+e5m2_type = torch.float8_e5m2 if is_cdna4() else torch.float8_e5m2fnuz
+e4m3_type = torch.float8_e4m3fn if is_cdna4() else torch.float8_e4m3fnuz
+
 name_to_torch_types = {
     'int8': torch.int8,
     'int32': torch.int32,
     'fp16': torch.float16,
     'fp32': torch.float32,
     'bf16': torch.bfloat16,
-    'fp8e5': torch.float8_e5m2fnuz,
-    'fp8e4': torch.float8_e4m3fnuz,
+    'fp8e5': e5m2_type,
+    'fp8e4': e4m3_type,
 }

 dtype_max = {
     dtype: (torch.finfo(dtype) if dtype.is_floating_point else torch.iinfo(dtype)).max
     for dtype in [
-        torch.float8_e5m2fnuz,
-        torch.float8_e4m3fnuz,
+        e5m2_type,
+        e4m3_type,
         torch.int8,
     ]
 }


 def dtype_is_8_bit(dtype):
-    return (dtype is torch.float8_e5m2fnuz) or \
-           (dtype is torch.float8_e4m3fnuz) or \
+    return (dtype is e5m2_type) or \
+           (dtype is e4m3_type) or \
            (dtype is torch.int8)


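The new `is_cdna4` check keys the fp8 dtypes off the GPU architecture: gfx950 (CDNA4) supports the OCP fp8 formats, while earlier CDNA parts use the `fnuz` variants, which drop infinities and negative zero, keep a single NaN encoding, and shift the exponent bias. That shift changes the representable range, which is why `dtype_max` is built from the selected aliases rather than hard-coded types:

```python
import torch

# OCP formats, selected on gfx950 (CDNA4):
print(torch.finfo(torch.float8_e4m3fn).max)    # 448.0
print(torch.finfo(torch.float8_e5m2).max)      # 57344.0

# fnuz variants, selected on earlier architectures; e4m3fnuz tops out
# at 240 rather than 448 because of its shifted exponent bias:
print(torch.finfo(torch.float8_e4m3fnuz).max)  # 240.0
print(torch.finfo(torch.float8_e5m2fnuz).max)  # 57344.0
```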
@@ -278,7 +279,7 @@ def get_type(provider):
         x_vals=get_x_vals(),
         line_arg='provider',
         line_vals=[
-            'rocblas(fp16)', 'rocblas(bf16)', 'triton(fp16)', 'triton(bf16)', 'triton(int8)', 'triton(fp8e4)',
+            'hipblaslt(fp16)', 'hipblaslt(bf16)', 'triton(fp16)', 'triton(bf16)', 'triton(int8)', 'triton(fp8e4)',
             'triton(fp8e5)'
         ],
         line_names=[
@@ -293,9 +294,10 @@ def benchmark(M, N, K, provider, model=None):
     out_dtype = in_dtype

     quantiles = [0.5, 0.2, 0.8]
-    if 'rocblas' in provider:
+    if 'hipblaslt' in provider:
         a = torch.randn((M, K), dtype=in_dtype, device='cuda')
-        b = torch.randn((K, N), dtype=in_dtype, device='cuda')
+        b = torch.randn((N, K), dtype=in_dtype, device='cuda')
+        b = b.T

         ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), quantiles=quantiles)
     else:  # triton, different data types
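Note the reworked operand setup in the hipBLASLt branch: `b` is now allocated as `(N, K)` and then transposed. `.T` returns a strided view, so `b` is still logically `(K, N)` but is laid out column-major in memory, presumably so the library benchmarks the transposed-B GEMM layout rather than the row-major one. A small demonstration of the difference:

```python
import torch

M, N, K = 512, 512, 512
b_row = torch.randn((K, N), device='cuda', dtype=torch.float16)    # old code
b_col = torch.randn((N, K), device='cuda', dtype=torch.float16).T  # new code

# Same logical shape, different memory layout (strides are swapped):
print(b_row.shape, b_row.stride())  # torch.Size([512, 512]) (512, 1)
print(b_col.shape, b_col.stride())  # torch.Size([512, 512]) (1, 512)
```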