@triton.autotune(
    configs=[
+        triton.Config(
+            {
+                'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 4, 'waves_per_eu': 2,
+                'kpack': 2, 'matrix_instr_nonkdim': 16
+            }, num_warps=8, num_stages=2),
        triton.Config(
            {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 4, 'waves_per_eu': 0},
            num_warps=8, num_stages=2),
        triton.Config(
            {
-                'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8, 'waves_per_eu': 2,
-                'kpack': 2, 'matrix_instr_nonkdim': 16
+                'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 4, 'waves_per_eu': 2,
+                'kpack': 1, 'matrix_instr_nonkdim': 16
            }, num_warps=8, num_stages=2),
        triton.Config(
            {
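
The keys `waves_per_eu`, `kpack`, and `matrix_instr_nonkdim` in these configs are AMD-backend tuning knobs rather than standard Triton meta-parameters. Below is an annotated copy of one config; the comments are my informal reading of the ROCm backend's knobs, not authoritative documentation.

```python
import triton

# Annotated copy of one config from this diff (ROCm/AMD backend assumed).
config = triton.Config(
    {
        'BLOCK_SIZE_M': 256,         # output tile rows per workgroup
        'BLOCK_SIZE_N': 128,         # output tile columns per workgroup
        'BLOCK_SIZE_K': 64,          # K-slice consumed per main-loop iteration
        'GROUP_SIZE_M': 4,           # tile-swizzling group size for better L2 reuse
        'waves_per_eu': 2,           # occupancy hint: wavefronts per execution unit (0 = compiler's choice)
        'kpack': 2,                  # K-width packing for MFMA operand loads
        'matrix_instr_nonkdim': 16,  # non-K dimension of the MFMA instruction (16 -> 16x16 MFMA)
    },
    num_warps=8,    # wavefronts per workgroup
    num_stages=2,   # software-pipelining depth of the K loop
)
```
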
@@ -128,7 +133,7 @@ def matmul_kernel(
    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
    if APPLY_SCALE:
-        a_scale = tl.load(a_scale_ptr)
+        a_scale = tl.load(a_scale_ptr) if a_scale_ptr else 1.0
        b_scale = tl.load(b_scale_ptr)

    acc_dtype = tl.float32 if c_ptr.type.element_ty != tl.int8 else tl.int32
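
The new conditional load relies on Triton specializing kernel arguments: when the host passes `None` for `a_scale_ptr`, the branch is resolved at compile time and no load is emitted. A minimal sketch of the same pattern, assuming a CUDA/ROCm device and a hypothetical `scaled_copy_kernel`:

```python
import torch
import triton
import triton.language as tl


@triton.jit
def scaled_copy_kernel(x_ptr, out_ptr, scale_ptr, n, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask)
    # With scale_ptr=None the kernel is specialized and this load folds away.
    scale = tl.load(scale_ptr) if scale_ptr else 1.0
    tl.store(out_ptr + offs, x * scale, mask=mask)


x = torch.randn(1024, device='cuda')
out = torch.empty_like(x)
grid = (triton.cdiv(x.numel(), 256), )
scaled_copy_kernel[grid](x, out, None, x.numel(), BLOCK=256)  # unscaled
scaled_copy_kernel[grid](x, out, torch.tensor([2.0], device='cuda'), x.numel(), BLOCK=256)  # scaled
```
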
@@ -143,6 +148,8 @@ def matmul_kernel(
        else:
            a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
            b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
+        # Type conversion to support mixed-precision GEMMs where b is lower precision than a.
+        b = b.to(a_ptr.type.element_ty)
        accumulator += tl.dot(a, b, input_precision="ieee")

        # Advance the ptrs to the next K block.
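
The upcast is needed because `tl.dot` expects both operands to share an element type, so the narrower `b` is widened to `a`'s type before the multiply; when the dtypes already match, the conversion is a no-op. A host-side torch analogue, assuming a PyTorch build with float8 support (>= 2.1):

```python
import torch

a = torch.randn(64, 32)                        # fp32 "wide" operand
b = torch.randn(32, 16).to(torch.float8_e5m2)  # fp8 "narrow" operand

# torch.matmul has no fp32 x fp8 overload; widen b first, mirroring
# `b = b.to(a_ptr.type.element_ty)` inside the kernel.
out = torch.matmul(a, b.to(a.dtype))
print(out.dtype)  # torch.float32
```
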
@@ -176,7 +183,10 @@ def leaky_relu(x):
def matmul(a, b, c, a_scale, b_scale, scale_a8_b8=False, activation=""):
    # Check constraints.
    assert a.shape[1] == b.shape[0], "Incompatible dimensions!!!"
-    assert a.dtype == b.dtype, "Mixed dtype GEMMs are not supported!!!"
+    assert (a.element_size()
+            >= b.element_size()), "Mixed dtype GEMMs are only supported when the data type of a is at least as wide as b!!!"
+    assert (a.is_floating_point() == b.is_floating_point()
+            ), "GEMMs between float and integer type tensors are not supported!!!"
    M, K = a.shape
    K, N = b.shape
    grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
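
The two new asserts encode the supported dtype combinations in terms of element width and float/integer kind. A quick standalone illustration of what they accept and reject:

```python
import torch

a = torch.empty(2, 2, dtype=torch.float16)
b = torch.empty(2, 2, dtype=torch.float8_e5m2)

# fp16 is 2 bytes, fp8 is 1 byte: a is at least as wide as b, so this pair passes.
assert a.element_size() >= b.element_size()
# Both are floating point, so the float/int check passes too.
assert a.is_floating_point() == b.is_floating_point()

# An fp16 'a' with an int8 'b' would trip the second assert:
b_int8 = torch.empty(2, 2, dtype=torch.int8)
print(a.is_floating_point(), b_int8.is_floating_point())  # True False -> rejected
```
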
@@ -262,32 +272,39 @@ def get_x_vals():

# Unit tests
#TODO(vgokhale): Test activation.
+# yapf: disable
@pytest.mark.parametrize(
-    "M, N, K, in_dtype, out_dtype, col_a, col_b",
-    [(*shape, in_dtype, out_dtype, col_a, col_b)
+    "M, N, K, in_dtype_a, in_dtype_b, out_dtype, col_a, col_b",
+    [(*shape, in_dtype_a, in_dtype_b, out_dtype, col_a, col_b)
     for shape in get_x_vals()
-     for in_dtype, out_dtype in [('fp16', 'fp16'), ('bf16', 'bf16'), ('fp32', 'fp32'), (
-         'fp8e4', 'fp16'), ('fp8e5', 'fp16'), ('int8', 'int8'), ('int8', 'int32')]
+     for in_dtype_a, in_dtype_b, out_dtype in [
+         ('fp16', 'fp16', 'fp16'), ('bf16', 'bf16', 'bf16'), ('fp32', 'fp32', 'fp32'),
+         ('fp8e4', 'fp8e4', 'fp16'), ('fp8e5', 'fp8e5', 'fp16'), ('fp16', 'fp8e4', 'fp16'),
+         ('fp16', 'fp8e5', 'fp16'), ('bf16', 'fp8e4', 'bf16'), ('bf16', 'fp8e5', 'bf16'),
+         ('int8', 'int8', 'int8'), ('int8', 'int8', 'int32')]
     # Defines if a matrix is row or column major.
     for col_a in [True, False]
     for col_b in [True, False]])
-def test_correctness(M, N, K, col_a, col_b, in_dtype, out_dtype):
-    torch_in_dtype = name_to_torch_types[in_dtype]
-    a, a_fp32, a_scale = gen_input(M, K, torch_in_dtype, col_a, 1, device='cuda')
-    b, b_fp32, b_scale = gen_input(K, N, torch_in_dtype, col_b, 2, device='cuda')
+# yapf: enable
+def test_correctness(M, N, K, col_a, col_b, in_dtype_a, in_dtype_b, out_dtype):
+    torch_in_dtype_a = name_to_torch_types[in_dtype_a]
+    torch_in_dtype_b = name_to_torch_types[in_dtype_b]
+    a, a_fp32, a_scale = gen_input(M, K, torch_in_dtype_a, col_a, 1, device='cuda')
+    b, b_fp32, b_scale = gen_input(K, N, torch_in_dtype_b, col_b, 2, device='cuda')
    torch_out_dtype = name_to_torch_types[out_dtype]
    c = torch.empty((M, N), device=a.device, dtype=torch_out_dtype)
    # For 8-bit, we have scaled to the dynamic range of the data type.
    # This requires us to compute in fp32 because for e5m2, the range is the same as fp16 (e5m10).
    # If we used fp16, the torch.matmul call could return infs.
-    if dtype_is_8_bit(torch_in_dtype):
+    if dtype_is_8_bit(torch_in_dtype_a) or dtype_is_8_bit(torch_in_dtype_b):
        matmul(a, b, c, a_scale, b_scale, scale_a8_b8=True, activation="")
        torch_output = torch.matmul(a_fp32, b_fp32)
-        torch_output = torch_output * a_scale * b_scale
+        # Fall back to a scale of 1.0 when a_scale is not set.
+        torch_output = torch_output * (a_scale or 1.0) * b_scale
    # For other dtypes, use the same torch matmul as the dtype.
    else:
        matmul(a, b, c, a_scale=None, b_scale=None, scale_a8_b8=False, activation="")
-        torch_output = torch.matmul(a.to(torch_in_dtype), b.to(torch_in_dtype))
+        torch_output = torch.matmul(a.to(torch_in_dtype_a), b.to(torch_in_dtype_b))
    if out_dtype == 'int8':
        torch.testing.assert_close(c.to(torch.float32),
                                   torch_output.to(torch.int8).to(torch.float32), atol=1e-3, rtol=1e-2)
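
The fp32-reference comment above can be made concrete: fp8e5m2 shares fp16's 5-bit exponent, and fp16 saturates at 65504, so values that survive in fp32 can overflow to inf once they pass through fp16. A one-liner demonstration:

```python
import torch

# fp16 tops out at 65504; anything larger becomes inf on conversion.
print(torch.finfo(torch.float16).max)           # 65504.0
print(torch.tensor(70000.0).to(torch.float16))  # tensor(inf, dtype=torch.float16)
```
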
@@ -297,7 +314,7 @@ def test_correctness(M, N, K, col_a, col_b, in_dtype, out_dtype):

def get_type(provider):
    res = re.findall(r'\(.*?\)', provider)
-    return res[0][1:-1]
+    return res[0][1:-1].split('/', 1)


@triton.testing.perf_report(
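
Since `get_type` now returns a two-element list, every provider label must encode both operand dtypes, which is why the single-dtype labels below gain an `fp16/fp16`-style pair. For example:

```python
import re

def get_type(provider):  # as defined above
    res = re.findall(r'\(.*?\)', provider)
    return res[0][1:-1].split('/', 1)

print(get_type('triton(fp16/fp8e4)'))    # ['fp16', 'fp8e4']
print(get_type('hipblaslt(bf16/bf16)'))  # ['bf16', 'bf16']
```
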
@@ -306,39 +323,38 @@ def get_type(provider):
        x_vals=get_x_vals(),
        line_arg='provider',
        line_vals=[
-            'hipblaslt(fp16)', 'hipblaslt(bf16)', 'triton(fp16)', 'triton(bf16)', 'triton(int8)', 'triton(fp8e4)',
-            'triton(fp8e5)'
+            'hipblaslt(fp16/fp16)', 'hipblaslt(bf16/bf16)', 'triton(fp16/fp16)', 'triton(bf16/bf16)',
+            'triton(int8/int8)', 'triton(fp8e4/fp8e4)', 'triton(fp8e5/fp8e5)', 'triton(fp16/fp8e4)',
+            'triton(fp16/fp8e5)'
        ],
        line_names=[
-            "rocBLAS.Fp16", "rocBLAS.Bf16", "Triton.Fp16", "Triton.Bf16", "Triton.Int8", "Triton.Fp8E4", "Triton.Fp8E5"
+            "rocBLAS.Fp16", "rocBLAS.Bf16", "Triton.Fp16", "Triton.Bf16", "Triton.Int8", "Triton.Fp8E4", "Triton.Fp8E5",
+            "Triton.Fp16.Fp8E4", "Triton.Fp16.Fp8E5"
        ],
        ylabel="TFLOPS",
        plot_name="matmul-performance",
        args={},
    ))
def benchmark(M, N, K, provider, model=None):
-    in_dtype = name_to_torch_types[get_type(provider)]
-    out_dtype = in_dtype
+    in_dtype_a, in_dtype_b = [name_to_torch_types[x] for x in get_type(provider)]
+    out_dtype = in_dtype_a

    quantiles = [0.5, 0.2, 0.8]
    if 'hipblaslt' in provider:
-        a = torch.randn((M, K), dtype=in_dtype, device='cuda')
-        b = torch.randn((N, K), dtype=in_dtype, device='cuda')
+        a = torch.randn((M, K), dtype=in_dtype_a, device='cuda')
+        b = torch.randn((N, K), dtype=in_dtype_b, device='cuda')
        b = b.T

        ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), quantiles=quantiles)
    else:  # triton, different data types
        assert "triton" in provider
-        a, _, a_scale = gen_input(M, K, in_dtype, False, 1, device='cuda')
-        b, _, b_scale = gen_input(K, N, in_dtype, True, 2, device='cuda')
+        a, _, a_scale = gen_input(M, K, in_dtype_a, False, 1, device='cuda')
+        b, _, b_scale = gen_input(K, N, in_dtype_b, True, 2, device='cuda')
        # Allocates output.
        c = torch.empty((M, N), device=a.device, dtype=out_dtype)
-
-        if dtype_is_8_bit(in_dtype):
-            a_scale = a_scale.item()
-            b_scale = b_scale.item()
-        ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b, c, a_scale, b_scale, activation=""),
-                                                     quantiles=quantiles)
+        scale_a8_b8 = dtype_is_8_bit(in_dtype_a) or dtype_is_8_bit(in_dtype_b)
+        ms, min_ms, max_ms = triton.testing.do_bench(
+            lambda: matmul(a, b, c, a_scale, b_scale, scale_a8_b8=scale_a8_b8, activation=""), quantiles=quantiles)
    global verbose
    if verbose:
        print(f'SIZE: {M},{N},{K}   Best tuning config: ({matmul_kernel.best_config()})')
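
Putting the pieces together, a direct mixed-dtype call outside the benchmark harness would look roughly like the sketch below, reusing this file's helpers (`gen_input`, `name_to_torch_types`, `dtype_is_8_bit`, `matmul`); the shapes are illustrative.

```python
import torch

M, N, K = 4096, 4096, 4096
in_dtype_a = name_to_torch_types['fp16']
in_dtype_b = name_to_torch_types['fp8e5']
a, _, a_scale = gen_input(M, K, in_dtype_a, False, 1, device='cuda')
b, _, b_scale = gen_input(K, N, in_dtype_b, True, 2, device='cuda')
c = torch.empty((M, N), device='cuda', dtype=in_dtype_a)

# One 8-bit operand is enough to enable scaling, matching benchmark():
scale = dtype_is_8_bit(in_dtype_a) or dtype_is_8_bit(in_dtype_b)
matmul(a, b, c, a_scale, b_scale, scale_a8_b8=scale, activation="")
```
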