
from utils.benchmark_utils import get_available_models, get_model_configs

+# TODO: Make this an argument; the benchmarking code, tests, and the kernel helper all need to change with it.
+SCALE_BLOCK_SIZE = 128
+

@triton.autotune(
    configs=[
+        triton.Config(
+            {
+                'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 4, 'waves_per_eu': 2,
+                'kpack': 2, 'matrix_instr_nonkdim': 16
+            }, num_warps=4, num_stages=2),
        triton.Config(
            {
                'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 4, 'waves_per_eu': 2,
@@ -60,7 +68,13 @@ def matmul_kernel(
        stride_cn,
        a_scale_ptr,
        b_scale_ptr,
+        stride_ascale_m,
+        stride_ascale_k,
+        stride_bscale_k,
+        stride_bscale_n,
        # Meta-parameters
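+        # GROUP_K / GROUP_N: extent of K and N covered by a single block-scale entry
+        # (the wrapper below passes SCALE_BLOCK_SIZE for both).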
+        GROUP_K: tl.constexpr,
+        GROUP_N: tl.constexpr,
        BLOCK_SIZE_M: tl.constexpr,
        BLOCK_SIZE_N: tl.constexpr,
        BLOCK_SIZE_K: tl.constexpr,
@@ -76,12 +90,19 @@ def matmul_kernel(

    NUM_XCDS: tl.constexpr = 8

+    tl.static_assert((APPLY_SCALE is None) or (APPLY_SCALE == 'tensor') or (APPLY_SCALE == 'block'),
+                     f"Scaling mode {APPLY_SCALE} is not supported!!!")
+
    tl.assume(stride_am > 0)
    tl.assume(stride_ak > 0)
    tl.assume(stride_bk > 0)
    tl.assume(stride_bn > 0)
    tl.assume(stride_cm > 0)
    tl.assume(stride_cn > 0)
+    tl.assume(stride_ascale_m > 0)
+    tl.assume(stride_ascale_k > 0)
+    tl.assume(stride_bscale_k > 0)
+    tl.assume(stride_bscale_n > 0)

    # -----------------------------------------------------------
    # Map program ids `pid` to the block of C it should compute.
@@ -132,9 +153,16 @@ def matmul_kernel(
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
-    if APPLY_SCALE:
-        a_scale = tl.load(a_scale_ptr) if (a_scale_ptr) else 1.0
+    if APPLY_SCALE == 'tensor':
+        a_scale = tl.load(a_scale_ptr) if a_scale_ptr else 1.0
        b_scale = tl.load(b_scale_ptr)
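+    # Block scaling: A carries one scale per row per GROUP_K-wide slice of K, and B carries one scale
+    # per (GROUP_K x GROUP_N) tile. Set up per-program pointers into the scale tensors at K block 0.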
+    elif APPLY_SCALE == 'block':
+        k_start = 0
+        offs_ks = k_start // GROUP_K
+        a_scale_ptrs = None if a_scale_ptr is None else (a_scale_ptr + offs_am * stride_ascale_m +
+                                                         offs_ks * stride_ascale_k)
+        offs_bsn = offs_bn // GROUP_N
+        b_scale_ptrs = b_scale_ptr + offs_bsn * stride_bscale_n + offs_ks * stride_bscale_k

    acc_dtype = tl.float32 if c_ptr.type.element_ty != tl.int8 else tl.int32
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype)
@@ -148,15 +176,37 @@ def matmul_kernel(
        else:
            a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
            b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
+
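+        # Block scaling: fetch the scales covering the current K block (one per row of A, one per
+        # GROUP_N-wide column group of B); they are applied to this block's partial product below.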
+        if APPLY_SCALE == 'block':
+            b_scale = tl.load(b_scale_ptrs)
+            if a_scale_ptrs is not None:
+                a_scale = tl.load(a_scale_ptrs)
+
        # Type conversion to support mixed precision GEMMs where b is lower precision than a
        b = b.to(a_ptr.type.element_ty)
-        accumulator += tl.dot(a, b, input_precision="ieee")
+
+        if APPLY_SCALE == 'block':
+            if a_scale_ptrs is not None:
+                accumulator += tl.dot(a, b, input_precision="ieee") * a_scale[:, None] * b_scale[None, :]
+            else:
+                accumulator += tl.dot(a, b, input_precision="ieee") * b_scale[None, :]
+        else:
+            accumulator += tl.dot(a, b, input_precision="ieee")

        # Advance the ptrs to the next K block.
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk
+
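+        # Advance the scale pointers by the number of GROUP_K-wide scale blocks this K step crossed
+        # (zero when BLOCK_SIZE_K steps stay within the same GROUP_K block, so the scales are reused).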
+        if APPLY_SCALE == 'block':
+            k_cur = k * BLOCK_SIZE_K // GROUP_K
+            k_nxt = (k + 1) * BLOCK_SIZE_K // GROUP_K
+            offs_ks = k_nxt - k_cur
+            b_scale_ptrs += offs_ks * stride_bscale_k
+            if a_scale_ptrs is not None:
+                a_scale_ptrs += offs_ks * stride_ascale_k
+
    # Apply scale to recover dynamic range reduced due to lower precision inputs.
-    if APPLY_SCALE:
+    if APPLY_SCALE == 'tensor':
        accumulator = accumulator * a_scale * b_scale
    # Apply activation function, if specified.
    # TODO(vgokhale): Add different types of activations.
@@ -180,13 +230,14 @@ def leaky_relu(x):


# Wrapper for gemm kernel.
-def matmul(a, b, c, a_scale, b_scale, scale_a8_b8=False, activation=""):
+def matmul(a, b, c, a_scale, b_scale, scale_a8_b8=None, activation=""):
    # Check constraints.
    assert a.shape[1] == b.shape[0], "Incompatible dimensions!!!"
    assert (a.element_size()
            >= b.element_size()), "Mixed dtype GEMMs are only supported when the data type of a is at least as wide as that of b!!!"
    assert (a.is_floating_point() == b.is_floating_point()
            ), "GEMMs between float and integer type tensors are not supported!!!"
+    assert (scale_a8_b8 in [None, 'tensor', 'block']), f"Scaling mode {scale_a8_b8} is not supported!!!"
    M, K = a.shape
    K, N = b.shape
    grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
@@ -205,6 +256,12 @@ def matmul(a, b, c, a_scale, b_scale, scale_a8_b8=False, activation=""):
        c.stride(1),
        a_scale,
        b_scale,
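+        # Per-tensor scales are 0-dim tensors (ndim == 0), so their strides are passed as 0;
+        # the kernel only reads these strides on the block-scaling path.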
+        a_scale.stride(0) if (a_scale is not None) and a_scale.ndim else 0,
+        a_scale.stride(1) if (a_scale is not None) and a_scale.ndim else 0,
+        b_scale.stride(0) if (b_scale is not None) and b_scale.ndim else 0,
+        b_scale.stride(1) if (b_scale is not None) and b_scale.ndim else 0,
+        GROUP_K=SCALE_BLOCK_SIZE,
+        GROUP_N=SCALE_BLOCK_SIZE,
        APPLY_SCALE=scale_a8_b8,
        ACTIVATION=activation,
    )
@@ -243,7 +300,7 @@ def dtype_is_8_bit(dtype):
        (dtype is torch.int8)


-def gen_input(M, N, dtype, needTrans, seed, device='cuda'):
+def gen_input(M, N, dtype, needTrans, seed=0, fp8_scaling_mode='tensor', device='cuda'):
    torch.manual_seed(seed)

    if needTrans:
@@ -252,9 +309,28 @@ def gen_input(M, N, dtype, needTrans, seed, device='cuda'):
        raw_data = torch.randn((M, N), dtype=torch.float32, device='cuda')
    scale = None
    if dtype_is_8_bit(dtype):
-        max_val = torch.max(torch.abs(raw_data))
-        scale = max_val / dtype_max[dtype]
-        raw_data = raw_data / scale
+        if fp8_scaling_mode == 'token':
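+            # 'token' scaling: one scale per SCALE_BLOCK_SIZE-wide slice of each row (used for the A operand).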
+            assert raw_data.size(1) % SCALE_BLOCK_SIZE == 0
+            raw_data = raw_data.view(M, -1, SCALE_BLOCK_SIZE)
+            max_val = raw_data.abs().float().amax(dim=2).view(M, -1).clamp(1e-4)
+            scale = max_val.unsqueeze(2) / dtype_max[dtype]
+            raw_data = (raw_data / scale).view(M, N)
+            scale = scale.view(M, -1)
+            scale = scale.T.contiguous().T
+        elif fp8_scaling_mode == 'block':
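+            # 'block' scaling: one scale per SCALE_BLOCK_SIZE x SCALE_BLOCK_SIZE tile (used for the B operand);
+            # the data is zero-padded up to a multiple of the block size before the per-tile maxima are taken.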
+            x_padded = torch.zeros((triton.cdiv(M, SCALE_BLOCK_SIZE) * SCALE_BLOCK_SIZE,
+                                    triton.cdiv(N, SCALE_BLOCK_SIZE) * SCALE_BLOCK_SIZE), dtype=raw_data.dtype,
+                                   device=raw_data.device)
+            x_padded[:M, :N] = raw_data
+            x_view = x_padded.view(-1, SCALE_BLOCK_SIZE, x_padded.size(1) // SCALE_BLOCK_SIZE, SCALE_BLOCK_SIZE)
+            x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
+            x_scaled = x_view * (dtype_max[dtype] / x_amax)
+            raw_data = x_scaled.view_as(x_padded)[:M, :N].T.contiguous().T
+            scale = (x_amax / dtype_max[dtype]).view(x_view.size(0), x_view.size(2))
+        elif fp8_scaling_mode == 'tensor':
+            max_val = torch.max(torch.abs(raw_data))
+            scale = max_val / dtype_max[dtype]
+            raw_data = raw_data / scale

    input = raw_data.to(dtype)
    input_f32 = input.to(torch.float32)
@@ -289,21 +365,21 @@ def get_x_vals():
def test_correctness(M, N, K, col_a, col_b, in_dtype_a, in_dtype_b, out_dtype):
    torch_in_dtype_a = name_to_torch_types[in_dtype_a]
    torch_in_dtype_b = name_to_torch_types[in_dtype_b]
-    a, a_fp32, a_scale = gen_input(M, K, torch_in_dtype_a, col_a, 1, device='cuda')
-    b, b_fp32, b_scale = gen_input(K, N, torch_in_dtype_b, col_b, 2, device='cuda')
+    a, a_fp32, a_scale = gen_input(M, K, torch_in_dtype_a, col_a, seed=1, device='cuda')
+    b, b_fp32, b_scale = gen_input(K, N, torch_in_dtype_b, col_b, seed=2, device='cuda')
    torch_out_dtype = name_to_torch_types[out_dtype]
    c = torch.empty((M, N), device=a.device, dtype=torch_out_dtype)
    # For 8-bit, we have scaled to the dynamic range of the data type.
    # This requires us to compute in fp32 because for e5m2, the range is the same as fp16 (e5m10).
    # If we use fp16, it is possible to return infs from the torch.matmul call.
    if dtype_is_8_bit(torch_in_dtype_a) or dtype_is_8_bit(torch_in_dtype_b):
-        matmul(a, b, c, a_scale, b_scale, scale_a8_b8=True, activation="")
+        matmul(a, b, c, a_scale, b_scale, scale_a8_b8='tensor', activation="")
        torch_output = torch.matmul(a_fp32, b_fp32)
        # Set a_scale to 1.0 if it is not set.
        torch_output = torch_output * (a_scale or 1.0) * b_scale
    # For other dtypes, use the same torch matmul as the dtype.
    else:
-        matmul(a, b, c, a_scale=None, b_scale=None, scale_a8_b8=False, activation="")
+        matmul(a, b, c, a_scale=None, b_scale=None, scale_a8_b8=None, activation="")
        torch_output = torch.matmul(a.to(torch_in_dtype_a), b.to(torch_in_dtype_b))
    if out_dtype == 'int8':
        torch.testing.assert_close(c.to(torch.float32),
@@ -312,6 +388,61 @@ def test_correctness(M, N, K, col_a, col_b, in_dtype_a, in_dtype_b, out_dtype):
        torch.testing.assert_close(c, torch_output.to(torch_out_dtype), atol=5e-3, rtol=1e-2)


+# yapf: disable
+@pytest.mark.parametrize(
+    "M, N, K, in_dtype_a, in_dtype_b, out_dtype, col_a, col_b",
+    [(*shape, in_dtype_a, in_dtype_b, out_dtype, col_a, col_b)
+     for shape in get_x_vals()
+     for in_dtype_a, in_dtype_b, out_dtype in [
+         ('fp8e4', 'fp8e4', 'fp16'), ('fp8e5', 'fp8e5', 'fp16'), ('fp16', 'fp8e4', 'fp16'),
+         ('fp16', 'fp8e5', 'fp16'), ('bf16', 'fp8e4', 'bf16'), ('bf16', 'fp8e5', 'bf16')]
+     # Defines whether each matrix is row- or column-major.
+     for col_a in [True, False]
+     for col_b in [True, False]])
+# yapf: enable
+def test_correctness_block_scaling(M, N, K, col_a, col_b, in_dtype_a, in_dtype_b, out_dtype):
+    if (N % SCALE_BLOCK_SIZE != 0) or (K % SCALE_BLOCK_SIZE != 0):
+        pytest.skip("Skip N/K sizes not aligned to SCALE_BLOCK_SIZE")
+    # Generate inputs.
+    torch_in_dtype_a = name_to_torch_types[in_dtype_a]
+    torch_in_dtype_b = name_to_torch_types[in_dtype_b]
+    a, a_fp32, a_scale = gen_input(M, K, torch_in_dtype_a, col_a, seed=1, fp8_scaling_mode='token', device='cuda')
+    b, b_fp32, b_scale = gen_input(K, N, torch_in_dtype_b, col_b, seed=2, fp8_scaling_mode='block', device='cuda')
+    # Create the output tensor.
+    torch_out_dtype = name_to_torch_types[out_dtype]
+    c = torch.empty((M, N), device=a.device, dtype=torch_out_dtype)
+    # For 8-bit, we have scaled to the dynamic range of the data type.
+    # This requires us to compute in fp32 because for e5m2, the range is the same as fp16 (e5m10).
+    # If we use fp16, it is possible to return infs from the torch.matmul call.
+    matmul(a, b, c, a_scale, b_scale, scale_a8_b8='block', activation="")
+    # Reference implementation.
+    block_k = SCALE_BLOCK_SIZE
+    block_n = SCALE_BLOCK_SIZE
+    k_tiles = triton.cdiv(K, block_k)
+    n_tiles = triton.cdiv(N, block_n)
+    c_ref = torch.zeros((M, N), device=a_fp32.device, dtype=torch.float32)
+
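+    # Tile-by-tile fp32 reference: for each (K, N) scale tile, accumulate (A_tile @ B_tile) scaled by the
+    # per-row A scales for that K slice and the per-tile B scale.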
+    A_tiles = [a_fp32[:, i * block_k:min((i + 1) * block_k, K)] for i in range(k_tiles)]
+    B_tiles = [[
+        b_fp32[
+            i * block_k:min((i + 1) * block_k, K),
+            j * block_n:min((j + 1) * block_n, N),
+        ] for j in range(n_tiles)
+    ] for i in range(k_tiles)]
+    C_tiles = [c_ref[:, j * block_n:min((j + 1) * block_n, N)] for j in range(n_tiles)]
+    As_tiles = [a_scale[:, i:i + 1] for i in range(k_tiles)] if (a_scale is not None) else None
+
+    for i in range(k_tiles):
+        for j in range(n_tiles):
+            a_tile = A_tiles[i]
+            b_tile = B_tiles[i][j]
+            c_tile = C_tiles[j]
+            s_tile = (As_tiles[i] * b_scale[i][j]) if dtype_is_8_bit(torch_in_dtype_a) else b_scale[i][j]
+            c_tile[:, :] += torch.matmul(a_tile, b_tile) * s_tile
+
+    torch.testing.assert_close(c, c_ref.to(torch_out_dtype), atol=5e-3, rtol=1e-2)
+
+
def get_type(provider):
    res = re.findall(r'\(.*?\)', provider)
    return res[0][1:-1].split('/', 1)
@@ -341,16 +472,28 @@ def benchmark(M, N, K, provider, model=None, args=None):

    quantiles = [0.5, 0.2, 0.8]
    layout_tn = args.layout == 'tn'
-    a, _, a_scale = gen_input(M, K, in_dtype_a, False, 1, device='cuda')
-    b, _, b_scale = gen_input(K, N, in_dtype_b, layout_tn, 2, device='cuda')
+
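+    # int8 inputs only go through per-tensor scaling here; fp8 inputs follow the -fp8_scaling_mode flag
+    # (token-scaled A and block-scaled B when 'block' is selected).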
+    if args.fp8_scaling_mode == 'tensor' or in_dtype_b == torch.int8:
+        a, _, a_scale = gen_input(M, K, in_dtype_a, False, seed=1, device='cuda')
+        b, _, b_scale = gen_input(K, N, in_dtype_b, layout_tn, seed=2, device='cuda')
+    else:
+        a, _, a_scale = gen_input(M, K, in_dtype_a, False, seed=1, fp8_scaling_mode='token', device='cuda')
+        b, _, b_scale = gen_input(K, N, in_dtype_b, layout_tn, seed=2, fp8_scaling_mode='block', device='cuda')
+
    if 'hipblaslt' in provider:
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), quantiles=quantiles)
    else:  # triton, different data types
        assert "triton" in provider
        # Allocates output.
        c = torch.empty((M, N), device=a.device, dtype=out_dtype)

-        scale_a8_b8 = dtype_is_8_bit(in_dtype_a) or dtype_is_8_bit(in_dtype_b)
+        # When either input is 8-bit: int8 always uses per-tensor scaling;
+        # fp8 follows the scaling mode requested on the command line.
+        scale_a8_b8 = None
+        if dtype_is_8_bit(in_dtype_a) or dtype_is_8_bit(in_dtype_b):
+            scale_a8_b8 = 'tensor' if in_dtype_b == torch.int8 else args.fp8_scaling_mode
+
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: matmul(a, b, c, a_scale, b_scale, scale_a8_b8=scale_a8_b8, activation=""), quantiles=quantiles)
        if args.v:
@@ -381,6 +524,8 @@ def parse_args():
    parser.add_argument("-dtype", type=str, default=None, help="Data type of inputs and outputs")
    parser.add_argument("-b_dtype", type=str, default=None,
                        help="Data type of B operand, if specified (else same as dtype)")
+    parser.add_argument("-fp8_scaling_mode", type=str, default='tensor', choices=['tensor', 'block'],
+                        help="Type of scaling to apply when either or both inputs are fp8")

    args = parser.parse_args()
