1010#include " llvm/ADT/TypeSwitch.h"
1111
1212#include " intel/include/Dialect/TritonIntelGPU/IR/Dialect.h"
13+ #include " intel/include/Dialect/TritonIntelGPU/IR/Utils.h"
1314
1415using namespace mlir ;
1516
@@ -83,23 +84,29 @@ Value convertBf16ToFp32(Location loc, ConversionPatternRewriter &rewriter,
8384 auto moduleOp = definingOp->getParentWithTrait <OpTrait::SymbolTable>();
8485 if (moduleOp->hasAttr (triton::gpu::intel::TritonIntelGPUDialect::
8586 getSupportBF16ConversionAttrName ())) {
86- constexpr StringLiteral baseName = " __spirv_ConvertBF16ToFINTEL" ;
87- Type inTy = getTypeWithSameShape (v.getType (), i16_ty);
88- Type outTy = getTypeWithSameShape (inTy, f32_ty);
89- std::string funcName = mlir::triton::gpu::intel::mangle (baseName, inTy);
90-
91- auto bitcastValue = b.bitcast (v, inTy).getResult ();
92-
93- auto memAttr = rewriter.getAttr <LLVM::MemoryEffectsAttr>(
94- /* other=*/ LLVM::ModRefInfo::NoModRef,
95- /* argMem=*/ LLVM::ModRefInfo::NoModRef,
96- /* inaccessibleMem=*/ LLVM::ModRefInfo::NoModRef);
97- auto funcAttrs = gpu::intel::noUnwindWillReturnAttrs;
98- funcAttrs.memEffectsAttr = memAttr;
99-
100- auto call = gpu::intel::createDeviceFunctionCall (
101- rewriter, funcName, outTy, {inTy}, {bitcastValue}, {}, funcAttrs);
102- return call.getResult ();
87+ // For SPIRV target, use specialized intrinsic call for conversion.
88+ // Otherwise, use fpext operation.
89+ if (gpu::intel::hasSpirvTargetArch (moduleOp)) {
90+ constexpr StringLiteral baseName = " __spirv_ConvertBF16ToFINTEL" ;
91+ Type inTy = getTypeWithSameShape (v.getType (), i16_ty);
92+ Type outTy = getTypeWithSameShape (inTy, f32_ty);
93+ std::string funcName = mlir::triton::gpu::intel::mangle (baseName, inTy);
94+
95+ auto bitcastValue = b.bitcast (v, inTy).getResult ();
96+
97+ auto memAttr = rewriter.getAttr <LLVM::MemoryEffectsAttr>(
98+ /* other=*/ LLVM::ModRefInfo::NoModRef,
99+ /* argMem=*/ LLVM::ModRefInfo::NoModRef,
100+ /* inaccessibleMem=*/ LLVM::ModRefInfo::NoModRef);
101+ auto funcAttrs = gpu::intel::noUnwindWillReturnAttrs;
102+ funcAttrs.memEffectsAttr = memAttr;
103+
104+ auto call = gpu::intel::createDeviceFunctionCall (
105+ rewriter, funcName, outTy, {inTy}, {bitcastValue}, {}, funcAttrs);
106+ return call.getResult ();
107+ }
108+
109+ return rewriter.create <LLVM::FPExtOp>(loc, f32_ty, v);
103110 }
104111 }
105112
@@ -118,22 +125,27 @@ Value convertFp32ToBf16(Location loc, ConversionPatternRewriter &rewriter,
118125 getSupportBF16ConversionAttrName ()) &&
119126 rounding == RoundingMode::RTNE) {
120127 // Intel SPIR-V extension only supports round-to-nearest-even
121- constexpr StringLiteral baseName = " __spirv_ConvertFToBF16INTEL" ;
122- Type inTy = v.getType ();
123- Type funcOutTy = getTypeWithSameShape (inTy, i16_ty);
124- Type outTy = getTypeWithSameShape (inTy, bf16_ty);
125- std::string funcName = mlir::triton::gpu::intel::mangle (baseName, inTy);
126-
127- auto memAttr = rewriter.getAttr <LLVM::MemoryEffectsAttr>(
128- /* other=*/ LLVM::ModRefInfo::NoModRef,
129- /* argMem=*/ LLVM::ModRefInfo::NoModRef,
130- /* inaccessibleMem=*/ LLVM::ModRefInfo::NoModRef);
131- auto funcAttrs = gpu::intel::noUnwindWillReturnAttrs;
132- funcAttrs.memEffectsAttr = memAttr;
133-
134- auto call = gpu::intel::createDeviceFunctionCall (
135- rewriter, funcName, funcOutTy, {inTy}, {v}, {}, funcAttrs);
136- return b.bitcast (call.getResult (), outTy);
128+ // LLVM fptrunc operation also assumes round-to-nearest mode
129+ if (gpu::intel::hasSpirvTargetArch (moduleOp)) {
130+ constexpr StringLiteral baseName = " __spirv_ConvertFToBF16INTEL" ;
131+ Type inTy = v.getType ();
132+ Type funcOutTy = getTypeWithSameShape (inTy, i16_ty);
133+ Type outTy = getTypeWithSameShape (inTy, bf16_ty);
134+ std::string funcName = mlir::triton::gpu::intel::mangle (baseName, inTy);
135+
136+ auto memAttr = rewriter.getAttr <LLVM::MemoryEffectsAttr>(
137+ /* other=*/ LLVM::ModRefInfo::NoModRef,
138+ /* argMem=*/ LLVM::ModRefInfo::NoModRef,
139+ /* inaccessibleMem=*/ LLVM::ModRefInfo::NoModRef);
140+ auto funcAttrs = gpu::intel::noUnwindWillReturnAttrs;
141+ funcAttrs.memEffectsAttr = memAttr;
142+
143+ auto call = gpu::intel::createDeviceFunctionCall (
144+ rewriter, funcName, funcOutTy, {inTy}, {v}, {}, funcAttrs);
145+ return b.bitcast (call.getResult (), outTy);
146+ }
147+
148+ return rewriter.create <LLVM::FPTruncOp>(loc, bf16_ty, v);
137149 }
138150 }
139151
0 commit comments