Skip to content

Commit 3962235

Browse files
Add convertWithFunctionCall utility
Signed-off-by: Whitney Tsang <[email protected]>
1 parent 8bbfd21 commit 3962235

File tree

3 files changed

+45
-60
lines changed

3 files changed

+45
-60
lines changed

third_party/intel/lib/TritonIntelGPUToLLVM/BF16Casts.cpp

Lines changed: 14 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
#include "Dialect/TritonIntelGPU/Transforms/Utility.h"
44
#include "Utils/LLVMIntr.h"
5-
#include "Utils/Mangling.h"
65

76
#include "mlir/Dialect/Arith/IR/Arith.h"
87
#include "mlir/Support/LLVM.h"
@@ -13,6 +12,7 @@
1312
#include "intel/include/Dialect/TritonIntelGPU/IR/Utils.h"
1413

1514
using namespace mlir;
15+
using namespace triton::gpu::intel;
1616

1717
namespace {
1818
static bool isBF16OrTensorOf(Type type) {
@@ -79,74 +79,28 @@ struct TruncBF16 : ConvertOpToLLVMPattern<arith::TruncFOp> {
7979
namespace mlir::triton::intel {
8080
Value convertBf16ToFp32(Location loc, ConversionPatternRewriter &rewriter,
8181
Value v) {
82-
auto b = TritonLLVMOpBuilder(loc, rewriter);
83-
if (auto definingOp = v.getDefiningOp()) {
84-
auto moduleOp = definingOp->getParentWithTrait<OpTrait::SymbolTable>();
85-
if (moduleOp->hasAttr(triton::gpu::intel::TritonIntelGPUDialect::
86-
getSupportBF16ConversionAttrName())) {
87-
// For SPIRV target, use specialized intrinsic call for conversion.
88-
// Otherwise, use fpext operation.
89-
if (gpu::intel::hasSpirvTargetArch(moduleOp)) {
90-
constexpr StringLiteral baseName = "__spirv_ConvertBF16ToFINTEL";
91-
Type inTy = getTypeWithSameShape(v.getType(), i16_ty);
92-
Type outTy = getTypeWithSameShape(inTy, f32_ty);
93-
std::string funcName = mlir::triton::gpu::intel::mangle(baseName, inTy);
94-
95-
auto bitcastValue = b.bitcast(v, inTy).getResult();
96-
97-
auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
98-
/*other=*/LLVM::ModRefInfo::NoModRef,
99-
/*argMem=*/LLVM::ModRefInfo::NoModRef,
100-
/*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef);
101-
auto funcAttrs = gpu::intel::noUnwindWillReturnAttrs;
102-
funcAttrs.memEffectsAttr = memAttr;
103-
104-
auto call = gpu::intel::createDeviceFunctionCall(
105-
rewriter, funcName, outTy, {inTy}, {bitcastValue}, {}, funcAttrs);
106-
return call.getResult();
107-
}
108-
109-
return LLVM::FPExtOp::create(rewriter, loc, f32_ty, v);
110-
}
82+
TritonLLVMIRRewriter b(loc, rewriter);
83+
auto as_int16 = b.bitcast(v, i16_ty).getResult();
84+
auto result = convertWithFunctionCall(
85+
b, as_int16, "__spirv_ConvertBF16ToFINTEL", i16_ty, f32_ty,
86+
TritonIntelGPUDialect::getSupportBF16ConversionAttrName());
87+
if (result) {
88+
return result;
11189
}
11290

113-
auto as_int16 = b.bitcast(v, i16_ty);
11491
auto as_int32 = b.zext(i32_ty, as_int16);
11592
auto shifted = b.shl(i32_ty, as_int32, b.i32_val(16));
11693
return (b.bitcast(shifted, f32_ty));
11794
}
11895

11996
Value convertFp32ToBf16(Location loc, ConversionPatternRewriter &rewriter,
12097
Value v, RoundingMode rounding) {
121-
auto b = TritonLLVMOpBuilder(loc, rewriter);
122-
if (auto definingOp = v.getDefiningOp()) {
123-
auto moduleOp = definingOp->getParentWithTrait<OpTrait::SymbolTable>();
124-
if (moduleOp->hasAttr(triton::gpu::intel::TritonIntelGPUDialect::
125-
getSupportBF16ConversionAttrName()) &&
126-
rounding == RoundingMode::RTNE) {
127-
// Intel SPIR-V extension only supports round-to-nearest-even
128-
// LLVM fptrunc operation also assumes round-to-nearest mode
129-
if (gpu::intel::hasSpirvTargetArch(moduleOp)) {
130-
constexpr StringLiteral baseName = "__spirv_ConvertFToBF16INTEL";
131-
Type inTy = v.getType();
132-
Type funcOutTy = getTypeWithSameShape(inTy, i16_ty);
133-
Type outTy = getTypeWithSameShape(inTy, bf16_ty);
134-
std::string funcName = mlir::triton::gpu::intel::mangle(baseName, inTy);
135-
136-
auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
137-
/*other=*/LLVM::ModRefInfo::NoModRef,
138-
/*argMem=*/LLVM::ModRefInfo::NoModRef,
139-
/*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef);
140-
auto funcAttrs = gpu::intel::noUnwindWillReturnAttrs;
141-
funcAttrs.memEffectsAttr = memAttr;
142-
143-
auto call = gpu::intel::createDeviceFunctionCall(
144-
rewriter, funcName, funcOutTy, {inTy}, {v}, {}, funcAttrs);
145-
return b.bitcast(call.getResult(), outTy);
146-
}
147-
148-
return LLVM::FPTruncOp::create(rewriter, loc, bf16_ty, v);
149-
}
98+
TritonLLVMIRRewriter b(loc, rewriter);
99+
auto result = convertWithFunctionCall(
100+
b, v, "__spirv_ConvertFToBF16INTEL", f32_ty, i16_ty,
101+
TritonIntelGPUDialect::getSupportBF16ConversionAttrName());
102+
if (result) {
103+
return b.bitcast(result, bf16_ty);
150104
}
151105

152106
assert(!isa<VectorType>(v.getType()) && "Not yet supported");

third_party/intel/lib/TritonIntelGPUToLLVM/Utility.cpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,3 +169,30 @@ convertTritonRoundingModeToLLVM(const triton::RoundingMode rounding) {
169169
}
170170

171171
} // namespace mlir::LLVM::intel
172+
173+
namespace mlir::triton::intel {
174+
Value convertWithFunctionCall(TritonLLVMIRRewriter &rewriter, Value value,
175+
StringRef baseName, Type inType, Type outType,
176+
StringRef hasAttrName) {
177+
auto op = value.getDefiningOp();
178+
if (!gpu::intel::hasSpirvTargetArch(op))
179+
return {};
180+
if (!hasAttrName.empty() &&
181+
!mlir::LLVM::intel::hasModuleAttr(op, hasAttrName))
182+
return {};
183+
184+
auto valueType = value.getType();
185+
Type inTy = mlir::LLVM::intel::getTypeWithSameShape(valueType, inType);
186+
Type outTy = mlir::LLVM::intel::getTypeWithSameShape(valueType, outType);
187+
std::string funcName = mlir::triton::gpu::intel::mangle(baseName, inTy);
188+
auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
189+
/*other=*/LLVM::ModRefInfo::NoModRef,
190+
/*argMem=*/LLVM::ModRefInfo::NoModRef,
191+
/*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef);
192+
auto funcAttrs = gpu::intel::noUnwindWillReturnAttrs;
193+
funcAttrs.memEffectsAttr = memAttr;
194+
return gpu::intel::createDeviceFunctionCall(rewriter, funcName, outTy, {inTy},
195+
{value}, {}, funcAttrs)
196+
.getResult();
197+
}
198+
} // namespace mlir::triton::intel

third_party/intel/lib/TritonIntelGPUToLLVM/Utility.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,10 @@ Value convertBf16ToFp32(Location loc, ConversionPatternRewriter &rewriter,
8888
Value convertFp32ToBf16(Location loc, ConversionPatternRewriter &rewriter,
8989
Value v, RoundingMode rounding);
9090

91+
Value convertWithFunctionCall(TritonLLVMIRRewriter &rewriter, Value value,
92+
StringRef baseName, Type inType, Type outType,
93+
StringRef hasAttrName = {});
94+
9195
} // namespace mlir::triton::intel
9296

9397
#endif // TRITON_CONVERSION_TRITONINTELGPU_TO_LLVM_UTILITY_H

0 commit comments

Comments
 (0)