@@ -160,13 +160,8 @@ def matmul_kernel(
160160 # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
161161 # of fp32 values for higher accuracy.
162162 accumulator = tl .zeros ((BLOCK_SIZE_M , BLOCK_SIZE_N ), dtype = tl .float32 )
163- ## ------ prepare LSB rounding/truncation masks -------
164- # NOTE mask will be applied on accumulator, which is alway FP32, so we may truncate up to 23b
165- # e.g., 20b -> trun_mask = 0xFFF00000, round_bit = 0x00080000
166- # 8b -> trun_mask = 0xFFFFFF00, round_bit = 0x00000080
167- trun_mask = ~ tl .cast ((1 << chunk_trun_bits ) - 1 , tl .uint32 )
168- round_bit = 1 << (chunk_trun_bits - 1 ) if chunk_trun_bits > 0 else 0
169- ## ---------------------------------------------------------
163+ ## ------ prepare LSB rounding/truncation masks outside the for loop -------
164+ round_bit , trun_mask = round_and_trun_mask (chunk_trun_bits , clamp_acc_to_dl16 )
170165
171166 for k in range (0 , tl .cdiv (K , BLOCK_SIZE_K )):
172167 # Load the next block of A and B, generate a mask by checking the K dimension.
@@ -181,10 +176,10 @@ def matmul_kernel(
181176 # tl.dot() default is using TF32 approximation, not good enough for LSB truncation exp
182177
183178 ## ------ add chunky LSB rounding/masking --------
184- if chunk_trun_bits > 0 :
185- accumulator_inner = round_and_trun (accumulator_inner , round_bit , trun_mask )
186- if clamp_acc_to_dl16 :
187- accumulator_inner = fp32_clamp_to_dl16 ( accumulator_inner )
179+ if clamp_acc_to_dl16 or chunk_trun_bits > 0 :
180+ accumulator_inner = round_and_trun (
181+ accumulator_inner , round_bit , trun_mask , clamp_acc_to_dl16
182+ )
188183 ## ---------------------------------------------------------
189184 if truncate_then_accumulate :
190185 accumulator += accumulator_inner
@@ -382,13 +377,8 @@ def matmul_kernel_DABC(
382377 # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
383378 # of fp32 values for higher accuracy, i.e. C should have been cast to fp32 already
384379 accumulator = tl .load (c_ptrs , mask = c_mask , other = 0.0 )
385- ## ------ prepare LSB rounding/truncation masks -------
386- # NOTE mask will be applied on accumulator, which is alway FP32, so we may truncate up to 23b
387- # e.g., 20b -> trun_mask = 0xFFF00000, round_bit = 0x00080000
388- # 8b -> trun_mask = 0xFFFFFF00, round_bit = 0x00000080
389- trun_mask = ~ tl .cast ((1 << chunk_trun_bits ) - 1 , tl .uint32 )
390- round_bit = 1 << (chunk_trun_bits - 1 ) if chunk_trun_bits > 0 else 0
391- ## ---------------------------------------------------------
380+ ## ------ prepare LSB rounding/truncation masks outside the for loop -------
381+ round_bit , trun_mask = round_and_trun_mask (chunk_trun_bits , clamp_acc_to_dl16 )
392382
393383 for k in range (0 , tl .cdiv (K , BLOCK_SIZE_K )):
394384 # Load the next block of A, B, and C, generate a mask by checking the K dimension.
@@ -408,10 +398,10 @@ def matmul_kernel_DABC(
408398 # precision as well, hence, could lose some precision!
409399
410400 ## ------ add chunky LSB rounding/masking --------
411- if chunk_trun_bits > 0 :
412- accumulator_inner = round_and_trun (accumulator_inner , round_bit , trun_mask )
413- if clamp_acc_to_dl16 :
414- accumulator_inner = fp32_clamp_to_dl16 ( accumulator_inner )
401+ if clamp_acc_to_dl16 or chunk_trun_bits > 0 :
402+ accumulator_inner = round_and_trun (
403+ accumulator_inner , round_bit , trun_mask , clamp_acc_to_dl16
404+ )
415405 ## ---------------------------------------------------------
416406 if truncate_then_accumulate :
417407 accumulator += accumulator_inner
@@ -440,34 +430,64 @@ def leaky_relu(x):
440430
441431
@triton.jit
def round_and_trun_mask(chunk_trun_bits, clamp_acc_to_dl16):
    """
    Build the rounding bit and LSB-truncation mask once, outside the K loop.

    The masks are applied to the "inner" accumulator, which is always FP32
    (e8m23), so at most 23 mantissa bits may be truncated. If emulating
    DL16/DL8 accumulators, pay attention to the exponent bias as well.
    Examples: 20b -> trun_mask = 0xFFF00000, round_bit = 0x00080000
               8b -> trun_mask = 0xFFFFFF00, round_bit = 0x00000080

    Args:
        chunk_trun_bits: number of mantissa LSBs to round away (0 = no-op).
        clamp_acc_to_dl16: emulate a DL16 (1-6-9) accumulator. DL16 keeps
            9 mantissa bits, so at least 23 - 9 = 14 bits are truncated.

    Returns:
        (round_bit, trun_mask) ready to be passed to round_and_trun().
    """
    if clamp_acc_to_dl16:
        # DL16 is e6m9, hence truncate at least 23 - 9 = 14 bits — but do not
        # silently weaken a caller-requested truncation wider than 14 bits
        # (the pre-refactor code composed truncation with the DL16 clamp).
        if chunk_trun_bits < 14:
            chunk_trun_bits = 14
    round_bit = 1 << (chunk_trun_bits - 1) if chunk_trun_bits > 0 else 0
    trun_mask = ~tl.cast((1 << chunk_trun_bits) - 1, tl.uint32)
    return round_bit, trun_mask
446447
447448
@triton.jit
def round_and_trun(x, round_bit, trun_mask, clamp_acc_to_dl16):
    """
    Round and LSB-truncate the FP32 tensor `x` (usually an accumulator),
    optionally saturating it to the DL16 (1-6-9) dynamic range.

    The rounding is done on the raw bit pattern: reinterpret as uint32, add
    the rounding bit, zero out the LSBs with `trun_mask`, reinterpret back.
    """
    bits = libdevice.float_as_uint(x)
    x = libdevice.uint_as_float((bits + round_bit) & trun_mask)

    if clamp_acc_to_dl16:
        # DL16 saturation thresholds (values carried over unchanged):
        #   2^32 * (2 - 2^-9)   = 8581545984.0        -> magnitudes at or
        #       beyond this overflow to +/-inf
        #   2^-31 * (1 + 2^-9)  = 4.665707820095122e-10 -> smaller magnitudes
        #       flush to zero
        dl16_max = 8581545984.0
        dl16_min = 4.665707820095122e-10
        # The two saturation branches are disjoint, so their order is free.
        x = tl.where(x <= -dl16_max, float("-inf"), x)
        x = tl.where(x >= dl16_max, float("inf"), x)
        x = tl.where(tl.abs(x) < dl16_min, 0, x)
    return x
469466
470467
468+ # @triton.jit
469+ # def fp32_clamp_to_dl16(x):
470+ # """clamp FP32 (1-8-23) TENSOR x to DL16 (1-6-9) range."""
471+ # # 1. rounding: add round bit, zero out last 13 bits, back to float
472+ # x = libdevice.float_as_uint(x)
473+ # round_bit = 1 << (23 - 9 - 1)
474+ # mask_13x0 = ~tl.cast((1 << 13) - 1, tl.uint32)
475+ # x = libdevice.uint_as_float((x + round_bit) & mask_13x0)
476+
477+ # # 2. clamp to min/max:
478+ # # max = 2^32 * 1.(1111 1111 0)_base2 => 2^32*1.(1111 1111 1) will become inf
479+ # # (32 + 127) << 23 | (0xFF8 << (23 - 12)) in FP32 is 8581545984.0
480+ # # min = 2^-31 * 1.(0000 0000 1)_base2 => set to 0 for those smaller than this
481+ # # (-31 + 127) << 23 | (1 << (23 - 9)) in FP32 is 4.665707820095122e-10
482+ # dl16_max = 8581545984.0
483+ # dl16_min = 4.665707820095122e-10
484+ # x = tl.where(x >= dl16_max, float("inf"), x)
485+ # x = tl.where(x <= -dl16_max, float("-inf"), x)
486+ # x = tl.where(tl.abs(x) < dl16_min, 0, x)
487+
488+ # return x
489+
490+
471491def tl_matmul_chunk_truncate (
472492 a ,
473493 b ,
0 commit comments