|
1 | 1 | #include <vector> |
2 | 2 |
|
| 3 | +#include "triton/Dialect/Triton/IR/Dialect.h" |
3 | 4 | #include "triton/Dialect/Triton/IR/Utility.h" |
4 | 5 | #include "triton/Dialect/TritonGPU/IR/Attributes.h" |
5 | 6 | #include "triton/Dialect/TritonGPU/IR/Dialect.h" |
|
14 | 15 | #include "llvm/Support/ErrorHandling.h" |
15 | 16 | #include "llvm/Support/MathExtras.h" |
16 | 17 |
|
| 18 | +using mlir::triton::ScaleDotElemType; |
| 19 | + |
17 | 20 | namespace mlir::triton::gpu { |
18 | 21 | namespace { |
19 | 22 |
|
@@ -1335,4 +1338,154 @@ LinearLayout chooseDsReadB64Tr16Layout(Attribute enc, ArrayRef<int64_t> shape, |
1335 | 1338 | return chooseDotDsReadB64Tr16Layout(dot, shape, elemBitWidth); |
1336 | 1339 | } |
1337 | 1340 |
|
| 1341 | +LinearLayout |
| 1342 | +chooseScaledMfmaOperandLayout(AMDMfmaEncodingAttr mfmaEnc, int kWidth, |
| 1343 | + int dotOperandIdx, ScaleDotElemType elemType, |
| 1344 | + llvm::ArrayRef<int64_t> dotOperandShape) { |
| 1345 | + MLIRContext *ctx = mfmaEnc.getContext(); |
| 1346 | + unsigned mDim = mfmaEnc.getMDim(); |
| 1347 | + if (elemType == ScaleDotElemType::E2M1) { |
| 1348 | + auto newEncoding = |
| 1349 | + DotOperandEncodingAttr::get(ctx, dotOperandIdx, mfmaEnc, kWidth / 2); |
| 1350 | + return newEncoding.toLinearLayout(dotOperandShape); |
| 1351 | + } |
| 1352 | + |
| 1353 | + // For mxfp8, each lane contains 32 elements, consisting of two blocks |
| 1354 | + // of 16 consecutive elements. There's a gap between these two blocks, |
| 1355 | + // which is not supported by normal dot layout. |
| 1356 | + assert(elemType == ScaleDotElemType::E4M3 || |
| 1357 | + elemType == ScaleDotElemType::E5M2); |
| 1358 | + using basisT = std::vector<std::vector<int32_t>>; |
| 1359 | + unsigned rank = dotOperandShape.size(); |
| 1360 | + auto standardOutDims = standardOutDimNames(ctx, rank); |
| 1361 | + auto warpOrder = mfmaEnc.getWarpOrder(); |
| 1362 | + |
| 1363 | + StringAttr kRegister = StringAttr::get(ctx, "register"); |
| 1364 | + StringAttr kLane = StringAttr::get(ctx, "lane"); |
| 1365 | + StringAttr kWarp = StringAttr::get(ctx, "warp"); |
| 1366 | + |
| 1367 | + basisT regBase = {{0, 1}, {0, 2}, {0, 4}, {0, 8}}; |
| 1368 | + basisT laneBase = {{1, 0}, {2, 0}, {4, 0}, {8, 0}}; |
| 1369 | + int32_t kTileSize; |
| 1370 | + if (mDim == 16) { |
| 1371 | + regBase.emplace_back(std::vector<int32_t>{0, 64}); |
| 1372 | + laneBase.emplace_back(std::vector<int32_t>{0, 16}); |
| 1373 | + laneBase.emplace_back(std::vector<int32_t>{0, 32}); |
| 1374 | + kTileSize = kWidth * 4; |
| 1375 | + } else { |
| 1376 | + assert(mDim == 32); |
| 1377 | + regBase.emplace_back(std::vector<int32_t>{0, 32}); |
| 1378 | + laneBase.emplace_back(std::vector<int32_t>{16, 0}); |
| 1379 | + laneBase.emplace_back(std::vector<int32_t>{0, 16}); |
| 1380 | + kTileSize = kWidth * 2; |
| 1381 | + } |
| 1382 | + // Add repeats of registers along K dimension to register base vectors |
| 1383 | + int64_t kSize = dotOperandIdx == 0 ? dotOperandShape[1] : dotOperandShape[0]; |
| 1384 | + for (int32_t elem = kTileSize; elem < kSize; elem *= 2) { |
| 1385 | + regBase.emplace_back(std::vector<int32_t>{0, elem}); |
| 1386 | + } |
| 1387 | + |
| 1388 | + // Order of dimensionality changes on A/B operand, so here we need to reverse |
| 1389 | + // if it's operand B. |
| 1390 | + std::vector<int> repOrder = {0, 1}; |
| 1391 | + if (dotOperandIdx == 1) { |
| 1392 | + std::reverse(repOrder.begin(), repOrder.end()); |
| 1393 | + } |
| 1394 | + |
| 1395 | + auto regLanes = LinearLayout( |
| 1396 | + {{kRegister, regBase}, {kLane, laneBase}}, |
| 1397 | + {standardOutDims[repOrder[0]], standardOutDims[repOrder[1]]}); |
| 1398 | + |
| 1399 | + auto warps = identityStandardND(kWarp, mfmaEnc.getWarpsPerCTA(), warpOrder); |
| 1400 | + |
| 1401 | + return combineCtaCgaWithShape(regLanes.transposeOuts(standardOutDims) * |
| 1402 | + warps.transposeOuts(standardOutDims), |
| 1403 | + mfmaEnc.getCTALayout(), dotOperandShape); |
| 1404 | +} |
| 1405 | + |
| 1406 | +LinearLayout chooseScaledMfmaScaleLayout( |
| 1407 | + MLIRContext *ctx, int dotOperandIdx, |
| 1408 | + const std::vector<std::vector<int32_t>> &dotOperandWarpBasis, |
| 1409 | + ArrayRef<int64_t> dotOperandShape, unsigned mfmaMDim) { |
| 1410 | + using basisT = std::vector<std::vector<int32_t>>; |
| 1411 | + unsigned rank = dotOperandShape.size(); |
| 1412 | + auto order = mlir::triton::gpu::getMatrixOrder(rank, /*rowMajor=*/true); |
| 1413 | + auto standardOutDims = standardOutDimNames(ctx, rank); |
| 1414 | + StringAttr kRegister = StringAttr::get(ctx, "register"); |
| 1415 | + StringAttr kLane = StringAttr::get(ctx, "lane"); |
| 1416 | + StringAttr kWarp = StringAttr::get(ctx, "warp"); |
| 1417 | + StringAttr kBlock = StringAttr::get(ctx, "block"); |
| 1418 | + // Init register layout. Will be adjusted later |
| 1419 | + auto regs = mlir::triton::identityStandardND(kRegister, {1, 1}, order); |
| 1420 | + LinearLayout lanes = LinearLayout::empty(); |
| 1421 | + // In scaled dot, the shapes of operands(without batch dimension) are, |
| 1422 | + // respectively: |
| 1423 | + // - A: [M, K] |
| 1424 | + // - B: [K, N] |
| 1425 | + // - aScale: [M, K / 32] |
| 1426 | + // - bScale: [N, K / 32] |
| 1427 | + // |
| 1428 | + // To correctly feed A/B and its scale into instruction, we need to |
| 1429 | + // distribute aScale/bScale among warps in the same way as A/B. But bScale |
| 1430 | + // is not transposed like B. So we need to transpose the warp layout of |
| 1431 | + // bScale. |
| 1432 | + // |
| 1433 | + // The tricky part is, our desired outputs are [dim0, dim1], but |
| 1434 | + // at this position, the layouts are transposed to [dim1, dim0]. So |
| 1435 | + // instead of reverse bScale's layout, we need to reverse aScale's. There |
| 1436 | + // will be a transpose in the end to correct everything. |
| 1437 | + basisT warps = dotOperandWarpBasis; |
| 1438 | + if (dotOperandIdx == 0) { |
| 1439 | + for (auto &basis : warps) { |
| 1440 | + std::reverse(basis.begin(), basis.end()); |
| 1441 | + } |
| 1442 | + } |
| 1443 | + // In general, for both 32x32 and 16x16 scaled mfma, and no matter what |
| 1444 | + // data type the A/B operand is, each lane takes 32 elements from A/B |
| 1445 | + // alone K dim, and 1 or 2 elements from scale accordingly. The number of |
| 1446 | + // scale's elements in a lane varies because the 32 elements from A/B may |
| 1447 | + // not be consecutive. |
| 1448 | + // |
| 1449 | + // For mxfp4, these 32 elements are consecutive, so only 1 scale element |
| 1450 | + // is required. But for mxfp6/mxfp8, there are 2 16-consecutive elements |
| 1451 | + // blocks, so 2 scale elements are required. |
| 1452 | + if (mfmaMDim == 32) { |
| 1453 | + // For ROCDL::mfma_scale_f32_32x32x64_f8f6f4 with fp4 input, each lane |
| 1454 | + // takes 32 consecutive elements from A alone K dimension. The first |
| 1455 | + // 32 lanes collectively handle A[0:32][0:32], and the other 32 lanes |
| 1456 | + // collectively handle A[0:32][32:64]. Each lane take 1 scale element |
| 1457 | + // accordingly. Similar to B and bScale. |
| 1458 | + lanes = LinearLayout( |
| 1459 | + {{kLane, {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 16}, {1, 0}}}, |
| 1460 | + {kWarp, warps}, |
| 1461 | + {kBlock, {}}}, |
| 1462 | + {standardOutDims[order[0]], standardOutDims[order[1]]}); |
| 1463 | + } else { |
| 1464 | + assert(mfmaMDim == 16); |
| 1465 | + // For ROCDL::mfma_scale_f32_16x16x128_f8f6f4 with fp4 input, each lane |
| 1466 | + // takes 32 consecutive elements from A alone K dimension. The first |
| 1467 | + // 16 lanes collectively handle A[0:16][0:32], and another 16 lanes |
| 1468 | + // collectively handle A[0:16][32:64] and so on. Each lane take 1 scale |
| 1469 | + // element accordingly. Similar to B and bScale. |
| 1470 | + lanes = |
| 1471 | + LinearLayout({{kLane, {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {1, 0}, {2, 0}}}, |
| 1472 | + {kWarp, warps}, |
| 1473 | + {kBlock, {}}}, |
| 1474 | + {standardOutDims[order[0]], standardOutDims[order[1]]}); |
| 1475 | + } |
| 1476 | + LinearLayout newLL = regs * lanes; |
| 1477 | + |
| 1478 | + // Adjust register-level layout to fill the shape, at this level, both |
| 1479 | + // aScale and bScale should align with A operand. |
| 1480 | + SmallVector<int, 2> repOrder = {1, 0}; |
| 1481 | + for (auto d : repOrder) { |
| 1482 | + auto outDim = standardOutDims[d]; |
| 1483 | + auto dimSize = newLL.getOutDimSize(outDim); |
| 1484 | + newLL *= LinearLayout::identity1D(dotOperandShape[d] / dimSize, kRegister, |
| 1485 | + outDim); |
| 1486 | + } |
| 1487 | + newLL = newLL.transposeOuts(standardOutDims); |
| 1488 | + return newLL; |
| 1489 | +} |
| 1490 | + |
1338 | 1491 | } // namespace mlir::triton::gpu |
0 commit comments