Skip to content

Commit 6ec5e0c

Browse files
authored
Revert "[NVIDIA] Add native MXFP FP8 scaled_dot for SM120 (#7918)" (#8029)
This reverts commit 001ec4b: @ita9naiwa — sorry, I need to revert this PR; it causes both functional and performance regressions on our nightly internal workloads. I'll isolate a reproducer for you to be able to debug and fix them.
1 parent 1c2e9bb commit 6ec5e0c

File tree

12 files changed: +141 additions, −853 deletions

include/triton/Dialect/Triton/IR/TritonOpInterfaces.td

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -49,12 +49,7 @@ def DotOpInterface : OpInterface<"DotOpInterface"> {
4949
/*retType=*/"::mlir::Value",
5050
/*methodName=*/"getB",
5151
/*args=*/(ins)>,
52-
InterfaceMethod<
53-
/*desc=*/"Get the output tensor",
54-
/*retType=*/"::mlir::Value",
55-
/*methodName=*/"getD",
56-
/*args=*/(ins)>,
57-
InterfaceMethod<
52+
InterfaceMethod<
5853
/*desc=*/"Verify the dimensions of the A and B DotOp operands.",
5954
/*retType=*/"bool",
6055
/*methodName=*/"verifyDims",
@@ -69,7 +64,6 @@ def DotOpInterface : OpInterface<"DotOpInterface"> {
6964
auto aTy = cast<ShapedType>($_op.getA().getType());
7065
auto bTy = cast<ShapedType>($_op.getB().getType());
7166
auto cTy = cast<ShapedType>($_op->getOperand(2).getType());
72-
auto dTy = cast<ShapedType>($_op.getD().getType());
7367
auto aShape = aTy.getShape();
7468
auto bShape = bTy.getShape();
7569
auto cShape = cTy.getShape();

include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@ class AMDRotatingSharedEncodingAttr;
1919
class AMDMfmaEncodingAttr;
2020
class TensorOrMemDesc;
2121
class MemDescType;
22-
class CTALayoutAttr;
2322

2423
// - BlockedEncodingAttrs have the following input dimensions.
2524
//
@@ -127,13 +126,6 @@ LinearLayout chooseScaledMfmaScaleLayout(MLIRContext *ctx, int dotOperandIdx,
127126
ArrayRef<unsigned> tilesPerWarp,
128127
ArrayRef<unsigned> warpsPerCTA);
129128

130-
LinearLayout getSM120DotScaledScaleLayout(MLIRContext *ctx, int dotOperandIdx,
131-
ArrayRef<int64_t> dotOperandShape,
132-
ArrayRef<unsigned> tilesPerWarp,
133-
ArrayRef<unsigned> warpsPerCTA,
134-
unsigned instrM, unsigned instrN,
135-
CTALayoutAttr ctaLayoutAttr);
136-
137129
// Create LinearLayout for nvidia mma tile.
138130
LinearLayout nvidiaMmaTile(MLIRContext *ctx, ArrayRef<unsigned> tileShape,
139131
unsigned kWidth, ArrayRef<unsigned> order,

lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp

Lines changed: 0 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -1403,83 +1403,6 @@ LinearLayout chooseDsReadB64TrLayout(Attribute enc, ArrayRef<int64_t> shape,
14031403
return chooseDotDsReadB64TrLayout(dot, shape, elemBitWidth);
14041404
}
14051405

1406-
// Warp-level block scaling (sm_120, m16n8k32)
1407-
// Reference: NVIDIA PTX ISA "Warp-level block scaling"
1408-
// https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-block-scaling
1409-
//
1410-
// Semantics:
1411-
// D = (A * SF_A) * (B * SF_B) + C
1412-
// scale_vec::1X -> SF_A shape Mx1 (per-row), SF_B shape 1xN (per-col)
1413-
//
1414-
// Providers (within each warp quad of 4 lanes):
1415-
// - A scales are provided by a lane-pair selected by thread-id-a ∈ {0,1}
1416-
// (0 => lanes {0,1}, 1 => lanes {2,3} in the quad).
1417-
// - B scales are provided by a single lane selected by thread-id-b ∈
1418-
// {0,1,2,3}.
1419-
//
1420-
// Byte selectors (which subfield of the 32-bit metadata is used):
1421-
// - 1X: 1 byte => byte-id ∈ {0,1,2,3}
1422-
//
1423-
// Implementation notes:
1424-
// - We support only scale_vec::1X for now.
1425-
// - We choose a fixed provider for A (thread-id-a = 0) and B (thread-id-b =
1426-
// 0)
1427-
// - In this implementation, each lane in a quad has the same scale factor.
1428-
LinearLayout getSM120DotScaledScaleLayout(
1429-
MLIRContext *ctx, int dotOperandIdx, ArrayRef<int64_t> dotOperandShape,
1430-
ArrayRef<unsigned> tilesPerWarp, ArrayRef<unsigned> warpsPerCTA,
1431-
unsigned mmaInstrM, unsigned mmaInstrN, CTALayoutAttr ctaLayoutAttr) {
1432-
unsigned rank = dotOperandShape.size();
1433-
auto outDims = standardOutDimNames(ctx, rank);
1434-
1435-
StringAttr kRegister = StringAttr::get(ctx, "register");
1436-
StringAttr kLane = StringAttr::get(ctx, "lane");
1437-
StringAttr kWarp = StringAttr::get(ctx, "warp");
1438-
1439-
const unsigned mIndex = 0;
1440-
const unsigned nIndex = 1;
1441-
const int instrM = mmaInstrM;
1442-
const int instrN = mmaInstrN;
1443-
const int kSize = dotOperandShape[1];
1444-
const int mWarps = warpsPerCTA[mIndex];
1445-
const int nWarps = warpsPerCTA[nIndex];
1446-
const int totalWarps = mWarps * nWarps;
1447-
const unsigned mRep_warp = tilesPerWarp[mIndex];
1448-
const unsigned nRep_warp = tilesPerWarp[nIndex];
1449-
const unsigned kRep = std::min<unsigned>(kSize, 2);
1450-
1451-
std::vector<std::vector<int32_t>> registerBase;
1452-
std::vector<std::vector<int32_t>> laneBase;
1453-
std::vector<std::vector<int32_t>> warpBase;
1454-
if (dotOperandIdx == 0) { // per-row A-scale
1455-
laneBase = {{0, 8}, {0, 0}, {0, 1}, {0, 2}, {0, 4}};
1456-
for (int offset = instrM * mWarps; offset < instrM * mWarps * mRep_warp;
1457-
offset <<= 1)
1458-
registerBase.push_back({0, offset});
1459-
for (int w = mWarps; w < totalWarps; w <<= 1)
1460-
warpBase.push_back({0, 0});
1461-
for (int offset = instrM; offset < instrM * mWarps; offset <<= 1)
1462-
warpBase.push_back({0, offset});
1463-
} else { // per-col B-scale
1464-
laneBase = {{0, 0}, {0, 0}, {0, 1}, {0, 2}, {0, 4}};
1465-
if (nRep_warp > 1)
1466-
registerBase.push_back({0, nWarps * instrN});
1467-
for (int k = 1; k < kRep; k += 1)
1468-
registerBase.push_back({1 << (k - 1), 0});
1469-
for (int offset = instrN; offset < instrN * nWarps; offset <<= 1)
1470-
warpBase.push_back({0, offset});
1471-
for (int w = nWarps; w < totalWarps; w <<= 1)
1472-
warpBase.push_back({0, 0});
1473-
}
1474-
1475-
const unsigned kIdx = (dotOperandShape[0] == 1) ? 0 : 1;
1476-
const unsigned mnIdx = 1 - kIdx;
1477-
LinearLayout ctaLayout(
1478-
{{kRegister, registerBase}, {kLane, laneBase}, {kWarp, warpBase}},
1479-
{outDims[kIdx], outDims[mnIdx]});
1480-
return combineCtaCgaWithShape(ctaLayout, ctaLayoutAttr, dotOperandShape);
1481-
}
1482-
14831406
LinearLayout chooseScaledMfmaScaleLayout(MLIRContext *ctx, int dotOperandIdx,
14841407
ArrayRef<int64_t> dotOperandShape,
14851408
unsigned mfmaMDim,

0 commit comments

Comments (0)