
Commit 48feaa4

Merge branch 'liyang/upcast_mxfp_and_dot_scaled' of https://github.com/intel/intel-xpu-backend-for-triton into liyang/upcast_mxfp_and_dot_scaled
2 parents: bc32cd2 + b36c35e

File tree: 8 files changed, 55 additions and 48 deletions

  lib/Dialect/TritonGPU/IR/Ops.cpp
  python/test/unit/language/test_core.py
  third_party/intel/include/Analysis/DPAS.h
  third_party/intel/lib/Analysis/DPAS.cpp
  third_party/intel/lib/TritonIntelGPUToLLVM/UpcastMXFPToLLVM.cpp
  third_party/intel/lib/TritonIntelGPUToLLVM/Utility.cpp
  third_party/intel/lib/TritonIntelGPUToLLVM/Utility.h
  third_party/intel/lib/TritonIntelGPUTransforms/AccelerateMatmul.cpp

lib/Dialect/TritonGPU/IR/Ops.cpp

Lines changed: 3 additions & 2 deletions
@@ -152,9 +152,10 @@ LogicalResult UpcastMXFPOp::inferReturnTypes(
   } else if (auto oldEncoding = dyn_cast<BlockedEncodingAttr>(encoding)) {
     // TODO: Temporary code, remove once upcast_mxfp support dot encoding.
     assert(!tools::getBoolEnv("TRITON_INTEL_UPCASTMXFP_DOTOP_ENCODING"));
-    newShape.back() *= 2;
     SmallVector<unsigned> sizePerThread = oldEncoding.getSizePerThread();
-    sizePerThread.back() *= 2;
+    int opIdx = sizePerThread.back() == 1 ? 1 : 0;
+    sizePerThread[!opIdx] *= 2;
+    newShape[!opIdx] *= 2;
     newVEncoding = BlockedEncodingAttr::get(
         ctx, sizePerThread, oldEncoding.getThreadsPerWarp(),
         oldEncoding.getWarpsPerCTA(), oldEncoding.getCTAOrder(),
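Note on the Ops.cpp change: because two E2M1 values are packed into one i8, upcasting doubles the operand's K extent, and which dimension is K depends on whether the tensor is operand A or B. Below is a minimal host-side sketch of that shape rule, assuming 2-D operands with the usual A = M x K, B = K x N convention; the helper name upcastedShape is illustrative and not part of the Triton code.

#include <array>
#include <cassert>
#include <cstdint>
#include <cstdio>

// For a packed E2M1 operand the K dimension doubles after upcasting:
// operand A (opIdx == 0) is M x K, so dim 1 doubles; operand B (opIdx == 1)
// is K x N, so dim 0 doubles.
std::array<int64_t, 2> upcastedShape(std::array<int64_t, 2> shape, int opIdx) {
  assert(opIdx == 0 || opIdx == 1);
  shape[opIdx == 0 ? 1 : 0] *= 2;
  return shape;
}

int main() {
  auto a = upcastedShape({128, 32}, 0); // A: 128 x 32 packed i8 -> 128 x 64 bf16
  auto b = upcastedShape({32, 128}, 1); // B: 32 x 128 packed i8 -> 64 x 128 bf16
  std::printf("A: %lld x %lld, B: %lld x %lld\n", (long long)a[0],
              (long long)a[1], (long long)b[0], (long long)b[1]);
  return 0;
}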

python/test/unit/language/test_core.py

Lines changed: 0 additions & 3 deletions
@@ -3440,9 +3440,6 @@ def test_scaled_dot(M, N, K, col_a, col_b, rhs_scale, normal_type, mxfp_type, nu
         pytest.skip(f"scaled_dot({normal_type}, {mxfp_type}) only implemented for MI300")
     if mma == 16 and K == 64:
         pytest.skip(f"K == {K} too small for mfma {mma} in scaled_dot")
-    if is_xpu():
-        if rhs_scale:
-            pytest.skip("scaled_dot with rhs_scale not supported on XPU")
 
 @triton.jit
 def dot_scale_kernel(a_base, stride_a0, stride_a1, a_scale, b_base, stride_b0, stride_b1, b_scale, out,

third_party/intel/include/Analysis/DPAS.h

Lines changed: 2 additions & 2 deletions
@@ -24,6 +24,8 @@ class DPASAnalysis {
     FP32_FP32_TF32_TF32,
     FP16_FP16_FP16_FP16,
     BF16_BF16_BF16_BF16,
+    U32_U32_U8_U8,
+    S32_S32_S8_S8,
     // data types for dot scaled.
     FP32_FP32_BF16_FP8,
     FP32_FP32_BF16_FP4,
@@ -32,8 +34,6 @@ class DPASAnalysis {
     FP32_FP32_FP8_FP4,
     FP32_FP32_FP4_BF16,
     FP32_FP32_FP4_FP8,
-    U32_U32_U8_U8,
-    S32_S32_S8_S8,
     NOT_APPLICABLE
   };
 

third_party/intel/lib/Analysis/DPAS.cpp

Lines changed: 3 additions & 3 deletions
@@ -150,7 +150,8 @@ DPASAnalysis::getDPASType(OpTy op) {
     if (aElemTy.isBF16() &&
         (bElemTy.isFloat8E4M3FN() || bElemTy.isFloat8E5M2()))
       return DPASEngineType::FP32_FP32_BF16_FP8;
-    if (aElemTy.isBF16() && bElemTy.isFloat4E2M1FN())
+    // 2 E2M1 are packed into 1 int8
+    if (aElemTy.isBF16() && bElemTy.isInteger(8))
       return DPASEngineType::FP32_FP32_BF16_FP4;
     if ((aElemTy.isFloat8E4M3FN() || aElemTy.isFloat8E5M2()) &&
         bElemTy.isBF16())
@@ -159,9 +160,8 @@ DPASAnalysis::getDPASType(OpTy op) {
         (bElemTy.isFloat8E4M3FN() || bElemTy.isFloat8E5M2()))
       return DPASEngineType::FP32_FP32_FP8_FP8;
     if ((aElemTy.isFloat8E4M3FN() || aElemTy.isFloat8E5M2()) &&
-        bElemTy.isFloat4E2M1FN())
+        bElemTy.isInteger(8))
       return DPASEngineType::FP32_FP32_FP8_FP4;
-    // 2 E2M1 are packed into 1 int8
     if (aElemTy.isInteger(8) && bElemTy.isBF16())
       return DPASEngineType::FP32_FP32_FP4_BF16;
     if (aElemTy.isInteger(8) &&
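Note on the DPAS.cpp change: by the time operands reach getDPASType, a packed E2M1 operand is represented as i8, with two 4-bit values per byte (as the comment in the hunk says), which is why the checks switch from isFloat4E2M1FN() to isInteger(8). Below is a minimal sketch of that nibble packing, assuming low-nibble-first order; packE2M1 and unpackE2M1 are illustrative names, not Triton helpers.

#include <cassert>
#include <cstdint>
#include <utility>

// Pack two 4-bit E2M1 encodings (each already a value in 0..15) into one
// byte, low nibble first, and recover them again.
inline uint8_t packE2M1(uint8_t lo, uint8_t hi) {
  return static_cast<uint8_t>((lo & 0xF) | ((hi & 0xF) << 4));
}

inline std::pair<uint8_t, uint8_t> unpackE2M1(uint8_t byte) {
  return {static_cast<uint8_t>(byte & 0xF), static_cast<uint8_t>(byte >> 4)};
}

int main() {
  uint8_t packed = packE2M1(0x3, 0xA);
  auto [lo, hi] = unpackE2M1(packed);
  assert(lo == 0x3 && hi == 0xA);
  return 0;
}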

third_party/intel/lib/TritonIntelGPUToLLVM/UpcastMXFPToLLVM.cpp

Lines changed: 20 additions & 2 deletions
@@ -17,6 +17,24 @@ using namespace mlir::triton::gpu;
 
 namespace {
 
+static Value mxfpScaleBf16(ConversionPatternRewriter &rewriter, Location loc,
+                           Value v, Value scale) {
+  Value vBf16 = bitcast(v, bf16_ty);
+  Value nanBf16 = bitcast(i16_val(0x7fff), bf16_ty);
+  Value scaleIsNan = icmp_eq(scale, i8_val(0xff));
+  Value scaleBf16 = bitcast(shl(zext(i16_ty, scale), i16_val(7)), bf16_ty);
+
+  Value v0 = mlir::triton::intel::convertBf16ToFp32(loc, rewriter, vBf16);
+  Value v1 = mlir::triton::intel::convertBf16ToFp32(loc, rewriter, scaleBf16);
+  auto result = rewriter.create<LLVM::FMulOp>(loc, f32_ty, v0, v1);
+  auto undefRounding = static_cast<mlir::triton::RoundingMode>(-1);
+  Value scaledBf16 = mlir::triton::intel::convertFp32ToBf16(
+      loc, rewriter, result, undefRounding);
+  // Value scaledBf16 = fmul(vBf16, scaleBf16);
+  // Account for NaN in the scale as per the mxfp specification.
+  return select(scaleIsNan, nanBf16, scaledBf16);
+};
+
 class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern<UpcastMXFPOp> {
 private:
   const TargetInfoBase &targetInfo;
@@ -48,8 +66,8 @@ class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern<UpcastMXFPOp> {
 
     for (auto [i, scaleVal] : llvm::enumerate(scaleVals)) {
       for (int j = 0; j < 32; ++j) {
-        xVals[32 * i + j] = LLVM::intel::mxfpScaleBf16(
-            rewriter, loc, xVals[32 * i + j], scaleVal);
+        xVals[32 * i + j] =
+            mxfpScaleBf16(rewriter, loc, xVals[32 * i + j], scaleVal);
       }
     }
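Note on mxfpScaleBf16 above: the 8-bit scale is an E8M0 exponent, so shl(zext(scale), 7) followed by a bitcast places it directly in the bf16 exponent field, giving 2^(scale - 127), with 0xFF treated as NaN per the mxfp specification. Below is a small host-side sketch of the same decoding, done on fp32 bits since bf16 is the upper half of fp32; decodeE8M0Scale is an illustrative name, not a Triton helper.

#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstring>

// The E8M0 scale byte holds only an exponent. Shifting it into the exponent
// field decodes it as 2^(s - 127); the value 0xFF is reserved for NaN.
// bf16 bits (s << 7) and fp32 bits (s << 23) represent the same value.
float decodeE8M0Scale(uint8_t s) {
  if (s == 0xFF)
    return std::nanf("");
  uint32_t fp32Bits = static_cast<uint32_t>(s) << 23;
  float f;
  std::memcpy(&f, &fp32Bits, sizeof(f));
  return f;
}

int main() {
  assert(decodeE8M0Scale(127) == 1.0f); // exponent bias
  assert(decodeE8M0Scale(128) == 2.0f);
  assert(decodeE8M0Scale(126) == 0.5f);
  return 0;
}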

third_party/intel/lib/TritonIntelGPUToLLVM/Utility.cpp

Lines changed: 0 additions & 17 deletions
@@ -159,21 +159,4 @@ LLVM::LLVMFuncOp getSpirvPrintfDeclaration(RewriterBase &rewriter) {
   return printFunc;
 }
 
-Value mxfpScaleBf16(ConversionPatternRewriter &rewriter, Location loc, Value v,
-                    Value scale) {
-  Value vBf16 = bitcast(v, bf16_ty);
-  Value nanBf16 = bitcast(i16_val(0x7fff), bf16_ty);
-  Value scaleIsNan = icmp_eq(scale, i8_val(0xff));
-  Value scaleBf16 = bitcast(shl(zext(i16_ty, scale), i16_val(7)), bf16_ty);
-
-  Value v0 = mlir::triton::intel::convertBf16ToFp32(loc, rewriter, vBf16);
-  Value v1 = mlir::triton::intel::convertBf16ToFp32(loc, rewriter, scaleBf16);
-  auto result = rewriter.create<LLVM::FMulOp>(loc, f32_ty, v0, v1);
-  auto undefRounding = static_cast<mlir::triton::RoundingMode>(-1);
-  Value scaledBf16 = mlir::triton::intel::convertFp32ToBf16(
-      loc, rewriter, result, undefRounding);
-  // Value scaledBf16 = fmul(vBf16, scaleBf16);
-  // Account for NaN in the scale as per the mxfp specification.
-  return select(scaleIsNan, nanBf16, scaledBf16);
-};
 } // namespace mlir::LLVM::intel

third_party/intel/lib/TritonIntelGPUToLLVM/Utility.h

Lines changed: 0 additions & 2 deletions
@@ -127,8 +127,6 @@ static Value getModuleWarpSize(RewriterBase &rewriter, Location loc) {
   return i32_val(triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod));
 }
 
-Value mxfpScaleBf16(ConversionPatternRewriter &rewriter, Location loc, Value v,
-                    Value scale);
 } // namespace mlir::LLVM::intel
 
 // -----------------------------------------------------------------------

third_party/intel/lib/TritonIntelGPUTransforms/AccelerateMatmul.cpp

Lines changed: 27 additions & 17 deletions
@@ -291,13 +291,16 @@ class DecomposeScaledBlocked : public OpRewritePattern<tt::DotScaledOp> {
     static_assert(opIdx == 0 || opIdx == 1, "Illegal operand index");
     assert(opDesc.scale && "Expecting valid operand & scale");
 
-    unsigned opsPerChannel = dpasEnc.getOpsPerChannel();
-
     MLIRContext *ctx = opDesc.op.getContext();
+    unsigned numWarps = ttg::TritonGPUDialect::getNumWarps(mod);
+    unsigned warpSize = ttg::TritonGPUDialect::getThreadsPerWarp(mod);
+    unsigned opsPerChannel = dpasEnc.getOpsPerChannel();
     unsigned rank = retType.getRank();
+
     if (upcastMXFPUseDotOpEnc) {
       if (opDesc.elemType == tt::ScaleDotElemType::E2M1)
         opsPerChannel *= 2;
+
       auto opEncoding = ttg::intel::DpasEncodingAttr::get(
           ctx, dpasEnc.getRepeatCount(), dpasEnc.getSystolicDepth(),
           dpasEnc.getExecutionSize(), opsPerChannel, dpasEnc.getWarpsPerCTA(),
@@ -312,7 +315,6 @@ class DecomposeScaledBlocked : public OpRewritePattern<tt::DotScaledOp> {
       unsigned instrShapeM = dpasEnc.getDPASInstShapeA()[1];
       SmallVector<unsigned, 2> threadsPerWarp{instrShapeM,
                                               warpSize / instrShapeM};
-      int numWarps = ttg::TritonGPUDialect::getNumWarps(mod);
       SmallVector<unsigned, 2> warpsPerCTA(rank, 1);
       warpsPerCTA[0] = numWarps;
       auto CTALayout = ttg::getCTALayout(retType.getEncoding());
@@ -325,7 +327,6 @@ class DecomposeScaledBlocked : public OpRewritePattern<tt::DotScaledOp> {
       return createUpcastMxfpOp(op, scale, opDesc.elemType, rewriter);
     }
 
-    // Temporary code: remove once upcast_mxfp support dot encoding.
     auto scaleEncoding = dyn_cast<ttg::BlockedEncodingAttr>(
         opDesc.scale.getType().getEncoding());
     assert(scaleEncoding && "Expecting blocked encoding for scale");
@@ -334,19 +335,28 @@ class DecomposeScaledBlocked : public OpRewritePattern<tt::DotScaledOp> {
     // https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
     // the scalingBlockSize should be 32 for E5M2, E4M3 and E2M1
     unsigned scalingBlockSize = 32;
-    // 2 FP4E2M1 are packed in 1 I8
+    // 2 FP4E2M1 are packed in one i8
    if (opDesc.elemType == tt::ScaleDotElemType::E2M1)
      scalingBlockSize = 16;
-    SmallVector<unsigned, 2> sizePerThread(rank, 1);
-    sizePerThread[rank - 1 - opIdx] = scalingBlockSize;
-    auto newOpEncoding = ttg::BlockedEncodingAttr::get(
-        ctx, sizePerThread, scaleEncoding.getThreadsPerWarp(),
-        scaleEncoding.getWarpsPerCTA(), scaleEncoding.getCTAOrder(),
-        scaleEncoding.getCTALayout());
 
+    SmallVector<unsigned> sizePerThread = {1, 1};
+    SmallVector<unsigned> threadsPerWarp = {1, 1};
+    sizePerThread[!opIdx] = scalingBlockSize;
+    threadsPerWarp[opIdx] = warpSize;
+    SmallVector<unsigned> warpsPerCTA = {numWarps, 1};
+
+    auto newOpEncoding = ttg::BlockedEncodingAttr::get(
+        ctx, sizePerThread, threadsPerWarp, warpsPerCTA,
+        scaleEncoding.getCTAOrder(), scaleEncoding.getCTALayout());
     TensorValue op =
         createArg(opDesc.op, opDesc.elemType, newOpEncoding, rewriter);
-    TensorValue scale = opDesc.scale;
+
+    warpsPerCTA = opIdx ? SmallVector<unsigned>{1, numWarps}
+                        : SmallVector<unsigned>{numWarps, 1};
+    auto newScaleEncoding = ttg::BlockedEncodingAttr::get(
+        ctx, {1, 1}, {warpSize, 1}, warpsPerCTA, scaleEncoding.getCTAOrder(),
+        scaleEncoding.getCTALayout());
+    TensorValue scale = createScale(opDesc.scale, newScaleEncoding, rewriter);
 
     auto retDpasEncoding = ttg::intel::DpasEncodingAttr::get(
         ctx, dpasEnc.getRepeatCount(), dpasEnc.getSystolicDepth(),
@@ -357,11 +367,11 @@ class DecomposeScaledBlocked : public OpRewritePattern<tt::DotScaledOp> {
 
     auto upcastOp = createUpcastMxfpOp(op, scale, opDesc.elemType, rewriter);
 
-    auto upcastRetType = cast<RankedTensorType>(upcastOp.getType());
-    retType = RankedTensorType::get(retType.getShape(),
-                                    retType.getElementType(), retDotOpEncoding);
-    return rewriter.create<ttg::ConvertLayoutOp>(opDesc.op.getLoc(),
-                                                 upcastRetType, upcastOp);
+    auto resultType = cast<RankedTensorType>(upcastOp.getType());
+    resultType = RankedTensorType::get(
+        resultType.getShape(), resultType.getElementType(), retDotOpEncoding);
+    return rewriter.create<ttg::ConvertLayoutOp>(opDesc.op.getLoc(), resultType,
+                                                 upcastOp);
   }
 
   template <unsigned opIdx>
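Note on the AccelerateMatmul.cpp change: per the OCP microscaling spec referenced in the hunk, each scale covers a block of 32 elements along K, and because two E2M1 values share one i8 the packed operand spans only 16 bytes per scale block, which is what scalingBlockSize encodes. Below is a minimal sketch of that bookkeeping, assuming the packed K extent divides evenly into blocks; packedElemsPerScale is an illustrative name, not a Triton helper.

#include <cassert>
#include <cstdio>

// Packed elements along K covered by one mx scale: 32 logical elements per
// scale, but E2M1 stores two elements per i8, so the packed operand only
// advances 16 bytes per scale block.
unsigned packedElemsPerScale(bool isE2M1) { return isE2M1 ? 16u : 32u; }

int main() {
  unsigned packedK = 128; // packed K extent of the operand
  assert(packedK % packedElemsPerScale(true) == 0);
  std::printf("E5M2/E4M3: %u scales per row, E2M1: %u scales per row\n",
              packedK / packedElemsPerScale(false),
              packedK / packedElemsPerScale(true));
  return 0;
}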
