Commit a8d6ea9

zero out last 13 bits
Signed-off-by: cliu-us <[email protected]>
1 parent 005d54b commit a8d6ea9

1 file changed: +2 -2 lines changed
fms_mo/custom_ext_kernels/triton_kernels.py

Lines changed: 2 additions & 2 deletions
@@ -448,10 +448,10 @@ def round_and_trun(x, round_bit, trun_mask):
 @triton.jit
 def fp32_clamp_to_dl16(x):
     """clamp FP32 (1-8-23) TENSOR x to DL16 (1-6-9) range."""
-    # 1. rounding, add round bit to full uint representation
+    # 1. rounding: add round bit to full uint representation, zero out last 13 bits, back to float
     x = libdevice.float_as_uint(x)
     round_bit = 1 << (23 - 9 - 1)
-    x = libdevice.uint_as_float(x + round_bit)
+    x = libdevice.uint_as_float(((x + round_bit) >> 13) << 13)
 
     # 2. clamp to min/max:
     # max = 2^32 * 1.(1111 1111 0)_base2 => 2^32*1.(1111 1111 1) will become inf
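
The bit arithmetic in step 1 can be tried on the host without Triton. The following is a minimal sketch, not the kernel itself: it assumes scalar inputs and uses Python's `struct` module as a stand-in for `libdevice.float_as_uint` / `libdevice.uint_as_float`. The helper names `float_as_uint`, `uint_as_float`, and `round_and_zero_low_bits` are hypothetical and exist only for this illustration.

```python
import struct

FP32_MANTISSA_BITS = 23
DL16_MANTISSA_BITS = 9


def float_as_uint(x: float) -> int:
    """Reinterpret an FP32 value as its 32-bit unsigned bit pattern."""
    return struct.unpack("<I", struct.pack("<f", x))[0]


def uint_as_float(u: int) -> float:
    """Reinterpret a 32-bit unsigned bit pattern as an FP32 value."""
    # Mask to 32 bits to emulate uint32 wraparound (Python ints are unbounded).
    return struct.unpack("<f", struct.pack("<I", u & 0xFFFFFFFF))[0]


def round_and_zero_low_bits(x: float) -> float:
    """Mirror step 1 of the diff: add the round bit, then clear bits 0..12."""
    u = float_as_uint(x)
    round_bit = 1 << (FP32_MANTISSA_BITS - DL16_MANTISSA_BITS - 1)  # 1 << 13
    return uint_as_float(((u + round_bit) >> 13) << 13)


if __name__ == "__main__":
    for value in (1.0, 1.2345, -3.75):
        out = round_and_zero_low_bits(value)
        print(f"{value!r} -> {out!r} (bits 0x{float_as_uint(out):08X})")
```

Before this commit the sum `x + round_bit` was converted back to float with its low mantissa bits intact; the shift pair clears them. Under this emulation, an input of 1.0 comes back as 1.0 + 2^-10, because the added round bit sits at bit 13 and clearing bits 0 to 12 leaves it in place. Whether later steps of the kernel account for that is not visible in this hunk.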
