Skip to content

Commit 423d734

Browse files
authored
Fix FP16 to Fp8E4M3 RTNE Upper bound to +/-448 (#3814)
This change fixes the previous FP16-to-Fp8E4M3 RTNE upper bound from 256 to 448, according to https://www.opencompute.org/documents/ocp-8-bit-floating-point-specification-ofp8-revision-1-0-2023-12-01-pdf-1, page 13, Table 2: "S.1111.110₂ = ±448". It helps to solve the SGLANG Group Quant UT failure #3613.
1 parent 9856962 commit 423d734

File tree

2 files changed

+36
-36
lines changed

2 files changed

+36
-36
lines changed

python/test/unit/language/test_conversions.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ def launch_exhaustive_populate(dst_dtype, offset, numel, force_odd, output_bits,
8787

8888

8989
@triton.jit
90-
def arbitrary_fp32_downcast(x, rounding : tl.constexpr, exponent_bits : tl.constexpr, mantissa_bits : tl.constexpr, exponent_bias : tl.constexpr, device_ : tl.constexpr):
90+
def arbitrary_fp32_downcast(x, rounding : tl.constexpr, exponent_bits : tl.constexpr, mantissa_bits : tl.constexpr, exponent_bias : tl.constexpr, from_bf16 : tl.constexpr):
9191

9292
tl.static_assert(x.dtype == tl.float32, "input must be float32")
9393
numbits_dst : tl.constexpr = 1 + exponent_bits + mantissa_bits
@@ -118,7 +118,7 @@ def arbitrary_fp32_downcast(x, rounding : tl.constexpr, exponent_bits : tl.const
118118
mantissa = tl.where(exponent > -1, mantissa, mantissa * 0.5)
119119
exponent = tl.where(exponent > -1, exponent, exponent + 1)
120120

121-
if device_ == 'xpu':
121+
if from_bf16:
122122
# convert mantissa to int with proper rounding without inline asm.
123123
to_cast = mantissa.to(tl.uint32, bitcast=True)
124124
mantissa2 = (to_cast & 0x7fffff)
@@ -165,22 +165,23 @@ def arbitrary_fp32_downcast(x, rounding : tl.constexpr, exponent_bits : tl.const
165165

166166

167167
@triton.jit
168-
def downcast_emulated(src, dst, rounding : tl.constexpr, BLOCK_SIZE : tl.constexpr, exponent_bits : tl.constexpr, mantissa_bits : tl.constexpr, exponent_bias : tl.constexpr, device_: tl.constexpr):
168+
def downcast_emulated(src, dst, rounding : tl.constexpr, BLOCK_SIZE : tl.constexpr, exponent_bits : tl.constexpr, mantissa_bits : tl.constexpr, exponent_bias : tl.constexpr, from_bf16: tl.constexpr):
169169

170170
tl.static_assert(src.dtype.element_ty == tl.float32, "src dtype must be float32")
171171

172172
idxs = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
173173
x = tl.load(src + idxs)
174-
y = arbitrary_fp32_downcast(x, rounding, exponent_bits, mantissa_bits, exponent_bias, device_=device_)
174+
y = arbitrary_fp32_downcast(x, rounding, exponent_bits, mantissa_bits, exponent_bias, from_bf16)
175175
y = y.to(dst.dtype.element_ty, bitcast=True)
176176
tl.store(dst + idxs, y)
177177

178178

179179
def launch_downcast_emulated(src, src_dtype, dst_dtype, rounding, exponent_bits, mantissa_bits, exponent_bias, device, BLOCK_SIZE=4096):
180180

181181
dst = torch.empty(src.shape, dtype=matching_int(dst_dtype), device=device)
182+
from_bf16 = src_dtype == tl.bfloat16 and device == 'xpu'
182183
downcast_emulated[(src.shape[0] // BLOCK_SIZE,)](
183-
triton.reinterpret(src, src_dtype), triton.reinterpret(dst, dst_dtype), rounding, BLOCK_SIZE, exponent_bits, mantissa_bits, exponent_bias, device_=device)
184+
triton.reinterpret(src, tl.float32), triton.reinterpret(dst, dst_dtype), rounding, BLOCK_SIZE, exponent_bits, mantissa_bits, exponent_bias, from_bf16=from_bf16)
184185
# 0x80 in float8e4b8 or float8e5b16 represents inf/nan. downcast_emulated kernel will
185186
# convert -0. in higher precision to 0x80 and thus need to fix the result to 0.
186187
if dst_dtype == tl.float8e4b8 or dst_dtype == tl.float8e5b16:
@@ -243,7 +244,7 @@ def downcast_test(src_dtype, dst_dtype, rounding, exponent_bits, mantissa_bits,
243244
else:
244245
src = launch_type_convert_triton(src, src_dtype, tl.float32, device=device)
245246

246-
dst2 = launch_downcast_emulated(src, tl.float32, dst_dtype, rounding, exponent_bits, mantissa_bits, exponent_bias, device=device)
247+
dst2 = launch_downcast_emulated(src, src_dtype, dst_dtype, rounding, exponent_bits, mantissa_bits, exponent_bias, device=device)
247248

248249
dst = launch_upcast_emulated(dst, exponent_bits, mantissa_bits, exponent_bias, device=device)
249250
dst2 = launch_upcast_emulated(dst2, exponent_bits, mantissa_bits, exponent_bias, device=device)

third_party/intel/lib/TritonIntelGPUToLLVM/ElementwiseOpToLLVM.cpp

Lines changed: 29 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -527,57 +527,56 @@ static SmallVector<Value>
527527
Fp16_to_Fp8E4M3Nv_RTNE(Location loc, ConversionPatternRewriter &rewriter,
528528
const SmallVector<Value> &v) {
529529
auto b = TritonLLVMOpBuilder(loc, rewriter);
530-
Value val = b.zext(i32_ty, b.bitcast(v[0], i16_ty));
531-
Value sign = b.and_(i32_ty, val, b.i32_val(0x8000));
532-
Value nosign = b.and_(i32_ty, val, b.i32_val(0x7fff));
530+
Value val = b.bitcast(v[0], i16_ty);
531+
Value sign = b.and_(i16_ty, val, b.i16_val(0x8000));
532+
Value nosign = b.and_(i16_ty, val, b.i16_val(0x7fff));
533533

534-
Value exp = b.and_(i32_ty, b.lshr(nosign, b.i32_val(10)), b.i32_val(0x1f));
534+
Value exp = b.and_(i16_ty, b.lshr(nosign, b.i16_val(10)), b.i16_val(0x1f));
535535
// Check if we need a translation to a subnormal value. This happens when
536536
// exp value is in range [5, 8].
537537
Value is_subnormal =
538-
b.and_(b.icmp_uge(exp, b.i32_val(5)), b.icmp_ule(exp, b.i32_val(8)));
539-
Value shift = b.sub(i32_ty, b.i32_val(8), exp);
540-
Value subnormal = b.and_(i32_ty, nosign, b.i32_val(0x3ff));
541-
subnormal = b.or_(i32_ty, subnormal, b.i32_val(0x400));
538+
b.and_(b.icmp_uge(exp, b.i16_val(5)), b.icmp_ule(exp, b.i16_val(8)));
539+
Value shift = b.sub(i16_ty, b.i16_val(8), exp);
540+
Value subnormal = b.and_(i16_ty, nosign, b.i16_val(0x3ff));
541+
subnormal = b.or_(i16_ty, subnormal, b.i16_val(0x400));
542542
// Make rounding with respect to bits we are going to shift and cut off.
543-
Value round_step = b.shl(i32_ty, b.i32_val(0x100), shift);
544-
Value tail_mask = b.sub(i32_ty, round_step, b.i32_val(1));
545-
Value tail = b.and_(i32_ty, subnormal, tail_mask);
546-
Value threshold = b.shl(i32_ty, b.i32_val(0x80), shift);
543+
Value round_step = b.shl(i16_ty, b.i16_val(0x100), shift);
544+
Value tail_mask = b.sub(i16_ty, round_step, b.i16_val(1));
545+
Value tail = b.and_(i16_ty, subnormal, tail_mask);
546+
Value threshold = b.shl(i16_ty, b.i16_val(0x80), shift);
547547
Value odd_truncated =
548-
b.icmp_ne(b.and_(i32_ty, subnormal, round_step), b.i32_val(0));
548+
b.icmp_ne(b.and_(i16_ty, subnormal, round_step), b.i16_val(0));
549549
Value round_up = b.or_(b.icmp_ugt(tail, threshold),
550550
b.and_(b.icmp_eq(tail, threshold), odd_truncated));
551551
subnormal =
552-
b.select(round_up, b.add(i32_ty, subnormal, round_step), subnormal);
552+
b.select(round_up, b.add(i16_ty, subnormal, round_step), subnormal);
553553
// Now shift to get the final result.
554-
subnormal = b.lshr(i32_ty, subnormal, shift);
554+
subnormal = b.lshr(i16_ty, subnormal, shift);
555555

556556
// Normalized case. Start with rounding, then apply exp range to fit 4 bits,
557557
// adjust bias and shift left.
558558
// TODO: NaN values might be mishandled.
559-
tail = b.and_(i32_ty, nosign, b.i32_val(0x7f));
559+
tail = b.and_(i16_ty, nosign, b.i16_val(0x7f));
560560
odd_truncated =
561-
b.icmp_ne(b.and_(i32_ty, nosign, b.i32_val(0x80)), b.i32_val(0));
562-
round_up = b.or_(b.icmp_ugt(tail, b.i32_val(0x40)),
563-
b.and_(b.icmp_eq(tail, b.i32_val(0x40)), odd_truncated));
561+
b.icmp_ne(b.and_(i16_ty, nosign, b.i16_val(0x80)), b.i16_val(0));
562+
round_up = b.or_(b.icmp_ugt(tail, b.i16_val(0x40)),
563+
b.and_(b.icmp_eq(tail, b.i16_val(0x40)), odd_truncated));
564564
Value rounded =
565-
b.and_(i32_ty, b.add(i32_ty, nosign, b.i32_val(0x80)), b.i32_val(0x7f80));
566-
nosign = b.select(round_up, rounded, nosign);
565+
b.and_(i16_ty, b.add(i16_ty, nosign, b.i16_val(0x80)), b.i16_val(0x7f80));
566+
Value normal = b.select(round_up, rounded, nosign);
567567

568-
nosign = b.umax(i32_ty, nosign, b.i32_val(0x2000));
569-
nosign = b.umin(i32_ty, nosign, b.i32_val(0x5c00));
570-
nosign = b.sub(i32_ty, nosign, b.i32_val(0x2000));
571-
nosign = b.shl(i32_ty, nosign, b.i32_val(1));
568+
normal = b.umax(i16_ty, normal, b.i16_val(0x2000));
569+
normal = b.umin(i16_ty, normal, b.i16_val(0x5f00));
570+
normal = b.sub(i16_ty, normal, b.i16_val(0x2000));
571+
normal = b.shl(i16_ty, normal, b.i16_val(1));
572572

573573
// Choose between subnormal and normal values.
574-
nosign = b.select(is_subnormal, subnormal, nosign);
575-
576-
Value res_val = b.or_(i32_ty, nosign, sign);
577-
auto fp8x4VecTy = vec_ty(i8_ty, 4);
574+
Value res_val = b.select(is_subnormal, subnormal, normal);
575+
res_val = b.or_(i16_ty, res_val, sign);
576+
auto fp8x4VecTy = vec_ty(i8_ty, 2);
578577
Value res = b.bitcast(res_val, fp8x4VecTy);
579578

580-
return {b.extract_element(i8_ty, res, b.i32_val(1))};
579+
return {b.extract_element(i8_ty, res, b.i16_val(1))};
581580
}
582581

583582
static SmallVector<Value> Fp8E4M3Nv_to_Bf16(Location loc,

0 commit comments

Comments
 (0)