Commit 90f937a

Fix rhs scaling
1 parent 698a0bd commit 90f937a

5 files changed: +57 additions, -50 deletions

lib/Dialect/TritonGPU/IR/Ops.cpp

Lines changed: 3 additions & 2 deletions

@@ -152,9 +152,10 @@ LogicalResult UpcastMXFPOp::inferReturnTypes(
   } else if (auto oldEncoding = dyn_cast<BlockedEncodingAttr>(encoding)) {
     // TODO: Temporary code, remove once upcast_mxfp support dot encoding.
     assert(!tools::getBoolEnv("TRITON_INTEL_UPCASTMXFP_DOTOP_ENCODING"));
-    newShape.back() *= 2;
     SmallVector<unsigned> sizePerThread = oldEncoding.getSizePerThread();
-    sizePerThread.back() *= 2;
+    int opIdx = sizePerThread.back() == 1 ? 1 : 0;
+    sizePerThread[!opIdx] *= 2;
+    newShape[!opIdx] *= 2;
     newVEncoding = BlockedEncodingAttr::get(
         ctx, sizePerThread, oldEncoding.getThreadsPerWarp(),
         oldEncoding.getWarpsPerCTA(), oldEncoding.getCTAOrder(),

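Note on the Ops.cpp change: two E2M1 (FP4) values are packed into one i8, so the upcast doubles the logical tensor along the K dimension, which is the last axis for the LHS (opIdx 0) but the first axis for the RHS (opIdx 1). The previous code always doubled the last axis, which was wrong for the RHS; the fix infers the operand index from the blocked layout's sizePerThread (the RHS encoding built in AccelerateMatmul.cpp keeps sizePerThread.back() == 1) and doubles the matching axis. A standalone sketch of the axis selection, assuming rank-2 shapes (the helper and values below are illustrative, not from the patch):

#include <cassert>
#include <cstdint>
#include <vector>

// Doubles the K axis of a packed-i8 operand shape: axis 1 for the LHS
// (opIdx 0), axis 0 for the RHS (opIdx 1).
std::vector<int64_t> upcastedShape(std::vector<int64_t> shape, int opIdx) {
  assert(shape.size() == 2 && (opIdx == 0 || opIdx == 1));
  shape[!opIdx] *= 2;
  return shape;
}
// e.g. upcastedShape({128, 32}, 0) == {128, 64}  (A: [M, K/2] -> [M, K])
//      upcastedShape({32, 128}, 1) == {64, 128}  (B: [K/2, N] -> [K, N])
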
python/test/unit/language/test_core.py

Lines changed: 0 additions & 3 deletions

@@ -3440,9 +3440,6 @@ def test_scaled_dot(M, N, K, col_a, col_b, rhs_scale, normal_type, mxfp_type, nu
         pytest.skip(f"scaled_dot({normal_type}, {mxfp_type}) only implemented for MI300")
     if mma == 16 and K == 64:
         pytest.skip(f"K == {K} too small for mfma {mma} in scaled_dot")
-    if is_xpu():
-        if rhs_scale:
-            pytest.skip("scaled_dot with rhs_scale not supported on XPU")
 
 @triton.jit
 def dot_scale_kernel(a_base, stride_a0, stride_a1, a_scale, b_base, stride_b0, stride_b1, b_scale, out,

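With rhs scaling fixed, the XPU-specific skip is no longer needed: the rhs_scale variants of test_scaled_dot now run on XPU as well.
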
third_party/intel/include/Analysis/DPAS.h

Lines changed: 2 additions & 2 deletions

@@ -24,6 +24,8 @@ class DPASAnalysis {
     FP32_FP32_TF32_TF32,
     FP16_FP16_FP16_FP16,
     BF16_BF16_BF16_BF16,
+    U32_U32_U8_U8,
+    S32_S32_S8_S8,
     // data types for dot scaled.
     FP32_FP32_BF16_FP8,
     FP32_FP32_BF16_FP4,
@@ -32,8 +34,6 @@ class DPASAnalysis {
     FP32_FP32_FP8_FP4,
     FP32_FP32_FP4_BF16,
     FP32_FP32_FP4_FP8,
-    U32_U32_U8_U8,
-    S32_S32_S8_S8,
     NOT_APPLICABLE
   };

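The two integer DPAS types move above the "data types for dot scaled" comment so that the scaled-dot enumerators form a single contiguous range. One plausible use of that ordering (an assumption, not shown in this commit) is a range check to tell scaled-dot types apart from the regular ones:

// Hypothetical helper, assuming the post-commit enumerator order.
static bool isScaledDotType(DPASAnalysis::DPASEngineType t) {
  return t >= DPASAnalysis::DPASEngineType::FP32_FP32_BF16_FP8 &&
         t <= DPASAnalysis::DPASEngineType::FP32_FP32_FP4_FP8;
}
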
third_party/intel/lib/Analysis/DPAS.cpp

Lines changed: 3 additions & 3 deletions

@@ -140,7 +140,8 @@ DPASAnalysis::DPASEngineType DPASAnalysis::getDPASType(Operation *op) {
     if (aElemTy.isBF16() &&
         (bElemTy.isFloat8E4M3FN() || bElemTy.isFloat8E5M2()))
       return DPASEngineType::FP32_FP32_BF16_FP8;
-    if (aElemTy.isBF16() && bElemTy.isFloat4E2M1FN())
+    // 2 E2M1 are packed into 1 int8
+    if (aElemTy.isBF16() && bElemTy.isInteger(8))
       return DPASEngineType::FP32_FP32_BF16_FP4;
     if ((aElemTy.isFloat8E4M3FN() || aElemTy.isFloat8E5M2()) &&
         bElemTy.isBF16())
@@ -149,9 +150,8 @@ DPASAnalysis::DPASEngineType DPASAnalysis::getDPASType(Operation *op) {
         (bElemTy.isFloat8E4M3FN() || bElemTy.isFloat8E5M2()))
       return DPASEngineType::FP32_FP32_FP8_FP8;
     if ((aElemTy.isFloat8E4M3FN() || aElemTy.isFloat8E5M2()) &&
-        bElemTy.isFloat4E2M1FN())
+        bElemTy.isInteger(8))
       return DPASEngineType::FP32_FP32_FP8_FP4;
-    // 2 E2M1 are packed into 1 int8
     if (aElemTy.isInteger(8) && bElemTy.isBF16())
       return DPASEngineType::FP32_FP32_FP4_BF16;
     if (aElemTy.isInteger(8) &&

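The isFloat4E2M1FN checks become isInteger(8) because, by the time getDPASType runs, FP4 operands have already been packed: two E2M1 values travel in one i8 (the FP32_FP32_FP4_* cases below already tested isInteger(8) for the same reason, and the explanatory comment moves up to the first packed case). A minimal sketch of that packing; the low-nibble-first order is an assumption, not taken from this commit:

#include <cstdint>

// Pack two 4-bit E2M1 values into one byte (first value in the low nibble).
inline uint8_t packE2M1Pair(uint8_t first, uint8_t second) {
  return static_cast<uint8_t>((first & 0x0F) | ((second & 0x0F) << 4));
}

// Recover the two 4-bit values from the packed byte.
inline uint8_t unpackFirst(uint8_t packed) { return packed & 0x0F; }
inline uint8_t unpackSecond(uint8_t packed) { return packed >> 4; }
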
third_party/intel/lib/TritonIntelGPUTransforms/AccelerateMatmul.cpp

Lines changed: 49 additions & 40 deletions

@@ -291,10 +291,12 @@ class DecomposeScaledBlocked : public OpRewritePattern<tt::DotScaledOp> {
     static_assert(opIdx == 0 || opIdx == 1, "Illegal operand index");
     assert(opDesc.scale && "Expecting valid operand & scale");
 
-    unsigned opsPerChannel = dpasEnc.getOpsPerChannel();
-
     MLIRContext *ctx = opDesc.op.getContext();
+    unsigned numWarps = ttg::TritonGPUDialect::getNumWarps(mod);
+    unsigned warpSize = ttg::TritonGPUDialect::getThreadsPerWarp(mod);
+    unsigned opsPerChannel = dpasEnc.getOpsPerChannel();
     unsigned rank = retType.getRank();
+
     if (upcastMXFPUseDotOpEnc) {
       if (opDesc.elemType == tt::ScaleDotElemType::E2M1)
         opsPerChannel *= 2;
@@ -312,7 +314,6 @@ class DecomposeScaledBlocked : public OpRewritePattern<tt::DotScaledOp> {
       unsigned instrShapeM = dpasEnc.getDPASInstShapeA()[1];
       SmallVector<unsigned, 2> threadsPerWarp{instrShapeM,
                                               warpSize / instrShapeM};
-      int numWarps = ttg::TritonGPUDialect::getNumWarps(mod);
       SmallVector<unsigned, 2> warpsPerCTA(rank, 1);
       warpsPerCTA[0] = numWarps;
       auto CTALayout = ttg::getCTALayout(retType.getEncoding());
@@ -323,44 +324,52 @@ class DecomposeScaledBlocked : public OpRewritePattern<tt::DotScaledOp> {
       TensorValue scale = createScale(opDesc.scale, newScaleEncoding, rewriter);
 
       return createUpcastMxfpOp(op, scale, opDesc.elemType, rewriter);
-    } else {
-      auto scaleEncoding = dyn_cast<ttg::BlockedEncodingAttr>(
-          opDesc.scale.getType().getEncoding());
-      assert(scaleEncoding && "Expecting blocked encoding for scale");
-
-      // Referring to
-      // https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
-      // the scalingBlockSize should be 32 for E5M2, E4M3 and E2M1
-      unsigned scalingBlockSize = 32;
-      // 2 FP4E2M1 are packed in 1 I8
-      if (opDesc.elemType == tt::ScaleDotElemType::E2M1)
-        scalingBlockSize = 16;
-      SmallVector<unsigned, 2> sizePerThread(rank, 1);
-      sizePerThread[rank - 1 - opIdx] = scalingBlockSize;
-      auto newOpEncoding = ttg::BlockedEncodingAttr::get(
-          ctx, sizePerThread, scaleEncoding.getThreadsPerWarp(),
-          scaleEncoding.getWarpsPerCTA(), scaleEncoding.getCTAOrder(),
-          scaleEncoding.getCTALayout());
-
-      TensorValue op =
-          createArg(opDesc.op, opDesc.elemType, newOpEncoding, rewriter);
-      TensorValue scale = opDesc.scale;
-
-      auto retDpasEncoding = ttg::intel::DpasEncodingAttr::get(
-          ctx, dpasEnc.getRepeatCount(), dpasEnc.getSystolicDepth(),
-          dpasEnc.getExecutionSize(), opsPerChannel, dpasEnc.getWarpsPerCTA(),
-          dpasEnc.getRepCluster(), dpasEnc.getSubGroupSize());
-      auto retDotOpEncoding = ttg::DotOperandEncodingAttr::get(
-          ctx, opIdx, retDpasEncoding, retDpasEncoding.getOpsPerChannel());
-
-      auto upcastOp = createUpcastMxfpOp(op, scale, opDesc.elemType, rewriter);
-
-      auto retType = cast<RankedTensorType>(upcastOp.getType());
-      retType = RankedTensorType::get(
-          retType.getShape(), retType.getElementType(), retDotOpEncoding);
-      return rewriter.create<ttg::ConvertLayoutOp>(opDesc.op.getLoc(), retType,
-                                                   upcastOp);
     }
+
+    auto scaleEncoding = dyn_cast<ttg::BlockedEncodingAttr>(
+        opDesc.scale.getType().getEncoding());
+    assert(scaleEncoding && "Expecting blocked encoding for scale");
+
+    // Referring to
+    // https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
+    // the scalingBlockSize should be 32 for E5M2, E4M3 and E2M1
+    unsigned scalingBlockSize = 32;
+    // 2 FP4E2M1 are packed in 1 I8
+    if (opDesc.elemType == tt::ScaleDotElemType::E2M1)
+      scalingBlockSize = 16;
+    SmallVector<unsigned> sizePerThread = {1, 1};
+    SmallVector<unsigned> threadsPerWarp = {1, 1};
+    sizePerThread[!opIdx] = scalingBlockSize;
+    threadsPerWarp[opIdx] = warpSize;
+    SmallVector<unsigned> warpsPerCTA = {numWarps, 1};
+
+    auto newOpEncoding = ttg::BlockedEncodingAttr::get(
+        ctx, sizePerThread, threadsPerWarp, warpsPerCTA,
+        scaleEncoding.getCTAOrder(), scaleEncoding.getCTALayout());
+    TensorValue op =
+        createArg(opDesc.op, opDesc.elemType, newOpEncoding, rewriter);
+
+    warpsPerCTA = opIdx ? SmallVector<unsigned>{1, numWarps}
+                        : SmallVector<unsigned>{numWarps, 1};
+    auto newScaleEncoding = ttg::BlockedEncodingAttr::get(
+        ctx, {1, 1}, {warpSize, 1}, warpsPerCTA, scaleEncoding.getCTAOrder(),
+        scaleEncoding.getCTALayout());
+    TensorValue scale = createScale(opDesc.scale, newScaleEncoding, rewriter);
+
+    auto retDpasEncoding = ttg::intel::DpasEncodingAttr::get(
+        ctx, dpasEnc.getRepeatCount(), dpasEnc.getSystolicDepth(),
+        dpasEnc.getExecutionSize(), opsPerChannel, dpasEnc.getWarpsPerCTA(),
+        dpasEnc.getRepCluster(), dpasEnc.getSubGroupSize());
+    auto retDotOpEncoding = ttg::DotOperandEncodingAttr::get(
+        ctx, opIdx, retDpasEncoding, retDpasEncoding.getOpsPerChannel());
+
+    auto upcastOp = createUpcastMxfpOp(op, scale, opDesc.elemType, rewriter);
+
+    auto resultType = cast<RankedTensorType>(upcastOp.getType());
+    resultType = RankedTensorType::get(
+        resultType.getShape(), resultType.getElementType(), retDotOpEncoding);
+    return rewriter.create<ttg::ConvertLayoutOp>(opDesc.op.getLoc(), resultType,
+                                                 upcastOp);
   }
 
   template <unsigned opIdx>

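Summary of the AccelerateMatmul.cpp rewrite: the else block is flattened into an early-return structure, numWarps and warpSize are hoisted to the top of the function, and the non-dot-op-encoding path no longer reuses the scale's threadsPerWarp/warpsPerCTA for the operand. Instead, the operand layout is built explicitly (each lane owns one full scaling block along K, and the warp's lanes span the other axis), and the scale, previously passed through unchanged, now gets its own blocked encoding whose warpsPerCTA follows the operand index: warps along M for the LHS, along N for the RHS. A standalone sketch of the operand layout for rank 2 (the struct and the example numbers are illustrative, not from the patch):

#include <cstdio>

struct Layout {
  unsigned sizePerThread[2], threadsPerWarp[2], warpsPerCTA[2];
};

// Mirrors the patch's operand-layout logic for a rank-2 tensor. K is axis 1
// for the LHS (opIdx 0) and axis 0 for the RHS (opIdx 1): each lane holds one
// scaling block along K, and the warp's lanes span the non-K axis.
Layout operandLayout(unsigned opIdx, unsigned warpSize, unsigned numWarps,
                     unsigned scalingBlockSize) {
  Layout l = {{1, 1}, {1, 1}, {numWarps, 1}};
  l.sizePerThread[!opIdx] = scalingBlockSize;
  l.threadsPerWarp[opIdx] = warpSize;
  return l;
}

int main() {
  // E.g. warpSize = 16, numWarps = 4, scalingBlockSize = 32 (E5M2/E4M3).
  Layout a = operandLayout(0, 16, 4, 32); // sizePerThread {1,32}, threads {16,1}
  Layout b = operandLayout(1, 16, 4, 32); // sizePerThread {32,1}, threads {1,16}
  std::printf("A: {%u,%u}  B: {%u,%u}\n", a.sizePerThread[0], a.sizePerThread[1],
              b.sizePerThread[0], b.sizePerThread[1]);
  return 0;
}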