@@ -433,36 +433,70 @@ static Value convertBf16ToFp32(Location loc,
433433 return bitcast (shifted, f32_ty);
434434}
435435
436+ static Value buildGCNInstruction (Location loc, RewriterBase &rewritter,
437+ StringRef instrName,
438+ ArrayRef<StringRef> constraints,
439+ ArrayRef<Value> vals, Type retType) {
440+ assert (constraints.size () == vals.size () + 1 );
441+ assert (vals.size () == 2 || vals.size () == 3 );
442+ GCNBuilder builder;
443+ GCNInstr &instr = *builder.create (instrName.str ());
444+ GCNBuilder::Operand *out = builder.newOperand (constraints[0 ]);
445+ SmallVector<GCNBuilder::Operand *> operands;
446+ for (int i = 0 ; i < vals.size (); ++i) {
447+ operands.push_back (builder.newOperand (vals[i], constraints[i + 1 ]));
448+ }
449+
450+ if (vals.size () == 2 ) {
451+ instr (out, operands[0 ], operands[1 ]);
452+ } else {
453+ instr (out, operands[0 ], operands[1 ], operands[2 ]);
454+ }
455+
456+ return builder.launch (rewritter, loc, retType, false );
457+ }
458+
436459static Value convertFp32ToBf16 (Location loc,
437460 ConversionPatternRewriter &rewriter,
438461 const Value &v, const RoundingMode rounding) {
462+ auto as_int32 = bitcast (v, i32_ty);
439463 if (rounding == RoundingMode::RTZ) {
440- auto as_int32 = bitcast (v, i32_ty);
441464 auto shifted = lshr (i32_ty, as_int32, i32_val (16 ));
442465 auto truncated = trunc (i16_ty, shifted);
443466 return bitcast (truncated, bf16_ty);
444467 }
445- // Otherwise it is (rounding == RoundingMode::RTNE)
446- auto as_uint32 = bitcast (v, i32_ty);
447- auto check_exponent =
448- and_ (i32_ty, xor_ (i32_ty, as_uint32, i32_val (0xffffffff )),
449- i32_val (0x7f800000 ));
450- auto exponent_not_all1s = icmp_ne (check_exponent, i32_val (0 ));
451- auto exponent_all1s = icmp_eq (check_exponent, i32_val (0 ));
452- auto rounded =
453- add (i32_ty, i32_val (0x7fff ),
454- and_ (i32_ty, lshr (i32_ty, as_uint32, i32_val (16 )), i32_val (1 )));
455- rounded = add (i32_ty, rounded, as_uint32);
456- auto res = select (exponent_not_all1s, rounded, as_uint32);
457-
458- auto preserve_nan =
459- and_ (i1_ty, exponent_all1s,
460- icmp_ne (and_ (i32_ty, as_uint32, i32_val (0xffff )), i32_val (0 )));
461- auto nan = or_ (i32_ty, as_uint32, i32_val (0x10000 ));
462- res = select (preserve_nan, nan, res);
463-
464- auto shifted = lshr (i32_ty, res, i32_val (16 ));
465- auto truncated = trunc (i16_ty, shifted);
468+
469+ // This implementation is a faster version for fp32 to bf16 type conversion
470+ // It is from CK:
471+ // https://github.com/cgmillette/composable_kernel/commit/24e75bef6aa5
472+ // It uses less VGPR and less number of instructions compared to the
473+ // previous implementation
474+ SmallVector<StringRef> constraints0 = {" =s" , " v" , " v" };
475+ SmallVector<Value> vals0 = {v, v};
476+ Value isNan = buildGCNInstruction (loc, rewriter, " v_cmp_u_f32" , constraints0,
477+ vals0, i64_ty);
478+
479+ Value v16 = i32_val (16 );
480+ Value v1 = i32_val (1 );
481+ SmallVector<StringRef> constraints1 = {" =v" , " v" , " v" , " v" };
482+ SmallVector<Value> vals1 = {v, v16, v1};
483+ Value tmp = buildGCNInstruction (loc, rewriter, " v_bfe_u32" , constraints1,
484+ vals1, i32_ty);
485+
486+ SmallVector<StringRef> constraints2 = {" =v" , " v" , " v" , " v" };
487+ Value v7FFF = i32_val (0x7FFF );
488+ SmallVector<Value> vals2 = {v, tmp, v7FFF};
489+ Value tmp1 = buildGCNInstruction (loc, rewriter, " v_add3_u32" , constraints2,
490+ vals2, i32_ty);
491+
492+ SmallVector<StringRef> constraints3 = {" =v" , " v" , " v" , " s" };
493+ Value vNan = i32_val (0x7FFF0000 );
494+ SmallVector<Value> vals3 = {tmp1, vNan, isNan};
495+ Value cndMask = buildGCNInstruction (loc, rewriter, " v_cndmask_b32" ,
496+ constraints3, vals3, i32_ty);
497+
498+ Value shifted = lshr (i32_ty, cndMask, v16);
499+ Value truncated = trunc (i16_ty, shifted);
466500 return bitcast (truncated, bf16_ty);
467501}
468502
0 commit comments