import triton
import triton.language as tl

+DTYPE_I8 = [torch.int8]
+DTYPE_F8 = [torch.float8_e4m3fn, torch.float8_e5m2]
+DTYPE_8BIT = DTYPE_I8 + DTYPE_F8
+


def get_cuda_autotune_config(chunk_size=None):
    """Basic use of triton.Config() is like:
@@ -145,8 +149,7 @@ def matmul_kernel(
    # NOTE mask will be applied on accumulator, which is always FP32, so we may truncate up to 23b
    # e.g., 20b -> trun_mask = 0xFFF00000, round_bit = 0x00080000
    #        8b -> trun_mask = 0xFFFFFF00, round_bit = 0x00000080
-    full_32b_mask = 0xFFFFFFFF
-    trun_mask = (full_32b_mask << chunk_trun_bits) & full_32b_mask
+    trun_mask = tl.cast((0xFFFFFFFF >> chunk_trun_bits) << chunk_trun_bits, tl.uint32)
    round_bit = 1 << (chunk_trun_bits - 1) if chunk_trun_bits > 0 else 0
    ## ---------------------------------------------------------
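The shift-right-then-shift-left form stays within 32 bits, so the extra `& full_32b_mask` is no longer needed. For reference, the constants in the NOTE above can be reproduced with plain Python integer arithmetic; this is an illustrative host-side sketch only, and the helper name show_trun_constants is made up for the example.

# Reproduce the NOTE table above for a couple of truncation widths (illustrative only).
def show_trun_constants(chunk_trun_bits):
    trun_mask = (0xFFFFFFFF >> chunk_trun_bits) << chunk_trun_bits  # clear the low LSBs
    round_bit = 1 << (chunk_trun_bits - 1) if chunk_trun_bits > 0 else 0  # half a truncation step
    print(f"{chunk_trun_bits:2d}b -> trun_mask=0x{trun_mask:08X}, round_bit=0x{round_bit:08X}")

show_trun_constants(20)  # 20b -> trun_mask=0xFFF00000, round_bit=0x00080000
show_trun_constants(8)   #  8b -> trun_mask=0xFFFFFF00, round_bit=0x00000080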
@@ -160,7 +163,7 @@ def matmul_kernel(
        # tl.dot() default is using TF32 approximation, not good enough for LSB truncation exp

        ## ------ add chunky LSB rounding/masking --------
-        if chunk_trun_bits != 0:
+        if chunk_trun_bits > 0:
            accumulator = libdevice.uint_as_float(
                (libdevice.float_as_uint(accumulator) + round_bit) & trun_mask
            )
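The rounding applied to the FP32 accumulator here can also be reproduced on the host by reinterpreting the float's bits, which is what libdevice.float_as_uint / uint_as_float do on device. A minimal sketch using Python's standard struct module; the helper name round_and_truncate_fp32 is invented for this example.

import struct

def round_and_truncate_fp32(x, chunk_trun_bits):
    trun_mask = (0xFFFFFFFF >> chunk_trun_bits) << chunk_trun_bits
    round_bit = 1 << (chunk_trun_bits - 1) if chunk_trun_bits > 0 else 0
    (bits,) = struct.unpack("<I", struct.pack("<f", x))   # float_as_uint
    bits = (bits + round_bit) & trun_mask                 # round to nearest, then clear the LSBs
    (y,) = struct.unpack("<f", struct.pack("<I", bits))   # uint_as_float
    return y

print(round_and_truncate_fp32(1.2345678, 8))  # nearby value with the low 8 mantissa bits cleared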
@@ -269,7 +272,14 @@ def leaky_relu(x):
    return tl.where(x >= 0, x, 0.01 * x)


-def tl_matmul_chunk_truncate(a, b, activation="", chunk_trun_bits=0, chunk_size=16):
+def tl_matmul_chunk_truncate(
+    a,
+    b,
+    activation="",
+    chunk_trun_bits=0,
+    chunk_size=16,
+    cast_output_to_input_dtype=None,
+):
    """Triton matmul for HW behavior simulation. Supports float and int8.
    a. variable chunk size (i.e., BLOCK_SIZE_K)
    b. LSB truncation, must <23 if using float.
@@ -279,6 +289,9 @@ def tl_matmul_chunk_truncate(a, b, activation="", chunk_trun_bits=0, chunk_size=
        activation (str, optional): activation func to be fused, see relu example.
        chunk_trun_bits (int, optional): number of LSBs to be truncated/rounded.
        chunk_size (int, optional): BLOCK_SIZE_K, some HW has specific chunk size. must >= 16.
+        cast_output_to_input_dtype (bool, optional): the accumulator has higher precision than the
+                                                     inputs (FP32 or INT32); by default the final output
+                                                     is cast back to the input dtype for non-8-bit inputs.

    Returns:
        _type_: _description_
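With this default, FP16/BF16/FP32 inputs still come back in the input dtype, while INT8/FP8 inputs return the raw INT32/FP32 accumulator unless the flag is set explicitly. A hedged usage sketch (not part of the patch; shapes are arbitrary and a CUDA device new enough for the int8 path is assumed):

# Illustrative call patterns only.
a16 = torch.randn(256, 512, device="cuda", dtype=torch.float16)
b16 = torch.randn(512, 128, device="cuda", dtype=torch.float16)
c16 = tl_matmul_chunk_truncate(a16, b16, chunk_trun_bits=8)  # float16 out (default cast)

a8 = torch.randint(-128, 127, (256, 512), device="cuda", dtype=torch.int8)
b8 = torch.randint(-128, 127, (512, 128), device="cuda", dtype=torch.int8)
c32 = tl_matmul_chunk_truncate(a8, b8)  # int32 accumulator returned as-is by default
c8 = tl_matmul_chunk_truncate(a8, b8, cast_output_to_input_dtype=True)  # cast back to int8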
@@ -292,30 +305,37 @@ def tl_matmul_chunk_truncate(a, b, activation="", chunk_trun_bits=0, chunk_size=
    assert a.is_contiguous(), "Matrix A must be contiguous"
    assert a.dtype == b.dtype, "Input dtypes inconsistent"

+    if cast_output_to_input_dtype is None:
+        cast_output_to_input_dtype = a.dtype not in DTYPE_8BIT
    allowed_dtypes = [torch.float, torch.bfloat16, torch.float16]
    cuda_cc = torch.cuda.get_device_capability()
    if cuda_cc[0] >= 8:
-        allowed_dtypes.append(torch.int8)
+        allowed_dtypes += DTYPE_I8
    if cuda_cc[0] >= 9 or cuda_cc == (8, 9):
-        allowed_dtypes += [torch.float8_e4m3fn, torch.float8_e5m2]
+        allowed_dtypes += DTYPE_F8
    assert a.dtype in allowed_dtypes, "Input dtype is not supported"
    M, K = a.shape
    K, N = b.shape

-    # Allocates output, always accumulate in FP32/INT32 then cast (if floats)
+    # Allocate output; always accumulate in FP32 (if floats) or INT32, then cast
    def isPowerofTwo(x):
        """triton-specific limitation: block size needs to be power of 2."""
        return (x & (x - 1)) == 0

-    if a.dtype == torch.int8:
+    min_chunk_size = 32 if a.dtype in DTYPE_8BIT else 16
+    if isPowerofTwo(chunk_size):
+        chunk_size = max(chunk_size, min_chunk_size)
+    else:
+        chunk_size = min_chunk_size
+
+    if a.dtype in DTYPE_I8:
+        acc_dtype = torch.int32
        mm_kernel = imatmul_kernel
-        chunk_size = max(chunk_size, 32) if isPowerofTwo(chunk_size) else 32
-        c = torch.zeros((M, N), device=a.device, dtype=torch.int32)
    else:
-        assert chunk_trun_bits < 23, "FP32 accumulator only has 23 mantissa bits"
+        acc_dtype = torch.float32
        mm_kernel = matmul_kernel
-        chunk_size = max(chunk_size, 16) if isPowerofTwo(chunk_size) else 16
-        c = torch.zeros((M, N), device=a.device, dtype=torch.float32)
+        assert chunk_trun_bits < 23, "FP32 accumulator only has 23 mantissa bits"
+    c = torch.zeros((M, N), device=a.device, dtype=acc_dtype)

    # 1D launch kernel where each block gets its own program.
    def grid(META):
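The chunk_size handling above replaces the old per-dtype branches with a single normalization step; mirrored on the host it behaves as in this sketch (normalize_chunk_size is an invented name, for illustration only):

def normalize_chunk_size(chunk_size, is_8bit):
    min_chunk_size = 32 if is_8bit else 16
    is_pow2 = (chunk_size & (chunk_size - 1)) == 0
    return max(chunk_size, min_chunk_size) if is_pow2 else min_chunk_size

print(normalize_chunk_size(64, is_8bit=False))  # 64: power of two and above the minimum
print(normalize_chunk_size(8, is_8bit=True))    # 32: power of two but below the 8-bit minimum
print(normalize_chunk_size(48, is_8bit=False))  # 16: not a power of two, fall back to the minimum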
@@ -327,7 +347,7 @@ def grid(META):
        kernel_config = {
            "BLOCK_SIZE_M": 128,
            "BLOCK_SIZE_K": chunk_size,
-            "BLOCK_SIZE_N": 128,  # was 32
+            "BLOCK_SIZE_N": 32,
            "GROUP_SIZE_M": 8,
            "num_warps": 2,
            "num_stages": 5,
@@ -336,7 +356,7 @@ def grid(META):
        kernel_config = {
            "BLOCK_SIZE_M": 128,
            "BLOCK_SIZE_K": chunk_size,
-            "BLOCK_SIZE_N": 128,  # was 64
+            "BLOCK_SIZE_N": 64,
            "GROUP_SIZE_M": 8,
            "num_warps": 4,
            "num_stages": 4,
@@ -359,4 +379,4 @@ def grid(META):
        ACTIVATION=activation,
        **kernel_config,  # if using auto-tune, comment this line out.
    )
-    return c.to(a.dtype) if a.dtype != torch.int8 else c
+    return c.to(a.dtype) if cast_output_to_input_dtype else c
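One way to sanity-check the new return path is against a plain torch.matmul reference: with chunk_trun_bits=0 the results should agree up to accumulation-order differences, and the gap should grow as more LSBs are truncated. Purely illustrative; the tolerances below are ballpark guesses.

a = torch.randn(512, 1024, device="cuda", dtype=torch.bfloat16)
b = torch.randn(1024, 256, device="cuda", dtype=torch.bfloat16)

ref = torch.matmul(a.float(), b.float()).to(torch.bfloat16)     # FP32 reference, cast like the kernel
out_exact = tl_matmul_chunk_truncate(a, b, chunk_trun_bits=0)   # no truncation
out_trunc = tl_matmul_chunk_truncate(a, b, chunk_trun_bits=12)  # simulate a lossy accumulator

print(torch.allclose(out_exact.float(), ref.float(), rtol=1e-2, atol=1e-2))
print((out_trunc.float() - ref.float()).abs().max())  # error grows with chunk_trun_bits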