Skip to content

Commit 3bcc576

Browse files
authored
Enable BF16 casts for non-SPIRV targets. (#3992)
Signed-off-by: Ilya Enkovich <[email protected]>
1 parent b6f7f3b commit 3bcc576

File tree

1 file changed

+45
-33
lines changed

1 file changed

+45
-33
lines changed

third_party/intel/lib/TritonIntelGPUToLLVM/BF16Casts.cpp

Lines changed: 45 additions & 33 deletions
Original file line number | Diff line number | Diff line change
@@ -10,6 +10,7 @@
1010
#include "llvm/ADT/TypeSwitch.h"
1111

1212
#include "intel/include/Dialect/TritonIntelGPU/IR/Dialect.h"
13+
#include "intel/include/Dialect/TritonIntelGPU/IR/Utils.h"
1314

1415
using namespace mlir;
1516

@@ -83,23 +84,29 @@ Value convertBf16ToFp32(Location loc, ConversionPatternRewriter &rewriter,
8384
auto moduleOp = definingOp->getParentWithTrait<OpTrait::SymbolTable>();
8485
if (moduleOp->hasAttr(triton::gpu::intel::TritonIntelGPUDialect::
8586
getSupportBF16ConversionAttrName())) {
86-
constexpr StringLiteral baseName = "__spirv_ConvertBF16ToFINTEL";
87-
Type inTy = getTypeWithSameShape(v.getType(), i16_ty);
88-
Type outTy = getTypeWithSameShape(inTy, f32_ty);
89-
std::string funcName = mlir::triton::gpu::intel::mangle(baseName, inTy);
90-
91-
auto bitcastValue = b.bitcast(v, inTy).getResult();
92-
93-
auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
94-
/*other=*/LLVM::ModRefInfo::NoModRef,
95-
/*argMem=*/LLVM::ModRefInfo::NoModRef,
96-
/*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef);
97-
auto funcAttrs = gpu::intel::noUnwindWillReturnAttrs;
98-
funcAttrs.memEffectsAttr = memAttr;
99-
100-
auto call = gpu::intel::createDeviceFunctionCall(
101-
rewriter, funcName, outTy, {inTy}, {bitcastValue}, {}, funcAttrs);
102-
return call.getResult();
87+
// For SPIRV target, use specialized intrinsic call for conversion.
88+
// Otherwise, use fpext operation.
89+
if (gpu::intel::hasSpirvTargetArch(moduleOp)) {
90+
constexpr StringLiteral baseName = "__spirv_ConvertBF16ToFINTEL";
91+
Type inTy = getTypeWithSameShape(v.getType(), i16_ty);
92+
Type outTy = getTypeWithSameShape(inTy, f32_ty);
93+
std::string funcName = mlir::triton::gpu::intel::mangle(baseName, inTy);
94+
95+
auto bitcastValue = b.bitcast(v, inTy).getResult();
96+
97+
auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
98+
/*other=*/LLVM::ModRefInfo::NoModRef,
99+
/*argMem=*/LLVM::ModRefInfo::NoModRef,
100+
/*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef);
101+
auto funcAttrs = gpu::intel::noUnwindWillReturnAttrs;
102+
funcAttrs.memEffectsAttr = memAttr;
103+
104+
auto call = gpu::intel::createDeviceFunctionCall(
105+
rewriter, funcName, outTy, {inTy}, {bitcastValue}, {}, funcAttrs);
106+
return call.getResult();
107+
}
108+
109+
return rewriter.create<LLVM::FPExtOp>(loc, f32_ty, v);
103110
}
104111
}
105112

@@ -118,22 +125,27 @@ Value convertFp32ToBf16(Location loc, ConversionPatternRewriter &rewriter,
118125
getSupportBF16ConversionAttrName()) &&
119126
rounding == RoundingMode::RTNE) {
120127
// Intel SPIR-V extension only supports round-to-nearest-even
121-
constexpr StringLiteral baseName = "__spirv_ConvertFToBF16INTEL";
122-
Type inTy = v.getType();
123-
Type funcOutTy = getTypeWithSameShape(inTy, i16_ty);
124-
Type outTy = getTypeWithSameShape(inTy, bf16_ty);
125-
std::string funcName = mlir::triton::gpu::intel::mangle(baseName, inTy);
126-
127-
auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
128-
/*other=*/LLVM::ModRefInfo::NoModRef,
129-
/*argMem=*/LLVM::ModRefInfo::NoModRef,
130-
/*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef);
131-
auto funcAttrs = gpu::intel::noUnwindWillReturnAttrs;
132-
funcAttrs.memEffectsAttr = memAttr;
133-
134-
auto call = gpu::intel::createDeviceFunctionCall(
135-
rewriter, funcName, funcOutTy, {inTy}, {v}, {}, funcAttrs);
136-
return b.bitcast(call.getResult(), outTy);
128+
// LLVM fptrunc operation also assumes round-to-nearest mode
129+
if (gpu::intel::hasSpirvTargetArch(moduleOp)) {
130+
constexpr StringLiteral baseName = "__spirv_ConvertFToBF16INTEL";
131+
Type inTy = v.getType();
132+
Type funcOutTy = getTypeWithSameShape(inTy, i16_ty);
133+
Type outTy = getTypeWithSameShape(inTy, bf16_ty);
134+
std::string funcName = mlir::triton::gpu::intel::mangle(baseName, inTy);
135+
136+
auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
137+
/*other=*/LLVM::ModRefInfo::NoModRef,
138+
/*argMem=*/LLVM::ModRefInfo::NoModRef,
139+
/*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef);
140+
auto funcAttrs = gpu::intel::noUnwindWillReturnAttrs;
141+
funcAttrs.memEffectsAttr = memAttr;
142+
143+
auto call = gpu::intel::createDeviceFunctionCall(
144+
rewriter, funcName, funcOutTy, {inTy}, {v}, {}, funcAttrs);
145+
return b.bitcast(call.getResult(), outTy);
146+
}
147+
148+
return rewriter.create<LLVM::FPTruncOp>(loc, bf16_ty, v);
137149
}
138150
}
139151

0 commit comments

Comments
 (0)