Commit 814b862
[NVIDIA] Rewrite getSM120DotScaledScaleLayout and Refactor MMAv2 (#8482)
### Context

This PR is split from #8430 and focuses on LinearLayout-related cleanups that should land before introducing FP4 support.

### Changes

- The existing implementation of `getSM120DotScaledScaleLayout` built the layout from manually written bases, which was hard to follow and contained bugs. It is rewritten using LL helpers such as `identity1D` / `zeros1D` together with the direct-sum operator `*`; the result is much clearer and extends trivially to FP4.
- `MMAv2.cpp` duplicated the same i8 four times into an i32 rather than packing four distinct i8 values. We now simply sign-extend one i8 into an i32 before every `mma_sync` and hardcode `byteId` to 0. Together with the LL change, this lets us significantly simplify the `MMAv2.cpp` code. We also replaced non-obvious hardcoded constants with the `NumRegisters` and `BaseOffset` structs.

### Notes

- This PR makes no performance change.
- FP4 support for sm_120 will follow shortly.
1 parent 33e7dc2 commit 814b862
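To make the `MMAv2.cpp` packing change concrete, here is a scalar sketch of the two behaviors described in the commit message (our illustration of the integer semantics only, not the actual LLVM-IR lowering in `MMAv2.cpp`): the old path replicated one i8 scale into all four bytes of an i32, while the new path sign-extends a single i8 and relies on `byteId` 0 so that only byte 0 of the metadata word is consulted.

```cpp
#include <cstdint>
#include <cstdio>

// Old behavior: duplicate the same i8 into all four bytes of the i32
// metadata word (four copies of one value, not four distinct values).
uint32_t replicate4(int8_t b) {
  uint32_t u = static_cast<uint8_t>(b);
  return u | (u << 8) | (u << 16) | (u << 24);
}

// New behavior per the commit message: sign-extend one i8 into an i32.
// With byteId hardcoded to 0, only byte 0 of this word is selected.
int32_t signExtendOne(int8_t b) { return static_cast<int32_t>(b); }

int main() {
  int8_t scale = -3; // 0xFD
  printf("replicated:    0x%08x\n", replicate4(scale));                        // 0xfdfdfdfd
  printf("sign-extended: 0x%08x\n", static_cast<uint32_t>(signExtendOne(scale))); // 0xfffffffd
  // Byte 0 is 0xFD in both encodings, so a byte-0 selector reads the same scale.
}
```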

File tree

5 files changed: +321 -305 lines

include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h

Lines changed: 3 additions & 5 deletions

```diff
@@ -142,12 +142,10 @@ LinearLayout chooseScaledWmmaScaleLayout(
     const std::vector<std::vector<int32_t>> &dotOperandWarpBasis,
     ArrayRef<int64_t> dotOperandShape);
 
-LinearLayout getSM120DotScaledScaleLayout(MLIRContext *ctx, int dotOperandIdx,
-                                          ArrayRef<int64_t> dotOperandShape,
-                                          ArrayRef<unsigned> tilesPerWarp,
+LinearLayout getSM120DotScaledScaleLayout(MLIRContext *ctx,
+                                          ArrayRef<int64_t> shape, int opIdx,
                                           ArrayRef<unsigned> warpsPerCTA,
-                                          unsigned instrM, unsigned instrN,
-                                          CTALayoutAttr ctaLayoutAttr);
+                                          CTALayoutAttr ctaLayout);
 
 // Create LinearLayout for nvidia mma tile.
 LinearLayout nvidiaMmaTile(MLIRContext *ctx, ArrayRef<unsigned> tileShape,
```

lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp

Lines changed: 36 additions & 62 deletions

```diff
@@ -1494,81 +1494,55 @@ LinearLayout chooseScaledWmmaScaleLayout(
   return newLL;
 }
 
-// Warp-level block scaling (sm_120, m16n8k32)
-// Reference: NVIDIA PTX ISA "Warp-level block scaling"
-// https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-block-scaling
+// PTX ISA - Warp-level MMA Block Scaling
+// https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-block-scaling
 //
-// Semantics:
-//   D = (A * SF_A) * (B * SF_B) + C
-//   scale_vec::1X -> SF_A shape Mx1 (per-row), SF_B shape 1xN (per-col)
+// This function generates layouts for scale tensors used in scaled dot
+// operations.
 //
-// Providers (within each warp quad of 4 lanes):
-//   - A scales are provided by a lane-pair selected by thread-id-a ∈ {0,1}
-//     (0 => lanes {0,1}, 1 => lanes {2,3} in the quad).
-//   - B scales are provided by a single lane selected by thread-id-b ∈
-//     {0,1,2,3}.
-//
-// Byte selectors (which subfield of the 32-bit metadata is used):
-//   - 1X: 1 byte => byte-id ∈ {0,1,2,3}
+// Supported .kind x scale_vec_size:
+//   mxf8f6f4 with UE8M0 scales -> .scale_vec::1X
 //
 // Implementation notes:
 // - We support only scale_vec::1X for now.
 // - We choose a fixed provider for A (thread-id-a = 0) and B (thread-id-b =
 //   0)
-// - In this implementation, each lane in a quad has the same scale factor.
-LinearLayout getSM120DotScaledScaleLayout(
-    MLIRContext *ctx, int dotOperandIdx, ArrayRef<int64_t> dotOperandShape,
-    ArrayRef<unsigned> tilesPerWarp, ArrayRef<unsigned> warpsPerCTA,
-    unsigned mmaInstrM, unsigned mmaInstrN, CTALayoutAttr ctaLayoutAttr) {
-  unsigned rank = dotOperandShape.size();
+// - We choose a fixed byte selector for A (byte-id-a = 0) and B (byte-id-b =
+//   0)
+// - Each lane in a quad has the same scale factor.
+LinearLayout getSM120DotScaledScaleLayout(MLIRContext *ctx,
+                                          ArrayRef<int64_t> shape, int opIdx,
+                                          ArrayRef<unsigned> warpsPerCTA,
+                                          CTALayoutAttr ctaLayout) {
+  unsigned rank = shape.size();
   auto outDims = standardOutDimNames(ctx, rank);
-
   StringAttr kRegister = StringAttr::get(ctx, "register");
   StringAttr kLane = StringAttr::get(ctx, "lane");
   StringAttr kWarp = StringAttr::get(ctx, "warp");
+  // - A: [M, K]
+  // - B: [K, N]
+  // - aScale: [M, K / K_GROUP_SIZE]
+  // - bScale: [N, K / K_GROUP_SIZE]
+  const unsigned kIdx = 1;
+  const unsigned mnIdx = 0;
 
-  const unsigned mIndex = 0;
-  const unsigned nIndex = 1;
-  const int instrM = mmaInstrM;
-  const int instrN = mmaInstrN;
-  const int kSize = dotOperandShape[1];
-  const int mWarps = warpsPerCTA[mIndex];
-  const int nWarps = warpsPerCTA[nIndex];
-  const int totalWarps = mWarps * nWarps;
-  const unsigned mRep_warp = tilesPerWarp[mIndex];
-  const unsigned nRep_warp = tilesPerWarp[nIndex];
-  const unsigned kRep = std::min<unsigned>(kSize, 2);
-
-  std::vector<std::vector<int32_t>> registerBase;
   std::vector<std::vector<int32_t>> laneBase;
-  std::vector<std::vector<int32_t>> warpBase;
-  if (dotOperandIdx == 0) { // per-row A-scale
-    laneBase = {{0, 8}, {0, 0}, {0, 1}, {0, 2}, {0, 4}};
-    for (int offset = instrM * mWarps; offset < instrM * mWarps * mRep_warp;
-         offset <<= 1)
-      registerBase.push_back({0, offset});
-    for (int w = mWarps; w < totalWarps; w <<= 1)
-      warpBase.push_back({0, 0});
-    for (int offset = instrM; offset < instrM * mWarps; offset <<= 1)
-      warpBase.push_back({0, offset});
-  } else { // per-col B-scale
-    laneBase = {{0, 0}, {0, 0}, {0, 1}, {0, 2}, {0, 4}};
-    if (nRep_warp > 1)
-      registerBase.push_back({0, nWarps * instrN});
-    for (int k = 1; k < kRep; k += 1)
-      registerBase.push_back({1 << (k - 1), 0});
-    for (int offset = instrN; offset < instrN * nWarps; offset <<= 1)
-      warpBase.push_back({0, offset});
-    for (int w = nWarps; w < totalWarps; w <<= 1)
-      warpBase.push_back({0, 0});
-  }
-
-  const unsigned kIdx = (dotOperandShape[0] == 1) ? 0 : 1;
-  const unsigned mnIdx = 1 - kIdx;
-  LinearLayout ctaLayout(
-      {{kRegister, registerBase}, {kLane, laneBase}, {kWarp, warpBase}},
-      {outDims[kIdx], outDims[mnIdx]});
-  return combineCtaCgaWithShape(ctaLayout, ctaLayoutAttr, dotOperandShape);
+  SmallVector<unsigned> order;
+  SmallVector<unsigned> mmaWarpsPerCTA;
+  if (opIdx == 0) {
+    laneBase = {{8, 0}, {0, 0}, {1, 0}, {2, 0}, {4, 0}};
+    order = SmallVector<unsigned>{1u, 0u};
+    mmaWarpsPerCTA = SmallVector<unsigned>{warpsPerCTA[0], warpsPerCTA[1]};
+  } else {
+    laneBase = {{0, 0}, {0, 0}, {1, 0}, {2, 0}, {4, 0}};
+    order = SmallVector<unsigned>{0u, 1u};
+    mmaWarpsPerCTA = SmallVector<unsigned>{warpsPerCTA[1], warpsPerCTA[0]};
+  }
+  LinearLayout LL =
+      LinearLayout::identity1D(shape[1], kRegister, outDims[kIdx]) *
+      LinearLayout({{kLane, laneBase}}, {outDims[mnIdx], outDims[kIdx]}) *
+      broadcastedDotOperandLayout(ctx, mmaWarpsPerCTA, order, 1u, kWarp);
+  return combineCtaCgaWithShape(LL, ctaLayout, shape);
 }
 
 LinearLayout chooseScaledMfmaScaleLayout(MLIRContext *ctx, int dotOperandIdx,
```
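A brief gloss on reading the LL pieces above (our annotation, not part of the commit): a `LinearLayout` is a linear map over GF(2), so it is fully determined by its bases, where basis $b_i$ is the image of the input index $2^i$, and a general index maps to the XOR of the bases of its set bits:

$$
L(x) = \bigoplus_{i\,:\,x_i = 1} b_i, \qquad b_i = L(2^i).
$$

For the B-scale `laneBase = {{0, 0}, {0, 0}, {1, 0}, {2, 0}, {4, 0}}`, a lane $\ell$ with bits $\ell_0,\dots,\ell_4$ maps to $(\mathrm{mn}, k) = (\ell_2 + 2\ell_3 + 4\ell_4,\ 0)$, since XORing distinct powers of two is plain addition. The two zero bases broadcast bits $\ell_0$ and $\ell_1$, so all four lanes of a quad read the same scale element, matching the "same scale factor in a quad" comment. The direct-sum operator `*` then concatenates layouts with disjoint input bases, which is why the register, lane, and warp contributions can be assembled from `identity1D`-style building blocks.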

lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp

Lines changed: 24 additions & 10 deletions

```diff
@@ -692,6 +692,25 @@ class ScaledBlockedToMMA : public mlir::OpRewritePattern<triton::DotScaledOp> {
         mlir::isa<LinearEncodingAttr>(bScaleType.getEncoding())) {
       return failure();
     }
+    auto aElemType = dotOp.getAElemType();
+    auto bElemType = dotOp.getBElemType();
+    auto isFP8 = [&](ScaleDotElemType elemType) -> bool {
+      return elemType == ScaleDotElemType::E4M3 ||
+             elemType == ScaleDotElemType::E5M2;
+    };
+    auto isFP4 = [&](ScaleDotElemType elemType) -> bool {
+      return elemType == ScaleDotElemType::E2M1;
+    };
+    // mixed precision is not supported
+    if (isFP8(aElemType) && isFP4(bElemType) ||
+        isFP4(aElemType) && isFP8(bElemType)) {
+      return failure();
+    }
+
+    auto scaleElemType = dotOp.getAScale().getType().getElementType();
+    if (scaleElemType != dotOp.getBScale().getType().getElementType()) {
+      return failure();
+    }
 
     // Common MMA encoding creation
     auto mmaResult =
@@ -738,23 +757,18 @@ class ScaledBlockedToMMA : public mlir::OpRewritePattern<triton::DotScaledOp> {
        return rep.size() >= 3 ? rep[2] : 1;
      }
    };
-    SmallVector<unsigned, 2> tilesPerWarp{computeTilePerWarp(newA, 0),
-                                          computeTilePerWarp(newB, 1)};
+
+    const auto mmaWarps = mmaResult.mmaEnc.getWarpsPerCTA(); // [wM, wN]
+
     // Convert scales to Linear layout
     auto convertScale = [&](Value scale, int opIdx) -> Value {
-      if (!scale)
-        return Value();
       auto ty = cast<RankedTensorType>(scale.getType());
       SmallVector<int64_t> shape = llvm::to_vector(ty.getShape());
       MLIRContext *ctx = ty.getContext();
-      const auto mmaWarps = mmaResult.mmaEnc.getWarpsPerCTA(); // [wM, wN]
-      const auto instr = mmaResult.mmaEnc.getInstrShape(); // [instrM, instrN]
-      const unsigned instrM = instr[0], instrN = instr[1];
-
       auto blocked = cast<triton::gpu::BlockedEncodingAttr>(ty.getEncoding());
+
       auto ll = triton::gpu::getSM120DotScaledScaleLayout(
-          ctx, opIdx, shape, tilesPerWarp,
-          /*warpsPerCTA=*/mmaWarps, instrM, instrN, blocked.getCTALayout());
+          ctx, shape, opIdx, mmaWarps, blocked.getCTALayout());
       auto newEnc = triton::gpu::LinearEncodingAttr::get(ctx, ll);
       auto newTy = RankedTensorType::get(shape, ty.getElementType(), newEnc);
       return rewriter.create<ConvertLayoutOp>(scale.getLoc(), newTy, scale);
```