
Commit 915c149

[AMD] Enable B scale for scaled_dot (#5112)
This commit adds direct support for the B (RHS) scale in the AccelerateMatmul and UpcastMXFPOp patterns. Along the way, the UpcastMXFPOp verifier is updated so that it accepts this case.
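
At the user level, this means a Triton kernel can now pass a scale for the RHS operand of tl.dot_scaled on AMD hardware. The sketch below is illustrative only and not taken from this commit; it assumes the tl.dot_scaled(lhs, lhs_scale, lhs_format, rhs, rhs_scale, rhs_format) form, and all pointer and size names are hypothetical.

# Illustrative sketch (not from this commit): a kernel that scales only the
# RHS (B) operand of tl.dot_scaled, the case this change enables on AMD.
import triton
import triton.language as tl


@triton.jit
def rhs_scaled_dot_kernel(a_ptr, b_ptr, b_scale_ptr, c_ptr,
                          M: tl.constexpr, N: tl.constexpr, K: tl.constexpr):
    # A: bf16 [M, K]; B: fp8 e4m3 [K, N]; B scale: e8m0 bytes [N, K // 32].
    a = tl.load(a_ptr + tl.arange(0, M)[:, None] * K + tl.arange(0, K)[None, :])
    b = tl.load(b_ptr + tl.arange(0, K)[:, None] * N + tl.arange(0, N)[None, :])
    b_scale = tl.load(b_scale_ptr + tl.arange(0, N)[:, None] * (K // 32) +
                      tl.arange(0, K // 32)[None, :])
    # lhs_scale is None: only the RHS carries a microscaling factor here.
    c = tl.dot_scaled(a, None, "bf16", b, b_scale, "e4m3")
    tl.store(c_ptr + tl.arange(0, M)[:, None] * N + tl.arange(0, N)[None, :], c)
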
1 parent d556ce9 commit 915c149

File tree: 4 files changed, +93 -66 lines

lib/Dialect/TritonGPU/IR/Ops.cpp

Lines changed: 29 additions & 17 deletions
@@ -1,10 +1,7 @@
 #include "mlir/IR/BuiltinTypes.h"
-#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
 #include "triton/Dialect/Triton/IR/Dialect.h"
 #include "triton/Dialect/TritonGPU/IR/Attributes.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
-#include "triton/Dialect/TritonNvidiaGPU/IR/Types.h"
-#include "llvm/Support/raw_ostream.h"
 
 #define GET_OP_CLASSES
 #include "triton/Dialect/TritonGPU/IR/Ops.cpp.inc"
@@ -39,19 +36,6 @@ LogicalResult UpcastMXFPOp::verify() {
     return emitOpError("NYI: fpType must be E2M1, E4M3, or E5M2");
   }
 
-  // Change to support fp8 types
-  const auto elems_packed = fpType == ScaleDotElemType::E2M1 ? 2 : 1;
-
-  if (xShape.back() != (32 / elems_packed) * scaleShape.back()) {
-    return emitOpError("last dimension of first operand must be 16 times "
-                       "larger than that of the second operand");
-  }
-
-  if (!std::equal(xShape.begin(), xShape.end() - 1, scaleShape.begin())) {
-    return emitOpError(
-        "all dimensions except the last must match between operands");
-  }
-
   auto layoutX = xTy.getEncoding();
   auto layoutScale = scaleTy.getEncoding();
   if (bool(layoutX) != bool(layoutScale)) {
@@ -82,6 +66,28 @@ LogicalResult UpcastMXFPOp::verify() {
     }
   }
 
+  // Change to support fp8 types
+  const auto elemsPacked = fpType == ScaleDotElemType::E2M1 ? 2 : 1;
+  // Figure out the K dimension for the input A/B. For A/B scale, the K
+  // dimension is always the last dimension.
+  const int opIdx = dotEncoding.getOpIdx();
+  const bool hasBatch = xShape.size() == 3;
+  const int kIdx = (opIdx == 0 ? 1 : 0) + hasBatch;
+
+  if (xShape[kIdx] != (32 / elemsPacked) * scaleShape.back()) {
+    return emitOpError("K dimension of first operand must be 16 times "
+                       "larger than last/K dimension of the second operand");
+  }
+
+  // Check other dimensions match too. For input A/B, we need to figure out the
+  // index for the M/N dimension. For scale, it's always {(batch), M/N, K}.
+  const int mnIdx = (opIdx == 0 ? 0 : 1) + hasBatch;
+  if (hasBatch && xShape[0] != scaleShape[0])
+    return emitOpError("batch dimension must match between operands");
+  if (xShape[mnIdx] != scaleShape[hasBatch]) {
+    return emitOpError("M/N dimension must match between operands");
+  }
+
   return success();
 }
 
@@ -100,14 +106,20 @@ LogicalResult UpcastMXFPOp::inferReturnTypes(
   RankedTensorType retTy;
 
   auto newShape = SmallVector<int64_t>(xShape);
-  newShape.back() *= 2;
   if (!encoding) {
+    newShape.back() *= 2;
     retTy = RankedTensorType::get(xShape, FloatType::getBF16(ctx));
   } else {
     auto oldEncoding = cast<DotOperandEncodingAttr>(encoding);
     auto newVEncoding = DotOperandEncodingAttr::get(
         ctx, oldEncoding.getOpIdx(), oldEncoding.getParent(),
         oldEncoding.getKWidth() * 2);
+    // Figure out the K dimension for the input A/B, given that the return
+    // type is upcasted A/B type so we need to update the proper dim size.
+    const int opIdx = oldEncoding.getOpIdx();
+    const bool hasBatch = xShape.size() == 3;
+    const int kIdx = (opIdx == 0 ? 1 : 0) + hasBatch;
+    newShape[kIdx] *= 2;
     retTy = RankedTensorType::get(newShape, FloatType::getBF16(ctx),
                                   newVEncoding);
   }
python/test/unit/language/test_core.py

Lines changed: 0 additions & 2 deletions
@@ -3386,8 +3386,6 @@ def test_scaled_dot(M, N, K, col_a, col_b, rhs_scale, normal_type, mxfp_type, nu
         if cc < (8, 9):
             pytest.skip("float8e4nv not supported on CUDA < 8.9")
     if is_hip():
-        if rhs_scale:
-            pytest.skip("scales on rhs not yet support for HIP")
         if not is_hip_cdna():
             pytest.skip("scaled_dot only implemented for HIP CDNA")
         if "e4m3" in (normal_type, mxfp_type) and not is_hip_mi300():

third_party/amd/lib/TritonAMDGPUToLLVM/UpcastMXFPToLLVM.cpp

Lines changed: 0 additions & 2 deletions
@@ -53,8 +53,6 @@ class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern<UpcastMXFPOp> {
 
     auto dotEncoding =
         cast<DotOperandEncodingAttr>(op.getSrc().getType().getEncoding());
-    if (dotEncoding.getOpIdx() == 1)
-      return rewriter.notifyMatchFailure(op, "NYI: dot RHS");
     auto mfmaEncoding = dyn_cast<AMDMfmaEncodingAttr>(dotEncoding.getParent());
     if (!mfmaEncoding)
       return rewriter.notifyMatchFailure(op, "NYI: non-mfma dot operand");

third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp

Lines changed: 64 additions & 45 deletions
@@ -497,24 +497,27 @@ class ScaledBlockedToMFMA final : public OpRewritePattern<triton::DotScaledOp> {
     if (!isa_and_nonnull<BlockedEncodingAttr>(oldRetType.getEncoding()))
       return rewriter.notifyMatchFailure(
           dotOp, "expected blocked encoding result tensor");
-
-    if (dotOp.getRhsScale())
-      return rewriter.notifyMatchFailure(dotOp, "NYI: RHS scale");
+    unsigned rank = oldRetType.getRank();
+    if (rank == 3)
+      return rewriter.notifyMatchFailure(dotOp, "NYI: 3d case");
 
     TensorValue a = dotOp.getLhs();
     TensorValue b = dotOp.getRhs();
     TensorValue aScale = dotOp.getLhsScale();
+    TensorValue bScale = dotOp.getRhsScale();
+    if (aScale && bScale)
+      return rewriter.notifyMatchFailure(dotOp, "NYI: both LHS and RHS scale");
+
     ScaleDotElemType aElemType = dotOp.getLhsType();
     ScaleDotElemType bElemType = dotOp.getRhsType();
-
-    if (!(aElemType == ScaleDotElemType::E2M1 ||
-          aElemType == ScaleDotElemType::E4M3 ||
-          aElemType == ScaleDotElemType::E5M2))
-      return rewriter.notifyMatchFailure(dotOp, "NYI: non-mxfp8/mxfp4 LHS");
-    if (!(bElemType == ScaleDotElemType::E4M3 ||
-          bElemType == ScaleDotElemType::E5M2 ||
-          bElemType == ScaleDotElemType::BF16))
-      return rewriter.notifyMatchFailure(dotOp, "NYI: non-fp8/bf16 RHS");
+    auto supportsTypes = [](ScaleDotElemType elemType) {
+      return elemType == ScaleDotElemType::E2M1 ||
+             elemType == ScaleDotElemType::E4M3 ||
+             elemType == ScaleDotElemType::E5M2 ||
+             elemType == ScaleDotElemType::BF16;
+    };
+    if (!supportsTypes(aElemType) || !supportsTypes(bElemType))
+      return rewriter.notifyMatchFailure(dotOp, "NYI: mxfp6 operand");
 
     MLIRContext *ctx = dotOp.getContext();
     auto moduleOp = dotOp->getParentOfType<ModuleOp>();
@@ -534,27 +537,30 @@ class ScaledBlockedToMFMA final : public OpRewritePattern<triton::DotScaledOp> {
     unsigned kDim = mfmaInstr.value().getKDim();
     unsigned kBase = mfmaInstr.value().getKBase();
 
-    // If A tensor contains mxfp4, we pack every two values into one int8 value
-    // there. For such cases, we have different initial kWidth for LHS and RHS,
-    // which will be "fixed" later by using upcast_mxfp to convert LHS to
-    // unpacked values. For such packed cases, we cannot support flexible kPack
-    // choices from the developer--it just does not apply here. So mandate the
-    // choice here.
-    bool isPacked = aElemType == ScaleDotElemType::E2M1;
-    unsigned kWdiths[] = {isPacked ? 4 : kBase * kPack,
-                          isPacked ? 8 : kBase * kPack};
-
-    // For A tensor, 32 consecutive elements along K dim share the same scale.
+    // For mxfp4 A/B tensor, we pack every two values into one int8 value there.
+    // For such cases, we have different initial kWidth for LHS and RHS, which
+    // will be "fixed" later by using upcast_mxfp to convert LHS to unpacked
+    // values. For such packed cases, we cannot support flexible kPack choices
+    // from the developer--it just does not apply here. So mandate the choice
+    // here.
+    bool isAPacked = aElemType == ScaleDotElemType::E2M1;
+    bool isBPacked = bElemType == ScaleDotElemType::E2M1;
+    bool isPacked = isAPacked || isBPacked;
+    unsigned kWdiths[] = {isPacked ? (isAPacked ? 4 : 8) : kBase * kPack,
+                          isPacked ? (isAPacked ? 8 : 4) : kBase * kPack};
+
+    // For A/B tensor, 32 consecutive elements along K dim share the same scale.
     // We'd like to keep the scale values together with the base values in the
     // same warp to avoid cross-warp data exchange. It means we want warpsPerCTA
-    // = 1 along the N dimension.
-    SmallVector<unsigned, 3> warpsPerCTA(oldRetType.getRank(), 1);
-    warpsPerCTA.front() = numWarps;
+    // = 1 along the N/M dimension for the mxfp A/B case. We achieve that by
+    // setting the M/N dimension as numWarps.
+    SmallVector<unsigned, 2> mfmaWarpsPerCTA(rank, 1);
+    mfmaWarpsPerCTA[aScale ? 0 : 1] = numWarps;
 
     // Always use transposed mfma layout. This enables larger vectorization
     // for global store instructions.
     auto mfmaEnc = ttg::AMDMfmaEncodingAttr::get(
-        ctx, /*versionMajor=*/mfmaVersion, /*versionMinor=*/0, warpsPerCTA,
+        ctx, /*versionMajor=*/mfmaVersion, /*versionMinor=*/0, mfmaWarpsPerCTA,
        /*instrShape=*/mDim, nDim, /*isTransposed=*/true, ctaLayout);
 
     auto newRetType = RankedTensorType::get(
@@ -571,11 +577,9 @@ class ScaledBlockedToMFMA final : public OpRewritePattern<triton::DotScaledOp> {
      auto newVType = RankedTensorType::get(
          vType.getShape(), vType.getElementType(), newVEncoding);
      v = rewriter.create<ttg::ConvertLayoutOp>(v.getLoc(), newVType, v);
-      if (type == ScaleDotElemType::BF16)
-        return v;
-      // Don't need to covert int8 holding mxfp4 for A--the upcast_mxfp op can
+      // Don't need to covert int8 holding mxfp4--the upcast_mxfp op can
      // take int8 tensor as input.
-      if (idx == 0 && type == ScaleDotElemType::E2M1)
+      if (type == ScaleDotElemType::BF16 || type == ScaleDotElemType::E2M1)
        return v;
 
      auto vTypeBf16 = RankedTensorType::get(
@@ -586,27 +590,42 @@ class ScaledBlockedToMFMA final : public OpRewritePattern<triton::DotScaledOp> {
     a = toMMABf16(a, 0, aElemType);
     b = toMMABf16(b, 1, bElemType);
 
-    // We need to have "matching" encoding between the A tensor and A scale
+    // We need to have "matching" encoding between the main tensor and scale
     // tensor to make sure the scale values needed is in the same warp. So we
     // adopt the same CTA layout and warps per CTA. The warp dimensions needs to
-    // match along M dimension too. With in a warp, we have 64 threads. We let
-    // each thread read in one scale value. So we need a threadsPerWarp = mDim
-    // along M dimension.
+    // match along M/N dimension too. With in a warp, we have 64 threads. We let
+    // each thread read in one scale value. So we need a threadsPerWarp =
+    // mDim/nDim along M/N dimension. Note that For MFMA intrinsics, mDim is
+    // always the same as nDim. And for scaled dot scale tensor, we always have
+    // K as the innermost dimension. So we have the same threadsPerWarp in the
+    // below no matter A or B scale. Similarly for warpsPerCTA, the non-K
+    // dimension is always at index 0.
+    assert(mDim == nDim);
     SmallVector<unsigned, 2> threadsPerWarp = {mDim, numThreads / mDim};
+    SmallVector<unsigned, 2> blockWarpsPerCTA(rank, 1);
+    blockWarpsPerCTA[0] = numWarps;
     auto newScaleEncoding = triton::gpu::BlockedEncodingAttr::get(
-        ctx, {1, 1}, threadsPerWarp, warpsPerCTA, {1, 0}, ctaLayout);
+        ctx, {1, 1}, threadsPerWarp, blockWarpsPerCTA, {1, 0}, ctaLayout);
+
+    auto upcastMXFP = [&](TensorValue main, TensorValue scale,
+                          ScaleDotElemType elemType) -> Value {
+      if (!scale)
+        return main;
 
-    auto newScaleType = RankedTensorType::get(aScale.getType().getShape(),
-                                              aScale.getType().getElementType(),
-                                              newScaleEncoding);
-    aScale = rewriter.create<ttg::ConvertLayoutOp>(aScale.getLoc(),
-                                                   newScaleType, aScale);
+      auto newScaleType = RankedTensorType::get(
+          scale.getType().getShape(), scale.getType().getElementType(),
+          newScaleEncoding);
+      auto convOp = rewriter.create<ttg::ConvertLayoutOp>(scale.getLoc(),
+                                                          newScaleType, scale);
 
-    auto scaledA = rewriter.create<triton::gpu::UpcastMXFPOp>(
-        dotOp.getLoc(), a, aScale, dotOp.getLhsType());
+      return rewriter.create<triton::gpu::UpcastMXFPOp>(dotOp.getLoc(), main,
+                                                        convOp, elemType);
+    };
 
-    auto newDot =
-        rewriter.create<DotOp>(dotOp.getLoc(), newRetType, scaledA, b, newAcc);
+    Value scaledA = upcastMXFP(a, aScale, dotOp.getLhsType());
+    Value scaledB = upcastMXFP(b, bScale, dotOp.getRhsType());
+    auto newDot = rewriter.create<DotOp>(dotOp.getLoc(), newRetType, scaledA,
                                         scaledB, newAcc);
     rewriter.replaceOpWithNewOp<ttg::ConvertLayoutOp>(dotOp, oldRetType,
                                                       newDot);
     return success();
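
As a summary of the layout reasoning in the comments above, the following Python sketch (illustrative values only, not code from this commit) shows how the warp and thread distribution is chosen depending on which operand carries the scale: all warps go along M when A is scaled and along N when B is scaled, and each of the 64 lanes holds one scale value along M/N.

# Sketch of the warp/thread distribution choice for the MFMA result layout
# and the scale tensor's blocked layout; num_warps/m_dim values are examples.
def mfma_layout_for_scaled_dot(a_has_scale, num_warps, m_dim, num_threads=64):
    # MFMA result layout: keep warpsPerCTA = 1 along the other dimension so
    # scale and base values stay within the same warp.
    warps_per_cta = [num_warps, 1] if a_has_scale else [1, num_warps]
    # Scale tensor is {M/N, K/32}: one scale per lane along M/N, rest along K.
    scale_threads_per_warp = [m_dim, num_threads // m_dim]
    scale_warps_per_cta = [num_warps, 1]
    return warps_per_cta, scale_threads_per_warp, scale_warps_per_cta

print(mfma_layout_for_scaled_dot(a_has_scale=True, num_warps=4, m_dim=32))
# ([4, 1], [32, 2], [4, 1])
print(mfma_layout_for_scaled_dot(a_has_scale=False, num_warps=4, m_dim=32))
# ([1, 4], [32, 2], [4, 1])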
