
Commit b220381

knwngloislo authored and committed
[AMD] Support dot_scaled(mxfp8, mxfp4) for gfx950 (triton-lang#5985)
This PR supports the following cases in dot_scaled:
- mxfp8 (both fp8 and bf8) x mxfp8
- mxfp8 x mxfp4, in either order
- the scale of either or both operands may be None
1 parent 268224f
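For context, a minimal sketch of what these cases look like at the Triton language level. The kernel is illustrative only and not part of this commit; it assumes the existing tl.dot_scaled API, uint8 bit-pattern operands, one e8m0 scale per 32 K-elements, and mxfp4 packing of two elements per byte.

    import triton
    import triton.language as tl

    @triton.jit
    def mxfp_kernel(a_ptr, a_scale_ptr, b_ptr, out_ptr,
                    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
                    BLOCK_K: tl.constexpr):
        offs_m = tl.arange(0, BLOCK_M)
        offs_n = tl.arange(0, BLOCK_N)
        offs_k = tl.arange(0, BLOCK_K)
        # A: mxfp8 bit patterns in uint8, one e8m0 scale per 32 elements along K.
        a = tl.load(a_ptr + offs_m[:, None] * BLOCK_K + offs_k[None, :])
        a_scale = tl.load(a_scale_ptr + offs_m[:, None] * (BLOCK_K // 32) +
                          tl.arange(0, BLOCK_K // 32)[None, :])
        # B: mxfp4, two e2m1 values packed per byte, so only BLOCK_K // 2 rows.
        b = tl.load(b_ptr + tl.arange(0, BLOCK_K // 2)[:, None] * BLOCK_N +
                    offs_n[None, :])
        # mxfp8 x mxfp4 with the mxfp4 scale omitted (None is allowed).
        c = tl.dot_scaled(a, a_scale, "e4m3", b, None, "e2m1")
        tl.store(out_ptr + offs_m[:, None] * BLOCK_N + offs_n[None, :], c)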

7 files changed, +481 −199 lines changed

include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h

Lines changed: 19 additions & 0 deletions

@@ -8,9 +8,14 @@
 #include "triton/Tools/LinearLayout.h"
 
+namespace mlir::triton {
+enum class ScaleDotElemType : uint32_t;
+} // namespace mlir::triton
+
 namespace mlir::triton::gpu {
 class SwizzledSharedEncodingAttr;
 class NVMMASharedEncodingAttr;
+class AMDMfmaEncodingAttr;
 
 // - BlockedEncodingAttrs have the following input dimensions.
 //
@@ -261,6 +266,20 @@ LinearLayout chooseLdMatrixLayout(Attribute enc, ArrayRef<int64_t> shape,
 // tensor from shared memory using the `ds_read_tr` instruction for AMD GPUs.
 LinearLayout chooseDsReadB64Tr16Layout(Attribute enc, ArrayRef<int64_t> shape,
                                        int32_t elemBitWidth);
+
+// Create a LinearLayout for an mxfp4 or mxfp8 operand in scaled mfma.
+// For mxfp4, we use the dot layout directly. Mxfp8 is not covered by the
+// dot layout, so we need to manually create a linear layout for it.
+LinearLayout
+chooseScaledMfmaOperandLayout(AMDMfmaEncodingAttr mfmaEnc, int kWidth,
+                              int dotOperandIdx, ScaleDotElemType elemType,
+                              llvm::ArrayRef<int64_t> dotOperandShape);
+
+// Create a LinearLayout for a scale operand in scaled mfma.
+LinearLayout chooseScaledMfmaScaleLayout(
+    MLIRContext *ctx, int dotOperandIdx,
+    const std::vector<std::vector<int32_t>> &dotOperandWarpBasis,
+    ArrayRef<int64_t> dotOperandShape, unsigned mfmaMDim);
 } // namespace mlir::triton::gpu
 
 #endif // TRITON_DIALECT_TRITONGPU_IR_LINEARLAYOUTCONVERSIONS_H

lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp

Lines changed: 153 additions & 0 deletions

@@ -1,5 +1,6 @@
 #include <vector>
 
+#include "triton/Dialect/Triton/IR/Dialect.h"
 #include "triton/Dialect/Triton/IR/Utility.h"
 #include "triton/Dialect/TritonGPU/IR/Attributes.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
@@ -14,6 +15,8 @@
 #include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/MathExtras.h"
 
+using mlir::triton::ScaleDotElemType;
+
 namespace mlir::triton::gpu {
 namespace {
 
@@ -1335,4 +1338,154 @@ LinearLayout chooseDsReadB64Tr16Layout(Attribute enc, ArrayRef<int64_t> shape,
   return chooseDotDsReadB64Tr16Layout(dot, shape, elemBitWidth);
 }
 
+LinearLayout
+chooseScaledMfmaOperandLayout(AMDMfmaEncodingAttr mfmaEnc, int kWidth,
+                              int dotOperandIdx, ScaleDotElemType elemType,
+                              llvm::ArrayRef<int64_t> dotOperandShape) {
+  MLIRContext *ctx = mfmaEnc.getContext();
+  unsigned mDim = mfmaEnc.getMDim();
+  if (elemType == ScaleDotElemType::E2M1) {
+    auto newEncoding =
+        DotOperandEncodingAttr::get(ctx, dotOperandIdx, mfmaEnc, kWidth / 2);
+    return newEncoding.toLinearLayout(dotOperandShape);
+  }
+
+  // For mxfp8, each lane contains 32 elements, consisting of two blocks
+  // of 16 consecutive elements. There's a gap between these two blocks,
+  // which is not supported by the normal dot layout.
+  assert(elemType == ScaleDotElemType::E4M3 ||
+         elemType == ScaleDotElemType::E5M2);
+  using basisT = std::vector<std::vector<int32_t>>;
+  unsigned rank = dotOperandShape.size();
+  auto standardOutDims = standardOutDimNames(ctx, rank);
+  auto warpOrder = mfmaEnc.getWarpOrder();
+
+  StringAttr kRegister = StringAttr::get(ctx, "register");
+  StringAttr kLane = StringAttr::get(ctx, "lane");
+  StringAttr kWarp = StringAttr::get(ctx, "warp");
+
+  basisT regBase = {{0, 1}, {0, 2}, {0, 4}, {0, 8}};
+  basisT laneBase = {{1, 0}, {2, 0}, {4, 0}, {8, 0}};
+  int32_t kTileSize;
+  if (mDim == 16) {
+    regBase.emplace_back(std::vector<int32_t>{0, 64});
+    laneBase.emplace_back(std::vector<int32_t>{0, 16});
+    laneBase.emplace_back(std::vector<int32_t>{0, 32});
+    kTileSize = kWidth * 4;
+  } else {
+    assert(mDim == 32);
+    regBase.emplace_back(std::vector<int32_t>{0, 32});
+    laneBase.emplace_back(std::vector<int32_t>{16, 0});
+    laneBase.emplace_back(std::vector<int32_t>{0, 16});
+    kTileSize = kWidth * 2;
+  }
+  // Add repeats of registers along the K dimension to the register base
+  // vectors.
+  int64_t kSize = dotOperandIdx == 0 ? dotOperandShape[1] : dotOperandShape[0];
+  for (int32_t elem = kTileSize; elem < kSize; elem *= 2) {
+    regBase.emplace_back(std::vector<int32_t>{0, elem});
+  }
+
+  // The order of dimensions differs between the A and B operands, so we
+  // need to reverse it for operand B.
+  std::vector<int> repOrder = {0, 1};
+  if (dotOperandIdx == 1) {
+    std::reverse(repOrder.begin(), repOrder.end());
+  }
+
+  auto regLanes = LinearLayout(
+      {{kRegister, regBase}, {kLane, laneBase}},
+      {standardOutDims[repOrder[0]], standardOutDims[repOrder[1]]});
+
+  auto warps = identityStandardND(kWarp, mfmaEnc.getWarpsPerCTA(), warpOrder);
+
+  return combineCtaCgaWithShape(regLanes.transposeOuts(standardOutDims) *
+                                    warps.transposeOuts(standardOutDims),
+                                mfmaEnc.getCTALayout(), dotOperandShape);
+}
+
+LinearLayout chooseScaledMfmaScaleLayout(
+    MLIRContext *ctx, int dotOperandIdx,
+    const std::vector<std::vector<int32_t>> &dotOperandWarpBasis,
+    ArrayRef<int64_t> dotOperandShape, unsigned mfmaMDim) {
+  using basisT = std::vector<std::vector<int32_t>>;
+  unsigned rank = dotOperandShape.size();
+  auto order = mlir::triton::gpu::getMatrixOrder(rank, /*rowMajor=*/true);
+  auto standardOutDims = standardOutDimNames(ctx, rank);
+  StringAttr kRegister = StringAttr::get(ctx, "register");
+  StringAttr kLane = StringAttr::get(ctx, "lane");
+  StringAttr kWarp = StringAttr::get(ctx, "warp");
+  StringAttr kBlock = StringAttr::get(ctx, "block");
+  // Init register layout. Will be adjusted later.
+  auto regs = mlir::triton::identityStandardND(kRegister, {1, 1}, order);
+  LinearLayout lanes = LinearLayout::empty();
+  // In scaled dot, the shapes of the operands (without the batch dimension)
+  // are, respectively:
+  // - A: [M, K]
+  // - B: [K, N]
+  // - aScale: [M, K / 32]
+  // - bScale: [N, K / 32]
+  //
+  // To correctly feed A/B and their scales into the instruction, we need to
+  // distribute aScale/bScale among warps in the same way as A/B. But bScale
+  // is not transposed like B, so we need to transpose the warp layout of
+  // bScale.
+  //
+  // The tricky part is that our desired outputs are [dim0, dim1], but at
+  // this position the layouts are transposed to [dim1, dim0]. So instead of
+  // reversing bScale's layout, we need to reverse aScale's. There will be a
+  // transpose at the end to correct everything.
+  basisT warps = dotOperandWarpBasis;
+  if (dotOperandIdx == 0) {
+    for (auto &basis : warps) {
+      std::reverse(basis.begin(), basis.end());
+    }
+  }
+  // In general, for both 32x32 and 16x16 scaled mfma, and no matter what
+  // data type the A/B operand is, each lane takes 32 elements from A/B
+  // along the K dim, and 1 or 2 elements from the scale accordingly. The
+  // number of scale elements in a lane varies because the 32 elements from
+  // A/B may not be consecutive.
+  //
+  // For mxfp4, these 32 elements are consecutive, so only 1 scale element
+  // is required. But for mxfp6/mxfp8, there are 2 blocks of 16 consecutive
+  // elements, so 2 scale elements are required.
+  if (mfmaMDim == 32) {
+    // For ROCDL::mfma_scale_f32_32x32x64_f8f6f4 with fp4 input, each lane
+    // takes 32 consecutive elements from A along the K dimension. The first
+    // 32 lanes collectively handle A[0:32][0:32], and the other 32 lanes
+    // collectively handle A[0:32][32:64]. Each lane takes 1 scale element
+    // accordingly. Similarly for B and bScale.
+    lanes = LinearLayout(
+        {{kLane, {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 16}, {1, 0}}},
+         {kWarp, warps},
+         {kBlock, {}}},
+        {standardOutDims[order[0]], standardOutDims[order[1]]});
+  } else {
+    assert(mfmaMDim == 16);
+    // For ROCDL::mfma_scale_f32_16x16x128_f8f6f4 with fp4 input, each lane
+    // takes 32 consecutive elements from A along the K dimension. The first
+    // 16 lanes collectively handle A[0:16][0:32], another 16 lanes
+    // collectively handle A[0:16][32:64], and so on. Each lane takes 1
+    // scale element accordingly. Similarly for B and bScale.
+    lanes =
+        LinearLayout({{kLane, {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {1, 0}, {2, 0}}},
+                      {kWarp, warps},
+                      {kBlock, {}}},
+                     {standardOutDims[order[0]], standardOutDims[order[1]]});
+  }
+  LinearLayout newLL = regs * lanes;
+
+  // Adjust the register-level layout to fill the shape; at this level, both
+  // aScale and bScale should align with the A operand.
+  SmallVector<int, 2> repOrder = {1, 0};
+  for (auto d : repOrder) {
+    auto outDim = standardOutDims[d];
+    auto dimSize = newLL.getOutDimSize(outDim);
+    newLL *= LinearLayout::identity1D(dotOperandShape[d] / dimSize, kRegister,
+                                      outDim);
+  }
+  newLL = newLL.transposeOuts(standardOutDims);
+  return newLL;
+}
+
 } // namespace mlir::triton::gpu
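To make those register/lane bases concrete, here is a small standalone sketch (plain Python, not Triton internals) of how a linear layout applies its bases: each set bit of an input index selects one basis vector, and the selected vectors are XOR-combined per output dimension. The basis values below are copied from the mDim == 32 cases above; everything else is illustrative.

    def apply_basis(basis, idx):
        # Bit i of the input index toggles basis[i]; XOR-combine per dimension.
        out = [0] * len(basis[0])
        for i, vec in enumerate(basis):
            if idx & (1 << i):
                out = [a ^ b for a, b in zip(out, vec)]
        return out

    # Operand A bases for mDim == 32 mxfp8 (out dims [M, K]), before the
    # K-dimension repeat bases are appended: registers walk K, lanes walk M
    # and then jump 16 along K.
    reg_base = [[0, 1], [0, 2], [0, 4], [0, 8], [0, 32]]
    lane_base = [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0], [0, 16]]

    print([apply_basis(reg_base, r) for r in (0, 15, 16, 31)])
    # -> [[0, 0], [0, 15], [0, 32], [0, 47]]: each lane's registers cover K
    #    offsets 0-15 and 32-47, two 16-element blocks with a gap, which is
    #    exactly the pattern a plain dot layout cannot express.
    print(apply_basis(lane_base, 32))
    # -> [0, 16]: lanes 32-63 pick up the remaining blocks (K 16-31, 48-63).

    # Scale lanes for mfmaMDim == 32 (components are [dim1, dim0] here,
    # before the final transposeOuts): lanes 0-31 hold the scale for rows
    # 0-31 of K-block 0, lanes 32-63 the scale for K-block 1, mirroring the
    # A/B split above.
    scale_lane_base = [[0, 1], [0, 2], [0, 4], [0, 8], [0, 16], [1, 0]]
    print(apply_basis(scale_lane_base, 31), apply_basis(scale_lane_base, 32))
    # -> [0, 31] [1, 0]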

python/test/unit/language/test_core.py

Lines changed: 12 additions & 3 deletions

@@ -34,6 +34,7 @@
     is_hip_cdna,
     is_hip_mi200,
     is_hip_mi300,
+    is_hip_mi350,
     is_xpu,
     get_arch,
     torch_float8_dtypes,
@@ -3764,8 +3765,8 @@ def test_scaled_dot(M, N, K, col_a, col_b, rhs_scale, mxfp_type, normal_type, nu
     if not is_hip_cdna():
         pytest.skip("scaled_dot only implemented for HIP CDNA")
     if "e4m3" in (mxfp_type, normal_type):
-        if not is_hip_mi300():
-            pytest.skip(f"scaled_dot({mxfp_type}, {normal_type}) only implemented for MI300")
+        if not (is_hip_mi300() or is_hip_mi350()):
+            pytest.skip(f"scaled_dot({mxfp_type}, {normal_type}) only implemented for MI300 and MI350")
     if mma == 16 and K == 64:
         pytest.skip(f"K == {K} too small for mfma {mma} in scaled_dot")
@@ -3938,7 +3939,15 @@ def make_arg(shape, ty, col_major=False):
         # Clamp to avoid relative error issues
         ret.clamp_(-2**comp_dtype_max_exp, 2**comp_dtype_max_exp - 1)
     else:
-        ret = torch.randint(256, shape, dtype=torch.uint8, device=device)
+        if is_hip_mi350():
+            # On other chips, the A/B operands are upcast to fp16/bf16
+            # before the matmul, whose larger range avoids overflow.
+            # On MI350, we use the V_MFMA_*_F8F6F4 instructions to compute
+            # the matmul directly on f8/f6/f4 data, so we need to narrow
+            # the input range to avoid overflow.
+            ret = torch.randint(20, 40, shape, dtype=torch.uint8, device=device)
+        else:
+            ret = torch.randint(256, shape, dtype=torch.uint8, device=device)
     if col_major:
         ret = ret.mT
     return ret
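As a quick sanity check of that narrowed range (an illustration, not part of the commit, assuming PyTorch's float8_e4m3fn type for decoding), reinterpreting the uint8 bit patterns shows why randint(20, 40) stays tiny while randint(256) can reach the format's extremes and NaN encodings:

    import torch

    def decode_e4m3(byte_vals):
        # Reinterpret uint8 bit patterns as float8 e4m3, widen for printing.
        u8 = torch.tensor(byte_vals, dtype=torch.uint8)
        return u8.view(torch.float8_e4m3fn).float().tolist()

    print(decode_e4m3([20, 39]))    # [0.046875, 0.234375]: the MI350 range
    print(decode_e4m3([126, 254]))  # [448.0, -448.0]: reachable via randint(256)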
