Skip to content

Commit b36c35e

Browse files
committed
make mxfpScaleBf16 private
1 parent 90f937a commit b36c35e

File tree

3 files changed

+20
-21
lines changed

3 files changed

+20
-21
lines changed

third_party/intel/lib/TritonIntelGPUToLLVM/UpcastMXFPToLLVM.cpp

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,24 @@ using namespace mlir::triton::gpu;
1717

1818
namespace {
1919

20+
// Scales a single value `v` by the 8-bit `scale` and returns the result as
// bf16, substituting a bf16 NaN when the scale byte equals 0xff.
//
// Steps, as written below:
//   1. `v` is bitcast to bf16 (no numeric conversion is performed on it).
//   2. `scale` is zero-extended to i16, shifted left by 7 and bitcast to bf16
//      (presumably this places the byte in the bf16 exponent field, i.e. the
//      scale is a biased power-of-two exponent -- confirm against the spec).
//   3. Both bf16 values are widened to f32 via the Intel convertBf16ToFp32
//      helper, multiplied with an LLVM FMulOp in f32, and narrowed back to
//      bf16 via convertFp32ToBf16.
//   4. The product is replaced by NaN when `scale == 0xff`.
//
// `rewriter` and `loc` are passed explicitly to the conversion helpers and to
// the FMulOp builder.
// NOTE(review): the bitcast/i16_val/shl/zext/icmp_eq/select helpers appear to
// be macros that pick up `rewriter` and `loc` implicitly -- confirm.
static Value mxfpScaleBf16(ConversionPatternRewriter &rewriter, Location loc,
21+
Value v, Value scale) {
22+
Value vBf16 = bitcast(v, bf16_ty);
23+
// Bit pattern 0x7fff reinterpreted as bf16; used as the NaN result below.
Value nanBf16 = bitcast(i16_val(0x7fff), bf16_ty);
24+
// A scale byte of 0xff is treated as "scale is NaN".
Value scaleIsNan = icmp_eq(scale, i8_val(0xff));
25+
// Build a bf16 from the scale byte: zero-extend to 16 bits, shift left by 7.
Value scaleBf16 = bitcast(shl(zext(i16_ty, scale), i16_val(7)), bf16_ty);
26+
27+
// The multiplication is carried out in f32 rather than directly in bf16
// (the direct bf16 fmul is left commented out further down).
Value v0 = mlir::triton::intel::convertBf16ToFp32(loc, rewriter, vBf16);
28+
Value v1 = mlir::triton::intel::convertBf16ToFp32(loc, rewriter, scaleBf16);
29+
auto result = rewriter.create<LLVM::FMulOp>(loc, f32_ty, v0, v1);
30+
// -1 is cast to RoundingMode and passed as the rounding argument.
// NOTE(review): presumably -1 is not a named enumerator and means "no
// explicit rounding mode" to convertFp32ToBf16 -- confirm in that helper.
auto undefRounding = static_cast<mlir::triton::RoundingMode>(-1);
31+
Value scaledBf16 = mlir::triton::intel::convertFp32ToBf16(
32+
loc, rewriter, result, undefRounding);
33+
// Value scaledBf16 = fmul(vBf16, scaleBf16);
34+
// Account for NaN in the scale as per the mxfp specification.
35+
return select(scaleIsNan, nanBf16, scaledBf16);
36+
};
37+
2038
class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern<UpcastMXFPOp> {
2139
private:
2240
const TargetInfoBase &targetInfo;
@@ -48,8 +66,8 @@ class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern<UpcastMXFPOp> {
4866

4967
for (auto [i, scaleVal] : llvm::enumerate(scaleVals)) {
5068
for (int j = 0; j < 32; ++j) {
51-
xVals[32 * i + j] = LLVM::intel::mxfpScaleBf16(
52-
rewriter, loc, xVals[32 * i + j], scaleVal);
69+
xVals[32 * i + j] =
70+
mxfpScaleBf16(rewriter, loc, xVals[32 * i + j], scaleVal);
5371
}
5472
}
5573

third_party/intel/lib/TritonIntelGPUToLLVM/Utility.cpp

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -159,21 +159,4 @@ LLVM::LLVMFuncOp getSpirvPrintfDeclaration(RewriterBase &rewriter) {
159159
return printFunc;
160160
}
161161

162-
Value mxfpScaleBf16(ConversionPatternRewriter &rewriter, Location loc, Value v,
163-
Value scale) {
164-
Value vBf16 = bitcast(v, bf16_ty);
165-
Value nanBf16 = bitcast(i16_val(0x7fff), bf16_ty);
166-
Value scaleIsNan = icmp_eq(scale, i8_val(0xff));
167-
Value scaleBf16 = bitcast(shl(zext(i16_ty, scale), i16_val(7)), bf16_ty);
168-
169-
Value v0 = mlir::triton::intel::convertBf16ToFp32(loc, rewriter, vBf16);
170-
Value v1 = mlir::triton::intel::convertBf16ToFp32(loc, rewriter, scaleBf16);
171-
auto result = rewriter.create<LLVM::FMulOp>(loc, f32_ty, v0, v1);
172-
auto undefRounding = static_cast<mlir::triton::RoundingMode>(-1);
173-
Value scaledBf16 = mlir::triton::intel::convertFp32ToBf16(
174-
loc, rewriter, result, undefRounding);
175-
// Value scaledBf16 = fmul(vBf16, scaleBf16);
176-
// Account for NaN in the scale as per the mxfp specification.
177-
return select(scaleIsNan, nanBf16, scaledBf16);
178-
};
179162
} // namespace mlir::LLVM::intel

third_party/intel/lib/TritonIntelGPUToLLVM/Utility.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -127,8 +127,6 @@ static Value getModuleWarpSize(RewriterBase &rewriter, Location loc) {
127127
return i32_val(triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod));
128128
}
129129

130-
Value mxfpScaleBf16(ConversionPatternRewriter &rewriter, Location loc, Value v,
131-
Value scale);
132130
} // namespace mlir::LLVM::intel
133131

134132
// -----------------------------------------------------------------------

0 commit comments

Comments
 (0)