@@ -527,57 +527,56 @@ static SmallVector<Value>
527527Fp16_to_Fp8E4M3Nv_RTNE (Location loc, ConversionPatternRewriter &rewriter,
528528 const SmallVector<Value> &v) {
529529 auto b = TritonLLVMOpBuilder (loc, rewriter);
530- Value val = b.zext (i32_ty, b. bitcast (v[0 ], i16_ty) );
531- Value sign = b.and_ (i32_ty , val, b.i32_val (0x8000 ));
532- Value nosign = b.and_ (i32_ty , val, b.i32_val (0x7fff ));
530+ Value val = b.bitcast (v[0 ], i16_ty);
531+ Value sign = b.and_ (i16_ty , val, b.i16_val (0x8000 ));
532+ Value nosign = b.and_ (i16_ty , val, b.i16_val (0x7fff ));
533533
534- Value exp = b.and_ (i32_ty , b.lshr (nosign, b.i32_val (10 )), b.i32_val (0x1f ));
534+ Value exp = b.and_ (i16_ty , b.lshr (nosign, b.i16_val (10 )), b.i16_val (0x1f ));
535535 // Check if we need a translation to a subnormal value. This happens when
536536 // exp value is in range [5, 8].
537537 Value is_subnormal =
538- b.and_ (b.icmp_uge (exp, b.i32_val (5 )), b.icmp_ule (exp, b.i32_val (8 )));
539- Value shift = b.sub (i32_ty , b.i32_val (8 ), exp);
540- Value subnormal = b.and_ (i32_ty , nosign, b.i32_val (0x3ff ));
541- subnormal = b.or_ (i32_ty , subnormal, b.i32_val (0x400 ));
538+ b.and_ (b.icmp_uge (exp, b.i16_val (5 )), b.icmp_ule (exp, b.i16_val (8 )));
539+ Value shift = b.sub (i16_ty , b.i16_val (8 ), exp);
540+ Value subnormal = b.and_ (i16_ty , nosign, b.i16_val (0x3ff ));
541+ subnormal = b.or_ (i16_ty , subnormal, b.i16_val (0x400 ));
542542 // Make rounding with respect to bits we are going to shift and cut off.
543- Value round_step = b.shl (i32_ty , b.i32_val (0x100 ), shift);
544- Value tail_mask = b.sub (i32_ty , round_step, b.i32_val (1 ));
545- Value tail = b.and_ (i32_ty , subnormal, tail_mask);
546- Value threshold = b.shl (i32_ty , b.i32_val (0x80 ), shift);
543+ Value round_step = b.shl (i16_ty , b.i16_val (0x100 ), shift);
544+ Value tail_mask = b.sub (i16_ty , round_step, b.i16_val (1 ));
545+ Value tail = b.and_ (i16_ty , subnormal, tail_mask);
546+ Value threshold = b.shl (i16_ty , b.i16_val (0x80 ), shift);
547547 Value odd_truncated =
548- b.icmp_ne (b.and_ (i32_ty , subnormal, round_step), b.i32_val (0 ));
548+ b.icmp_ne (b.and_ (i16_ty , subnormal, round_step), b.i16_val (0 ));
549549 Value round_up = b.or_ (b.icmp_ugt (tail, threshold),
550550 b.and_ (b.icmp_eq (tail, threshold), odd_truncated));
551551 subnormal =
552- b.select (round_up, b.add (i32_ty , subnormal, round_step), subnormal);
552+ b.select (round_up, b.add (i16_ty , subnormal, round_step), subnormal);
553553 // Now shift to get the final result.
554- subnormal = b.lshr (i32_ty , subnormal, shift);
554+ subnormal = b.lshr (i16_ty , subnormal, shift);
555555
556556 // Normalized case. Start with rounding, then apply exp range to fit 4 bits,
557557 // adjust bias and shift left.
558558 // TODO: NaN values might be mishandled.
559- tail = b.and_ (i32_ty , nosign, b.i32_val (0x7f ));
559+ tail = b.and_ (i16_ty , nosign, b.i16_val (0x7f ));
560560 odd_truncated =
561- b.icmp_ne (b.and_ (i32_ty , nosign, b.i32_val (0x80 )), b.i32_val (0 ));
562- round_up = b.or_ (b.icmp_ugt (tail, b.i32_val (0x40 )),
563- b.and_ (b.icmp_eq (tail, b.i32_val (0x40 )), odd_truncated));
561+ b.icmp_ne (b.and_ (i16_ty , nosign, b.i16_val (0x80 )), b.i16_val (0 ));
562+ round_up = b.or_ (b.icmp_ugt (tail, b.i16_val (0x40 )),
563+ b.and_ (b.icmp_eq (tail, b.i16_val (0x40 )), odd_truncated));
564564 Value rounded =
565- b.and_ (i32_ty , b.add (i32_ty , nosign, b.i32_val (0x80 )), b.i32_val (0x7f80 ));
566- nosign = b.select (round_up, rounded, nosign);
565+ b.and_ (i16_ty , b.add (i16_ty , nosign, b.i16_val (0x80 )), b.i16_val (0x7f80 ));
566+ Value normal = b.select (round_up, rounded, nosign);
567567
568- nosign = b.umax (i32_ty, nosign , b.i32_val (0x2000 ));
569- nosign = b.umin (i32_ty, nosign , b.i32_val ( 0x5c00 ));
570- nosign = b.sub (i32_ty, nosign , b.i32_val (0x2000 ));
571- nosign = b.shl (i32_ty, nosign , b.i32_val (1 ));
568+ normal = b.umax (i16_ty, normal , b.i16_val (0x2000 ));
569+ normal = b.umin (i16_ty, normal , b.i16_val ( 0x5f00 ));
570+ normal = b.sub (i16_ty, normal , b.i16_val (0x2000 ));
571+ normal = b.shl (i16_ty, normal , b.i16_val (1 ));
572572
573573 // Choose between subnormal and normal values.
574- nosign = b.select (is_subnormal, subnormal, nosign);
575-
576- Value res_val = b.or_ (i32_ty, nosign, sign);
577- auto fp8x4VecTy = vec_ty (i8_ty, 4 );
574+ Value res_val = b.select (is_subnormal, subnormal, normal);
575+ res_val = b.or_ (i16_ty, res_val, sign);
576+ auto fp8x4VecTy = vec_ty (i8_ty, 2 );
578577 Value res = b.bitcast (res_val, fp8x4VecTy);
579578
580- return {b.extract_element (i8_ty, res, b.i32_val (1 ))};
579+ return {b.extract_element (i8_ty, res, b.i16_val (1 ))};
581580}
582581
583582static SmallVector<Value> Fp8E4M3Nv_to_Bf16 (Location loc,
0 commit comments