@@ -114,6 +114,7 @@ def matmul_kernel(
114114 stride_cn ,
115115 chunk_trun_bits ,
116116 max_acc_bits , # pylint: disable=unused-argument
117+ clamp_acc_to_dl16 ,
117118 truncate_then_accumulate ,
118119 # Meta-parameters
119120 BLOCK_SIZE_M : tl .constexpr ,
@@ -159,13 +160,8 @@ def matmul_kernel(
159160 # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
160161 # of fp32 values for higher accuracy.
161162 accumulator = tl .zeros ((BLOCK_SIZE_M , BLOCK_SIZE_N ), dtype = tl .float32 )
162- ## ------ prepare LSB rounding/truncation masks -------
163- # NOTE mask will be applied on accumulator, which is alway FP32, so we may truncate up to 23b
164- # e.g., 20b -> trun_mask = 0xFFF00000, round_bit = 0x00080000
165- # 8b -> trun_mask = 0xFFFFFF00, round_bit = 0x00000080
166- trun_mask = tl .cast ((0xFFFFFFFF >> chunk_trun_bits ) << chunk_trun_bits , tl .uint32 )
167- round_bit = 1 << (chunk_trun_bits - 1 ) if chunk_trun_bits > 0 else 0
168- ## ---------------------------------------------------------
163+ ## ------ prepare LSB rounding/truncation masks outside the for loop -------
164+ round_bit , trun_mask = round_and_trun_mask (chunk_trun_bits , clamp_acc_to_dl16 )
169165
170166 for k in range (0 , tl .cdiv (K , BLOCK_SIZE_K )):
171167 # Load the next block of A and B, generate a mask by checking the K dimension.
@@ -180,8 +176,10 @@ def matmul_kernel(
180176 # tl.dot() default is using TF32 approximation, not good enough for LSB truncation exp
181177
182178 ## ------ add chunky LSB rounding/masking --------
183- if chunk_trun_bits > 0 :
184- accumulator_inner = round_and_trun (accumulator_inner , round_bit , trun_mask )
179+ if clamp_acc_to_dl16 or chunk_trun_bits > 0 :
180+ accumulator_inner = round_and_trun (
181+ accumulator_inner , round_bit , trun_mask , clamp_acc_to_dl16
182+ )
185183 ## ---------------------------------------------------------
186184 if truncate_then_accumulate :
187185 accumulator += accumulator_inner
@@ -226,6 +224,7 @@ def imatmul_kernel(
226224 stride_cn ,
227225 chunk_trun_bits ,
228226 max_acc_bits ,
227+ clamp_acc_to_dl16 , # pylint: disable=unused-argument
229228 truncate_then_accumulate ,
230229 # Meta-parameters
231230 BLOCK_SIZE_M : tl .constexpr ,
@@ -324,6 +323,7 @@ def matmul_kernel_DABC(
324323 stride_cn ,
325324 chunk_trun_bits ,
326325 max_acc_bits , # pylint: disable=unused-argument
326+ clamp_acc_to_dl16 ,
327327 truncate_then_accumulate ,
328328 # Meta-parameters
329329 BLOCK_SIZE_M : tl .constexpr ,
@@ -377,13 +377,8 @@ def matmul_kernel_DABC(
377377 # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
378378 # of fp32 values for higher accuracy, i.e. C should have been cast to fp32 already
379379 accumulator = tl .load (c_ptrs , mask = c_mask , other = 0.0 )
380- ## ------ prepare LSB rounding/truncation masks -------
381- # NOTE mask will be applied on accumulator, which is alway FP32, so we may truncate up to 23b
382- # e.g., 20b -> trun_mask = 0xFFF00000, round_bit = 0x00080000
383- # 8b -> trun_mask = 0xFFFFFF00, round_bit = 0x00000080
384- trun_mask = tl .cast ((0xFFFFFFFF >> chunk_trun_bits ) << chunk_trun_bits , tl .uint32 )
385- round_bit = 1 << (chunk_trun_bits - 1 ) if chunk_trun_bits > 0 else 0
386- ## ---------------------------------------------------------
380+ ## ------ prepare LSB rounding/truncation masks outside the for loop -------
381+ round_bit , trun_mask = round_and_trun_mask (chunk_trun_bits , clamp_acc_to_dl16 )
387382
388383 for k in range (0 , tl .cdiv (K , BLOCK_SIZE_K )):
389384 # Load the next block of A, B, and C, generate a mask by checking the K dimension.
@@ -403,8 +398,10 @@ def matmul_kernel_DABC(
403398 # precision as well, hence, could lose some precision!
404399
405400 ## ------ add chunky LSB rounding/masking --------
406- if chunk_trun_bits > 0 :
407- accumulator_inner = round_and_trun (accumulator_inner , round_bit , trun_mask )
401+ if clamp_acc_to_dl16 or chunk_trun_bits > 0 :
402+ accumulator_inner = round_and_trun (
403+ accumulator_inner , round_bit , trun_mask , clamp_acc_to_dl16
404+ )
408405 ## ---------------------------------------------------------
409406 if truncate_then_accumulate :
410407 accumulator += accumulator_inner
@@ -433,9 +430,39 @@ def leaky_relu(x):
433430
434431
435432@triton .jit
436- def round_and_trun (x , round_bit , trun_mask ):
433+ def round_and_trun_mask (chunk_trun_bits , clamp_acc_to_dl16 ):
434+ """
435+ Rounding and LSB truncation masks only need to be generated once.
436+ These masks will be applied to the "inner" accumulator, which is always FP32 (e8m23). We may truncate
437+ up to 23b for mantissa. If DL16/DL8, pay attention to exponent bias.
438+ Examples: 20b -> trun_mask = 0xFFF00000, round_bit = 0x00080000
439+ 8b -> trun_mask = 0xFFFFFF00, round_bit = 0x00000080
440+ """
441+ if clamp_acc_to_dl16 :
442+ # DL16 is e6m9, hence, truncate 23 - 9 = 14 bits
443+ chunk_trun_bits = 14
444+ round_bit = 1 << (chunk_trun_bits - 1 ) if chunk_trun_bits > 0 else 0
445+ trun_mask = ~ tl .cast ((1 << chunk_trun_bits ) - 1 , tl .uint32 )
446+ return round_bit , trun_mask
447+
448+
449+ @triton .jit
450+ def round_and_trun (x , round_bit , trun_mask , clamp_acc_to_dl16 ):
437451 """Round and truncate (usually for accumulator)."""
438- return libdevice .uint_as_float ((libdevice .float_as_uint (x ) + round_bit ) & trun_mask )
452+ x = libdevice .uint_as_float ((libdevice .float_as_uint (x ) + round_bit ) & trun_mask )
453+
454+ if clamp_acc_to_dl16 :
455+ # clamp to DL16 min/max:
456+ # max = 2^32 * 1.(1111 1111 0)_base2 = 2^32*(2 - 2^-9) = 8581545984.0
457+ # greater than this will become +inf (or -inf)
458+ # min = 2^-31 * 1.(0000 0000 1)_base2 = 2^-31*(1 + 2^-9) = 4.665707820095122e-10
459+ # smaller than this will become 0
460+ dl16_max = 8581545984.0
461+ dl16_min = 4.665707820095122e-10
462+ x = tl .where (x >= dl16_max , float ("inf" ), x )
463+ x = tl .where (x <= - dl16_max , float ("-inf" ), x )
464+ x = tl .where (tl .abs (x ) < dl16_min , 0 , x )
465+ return x
439466
440467
441468def tl_matmul_chunk_truncate (
@@ -448,6 +475,7 @@ def tl_matmul_chunk_truncate(
448475 max_acc_bits = 32 ,
449476 truncate_then_accumulate = True ,
450477 cast_output_to_input_dtype = None ,
478+ clamp_acc_to_dl16 = False ,
451479):
452480 """Triton matmul for HW behavior simulation. Supports float and int8.
453481 i. variable chunk size (i.e., BLOCK_SIZE_K)
@@ -461,7 +489,8 @@ def tl_matmul_chunk_truncate(
461489 chunk_size (int, optional): BLOCK_SIZE_K, some HW has specific chunk size. must >= 16.
462490 max_acc_bits (int, optional): num of bits for the accumulator, e.g. if INT24 is used, will
463491 clamp each chunk of a*b to [-2**23-1, 2**23].
464- (assuming no inf when overflow)
492+ (only used by INT)
493+ clamp_acc_to_dl16 (bool, optional): Only used by FP8; whether to clamp the local accumulator (FP32) to DL16
465494 truncate_then_accumulate (bool, optional): if True, c = truncate(a*b) + c, otherwise
466495 c = truncate(a*b+c)
467496 cast_output_to_input_dtype (bool, optional): accumulator has higher prec than input, usually
@@ -473,7 +502,7 @@ def tl_matmul_chunk_truncate(
473502
474503 NOTE:
475504 use empirical way to determine BLOCK sizes, may not be optimal. But need to avoid autotune for
476- real model inference. otherwise auto-tune will be triggered in every forward call.
505+ real model inference. otherwise auto-tune may be triggered in every forward call.
477506 """
478507
479508 # Check constraints.
@@ -584,6 +613,7 @@ def grid(META):
584613 c .stride (1 ),
585614 chunk_trun_bits = chunk_trun_bits ,
586615 max_acc_bits = max_acc_bits ,
616+ clamp_acc_to_dl16 = clamp_acc_to_dl16 ,
587617 truncate_then_accumulate = truncate_then_accumulate ,
588618 ACTIVATION = activation ,
589619 ** kernel_config , # if using auto-tune, comment this line out.
0 commit comments