
Commit ef77825

Support DecomposeScaledBlocked to Fp4ToFpOp (#3606)
This PR is split out as the first part of #3538. It decomposes `tt.scaled_dot` into `tt.dot` plus `tt.fp_to_fp`/`tt.fp4_to_fp`.
1 parent 6fa2562 commit ef77825
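
As context for the pattern added below, here is a minimal scalar model of what `tt.scaled_dot` computes and what the decomposition reduces it to. This is an illustrative C++ sketch (names like `decodeE8M0` and `scaledDot` are not part of the commit), assuming e8m0 scales with one scale per 32 elements along K:

// Scalar model of the rewrite: C = (A * 2^sA) @ (B * 2^sB), with one e8m0
// scale per 32 elements along K. Illustrative only.
#include <cmath>
#include <cstdint>
#include <vector>

// Decode an e8m0 scale: an 8-bit biased exponent with no sign or mantissa.
static float decodeE8M0(uint8_t s) {
  if (s == 0xFF)
    return std::nanf("");                // 0xFF encodes NaN
  return std::ldexp(1.0f, int(s) - 127); // 2^(s - 127)
}

// C[m][n] = sum_k (A[m][k] * 2^sA[m][k/32]) * (B[k][n] * 2^sB[k/32][n])
static void scaledDot(const std::vector<std::vector<float>> &A,
                      const std::vector<std::vector<uint8_t>> &sA,
                      const std::vector<std::vector<float>> &B,
                      const std::vector<std::vector<uint8_t>> &sB,
                      std::vector<std::vector<float>> &C) {
  size_t M = A.size(), K = A[0].size(), N = B[0].size();
  for (size_t m = 0; m < M; ++m)
    for (size_t n = 0; n < N; ++n) {
      float acc = 0.0f;
      for (size_t k = 0; k < K; ++k)
        acc += A[m][k] * decodeE8M0(sA[m][k / 32]) * B[k][n] *
               decodeE8M0(sB[k / 32][n]);
      C[m][n] = acc;
    }
}

The pattern below performs the same computation tensor-wise: it upcasts each operand to fp16/bf16, multiplies it by its broadcast scale, and hands the result to a regular `tt.dot`.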

7 files changed: +334 −40 lines changed

include/triton/Tools/Sys/GetEnv.hpp

Lines changed: 1 addition & 0 deletions
@@ -49,6 +49,7 @@ inline const std::set<std::string> CACHE_INVALIDATING_ENV_VARS = {
     "TRITON_INTEL_ENABLE_INSTR_SCHED",
     "TRITON_INTEL_RAISE_BLOCK_POINTER",
     "TRITON_INTEL_REDUCE_TRANSPOSE",
+    "TRITON_INTEL_DECOMPOSE_SCALED_BLOCKED",
     // clang-format on
 };

test/TritonIntelGPU/accelerate-matmul-pvc.mlir

Lines changed: 46 additions & 38 deletions
Large diffs are not rendered by default.

third_party/intel/include/Dialect/TritonIntelGPU/Transforms/DecomposeScaledBlocked.h

Lines changed: 8 additions & 0 deletions (new file)

#include "mlir/IR/PatternMatch.h"

namespace mlir::triton::gpu::intel {

void populateDecomposeScaledBlockedPatterns(mlir::RewritePatternSet &patterns,
                                            int benefit);

} // namespace mlir::triton::gpu::intel

third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp

Lines changed: 1 addition & 1 deletion
@@ -1047,7 +1047,7 @@ struct TritonIntelGPUInferLayoutInterface
   inferFp4ToFpOpEncoding(ArrayRef<int64_t> shape, int axis, Attribute inEnc,
                          Attribute &outEnc, bool fwdInference,
                          std::optional<Location> loc) const override {
-    // TODO
+    // Not required to support Fp4ToFpOp on DPAS layout.
     return failure();
   }
 };

third_party/intel/lib/TritonIntelGPUTransforms/AccelerateMatmul.cpp

Lines changed: 13 additions & 1 deletion
@@ -1,4 +1,5 @@
 #include "Dialect/TritonIntelGPU/IR/Attributes.h"
+#include "Dialect/TritonIntelGPU/Transforms/DecomposeScaledBlocked.h"
 #include "Dialect/TritonIntelGPU/Transforms/Utility.h"
 #include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/IR/Builders.h"
@@ -13,6 +14,7 @@
 #include "triton/Dialect/Triton/IR/Dialect.h"
 #include "triton/Dialect/Triton/IR/Utility.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
+#include "triton/Tools/Sys/GetEnv.hpp"
 #include "llvm/ADT/TypeSwitch.h"
 #include <optional>

@@ -669,7 +671,17 @@ class TritonIntelGPUAccelerateMatmulPass
     transposeDots(m);

     RewritePatternSet patterns(context);
-    patterns.add<BlockedToDPAS, DecomposeScaledBlocked>(context, dpasAnalysis);
+    // TODO: This env variable will be removed in the Fp4ToFp lowering PR.
+    // Keep it here to maintain the old implementation's functionality.
+    if (!mlir::triton::tools::getBoolEnv(
+            "TRITON_INTEL_DECOMPOSE_SCALED_BLOCKED"))
+      patterns.add<BlockedToDPAS, DecomposeScaledBlocked>(context,
+                                                          dpasAnalysis);
+    else {
+      constexpr int benefitDefault = 1;
+      patterns.add<BlockedToDPAS>(context, dpasAnalysis);
+      ttgi::populateDecomposeScaledBlockedPatterns(patterns, benefitDefault);
+    }
     if (applyPatternsGreedily(m, std::move(patterns)).failed())
       signalPassFailure();
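
For reference, a minimal sketch of how the new entry point could be driven on its own, for example from a test harness. `runDecomposeScaledBlocked` is a hypothetical helper; only `populateDecomposeScaledBlockedPatterns` comes from this commit:

#include "Dialect/TritonIntelGPU/Transforms/DecomposeScaledBlocked.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

// Hypothetical standalone driver: collect the DecomposeScaledBlocked pattern
// and apply it greedily, as the pass does when the env variable is set.
static mlir::LogicalResult runDecomposeScaledBlocked(mlir::ModuleOp module) {
  mlir::RewritePatternSet patterns(module.getContext());
  // Benefit 1 mirrors benefitDefault in AccelerateMatmul.cpp.
  mlir::triton::gpu::intel::populateDecomposeScaledBlockedPatterns(
      patterns, /*benefit=*/1);
  return mlir::applyPatternsGreedily(module, std::move(patterns));
}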

third_party/intel/lib/TritonIntelGPUTransforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -1,6 +1,7 @@
 add_triton_library(TritonIntelGPUTransforms
   AccelerateMatmul.cpp
   Coalesce.cpp
+  DecomposeScaledBlocked.cpp
   DistributeToWarps.cpp
   MatchTargetSize.cpp
   MaterializeBlockPointer.cpp

third_party/intel/lib/TritonIntelGPUTransforms/DecomposeScaledBlocked.cpp

Lines changed: 264 additions & 0 deletions (new file)

#include "Dialect/TritonIntelGPU/Transforms/DecomposeScaledBlocked.h"

#include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Attributes.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"

using namespace mlir;
using namespace mlir::triton;
using namespace mlir::triton::gpu;

namespace {

SmallVector<int, 2> getTransposeOrder(int rank) {
  assert(rank >= 2);
  auto transOrder = llvm::to_vector<2>(llvm::seq<int>(rank - 2));
  transOrder.push_back(rank - 1);
  transOrder.push_back(rank - 2);
  return transOrder;
}

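// Rewrites tt.dot_scaled into a plain tt.dot: each operand is upcast to the
// fp16/bf16 compute type, multiplied by its broadcast scale, converted to the
// matching DotOperand layout, and fed to tt.dot (see scaleArg below).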
class DecomposeScaledBlocked : public OpRewritePattern<DotScaledOp> {

public:
  DecomposeScaledBlocked(MLIRContext *context, int benefit)
      : OpRewritePattern<DotScaledOp>(context, benefit) {}

  LogicalResult matchAndRewrite(DotScaledOp scaledDotOp,
                                PatternRewriter &rewriter) const override {
    // Types
    auto computeType = getComputeType(scaledDotOp.getAElemType(),
                                      scaledDotOp.getBElemType(), rewriter);
    auto loc = scaledDotOp.getLoc();

    auto cvtDotOperand = [&](TypedValue<RankedTensorType> v,
                             int opIdx) -> TypedValue<RankedTensorType> {
      auto *ctx = rewriter.getContext();
      auto retEnc = scaledDotOp.getType().getEncoding();
      auto vType = v.getType();
      auto encoding = DotOperandEncodingAttr::get(ctx, opIdx, retEnc,
                                                  vType.getElementType());
      auto retTy = RankedTensorType::get(vType.getShape(),
                                         vType.getElementType(), encoding);
      return rewriter.create<ConvertLayoutOp>(loc, retTy, v);
    };

    auto scaledA = scaleArg(rewriter, scaledDotOp, 0, computeType);
    scaledA = cvtDotOperand(scaledA, 0);
    auto scaledB = scaleArg(rewriter, scaledDotOp, 1, computeType);
    scaledB = cvtDotOperand(scaledB, 1);
    auto newDot = rewriter.create<DotOp>(scaledDotOp.getLoc(), scaledA,
                                         scaledB, scaledDotOp.getC());

    rewriter.replaceOpWithNewOp<ConvertLayoutOp>(
        scaledDotOp, scaledDotOp.getType(), newDot);
    return success();
  }

private:
  FloatType getComputeType(ScaleDotElemType aType, ScaleDotElemType bType,
                           PatternRewriter &rewriter) const {
    if (aType == ScaleDotElemType::FP16 || bType == ScaleDotElemType::FP16)
      return rewriter.getF16Type();
    return rewriter.getBF16Type();
  }

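  // Convert the e8m0 scale tensor to the 16-bit compute type by building the
  // float bit pattern directly: zero-extend the 8-bit exponent, shift it into
  // the exponent field, and bitcast (then truncate f32 -> fp16 if needed).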
  TypedValue<RankedTensorType> scaleTo16(PatternRewriter &rewriter,
                                         TypedValue<RankedTensorType> scale,
                                         FloatType computeType) const {
    auto loc = scale.getLoc();
    auto scaleTy = scale.getType();
    assert(computeType == rewriter.getBF16Type() ||
           computeType == rewriter.getF16Type());

    // Choose an fp type that can fit the scale value.
    FloatType largeFpType = computeType == rewriter.getF16Type()
                                ? rewriter.getF32Type()
                                : computeType;
    int intWidth = largeFpType.getIntOrFloatBitWidth();
    auto intType = rewriter.getIntegerType(intWidth);

    auto zexted =
        rewriter.create<arith::ExtUIOp>(loc, scaleTy.clone(intType), scale);
    // getFPMantissaWidth() counts the implicit leading bit of the mantissa,
    // so shift by one bit less to land in the exponent field.
    int shiftValue = largeFpType.getFPMantissaWidth() - 1;
    auto shiftConst =
        rewriter.create<arith::ConstantIntOp>(loc, shiftValue, intWidth);
    auto shift =
        rewriter.create<SplatOp>(loc, scaleTy.clone(intType), shiftConst);
    auto shlRes = rewriter.create<arith::ShLIOp>(loc, zexted, shift);
    Value scaleFP =
        rewriter.create<BitcastOp>(loc, scaleTy.clone(largeFpType), shlRes);
    if (largeFpType != computeType) {
      scaleFP = rewriter.create<arith::TruncFOp>(
          loc, scaleTy.clone(computeType), scaleFP);
    }
    return cast<TypedValue<RankedTensorType>>(scaleFP);
  }

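  // Broadcast a scale to the shape of the scaled operand, so that each scale
  // element covers 32 consecutive elements along the scaled dimension `dim`,
  // e.g. [M, K/32] -> [M, K/32, 1] -> [M, K/32, 32] -> [M, K].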
  TypedValue<RankedTensorType>
  broadcastScale(PatternRewriter &rewriter, DotScaledOp scaledDotOp,
                 ModuleOp mod, TypedValue<RankedTensorType> scale,
                 int dim) const {
    auto *ctx = rewriter.getContext();
    auto loc = scale.getLoc();
    auto scaleTy = scale.getType();
    auto rank = scaleTy.getRank();
    // 2.1) Expand dims along the last dimension
    {
      // 2.1.1) Find default encoding for ExpandDims
      auto shape = to_vector(scaleTy.getShape());
      shape.insert(shape.end(), 1);
      auto nWarps = lookupNumWarps(scaledDotOp);
      auto threadsPerWarp = TritonGPUDialect::getThreadsPerWarp(mod);
      auto numCTAs = TritonGPUDialect::getNumCTAs(mod);
      auto blockedEnc = getDefaultBlockedEncoding(ctx, shape, nWarps,
                                                  threadsPerWarp, numCTAs);
      // 2.1.2) Cast scale16 to SliceEncoding
      auto sliceEnc = SliceEncodingAttr::get(ctx, rank, blockedEnc);
      auto sliceType = RankedTensorType::get(
          scaleTy.getShape(), scaleTy.getElementType(), sliceEnc);
      scale = rewriter.create<ConvertLayoutOp>(loc, sliceType, scale);
    }
    auto expandScale = rewriter.create<ExpandDimsOp>(loc, scale, rank);
    // 2.2) Broadcast the dimension to size 32
    auto scaleShape = to_vector(scaleTy.getShape());
    scaleShape.push_back(32);
    auto broadcastScale = rewriter.create<BroadcastOp>(
        loc, expandScale.getType().clone(scaleShape), expandScale);
    // 2.3) Transpose the dimension to the scaled dimension
    auto transposeOrder = llvm::to_vector(llvm::seq<int32_t>(rank));
    transposeOrder.insert(transposeOrder.begin() + dim + 1, rank);
    auto transposedScale =
        rewriter.create<TransOp>(loc, broadcastScale, transposeOrder);
    // 2.4) Reshape to the shape of v
    scaleShape.pop_back();
    scaleShape[dim] *= 32;
    auto reshapeScale =
        rewriter.create<ReshapeOp>(loc, scaleShape, transposedScale);
    return reshapeScale;
  }

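  // In e8m0 an all-ones exponent (0xFF) encodes NaN. Where the scale is NaN,
  // force the scaled value to NaN of the compute type.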
  TypedValue<RankedTensorType> maskNan(PatternRewriter &rewriter,
                                       DotScaledOp scaledDotOp, ModuleOp mod,
                                       TypedValue<RankedTensorType> mxfp,
                                       TypedValue<RankedTensorType> scale,
                                       FloatType computeType, int dim) const {
    // Implement tl.where(scale == 0xFF, float("nan"), mxfp)
    auto loc = scale.getLoc();

    // FIXME: use a larger int type (i32) when comparing with 0xFF to avoid
    // accidentally masking non-NaN values to NaN.
    // This piece of code will be removed after
    // https://github.com/intel/intel-xpu-backend-for-triton/issues/3605
    FloatType largeFpType = computeType == rewriter.getF16Type()
                                ? rewriter.getF32Type()
                                : computeType;
    int intWidth = largeFpType.getIntOrFloatBitWidth();
    auto intType = rewriter.getIntegerType(intWidth);
    // Use a wide int type for the scale, lest non-NaN values be masked to NaN.
    auto scaleTy = scale.getType().clone(intType);
    auto zexted = rewriter.create<arith::ExtUIOp>(loc, scaleTy, scale);

    // Scale is NaN
    auto constFF = rewriter.create<arith::ConstantOp>(
        loc, scaleTy,
        DenseElementsAttr::get(scaleTy,
                               APInt(scaleTy.getElementTypeBitWidth(), 0xff)));
    auto scaleIsNan = cast<TypedValue<RankedTensorType>>(
        rewriter
            .create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, zexted,
                                   constFF)
            .getResult());
    auto cond = broadcastScale(rewriter, scaledDotOp, mod, scaleIsNan, dim);
    // Make the "scale is NaN" mask layout-compatible with mxfp
    auto condTy = cond.getType();
    condTy = RankedTensorType::get(condTy.getShape(), condTy.getElementType(),
                                   mxfp.getType().getEncoding());
    cond = rewriter.create<ConvertLayoutOp>(loc, condTy, cond);

    // Create NaN
    auto mxfpTy = mxfp.getType();
    auto nan = APFloat::getNaN(
        cast<FloatType>(mxfpTy.getElementType()).getFloatSemantics());
    auto constNan = rewriter.create<arith::ConstantOp>(
        loc, mxfpTy, DenseElementsAttr::get(mxfpTy, nan));

    auto result = rewriter.create<arith::SelectOp>(loc, cond, constNan, mxfp);
    return cast<TypedValue<RankedTensorType>>(result.getResult());
  }

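  // Prepare one operand of the scaled dot: upcast it to the compute type, and
  // if it has a scale, multiply by the broadcast scale and (unless fastMath
  // is set) mask NaN scales.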
  TypedValue<RankedTensorType> scaleArg(PatternRewriter &rewriter,
                                        DotScaledOp scaledDotOp, int opIdx,
                                        FloatType computeType) const {
    auto v = opIdx == 0 ? scaledDotOp.getA() : scaledDotOp.getB();
    auto scale = opIdx == 0 ? scaledDotOp.getAScale() : scaledDotOp.getBScale();
    auto isFp4 =
        ScaleDotElemType::E2M1 ==
        (opIdx == 0 ? scaledDotOp.getAElemType() : scaledDotOp.getBElemType());
    auto fastMath = scaledDotOp.getFastMath();

    auto *ctx = rewriter.getContext();
    auto loc = v.getLoc();
    auto mod = scaledDotOp->getParentOfType<ModuleOp>();
    auto rank = v.getType().getRank();
    auto kDim = opIdx == 0 ? rank - 1 : rank - 2;

    // 0) Upcast value to computeType (fp16/bf16)
    if (isFp4) {
      // We always pack along the fastest moving dimension, kDim
      v = rewriter.create<Fp4ToFpOp>(loc, v, computeType, kDim);
    } else {
      auto vType16 = v.getType().clone(computeType);
      v = cast<TypedValue<RankedTensorType>>(
          rewriter.create<FpToFpOp>(loc, vType16, v).getResult());
    }
    if (!scale)
      return v;

    // For some weird reason, we take the scale shaped as if it came from the
    // lhs even when it is the rhs. In a normal world, we would accept this
    // parameter transposed, as we do with the mxfp.
    if (opIdx == 1) {
      auto order = getTransposeOrder(rank);
      scale = rewriter.create<TransOp>(loc, scale, order);
    }

    // 1) Cast scale to compute type (fp16/bf16)
    auto scale16 = scaleTo16(rewriter, scale, computeType);

    // 2) Broadcast scale to the same shape and layout as v
    auto reshapeScale =
        broadcastScale(rewriter, scaledDotOp, mod, scale16, kDim);
    reshapeScale =
        rewriter.create<ConvertLayoutOp>(loc, v.getType(), reshapeScale);

    // 3) Multiply
    auto mxfp = cast<TypedValue<RankedTensorType>>(
        rewriter.create<arith::MulFOp>(loc, v, reshapeScale).getResult());

    // Skip NaN checks if fastMath
    if (fastMath)
      return mxfp;

    // 4) If the scale is NaN, return NaN, else return the scaled value.
    return maskNan(rewriter, scaledDotOp, mod, mxfp, scale, computeType, kDim);
  }
};

} // namespace

namespace mlir::triton::gpu::intel {

void populateDecomposeScaledBlockedPatterns(RewritePatternSet &patterns,
                                            int benefit) {
  patterns.add<DecomposeScaledBlocked>(patterns.getContext(), benefit);
}

} // namespace mlir::triton::gpu::intel
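
As a sanity check on `scaleTo16`, here is a scalar C++ model of the same bit trick for the f32 path. `e8m0ToF32` is an illustrative name; s == 0 maps to 0.0f here rather than the denormal 2^-127, and s == 0xFF is the NaN case that `maskNan` handles separately:

#include <bit>
#include <cstdint>
#include <cstdio>

// Place the 8-bit e8m0 exponent into the f32 exponent field:
// bit_cast(uint32(s) << 23) == 2^(s - 127) for 0 < s < 0xFF,
// where 23 == getFPMantissaWidth(f32) - 1 == 24 - 1.
static float e8m0ToF32(uint8_t s) {
  return std::bit_cast<float>(uint32_t(s) << 23);
}

int main() {
  std::printf("%g %g %g\n", e8m0ToF32(126), e8m0ToF32(127), e8m0ToF32(128));
  // prints: 0.5 1 2
}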
