Commit 580fda2

Lowering Fp4ToFP to LLVM (#3607)
This PR is split out as the second part of #3538. It lowers `Fp4ToFpOp` to LLVM and removes `UpcastMXFPOp`. CI depends on #3606.
1 parent f1613cb commit 580fda2
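For background, the fp4 format handled here is OCP microscaling e2m1 (per the spec linked in the removed op description below): 1 sign bit, a 2-bit exponent with bias 1, and 1 mantissa bit, with two values packed per i8. A normal value decodes as (-1)^s * 2^(e-1) * (1 + m/2), so 0b0111 = 1.5 * 2^2 = 6.0 (the largest finite magnitude); the lone subnormal 0b0001 = 0.5; 0b0000 = 0.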

File tree

10 files changed: +136 -666 lines changed

include/triton/Tools/Sys/GetEnv.hpp

Lines changed: 0 additions & 1 deletion

@@ -49,7 +49,6 @@ inline const std::set<std::string> CACHE_INVALIDATING_ENV_VARS = {
     "TRITON_INTEL_ENABLE_INSTR_SCHED",
     "TRITON_INTEL_RAISE_BLOCK_POINTER",
     "TRITON_INTEL_REDUCE_TRANSPOSE",
-    "TRITON_INTEL_DECOMPOSE_SCALED_BLOCKED",
     // clang-format on
 };

third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUOps.td

Lines changed: 0 additions & 29 deletions

@@ -202,33 +202,4 @@ def TTIG_SubGroupTransposeOp
   let hasVerifier = 1;
 }

-// The same as ttg.upcast_mxfp, but we want Dot Layout from Dpas layout for input tensor
-def TTIG_UpcastMXFPOp : TTIG_Op<"upcast_mxfp", [Pure]> {
-  let summary = "Convert an mxfp tensor to bf16/fp16";
-
-  let hasVerifier = 1;
-
-  let description = [{
-    Compute the bf16 encoded in the given mxfp number as per
-    https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
-  }];
-  let arguments = (
-    ins
-    TT_Tensor:$src,
-    TT_Tensor:$scale,
-    TT_ScaleDotElemTypeAttr:$fp_type,
-    BoolAttr:$fastMath
-  );
-  let results = (outs TT_Tensor:$result);
-
-  let assemblyFormat = [{
-    $src `,` $scale `fp_type` `=` $fp_type attr-dict `:` type($src) `,` type($scale) `->` type($result)
-  }];
-
-  let extraClassDeclaration = [{
-    static RankedTensorType deduceOutputType(
-      TypedValue<RankedTensorType> inputTensor, ScaleDotElemType inputElemType, Type outputElemType);
-  }];
-}
-
 #endif

third_party/intel/lib/Dialect/TritonIntelGPU/IR/Ops.cpp

Lines changed: 0 additions & 125 deletions

@@ -209,129 +209,4 @@ LogicalResult SubGroupTransposeOp::verify() {
   return success();
 }

-LogicalResult UpcastMXFPOp::verify() {
-  auto fpType = getFpType();
-
-  auto xTy = getSrc().getType();
-  auto scaleTy = getScale().getType();
-  Builder b(getContext());
-  if (xTy.getElementType() != b.getBF16Type() &&
-      xTy.getElementType() != b.getF16Type() &&
-      xTy.getElementType() != b.getI8Type()) {
-    return emitOpError(
-        "element type of the first operand must be bf16/fp16 or i8");
-  }
-
-  if (scaleTy.getElementType() != b.getI8Type()) {
-    return emitOpError("element type of the second operand must be uint8");
-  }
-
-  auto xShape = xTy.getShape();
-  auto scaleShape = scaleTy.getShape();
-
-  if (xShape.size() != scaleShape.size() || xShape.size() < 2) {
-    return emitOpError(
-        "operands must have the same number of dimensions, at least 2");
-  }
-
-  if (!(fpType == ScaleDotElemType::E2M1 || fpType == ScaleDotElemType::E4M3 ||
-        fpType == ScaleDotElemType::E5M2)) {
-    return emitOpError("NYI: fpType must be E2M1, E4M3, or E5M2");
-  }
-
-  auto layoutX = xTy.getEncoding();
-  auto layoutScale = scaleTy.getEncoding();
-  if (bool(layoutX) != bool(layoutScale)) {
-    return emitOpError(
-        "Expected either both or neither operands to have an encoding");
-  }
-  // Nothing to check if no encoding. This is used to infer the return type in
-  // AccelerateMatmul.cpp
-  if (!layoutX) {
-    return success();
-  }
-
-  auto dotEncoding = dyn_cast<DotOperandEncodingAttr>(layoutX);
-  if (!dotEncoding) {
-    return emitOpError("Expected a DotOperandEncodingAttr for values");
-  }
-  if (!isa<BlockedEncodingAttr, LinearEncodingAttr>(layoutScale)) {
-    return emitOpError(
-        "Expected a BlockOperandEncoding or LinearOperandEncoding "
-        "for scales");
-  }
-
-  // Change to support fp8 types
-  const auto elemsPacked = fpType == ScaleDotElemType::E2M1 ? 2 : 1;
-  // Figure out the K dimension for the input A/B. For A/B scale, the K
-  // dimension is always the last dimension.
-  const int opIdx = dotEncoding.getOpIdx();
-  const bool hasBatch = xShape.size() == 3;
-  const int kIdx = (opIdx == 0 ? 1 : 0) + hasBatch;
-
-  if (xShape[kIdx] != (32 / elemsPacked) * scaleShape.back()) {
-    return emitOpError("K dimension of first operand must be 16 times "
-                       "larger than last/K dimension of the second operand");
-  }
-
-  // Check other dimensions match too. For input A/B, we need to figure out the
-  // index for the M/N dimension. For scale, it's always {(batch), M/N, K}.
-  const int mnIdx = (opIdx == 0 ? 0 : 1) + hasBatch;
-  if (hasBatch && xShape[0] != scaleShape[0])
-    return emitOpError("batch dimension must match between operands");
-  if (xShape[mnIdx] != scaleShape[hasBatch]) {
-    return emitOpError("M/N dimension must match between operands");
-  }
-
-  return success();
-}
-
-RankedTensorType
-UpcastMXFPOp::deduceOutputType(TypedValue<RankedTensorType> inputTensor,
-                               ScaleDotElemType inputElemType,
-                               Type outputElemType) {
-  MLIRContext *ctx = inputTensor.getContext();
-  auto xTy = inputTensor.getType();
-  if (inputElemType != ScaleDotElemType::E2M1)
-    return xTy;
-
-  auto xShape = xTy.getShape();
-  auto newShape = llvm::to_vector(xShape);
-  auto encoding = xTy.getEncoding();
-  if (!encoding) {
-    newShape.back() *= 2;
-    return RankedTensorType::get(xShape, outputElemType);
-  }
-
-  auto oldEncoding = cast<DotOperandEncodingAttr>(encoding);
-  const int opIdx = oldEncoding.getOpIdx();
-  // Note: For Intel the dot operands layout's kWidth parameter must match
-  // the parent's DPAS layout opsPerChannel so we need to materialize a
-  // new DPAS layout.
-  auto dpasEncoding = cast<intel::DpasEncodingAttr>(oldEncoding.getParent());
-  unsigned opsPerChannel =
-      intel::DpasEncodingAttr::getOpsPerChannel(outputElemType);
-  // e2m1 is packed 2 elements per int8, we must handle continuous 2
-  // elements when upcasting to bf16
-  if (xTy.getElementType() == IntegerType::get(ctx, 8))
-    opsPerChannel *= 2;
-  auto newDpasEncoding = intel::DpasEncodingAttr::get(
-      ctx, dpasEncoding.getRepeatCount(), dpasEncoding.getSystolicDepth(),
-      dpasEncoding.getExecutionSize(), opsPerChannel,
-      dpasEncoding.getWarpsPerCTA(), dpasEncoding.getRepCluster(),
-      product<unsigned>(dpasEncoding.getThreadsPerWarp()));
-
-  // Operand A is packed to i16 for scalar type < 16 bits.
-  int kWidth =
-      (opIdx == 0) && (opsPerChannel != 1) ? opsPerChannel / 2 : opsPerChannel;
-
-  Attribute newVEncoding =
-      DotOperandEncodingAttr::get(ctx, opIdx, newDpasEncoding, kWidth);
-
-  const bool hasBatch = xShape.size() == 3;
-  const int kIdx = (opIdx == 0 ? 1 : 0) + hasBatch;
-  newShape[kIdx] *= 2;
-  return RankedTensorType::get(newShape, outputElemType, newVEncoding);
-}
-
 } // namespace mlir::triton::gpu::intel
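To make the removed K-dimension check concrete (illustrative shapes, not taken from the commit): for E2M1, elemsPacked is 2, so the verifier required xShape[kIdx] == (32 / 2) * scaleShape.back(). An A operand of shape [128, 512] in i8 holds 1024 fp4 values along K, and one scale byte covers 32 of them, so the matching scale shape is [128, 32]: 512 == 16 * 32.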

third_party/intel/lib/TritonIntelGPUToLLVM/CMakeLists.txt

Lines changed: 1 addition & 1 deletion

@@ -8,6 +8,7 @@ add_triton_library(TritonIntelGPUToLLVM
   DotOpToLLVM/FMA.cpp
   DotOpToLLVM.cpp
   ElementwiseOpToLLVM.cpp
+  Fp4ToFpOpToLLVM.cpp
   HistogramOpToLLVM.cpp
   LoadStoreOpToLLVM.cpp
   MakeRangeOpToLLVM.cpp
@@ -20,7 +21,6 @@ add_triton_library(TritonIntelGPUToLLVM
   TritonGPUToLLVM.cpp
   TritonOpsToLLVM.cpp
   TypeConverter.cpp
-  UpcastMXFPToLLVM.cpp
   Utility.cpp

   DEPENDS
third_party/intel/lib/TritonIntelGPUToLLVM/Fp4ToFpOpToLLVM.cpp

Lines changed: 124 additions & 0 deletions

@@ -0,0 +1,124 @@
+#include "PatternTritonGPUOpToLLVM.h"
+
+#include "mlir/Conversion/LLVMCommon/Pattern.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/IR/ValueRange.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
+#include "triton/Dialect/Triton/IR/Dialect.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+
+using namespace mlir;
+using namespace mlir::triton;
+using namespace mlir::triton::gpu;
+using namespace mlir::triton::gpu::intel;
+
+namespace {
+SmallVector<Value> convertMxfp4x2ToBf16x2(RewriterBase &rewriter, Location loc,
+                                          ArrayRef<Value> values) {
+  auto b = TritonLLVMOpBuilder(loc, rewriter);
+  SmallVector<Value> results;
+  for (auto v : values) {
+    auto em0 = b.and_(v, b.i8_val(0x7));
+    auto em1 = b.and_(v, b.i8_val(0x70));
+    Value v0 =
+        b.or_(b.shl(b.zext(i16_ty, em0), b.i16_val(6)),
+              b.shl(b.zext(i16_ty, b.and_(v, b.i8_val(0x8))), b.i16_val(12)));
+    Value v1 =
+        b.or_(b.shl(b.zext(i16_ty, em1), b.i16_val(2)),
+              b.shl(b.zext(i16_ty, b.and_(v, b.i8_val(0x80))), b.i16_val(8)));
+    // Three cases:
+    // 1) x is normal and non-zero: Correct bias
+    v0 = b.select(b.icmp_ne(b.and_(em0, b.i8_val(0x6)), b.i8_val(0)),
+                  b.add(v0, b.i16_val((127 - 1) << 7)), v0);
+    v1 = b.select(b.icmp_ne(b.and_(em1, b.i8_val(0x60)), b.i8_val(0)),
+                  b.add(v1, b.i16_val((127 - 1) << 7)), v1);
+    // 2) x is subnormal (x == 0bs001 where s is the sign): Map to +-0.5 in
+    // bf16
+    v0 = b.bitcast(
+        b.select(b.icmp_eq(em0, b.i8_val(0x1)),
+                 b.or_(b.i16_val(16128), b.and_(v0, b.i16_val(0x8000))), v0),
+        bf16_ty);
+    v1 = b.bitcast(
+        b.select(b.icmp_eq(em1, b.i8_val(0x10)),
+                 b.or_(b.i16_val(16128), b.and_(v1, b.i16_val(0x8000))), v1),
+        bf16_ty);
+    // 3) x is zero, nothing to do
+    results.push_back(v0);
+    results.push_back(v1);
+  }
+  return results;
+}
+
+SmallVector<Value> convertMxfp4x2ToFp16x2(RewriterBase &rewriter, Location loc,
+                                          ArrayRef<Value> values) {
+  auto b = TritonLLVMOpBuilder(loc, rewriter);
+  SmallVector<Value> results;
+  for (auto v : values) {
+    auto em0 = b.and_(v, b.i8_val(0x7));
+    auto em1 = b.and_(v, b.i8_val(0x70));
+    // FP16 bits: sign = 1, exponent = 5, mantissa = 10
+    Value v0 =
+        b.or_(b.shl(b.zext(i16_ty, em0), b.i16_val(10 - 1)),
+              b.shl(b.zext(i16_ty, b.and_(v, b.i8_val(0x8))), b.i16_val(12)));
+    Value v1 =
+        b.or_(b.shl(b.zext(i16_ty, em1), b.i16_val(10 - 1 - 4)),
+              b.shl(b.zext(i16_ty, b.and_(v, b.i8_val(0x80))), b.i16_val(8)));
+
+    // Three cases:
+    // 1) x is normal and non-zero: Correct bias
+    v0 = b.select(b.icmp_ne(b.and_(em0, b.i8_val(0x6)), b.i8_val(0)),
+                  b.add(v0, b.i16_val((15 - 1) << 10)), v0);
+    v1 = b.select(b.icmp_ne(b.and_(em1, b.i8_val(0x60)), b.i8_val(0)),
+                  b.add(v1, b.i16_val((15 - 1) << 10)), v1);
+
+    // 2) x is subnormal (x == 0bs001 where s is the sign): Map to fp16 +-0.5
+    v0 = b.bitcast(
+        b.select(b.icmp_eq(em0, b.i8_val(0x1)),
+                 b.or_(b.i16_val(0x3800), b.and_(v0, b.i16_val(0x8000))), v0),
+        f16_ty);
+    v1 = b.bitcast(
+        b.select(b.icmp_eq(em1, b.i8_val(0x10)),
+                 b.or_(b.i16_val(0x3800), b.and_(v1, b.i16_val(0x8000))), v1),
+        f16_ty);
+    // 3) x is zero, nothing to do
+    results.push_back(v0);
+    results.push_back(v1);
+  }
+  return results;
+}
+
+class Fp4ToFpOpPattern : public ConvertOpToLLVMPattern<Fp4ToFpOp> {
+public:
+  Fp4ToFpOpPattern(LLVMTypeConverter &typeConverter, PatternBenefit benefit)
+      : ConvertOpToLLVMPattern<Fp4ToFpOp>(typeConverter, benefit) {}
+
+  LogicalResult
+  matchAndRewrite(Fp4ToFpOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    auto *ctx = op.getContext();
+    Type elemType = op.getType().getElementType();
+    assert(elemType == f16_ty || elemType == bf16_ty);
+    bool toFp16 = elemType == f16_ty;
+
+    SmallVector<Value> xVals =
+        unpackLLElements(loc, adaptor.getSrc(), rewriter);
+    xVals = toFp16 ? convertMxfp4x2ToFp16x2(rewriter, loc, xVals)
+                   : convertMxfp4x2ToBf16x2(rewriter, loc, xVals);
+
+    Value result =
+        packLLElements(loc, getTypeConverter(), xVals, rewriter, op.getType());
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+} // anonymous namespace
+
+void mlir::triton::intel::populateFp4ToFpToLLVMPatterns(
+    LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
+    PatternBenefit benefit) {
+  patterns.add<Fp4ToFpOpPattern>(typeConverter, benefit);
+}
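To sanity-check the bit manipulation above outside the compiler, here is a minimal standalone scalar model of the bf16 path (a hypothetical sketch, not part of the commit; the helper name and test values are ours):

#include <cassert>
#include <cstdint>

// Scalar model of convertMxfp4x2ToBf16x2: decode one e2m1 nibble
// (bit 3 = sign, bits 2-1 = exponent with bias 1, bit 0 = mantissa)
// into a bf16 bit pattern (1 sign, 8 exponent, 7 mantissa bits).
uint16_t fp4NibbleToBf16Bits(uint8_t nibble) {
  uint16_t em = nibble & 0x7;                  // exponent + mantissa bits
  uint16_t sign = (nibble & 0x8) ? 0x8000 : 0; // sign moves to bf16 bit 15
  // Mantissa bit lands on the bf16 mantissa MSB (bit 6), exponent on bits 7-8.
  uint16_t out = (em << 6) | sign;
  if (em & 0x6)          // 1) normal: rebias the exponent from 1 to 127
    out += (127 - 1) << 7;
  else if (em == 0x1)    // 2) subnormal 0bs001: exactly +-0.5 in bf16
    out = 0x3F00 | sign; //    (16128 == 0x3F00 == bf16 0.5)
  return out;            // 3) zero: the sign bit alone encodes +-0.0
}

int main() {
  uint8_t packed = 0xE7; // two fp4 values: low nibble 0x7, high nibble 0xE
  assert(fp4NibbleToBf16Bits(packed & 0xF) == 0x40C0); // bf16 +6.0
  assert(fp4NibbleToBf16Bits(packed >> 4) == 0xC080);  // bf16 -4.0
  assert(fp4NibbleToBf16Bits(0x1) == 0x3F00);          // bf16 +0.5
  assert(fp4NibbleToBf16Bits(0x8) == 0x8000);          // bf16 -0.0
  return 0;
}

The fp16 path is identical in structure: the shifts target the 10-bit fp16 mantissa field, the rebias constant becomes (15 - 1) << 10, and 0x3800 stands in for fp16 0.5.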

third_party/intel/lib/TritonIntelGPUToLLVM/PatternTritonGPUOpToLLVM.h

Lines changed: 3 additions & 4 deletions

@@ -38,10 +38,9 @@ void populateElementwiseOpToLLVMPatterns(
     ModuleAxisInfoAnalysis &axisInfoAnalysis, const TargetInfoBase &targetInfo,
     PatternBenefit benefit);

-void populateUpcastMXFPToLLVMPatterns(LLVMTypeConverter &typeConverter,
-                                      RewritePatternSet &patterns,
-                                      const TargetInfo &targetInfo,
-                                      PatternBenefit benefit);
+void populateFp4ToFpToLLVMPatterns(LLVMTypeConverter &typeConverter,
+                                   RewritePatternSet &patterns,
+                                   PatternBenefit benefit);

 void populateBF16CastsLLVMPatterns(LLVMTypeConverter &typeConverter,
                                    RewritePatternSet &patterns,

third_party/intel/lib/TritonIntelGPUToLLVM/PipelineManager.h

Lines changed: 1 addition & 2 deletions

@@ -278,8 +278,7 @@ class TritonGPUToLLVMPipelineManager {
                                              targetInfo, benefit);
       intel::populateMakeRangeOpToLLVMPattern(typeConverter, targetInfo,
                                               patterns, benefit);
-      intel::populateUpcastMXFPToLLVMPatterns(typeConverter, patterns,
-                                              targetInfo, benefit);
+      intel::populateFp4ToFpToLLVMPatterns(typeConverter, patterns, benefit);
     }

     intel::populateSPMDOpToLLVMPattern(typeConverter, patterns, targetInfo,
