Skip to content

Commit 27f406c

Browse files
authored
Reapply "[NVIDIA] Add native MXFP FP8 scaled_dot for SM120 (#7918)" (#8029) (#8129)
Adds a fix for the crash happening on BW. Still doing some perf tests.
1 parent 5f43194 commit 27f406c

File tree

12 files changed

+846
-140
lines changed

12 files changed

+846
-140
lines changed

include/triton/Dialect/Triton/IR/TritonOpInterfaces.td

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,12 @@ def DotOpInterface : OpInterface<"DotOpInterface"> {
4949
/*retType=*/"::mlir::Value",
5050
/*methodName=*/"getB",
5151
/*args=*/(ins)>,
52-
InterfaceMethod<
52+
InterfaceMethod<
53+
/*desc=*/"Get the output tensor",
54+
/*retType=*/"::mlir::Value",
55+
/*methodName=*/"getD",
56+
/*args=*/(ins)>,
57+
InterfaceMethod<
5358
/*desc=*/"Verify the dimensions of the A and B DotOp operands.",
5459
/*retType=*/"bool",
5560
/*methodName=*/"verifyDims",
@@ -64,6 +69,7 @@ def DotOpInterface : OpInterface<"DotOpInterface"> {
6469
auto aTy = cast<ShapedType>($_op.getA().getType());
6570
auto bTy = cast<ShapedType>($_op.getB().getType());
6671
auto cTy = cast<ShapedType>($_op->getOperand(2).getType());
72+
auto dTy = cast<ShapedType>($_op.getD().getType());
6773
auto aShape = aTy.getShape();
6874
auto bShape = bTy.getShape();
6975
auto cShape = cTy.getShape();

include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,13 @@ LinearLayout chooseScaledMfmaScaleLayout(MLIRContext *ctx, int dotOperandIdx,
135135
ArrayRef<unsigned> tilesPerWarp,
136136
ArrayRef<unsigned> warpsPerCTA);
137137

138+
// Builds the LinearLayout for a scale operand of an SM120 scaled dot.
// dotOperandIdx selects which scale this is (0 => A/per-row, 1 => B/per-col);
// instrM/instrN are the MMA instruction tile sizes. See the definition in
// LinearLayoutConversions.cpp for the PTX block-scaling semantics.
LinearLayout getSM120DotScaledScaleLayout(MLIRContext *ctx, int dotOperandIdx,
                                          ArrayRef<int64_t> dotOperandShape,
                                          ArrayRef<unsigned> tilesPerWarp,
                                          ArrayRef<unsigned> warpsPerCTA,
                                          unsigned instrM, unsigned instrN,
                                          CTALayoutAttr ctaLayoutAttr);
144+
138145
// Create LinearLayout for nvidia mma tile.
139146
LinearLayout nvidiaMmaTile(MLIRContext *ctx, ArrayRef<unsigned> tileShape,
140147
unsigned kWidth, ArrayRef<unsigned> order,

lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1401,6 +1401,83 @@ LinearLayout chooseDsReadB64TrLayout(Attribute enc, ArrayRef<int64_t> shape,
14011401
return chooseDotDsReadB64TrLayout(dot, shape, elemBitWidth);
14021402
}
14031403

1404+
// Warp-level block scaling (sm_120, m16n8k32)
1405+
// Reference: NVIDIA PTX ISA "Warp-level block scaling"
1406+
// https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-block-scaling
1407+
//
1408+
// Semantics:
1409+
// D = (A * SF_A) * (B * SF_B) + C
1410+
// scale_vec::1X -> SF_A shape Mx1 (per-row), SF_B shape 1xN (per-col)
1411+
//
1412+
// Providers (within each warp quad of 4 lanes):
1413+
// - A scales are provided by a lane-pair selected by thread-id-a ∈ {0,1}
1414+
// (0 => lanes {0,1}, 1 => lanes {2,3} in the quad).
1415+
// - B scales are provided by a single lane selected by thread-id-b ∈
1416+
// {0,1,2,3}.
1417+
//
1418+
// Byte selectors (which subfield of the 32-bit metadata is used):
1419+
// - 1X: 1 byte => byte-id ∈ {0,1,2,3}
1420+
//
1421+
// Implementation notes:
1422+
// - We support only scale_vec::1X for now.
1423+
// - We choose a fixed provider for A (thread-id-a = 0) and B (thread-id-b =
1424+
// 0)
1425+
// - In this implementation, each lane in a quad has the same scale factor.
1426+
LinearLayout getSM120DotScaledScaleLayout(
1427+
MLIRContext *ctx, int dotOperandIdx, ArrayRef<int64_t> dotOperandShape,
1428+
ArrayRef<unsigned> tilesPerWarp, ArrayRef<unsigned> warpsPerCTA,
1429+
unsigned mmaInstrM, unsigned mmaInstrN, CTALayoutAttr ctaLayoutAttr) {
1430+
unsigned rank = dotOperandShape.size();
1431+
auto outDims = standardOutDimNames(ctx, rank);
1432+
1433+
StringAttr kRegister = StringAttr::get(ctx, "register");
1434+
StringAttr kLane = StringAttr::get(ctx, "lane");
1435+
StringAttr kWarp = StringAttr::get(ctx, "warp");
1436+
1437+
const unsigned mIndex = 0;
1438+
const unsigned nIndex = 1;
1439+
const int instrM = mmaInstrM;
1440+
const int instrN = mmaInstrN;
1441+
const int kSize = dotOperandShape[1];
1442+
const int mWarps = warpsPerCTA[mIndex];
1443+
const int nWarps = warpsPerCTA[nIndex];
1444+
const int totalWarps = mWarps * nWarps;
1445+
const unsigned mRep_warp = tilesPerWarp[mIndex];
1446+
const unsigned nRep_warp = tilesPerWarp[nIndex];
1447+
const unsigned kRep = std::min<unsigned>(kSize, 2);
1448+
1449+
std::vector<std::vector<int32_t>> registerBase;
1450+
std::vector<std::vector<int32_t>> laneBase;
1451+
std::vector<std::vector<int32_t>> warpBase;
1452+
if (dotOperandIdx == 0) { // per-row A-scale
1453+
laneBase = {{0, 8}, {0, 0}, {0, 1}, {0, 2}, {0, 4}};
1454+
for (int offset = instrM * mWarps; offset < instrM * mWarps * mRep_warp;
1455+
offset <<= 1)
1456+
registerBase.push_back({0, offset});
1457+
for (int w = mWarps; w < totalWarps; w <<= 1)
1458+
warpBase.push_back({0, 0});
1459+
for (int offset = instrM; offset < instrM * mWarps; offset <<= 1)
1460+
warpBase.push_back({0, offset});
1461+
} else { // per-col B-scale
1462+
laneBase = {{0, 0}, {0, 0}, {0, 1}, {0, 2}, {0, 4}};
1463+
if (nRep_warp > 1)
1464+
registerBase.push_back({0, nWarps * instrN});
1465+
for (int k = 1; k < kRep; k += 1)
1466+
registerBase.push_back({1 << (k - 1), 0});
1467+
for (int offset = instrN; offset < instrN * nWarps; offset <<= 1)
1468+
warpBase.push_back({0, offset});
1469+
for (int w = nWarps; w < totalWarps; w <<= 1)
1470+
warpBase.push_back({0, 0});
1471+
}
1472+
1473+
const unsigned kIdx = (dotOperandShape[0] == 1) ? 0 : 1;
1474+
const unsigned mnIdx = 1 - kIdx;
1475+
LinearLayout ctaLayout(
1476+
{{kRegister, registerBase}, {kLane, laneBase}, {kWarp, warpBase}},
1477+
{outDims[kIdx], outDims[mnIdx]});
1478+
return combineCtaCgaWithShape(ctaLayout, ctaLayoutAttr, dotOperandShape);
1479+
}
1480+
14041481
LinearLayout chooseScaledMfmaScaleLayout(MLIRContext *ctx, int dotOperandIdx,
14051482
ArrayRef<int64_t> dotOperandShape,
14061483
unsigned mfmaMDim,

0 commit comments

Comments
 (0)