Skip to content

Commit 0a58d8e

Browse files
committed
Fix e2m1
1 parent cfabc03 commit 0a58d8e

File tree

4 files changed

+55
-35
lines changed

4 files changed

+55
-35
lines changed

lib/Dialect/TritonGPU/IR/Ops.cpp

Lines changed: 41 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include "triton/Dialect/Triton/IR/Utility.h"
55
#include "triton/Dialect/TritonGPU/IR/Attributes.h"
66
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
7+
#include "triton/Tools/Sys/GetEnv.hpp"
78

89
#define GET_OP_CLASSES
910
#include "triton/Dialect/TritonGPU/IR/Ops.cpp.inc"
@@ -109,6 +110,8 @@ LogicalResult UpcastMXFPOp::inferReturnTypes(
109110
auto xShape = xTy.getShape();
110111

111112
auto encoding = xTy.getEncoding();
113+
bool upcastMXFPUseDotOpEnc =
114+
mlir::triton::tools::getBoolEnv("TRITON_INTEL_UPCASTMXFP_DOTOP_ENCODING");
112115

113116
if (typeEncoded == ScaleDotElemType::E2M1) {
114117
RankedTensorType retTy;
@@ -118,34 +121,47 @@ LogicalResult UpcastMXFPOp::inferReturnTypes(
118121
newShape.back() *= 2;
119122
retTy = RankedTensorType::get(xShape, FloatType::getBF16(ctx));
120123
} else {
121-
auto oldEncoding = cast<DotOperandEncodingAttr>(encoding);
122-
123-
const int opIdx = oldEncoding.getOpIdx();
124-
const bool hasBatch = xShape.size() == 3;
125-
const int kIdx = (opIdx == 0 ? 1 : 0) + hasBatch;
126-
newShape[kIdx] *= 2;
127124
Type elemType = FloatType::getBF16(ctx);
128-
129-
// Note: For Intel the dot operands layout's kWidth parameter must match
130-
// the parent's DPAS layout opsPerChannel so we need to materialize a new
131-
// DPAS layout.
132125
Attribute newVEncoding;
133-
if (auto dpasEncoding =
134-
dyn_cast<intel::DpasEncodingAttr>(oldEncoding.getParent())) {
135-
auto newDpasEncoding = intel::DpasEncodingAttr::get(
136-
ctx, dpasEncoding.getRepeatCount(), dpasEncoding.getSystolicDepth(),
137-
dpasEncoding.getExecutionSize(),
138-
intel::DpasEncodingAttr::getOpsPerChannel(elemType),
139-
dpasEncoding.getWarpsPerCTA(), dpasEncoding.getRepCluster(),
140-
dpasEncoding.getSubGroupSize());
141-
newVEncoding = DotOperandEncodingAttr::get(
142-
ctx, opIdx, newDpasEncoding, newDpasEncoding.getOpsPerChannel());
126+
if (upcastMXFPUseDotOpEnc) {
127+
auto oldEncoding = cast<DotOperandEncodingAttr>(encoding);
128+
129+
const int opIdx = oldEncoding.getOpIdx();
130+
const bool hasBatch = xShape.size() == 3;
131+
const int kIdx = (opIdx == 0 ? 1 : 0) + hasBatch;
132+
newShape[kIdx] *= 2;
133+
134+
// Note: For Intel the dot operands layout's kWidth parameter must match
135+
// the parent's DPAS layout opsPerChannel so we need to materialize a
136+
// new DPAS layout.
137+
if (auto dpasEncoding =
138+
dyn_cast<intel::DpasEncodingAttr>(oldEncoding.getParent())) {
139+
auto newDpasEncoding = intel::DpasEncodingAttr::get(
140+
ctx, dpasEncoding.getRepeatCount(),
141+
dpasEncoding.getSystolicDepth(), dpasEncoding.getExecutionSize(),
142+
intel::DpasEncodingAttr::getOpsPerChannel(elemType),
143+
dpasEncoding.getWarpsPerCTA(), dpasEncoding.getRepCluster(),
144+
dpasEncoding.getSubGroupSize());
145+
newVEncoding = DotOperandEncodingAttr::get(
146+
ctx, opIdx, newDpasEncoding, newDpasEncoding.getOpsPerChannel());
147+
} else {
148+
// Figure out the K dimension for the input A/B, given that the return
149+
// type is the upcasted A/B type, so we need to update the proper dim size.
150+
newVEncoding = DotOperandEncodingAttr::get(
151+
ctx, oldEncoding.getOpIdx(), oldEncoding.getParent(),
152+
oldEncoding.getKWidth() * 2);
153+
}
143154
} else {
144-
// Figure out the K dimension for the input A/B, given that the return
145-
// type is upcasted A/B type so we need to update the proper dim size.
146-
newVEncoding = DotOperandEncodingAttr::get(ctx, oldEncoding.getOpIdx(),
147-
oldEncoding.getParent(),
148-
oldEncoding.getKWidth() * 2);
155+
auto oldEncoding = dyn_cast<BlockedEncodingAttr>(encoding);
156+
assert(oldEncoding &&
157+
"Expected a blocked encoding for UpcastMXFP op result.");
158+
newShape.back() *= 2;
159+
SmallVector<unsigned> sizePerThread = oldEncoding.getSizePerThread();
160+
sizePerThread.back() *= 2;
161+
newVEncoding = BlockedEncodingAttr::get(
162+
ctx, sizePerThread, oldEncoding.getThreadsPerWarp(),
163+
oldEncoding.getWarpsPerCTA(), oldEncoding.getCTAOrder(),
164+
oldEncoding.getCTALayout());
149165
}
150166
retTy = RankedTensorType::get(newShape, elemType, newVEncoding);
151167
}

python/test/unit/language/test_core.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3441,8 +3441,8 @@ def test_scaled_dot(M, N, K, col_a, col_b, rhs_scale, normal_type, mxfp_type, nu
34413441
if mma == 16 and K == 64:
34423442
pytest.skip(f"K == {K} too small for mfma {mma} in scaled_dot")
34433443
if is_xpu():
3444-
if "e2m1" in (normal_type, mxfp_type):
3445-
pytest.skip("scaled_dot e2m1 isn't supported on XPU")
3444+
if rhs_scale:
3445+
pytest.skip("scaled_dot with rhs_scale not supported on XPU")
34463446

34473447
@triton.jit
34483448
def dot_scale_kernel(a_base, stride_a0, stride_a1, a_scale, b_base, stride_b0, stride_b1, b_scale, out,

third_party/intel/lib/Analysis/DPAS.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#include "mlir/IR/BuiltinTypes.h"
44
#include "triton/Dialect/Triton/IR/Dialect.h"
55
#include "llvm/Support/Casting.h"
6+
#include <iostream>
67

78
namespace mlir::triton::gpu::intel {
89

@@ -150,9 +151,10 @@ DPASAnalysis::DPASEngineType DPASAnalysis::getDPASType(Operation *op) {
150151
if ((aElemTy.isFloat8E4M3FN() || aElemTy.isFloat8E5M2()) &&
151152
bElemTy.isFloat4E2M1FN())
152153
return DPASEngineType::FP32_FP32_FP8_FP4;
153-
if (aElemTy.isFloat4E2M1FN() && bElemTy.isBF16())
154+
// 2 E2M1 values are packed into 1 int8
155+
if (aElemTy.isInteger(8) && bElemTy.isBF16())
154156
return DPASEngineType::FP32_FP32_FP4_BF16;
155-
if (aElemTy.isFloat4E2M1FN() &&
157+
if (aElemTy.isInteger(8) &&
156158
(bElemTy.isFloat8E4M3FN() || bElemTy.isFloat8E5M2()))
157159
return DPASEngineType::FP32_FP32_FP4_FP8;
158160
}

third_party/intel/lib/TritonIntelGPUTransforms/AccelerateMatmul.cpp

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -252,8 +252,7 @@ class DecomposeScaledBlocked : public OpRewritePattern<tt::DotScaledOp> {
252252

253253
private:
254254
bool upcastMXFPUseDotOpEnc =
255-
mlir::triton::tools::getBoolEnv(
256-
"TRITON_INTEL_UPCASTMXFP_DOTOP_ENCODING") == 1;
255+
mlir::triton::tools::getBoolEnv("TRITON_INTEL_UPCASTMXFP_DOTOP_ENCODING");
257256

258257
struct OpDescriptor {
259258
TensorValue op;
@@ -294,11 +293,12 @@ class DecomposeScaledBlocked : public OpRewritePattern<tt::DotScaledOp> {
294293
assert(opDesc.scale && "Expecting valid operand & scale");
295294

296295
unsigned opsPerChannel = dpasEnc.getOpsPerChannel();
297-
if (opDesc.elemType == tt::ScaleDotElemType::E2M1)
298-
opsPerChannel *= 2;
299296

300297
MLIRContext *ctx = opDesc.op.getContext();
298+
unsigned rank = retType.getRank();
301299
if (upcastMXFPUseDotOpEnc) {
300+
if (opDesc.elemType == tt::ScaleDotElemType::E2M1)
301+
opsPerChannel *= 2;
302302
auto opEncoding = ttg::intel::DpasEncodingAttr::get(
303303
ctx, dpasEnc.getRepeatCount(), dpasEnc.getSystolicDepth(),
304304
dpasEnc.getExecutionSize(), opsPerChannel, dpasEnc.getWarpsPerCTA(),
@@ -313,7 +313,6 @@ class DecomposeScaledBlocked : public OpRewritePattern<tt::DotScaledOp> {
313313
unsigned instrShapeM = dpasEnc.getDPASInstShapeA()[1];
314314
SmallVector<unsigned, 2> threadsPerWarp{instrShapeM,
315315
warpSize / instrShapeM};
316-
unsigned rank = retType.getRank();
317316
int numWarps = ttg::TritonGPUDialect::getNumWarps(mod);
318317
SmallVector<unsigned, 2> warpsPerCTA(rank, 1);
319318
warpsPerCTA[0] = numWarps;
@@ -334,10 +333,13 @@ class DecomposeScaledBlocked : public OpRewritePattern<tt::DotScaledOp> {
334333
// https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
335334
// the scalingBlockSize should be 32 for E5M2, E4M3 and E2M1
336335
unsigned scalingBlockSize = 32;
336+
// 2 FP4E2M1 values are packed into 1 I8
337337
if (opDesc.elemType == tt::ScaleDotElemType::E2M1)
338338
scalingBlockSize = 16;
339+
SmallVector<unsigned, 2> sizePerThread(rank, 1);
340+
sizePerThread[rank - 1 - opIdx] = scalingBlockSize;
339341
auto newOpEncoding = ttg::BlockedEncodingAttr::get(
340-
ctx, {1, scalingBlockSize}, scaleEncoding.getThreadsPerWarp(),
342+
ctx, sizePerThread, scaleEncoding.getThreadsPerWarp(),
341343
scaleEncoding.getWarpsPerCTA(), scaleEncoding.getCTAOrder(),
342344
scaleEncoding.getCTALayout());
343345

0 commit comments

Comments
 (0)