@@ -84,9 +84,8 @@ Value convertBf16ToFp32(Location loc, ConversionPatternRewriter &rewriter,
8484 auto result = convertWithFunctionCall (
8585 b, as_int16, " __spirv_ConvertBF16ToFINTEL" , i16_ty, f32_ty,
8686 TritonIntelGPUDialect::getSupportBF16ConversionAttrName ());
87- if (result) {
87+ if (result)
8888 return result;
89- }
9089
9190 auto as_int32 = b.zext (i32_ty, as_int16);
9291 auto shifted = b.shl (i32_ty, as_int32, b.i32_val (16 ));
@@ -96,11 +95,19 @@ Value convertBf16ToFp32(Location loc, ConversionPatternRewriter &rewriter,
9695Value convertFp32ToBf16 (Location loc, ConversionPatternRewriter &rewriter,
9796 Value v, RoundingMode rounding) {
9897 TritonLLVMIRRewriter b (loc, rewriter);
99- auto result = convertWithFunctionCall (
100- b, v, " __spirv_ConvertFToBF16INTEL" , f32_ty, i16_ty,
101- TritonIntelGPUDialect::getSupportBF16ConversionAttrName ());
102- if (result) {
103- return b.bitcast (result, bf16_ty);
98+ // Intel SPIR-V extension only supports round-to-nearest-even
99+ // LLVM fptrunc operation also assumes round-to-nearest mode
100+ if (rounding == RoundingMode::RTNE) {
101+ std::string attrName = " __spirv_ConvertFToBF16INTEL" ;
102+ auto result = convertWithFunctionCall (
103+ b, v, attrName, f32_ty, i16_ty,
104+ TritonIntelGPUDialect::getSupportBF16ConversionAttrName ());
105+ if (result)
106+ return b.bitcast (result, bf16_ty);
107+
108+ auto op = v.getDefiningOp ();
109+ if (mlir::LLVM::intel::hasModuleAttr (op, attrName))
110+ return LLVM::FPTruncOp::create (rewriter, loc, bf16_ty, v);
104111 }
105112
106113 assert (!isa<VectorType>(v.getType ()) && " Not yet supported" );
0 commit comments