 import triton
 import triton.language as tl

+DTYPE_I8 = [torch.int8]
+DTYPE_F8 = [torch.float8_e4m3fn, torch.float8_e5m2]
+DTYPE_8BIT = DTYPE_I8 + DTYPE_F8
+

 def get_cuda_autotune_config(chunk_size=None):
     """Basic use of triton.Config() is like:
@@ -145,8 +149,7 @@ def matmul_kernel(
     # NOTE mask will be applied on accumulator, which is always FP32, so we may truncate up to 23b
     # e.g., 20b -> trun_mask = 0xFFF00000, round_bit = 0x00080000
     #        8b -> trun_mask = 0xFFFFFF00, round_bit = 0x00000080
-    full_32b_mask = 0xFFFFFFFF
-    trun_mask = (full_32b_mask << chunk_trun_bits) & full_32b_mask
+    trun_mask = tl.cast((0xFFFFFFFF >> chunk_trun_bits) << chunk_trun_bits, tl.uint32)
     round_bit = 1 << (chunk_trun_bits - 1) if chunk_trun_bits > 0 else 0
     ## ---------------------------------------------------------

@@ -160,7 +163,7 @@ def matmul_kernel(
         # tl.dot() default is using TF32 approximation, not good enough for LSB truncation exp

         ## ------ add chunky LSB rounding/masking --------
-        if chunk_trun_bits != 0:
+        if chunk_trun_bits > 0:
             accumulator = libdevice.uint_as_float(
                 (libdevice.float_as_uint(accumulator) + round_bit) & trun_mask
             )
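As a side note (not part of the commit): the round-then-mask step above can be reproduced on the host for a single FP32 value. The sketch below uses a hypothetical helper `round_and_truncate_fp32`, with `struct` standing in for `libdevice.float_as_uint` / `uint_as_float`, and assumes the same `trun_mask` / `round_bit` definitions as the kernel.

```python
# Illustrative host-side reference for the kernel's per-chunk rounding/truncation.
import struct

def round_and_truncate_fp32(x: float, chunk_trun_bits: int) -> float:
    """Round to nearest at bit `chunk_trun_bits`, then clear the LSBs below it."""
    if chunk_trun_bits == 0:
        return x
    trun_mask = (0xFFFFFFFF >> chunk_trun_bits) << chunk_trun_bits
    round_bit = 1 << (chunk_trun_bits - 1)
    bits = struct.unpack("<I", struct.pack("<f", x))[0]     # float_as_uint
    bits = (bits + round_bit) & trun_mask                   # add rounding bit, mask LSBs
    return struct.unpack("<f", struct.pack("<I", bits))[0]  # uint_as_float

# Truncating 8 LSBs keeps 23 - 8 = 15 mantissa bits of the FP32 accumulator:
print(round_and_truncate_fp32(3.14159265, 8))  # ~3.1416016 instead of ~3.1415927
```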
@@ -269,7 +272,14 @@ def leaky_relu(x):
     return tl.where(x >= 0, x, 0.01 * x)


-def tl_matmul_chunk_truncate(a, b, activation="", chunk_trun_bits=0, chunk_size=16):
+def tl_matmul_chunk_truncate(
+    a,
+    b,
+    activation="",
+    chunk_trun_bits=0,
+    chunk_size=16,
+    cast_output_to_input_dtype=True,
+):
     """Triton matmul for HW behavior simulation. Supports float and int8.
     a. variable chunk size (i.e., BLOCK_SIZE_K)
     b. LSB truncation, must be < 23 if using float.
@@ -279,6 +289,10 @@ def tl_matmul_chunk_truncate(a, b, activation="", chunk_trun_bits=0, chunk_size=
         activation (str, optional): activation func to be fused, see relu example.
         chunk_trun_bits (int, optional): number of LSBs to be truncated/rounded.
         chunk_size (int, optional): BLOCK_SIZE_K, some HW has specific chunk size. must >= 16.
+        cast_output_to_input_dtype (bool, optional): the accumulator has higher precision than
+                                                     the inputs, usually FP32 or INT32. By default
+                                                     the final output is cast back to the input
+                                                     dtype; set False to return the raw accumulator.

     Returns:
         _type_: _description_
@@ -295,27 +309,32 @@ def tl_matmul_chunk_truncate(a, b, activation="", chunk_trun_bits=0, chunk_size=
     allowed_dtypes = [torch.float, torch.bfloat16, torch.float16]
     cuda_cc = torch.cuda.get_device_capability()
     if cuda_cc[0] >= 8:
-        allowed_dtypes.append(torch.int8)
+        allowed_dtypes += DTYPE_I8
     if cuda_cc[0] >= 9 or cuda_cc == (8, 9):
-        allowed_dtypes += [torch.float8_e4m3fn, torch.float8_e5m2]
+        allowed_dtypes += DTYPE_F8
     assert a.dtype in allowed_dtypes, "Input dtype is not supported"
     M, K = a.shape
     K, N = b.shape

-    # Allocates output, always accumulate in FP32/INT32 then cast (if floats)
+    # Allocates output, always accumulate in FP32 (if floats) or INT32 then cast
     def isPowerofTwo(x):
         """triton-specific limitation: block size needs to be power of 2."""
         return (x & (x - 1)) == 0

-    if a.dtype == torch.int8:
+    min_chunk_size = 32 if a.dtype in DTYPE_8BIT else 16
+    if isPowerofTwo(chunk_size):
+        chunk_size = max(chunk_size, min_chunk_size)
+    else:
+        chunk_size = min_chunk_size
+
+    if a.dtype in DTYPE_I8:
+        acc_dtype = torch.int32
         mm_kernel = imatmul_kernel
-        chunk_size = max(chunk_size, 32) if isPowerofTwo(chunk_size) else 32
-        c = torch.zeros((M, N), device=a.device, dtype=torch.int32)
     else:
-        assert chunk_trun_bits < 23, "FP32 accumulator only has 23 mantissa bits"
+        acc_dtype = torch.float32
         mm_kernel = matmul_kernel
-        chunk_size = max(chunk_size, 16) if isPowerofTwo(chunk_size) else 16
-        c = torch.zeros((M, N), device=a.device, dtype=torch.float32)
+        assert chunk_trun_bits < 23, "FP32 accumulator only has 23 mantissa bits"
+    c = torch.zeros((M, N), device=a.device, dtype=acc_dtype)

     # 1D launch kernel where each block gets its own program.
     def grid(META):
@@ -327,7 +346,7 @@ def grid(META):
         kernel_config = {
             "BLOCK_SIZE_M": 128,
             "BLOCK_SIZE_K": chunk_size,
-            "BLOCK_SIZE_N": 128,  # was 32
+            "BLOCK_SIZE_N": 32,
             "GROUP_SIZE_M": 8,
             "num_warps": 2,
             "num_stages": 5,
@@ -336,7 +355,7 @@ def grid(META):
         kernel_config = {
             "BLOCK_SIZE_M": 128,
             "BLOCK_SIZE_K": chunk_size,
-            "BLOCK_SIZE_N": 128,  # was 64
+            "BLOCK_SIZE_N": 64,
             "GROUP_SIZE_M": 8,
             "num_warps": 4,
             "num_stages": 4,
@@ -359,4 +378,4 @@ def grid(META):
         ACTIVATION=activation,
         **kernel_config,  # if using auto-tune, comment this line out.
     )
-    return c.to(a.dtype) if a.dtype != torch.int8 else c
+    return c.to(a.dtype) if cast_output_to_input_dtype else c
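For reference (not part of the commit), a minimal call of the updated entry point might look like the sketch below; the shapes, dtype, and truncation settings are illustrative only and assume a CUDA device of compute capability 8.0 or newer.

```python
import torch

# Hypothetical example inputs; fp16 is accepted on any supported GPU.
a = torch.randn(256, 512, device="cuda", dtype=torch.float16)
b = torch.randn(512, 128, device="cuda", dtype=torch.float16)

# Emulate HW that accumulates in chunks of 16 along K and rounds away 8 LSBs
# of the FP32 accumulator after each chunk; output is cast back to fp16 by default.
c = tl_matmul_chunk_truncate(a, b, chunk_trun_bits=8, chunk_size=16)

# Keep the raw FP32 accumulator instead of casting back to the input dtype.
c_fp32 = tl_matmul_chunk_truncate(
    a, b, chunk_trun_bits=8, chunk_size=16, cast_output_to_input_dtype=False
)
```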