Skip to content

Commit 431cc25

Browse files
Add convertWithFunctionCall utility (#5626)
Signed-off-by: Whitney Tsang <[email protected]> Co-authored-by: Andrey Pavlenko <[email protected]>
1 parent 1c46fef commit 431cc25

File tree

3 files changed

+76
-60
lines changed

3 files changed

+76
-60
lines changed

third_party/intel/lib/TritonIntelGPUToLLVM/BF16Casts.cpp

Lines changed: 21 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
#include "Dialect/TritonIntelGPU/Transforms/Utility.h"
44
#include "Utils/LLVMIntr.h"
5-
#include "Utils/Mangling.h"
65

76
#include "mlir/Dialect/Arith/IR/Arith.h"
87
#include "mlir/Support/LLVM.h"
@@ -13,6 +12,7 @@
1312
#include "intel/include/Dialect/TritonIntelGPU/IR/Utils.h"
1413

1514
using namespace mlir;
15+
using namespace triton::gpu::intel;
1616

1717
namespace {
1818
static bool isBF16OrTensorOf(Type type) {
@@ -79,74 +79,35 @@ struct TruncBF16 : ConvertOpToLLVMPattern<arith::TruncFOp> {
7979
namespace mlir::triton::intel {
8080
Value convertBf16ToFp32(Location loc, ConversionPatternRewriter &rewriter,
8181
Value v) {
82-
auto b = TritonLLVMOpBuilder(loc, rewriter);
83-
if (auto definingOp = v.getDefiningOp()) {
84-
auto moduleOp = definingOp->getParentWithTrait<OpTrait::SymbolTable>();
85-
if (moduleOp->hasAttr(triton::gpu::intel::TritonIntelGPUDialect::
86-
getSupportBF16ConversionAttrName())) {
87-
// For SPIRV target, use specialized intrinsic call for conversion.
88-
// Otherwise, use fpext operation.
89-
if (gpu::intel::hasSpirvTargetArch(moduleOp)) {
90-
constexpr StringLiteral baseName = "__spirv_ConvertBF16ToFINTEL";
91-
Type inTy = getTypeWithSameShape(v.getType(), i16_ty);
92-
Type outTy = getTypeWithSameShape(inTy, f32_ty);
93-
std::string funcName = mlir::triton::gpu::intel::mangle(baseName, inTy);
94-
95-
auto bitcastValue = b.bitcast(v, inTy).getResult();
96-
97-
auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
98-
/*other=*/LLVM::ModRefInfo::NoModRef,
99-
/*argMem=*/LLVM::ModRefInfo::NoModRef,
100-
/*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef);
101-
auto funcAttrs = gpu::intel::noUnwindWillReturnAttrs;
102-
funcAttrs.memEffectsAttr = memAttr;
103-
104-
auto call = gpu::intel::createDeviceFunctionCall(
105-
rewriter, funcName, outTy, {inTy}, {bitcastValue}, {}, funcAttrs);
106-
return call.getResult();
107-
}
108-
109-
return LLVM::FPExtOp::create(rewriter, loc, f32_ty, v);
110-
}
111-
}
82+
TritonLLVMIRRewriter b(loc, rewriter);
83+
auto as_int16 = b.bitcast(v, i16_ty).getResult();
84+
auto result = convertWithFunctionCall(
85+
b, as_int16, "__spirv_ConvertBF16ToFINTEL", i16_ty, f32_ty,
86+
TritonIntelGPUDialect::getSupportBF16ConversionAttrName());
87+
if (result)
88+
return result;
11289

113-
auto as_int16 = b.bitcast(v, i16_ty);
11490
auto as_int32 = b.zext(i32_ty, as_int16);
11591
auto shifted = b.shl(i32_ty, as_int32, b.i32_val(16));
11692
return (b.bitcast(shifted, f32_ty));
11793
}
11894

11995
Value convertFp32ToBf16(Location loc, ConversionPatternRewriter &rewriter,
12096
Value v, RoundingMode rounding) {
121-
auto b = TritonLLVMOpBuilder(loc, rewriter);
122-
if (auto definingOp = v.getDefiningOp()) {
123-
auto moduleOp = definingOp->getParentWithTrait<OpTrait::SymbolTable>();
124-
if (moduleOp->hasAttr(triton::gpu::intel::TritonIntelGPUDialect::
125-
getSupportBF16ConversionAttrName()) &&
126-
rounding == RoundingMode::RTNE) {
127-
// Intel SPIR-V extension only supports round-to-nearest-even
128-
// LLVM fptrunc operation also assumes round-to-nearest mode
129-
if (gpu::intel::hasSpirvTargetArch(moduleOp)) {
130-
constexpr StringLiteral baseName = "__spirv_ConvertFToBF16INTEL";
131-
Type inTy = v.getType();
132-
Type funcOutTy = getTypeWithSameShape(inTy, i16_ty);
133-
Type outTy = getTypeWithSameShape(inTy, bf16_ty);
134-
std::string funcName = mlir::triton::gpu::intel::mangle(baseName, inTy);
135-
136-
auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
137-
/*other=*/LLVM::ModRefInfo::NoModRef,
138-
/*argMem=*/LLVM::ModRefInfo::NoModRef,
139-
/*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef);
140-
auto funcAttrs = gpu::intel::noUnwindWillReturnAttrs;
141-
funcAttrs.memEffectsAttr = memAttr;
142-
143-
auto call = gpu::intel::createDeviceFunctionCall(
144-
rewriter, funcName, funcOutTy, {inTy}, {v}, {}, funcAttrs);
145-
return b.bitcast(call.getResult(), outTy);
146-
}
147-
97+
TritonLLVMIRRewriter b(loc, rewriter);
98+
// Intel SPIR-V extension only supports round-to-nearest-even
99+
// LLVM fptrunc operation also assumes round-to-nearest mode
100+
if (rounding == RoundingMode::RTNE) {
101+
std::string attrName = "__spirv_ConvertFToBF16INTEL";
102+
auto result = convertWithFunctionCall(
103+
b, v, attrName, f32_ty, i16_ty,
104+
TritonIntelGPUDialect::getSupportBF16ConversionAttrName());
105+
if (result)
106+
return b.bitcast(result, bf16_ty);
107+
108+
auto op = v.getDefiningOp();
109+
if (mlir::LLVM::intel::hasModuleAttr(op, attrName))
148110
return LLVM::FPTruncOp::create(rewriter, loc, bf16_ty, v);
149-
}
150111
}
151112

152113
assert(!isa<VectorType>(v.getType()) && "Not yet supported");

third_party/intel/lib/TritonIntelGPUToLLVM/Utility.cpp

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,15 @@
77
//===----------------------------------------------------------------------===//
88

99
#include "Utility.h"
10+
#include "Utils/LLVMIntr.h"
11+
#include "Utils/Mangling.h"
12+
13+
#include "llvm/ADT/TypeSwitch.h"
1014

1115
#include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h"
1216

17+
#include "intel/include/Dialect/TritonIntelGPU/IR/Utils.h"
18+
1319
using namespace mlir;
1420
using namespace mlir::triton;
1521

@@ -168,4 +174,45 @@ convertTritonRoundingModeToLLVM(const triton::RoundingMode rounding) {
168174
}
169175
}
170176

177+
Type getTypeWithSameShape(Type type, Type elementType) {
178+
return TypeSwitch<Type, Type>(type)
179+
.Case([elementType](VectorType type) {
180+
return VectorType::get(type.getShape(), elementType,
181+
type.getScalableDims());
182+
})
183+
.Default(elementType);
184+
}
185+
186+
bool hasModuleAttr(Operation *op, StringRef attrName) {
187+
auto mod = op->getParentOfType<ModuleOp>();
188+
return mod && mod->hasAttr(attrName);
189+
}
190+
171191
} // namespace mlir::LLVM::intel
192+
193+
namespace mlir::triton::intel {
194+
Value convertWithFunctionCall(TritonLLVMIRRewriter &rewriter, Value value,
195+
StringRef baseName, Type inType, Type outType,
196+
StringRef hasAttrName) {
197+
auto op = value.getDefiningOp();
198+
if (!gpu::intel::hasSpirvTargetArch(op))
199+
return {};
200+
if (!hasAttrName.empty() &&
201+
!mlir::LLVM::intel::hasModuleAttr(op, hasAttrName))
202+
return {};
203+
204+
auto valueType = value.getType();
205+
Type inTy = mlir::LLVM::intel::getTypeWithSameShape(valueType, inType);
206+
Type outTy = mlir::LLVM::intel::getTypeWithSameShape(valueType, outType);
207+
std::string funcName = mlir::triton::gpu::intel::mangle(baseName, inTy);
208+
auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
209+
/*other=*/LLVM::ModRefInfo::NoModRef,
210+
/*argMem=*/LLVM::ModRefInfo::NoModRef,
211+
/*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef);
212+
auto funcAttrs = gpu::intel::noUnwindWillReturnAttrs;
213+
funcAttrs.memEffectsAttr = memAttr;
214+
return gpu::intel::createDeviceFunctionCall(rewriter, funcName, outTy, {inTy},
215+
{value}, {}, funcAttrs)
216+
.getResult();
217+
}
218+
} // namespace mlir::triton::intel

third_party/intel/lib/TritonIntelGPUToLLVM/Utility.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,10 @@ Block &createPredicatedBlock(RewriterBase &rewriter, Location loc, Value cond,
7979
LLVM::RoundingMode
8080
convertTritonRoundingModeToLLVM(const triton::RoundingMode rounding);
8181

82+
Type getTypeWithSameShape(Type type, Type elementType);
83+
84+
bool hasModuleAttr(Operation *op, StringRef attrName);
85+
8286
} // namespace mlir::LLVM::intel
8387

8488
namespace mlir::triton::intel {
@@ -88,6 +92,10 @@ Value convertBf16ToFp32(Location loc, ConversionPatternRewriter &rewriter,
8892
Value convertFp32ToBf16(Location loc, ConversionPatternRewriter &rewriter,
8993
Value v, RoundingMode rounding);
9094

95+
Value convertWithFunctionCall(TritonLLVMIRRewriter &rewriter, Value value,
96+
StringRef baseName, Type inType, Type outType,
97+
StringRef hasAttrName = {});
98+
9199
} // namespace mlir::triton::intel
92100

93101
#endif // TRITON_CONVERSION_TRITONINTELGPU_TO_LLVM_UTILITY_H

0 commit comments

Comments
 (0)