|
12 | 12 | #include "triton/Dialect/TritonGPU/IR/Dialect.h" |
13 | 13 | #include "triton/Dialect/TritonGPU/Transforms/Passes.h" |
14 | 14 | #include "triton/Dialect/TritonGPU/Transforms/Utility.h" |
| 15 | +#include "triton/Tools/StrUtil.h" |
15 | 16 | #include "llvm/ADT/ArrayRef.h" |
16 | 17 | #include "llvm/ADT/SmallVector.h" |
17 | 18 |
|
@@ -394,6 +395,10 @@ class DecomposeScaledBlocked |
394 | 395 | auto aType = scaledDotOp.getLhsType(); |
395 | 396 | auto bType = scaledDotOp.getRhsType(); |
396 | 397 |
|
| 398 | + auto rank = oldRetType.getShape().size(); |
| 399 | + if (rank != 2) |
| 400 | + return rewriter.notifyMatchFailure(scaledDotOp, "NYI: rank==3"); |
| 401 | + |
397 | 402 | assert((aType == ScaleDotElemType::E4M3 || |
398 | 403 | aType == ScaleDotElemType::E5M2 || |
399 | 404 | aType == ScaleDotElemType::E2M1) && |
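For context (this comes from the OCP microscaling formats rather than from this hunk): E4M3 and E5M2 are 8-bit float element types and E2M1 is the 4-bit fp4 type, each scaled by one E8M0 exponent per group of 32 elements along K, so an A operand of shape [M, K] comes with a scale tensor of shape [M, K/32]. A tiny standalone illustration of that sizing, with made-up operand sizes:

    #include <cstdio>

    int main() {
      // Hypothetical operand sizes, purely for illustration.
      int M = 128, K = 256;
      // One E8M0 scale per 32 elements along K (OCP microscaling block size).
      const int kGroup = 32;
      std::printf("A: [%d, %d] -> scale_a: [%d, %d]\n", M, K, M, K / kGroup);
    }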
@@ -430,71 +435,95 @@ class DecomposeScaledBlocked |
430 | 435 | // `bases[warps] = {(0, 0), (0, 0), ...}` |
431 | 436 |
|
432 | 437 | auto newAEncoding = DotOperandEncodingAttr::get(ctx, 0, mmaEnc, aKWidth); |
433 | | - auto rank = mmaEnc.getInstrShape().size(); |
| 438 | + |
434 | 439 | // MMAv3 uses the first dimension for the M dimension, while MMAv2 uses the |
435 | 440 | // penultimate (ugh) |
436 | | - auto instrShapeM = mmaEnc.getInstrShape()[versionMajor == 3 ? 0 : rank - 2]; |
| 441 | + auto instrShapeM = |
| 442 | + mmaEnc.getInstrShape()[versionMajor == 3 |
| 443 | + ? 0 |
| 444 | + : mmaEnc.getInstrShape().size() - 2]; |
437 | 445 | auto warpSize = getWarpSize(newAEncoding); |
438 | 446 | assert(instrShapeM <= warpSize); |
439 | 447 | // Necessary choice to leave all the scales of the tile in that given warp |
440 | 448 | auto threadsPerWarp = |
441 | 449 | SmallVector<unsigned>{instrShapeM, warpSize / instrShapeM}; |
442 | 450 |
|
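As a rough illustration of the threadsPerWarp choice above (assuming the usual MMAv2 instruction shape of {16, 8} and a 32-thread warp, neither of which is fixed by this hunk alone), the arithmetic leaves all 16 scale rows of an MMA tile inside a single warp, matching the {16, 2} lane tile used for the scale layout further down:

    #include <cassert>
    #include <cstdio>
    #include <vector>

    int main() {
      // Hypothetical values: instrShape = {16, 8} for an m16n8 MMAv2 tile,
      // 32 threads per warp.
      unsigned instrShapeM = 16;
      unsigned warpSize = 32;
      assert(instrShapeM <= warpSize);
      // Same shape as the SmallVector built above: 16 threads cover the M rows
      // of the tile, the remaining warpSize / 16 = 2 threads fan out along K.
      std::vector<unsigned> threadsPerWarp{instrShapeM, warpSize / instrShapeM};
      std::printf("threadsPerWarp = {%u, %u}\n", threadsPerWarp[0], threadsPerWarp[1]);
    }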
443 | | - assert(versionMajor == 2 && |
444 | | - "NYI: MMAv3. Need to rethink the scale layout otherwise"); |
445 | | - |
446 | | - // Copy the bases |
447 | | - |
| 451 | + // This has to align with the order in UpcastMXFPOp |
| 452 | + auto order = getMatrixOrder(rank, /*rowMajor=*/true); |
448 | 453 | Attribute newScaleEncoding = triton::gpu::BlockedEncodingAttr::get( |
449 | | - ctx, {1, 1}, threadsPerWarp, newAEncoding.getWarpsPerCTA(), |
450 | | - newAEncoding.getCTAOrder(), mmaEnc.getCTALayout()); |
| 454 | + ctx, {1, 1}, threadsPerWarp, newAEncoding.getWarpsPerCTA(), order, |
| 455 | + mmaEnc.getCTALayout()); |
451 | 456 |
|
| 457 | + // Lezcano: In the future we could just use the LLs unconditionally |
| 458 | + // Not doing it now as they are not as performant as Blocked encoding at |
| 459 | + // times, e.g., we bail on them in the backwardMaterialization pass |
452 | 460 | auto dotBroadcastsWarpLevel = mmaEnc.getWarpsPerCTA()[1] != 1; |
453 | 461 | if (dotBroadcastsWarpLevel) { |
454 | | - // If mma has warpsPerCTA == {2, 2}, then newAEncoding has |
455 | | - // warpsPerCTA == {2, 1}. In this case, we need to broadcast the warps |
456 | | - // on the second dimension as per |
457 | | - // A: 0 1 | 0 1 |
458 | | - // - - | - - |
459 | | - // 2 3 | 2 3 |
460 | | - // This broadcasting is not representable by standard blocked encodings, |
461 | | - // so we need to use linear layouts. |
462 | | - // This broadcasting is implemented in ampereDotToLinearLayout |
463 | | - auto blocked = cast<BlockedEncodingAttr>(newScaleEncoding); |
464 | | - auto blockedLL = *blocked.toLinearLayout(a.getType().getShape()); |
465 | | - LinearLayout::BasesT scaleBases = blockedLL.getBases(); |
466 | | - auto nBases = llvm::Log2_32(mmaEnc.getWarpsPerCTA()[1]); |
467 | | - auto &warps = scaleBases[StringAttr::get(ctx, "warp")]; |
468 | | - // Prepend the vector of zeros to the warpBases |
469 | | - warps.insert(warps.begin(), nBases, std::vector<int32_t>(rank, 0)); |
470 | | - auto outDims = llvm::to_vector(blockedLL.getOutDimNames()); |
471 | | - auto newLL = LinearLayout(scaleBases, outDims); |
472 | | - auto llEncoding = LinearEncodingAttr::get(ctx, std::move(newLL)); |
473 | | - // Adjust the shape of the layout to match the scale operand |
474 | | - auto scaleShape = scale.getType().getShape(); |
475 | | - newScaleEncoding = |
476 | | - LinearEncodingAttr::get(ctx, *llEncoding.toLinearLayout(scaleShape)); |
| 462 | + auto kRegister = StringAttr::get(ctx, "register"); |
| 463 | + auto regs = identityStandardND(kRegister, {1, 1}, order); |
| 464 | + auto lanes = |
| 465 | + identityStandardND(StringAttr::get(ctx, "lane"), {16, 2}, order); |
| 466 | + |
| 467 | + // Extract warp layout from dotAEncoding |
| 468 | + // In the future we'll have some nice division utils, but until then... |
| 469 | + auto dotLL = *newAEncoding.toLinearLayout(a.getType().getShape()); |
| 470 | + LinearLayout::BasesT scaleBases = dotLL.getBases(); |
| 471 | + auto kWarp = StringAttr::get(ctx, "warp"); |
| 472 | + auto &warpBases = scaleBases[kWarp]; |
| 473 | + // The tile shape was [16, 2 * 4 * kWidth] with broadcasting in K |
| 474 | + // We divide the M dimension by 16 |
| 475 | + auto div = 16; |
| 476 | + for (auto &warpBase : warpBases) { |
| 477 | + if (warpBase[rank - 2] != 0) { |
| 478 | + assert(warpBase[rank - 2] % div == 0); |
| 479 | + warpBase[rank - 2] /= div; |
| 480 | + } |
| 481 | + } |
| 482 | + |
| 483 | + LinearLayout::BasesT warpBlockBases; |
| 484 | + auto standardOutDims = llvm::to_vector(dotLL.getOutDimNames()); |
| 485 | + warpBlockBases[kWarp] = warpBases; |
| 486 | + auto kBlock = StringAttr::get(ctx, "block"); |
| 487 | + assert(scaleBases[kBlock].empty() && "NYI: CGAs"); |
| 488 | + warpBlockBases[kBlock] = {}; |
| 489 | + auto warpBlock = LinearLayout(std::move(warpBlockBases), standardOutDims); |
| 490 | + |
| 491 | + auto newLL = |
| 492 | + (regs * lanes) * |
| 493 | + warpBlock.transposeOuts(llvm::to_vector(lanes.getOutDimNames())); |
| 494 | + auto shape = scale.getType().getShape(); |
| 495 | + |
| 496 | + // Broadcast to the correct shape. Equivalent to |
| 497 | + // newLL = ensureLayoutNotSmallerThan(newLL.transposeOuts(getRepOrder), |
| 498 | + // shape); |
| 499 | + for (auto d : newAEncoding.getRepOrder()) { |
| 500 | + auto outDim = standardOutDims[d]; |
| 501 | + auto dimSize = newLL.getOutDimSize(outDim); |
| 502 | + newLL *= |
| 503 | + LinearLayout::identity1D(shape[d] / dimSize, kRegister, outDim); |
| 504 | + } |
| 505 | + newLL = newLL.transposeOuts(standardOutDims); |
| 506 | + newScaleEncoding = LinearEncodingAttr::get(ctx, std::move(newLL)); |
477 | 507 | } |
478 | 508 |
|
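A minimal standalone sketch of the final broadcast step in the block above (the loop that is noted to be equivalent to ensureLayoutNotSmallerThan): each output dimension of the scale layout is padded up to the full scale shape by adding register bases. The sizes below are made up for illustration; the real code does this with LinearLayout::identity1D along the register dimension.

    #include <cstdio>
    #include <vector>

    int main() {
      // Hypothetical sizes: the layout built so far covers a {32, 2} tile of
      // scales, while the whole scale tensor is {128, 8}.
      std::vector<int> covered{32, 2};
      std::vector<int> scaleShape{128, 8};
      int replication = 1;
      for (int d = 0; d < 2; ++d) {
        // Mirrors newLL *= LinearLayout::identity1D(shape[d] / dimSize, ...):
        // pad dimension d up to the full shape with extra register bases.
        replication *= scaleShape[d] / covered[d];
      }
      std::printf("extra register replication factor = %d\n", replication);
    }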
479 | 509 | a = createArg(rewriter, a, 0, aType, newAEncoding, scale, newScaleEncoding); |
480 | 510 |
|
481 | | - // Upcast B operand |
482 | | - assert(bType != ScaleDotElemType::E2M1 && "NYI: rhs scale for fp4"); |
483 | | - auto newBEncoding = DotOperandEncodingAttr::get(ctx, 1, mmaEnc, bKWidth); |
484 | | - b = createArg(rewriter, b, 1, bType, newBEncoding, |
485 | | - /*scale=*/std::nullopt, /*scaleEncoding=*/std::nullopt); |
486 | 511 | Operation *newDot = nullptr; |
487 | 512 | if (versionMajor == 2) { |
| 513 | + // Upcast B operand |
| 514 | + assert(bType != ScaleDotElemType::E2M1 && "NYI: rhs scale for fp4"); |
| 515 | + auto newBEncoding = DotOperandEncodingAttr::get(ctx, 1, mmaEnc, bKWidth); |
| 516 | + b = createArg(rewriter, b, 1, bType, newBEncoding, |
| 517 | + /*scale=*/std::nullopt, /*scaleEncoding=*/std::nullopt); |
488 | 518 | newDot = rewriter.create<DotOp>(scaledDotOp.getLoc(), newRetType, a, b, |
489 | 519 | newAcc); |
490 | 520 | } else { |
491 | 521 | assert(versionMajor == 3); |
492 | 522 | // At the time of this writing, this is always true |
493 | 523 | auto allowTranspose = b.getType().getElementType().isBF16(); |
494 | | - b = cast<TypedValue<RankedTensorType>>( |
495 | | - getSharedMemoryMMAOperand(b, rewriter, 1, allowTranspose)); |
| 524 | + auto bShmem = getSharedMemoryMMAOperand(b, rewriter, 1, allowTranspose); |
496 | 525 | newDot = rewriter.create<triton::nvidia_gpu::WarpGroupDotOp>( |
497 | | - scaledDotOp.getLoc(), newRetType, a, b, newAcc, nullptr); |
| 526 | + scaledDotOp.getLoc(), newRetType, a, bShmem, newAcc, nullptr); |
498 | 527 | } |
499 | 528 |
|
500 | 529 | // convert dot instruction |
@@ -578,11 +607,11 @@ class DecomposeScaledBlocked |
578 | 607 | auto dotOp = rewriter.create<DotOp>( |
579 | 608 | scaledDotOp.getLoc(), scaledDotOp.getType(), a, b, scaledDotOp.getC()); |
580 | 609 |
|
581 | | - // Waiting for https://github.com/triton-lang/triton/pull/5003 to land |
582 | | - // cf. |
583 | | - // https://github.com/triton-lang/triton/pull/5003#issuecomment-2445091746 |
584 | | - // int versionMajor = getMMAVersionSafe(computeCapability, dotOp); |
585 | 610 | int versionMajor = 2; |
| 611 | + // We only support bf16 on the rhs for MMAv3 |
| 612 | + if (bType == ScaleDotElemType::BF16) { |
| 613 | + versionMajor = getMMAVersionSafe(computeCapability, dotOp); |
| 614 | + } |
586 | 615 | int versionMinor = computeCapability == 75 ? 1 : 0; |
587 | 616 |
|
588 | 617 | RankedTensorType oldRetType = dotOp.getType(); |
|
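A minimal standalone sketch of the MMA version selection in the last hunk. getMMAVersionSafe is stubbed out with a made-up compute-capability rule; the real helper also inspects the dot op itself.

    #include <cstdio>

    enum class ScaleDotElemType { E4M3, E5M2, E2M1, BF16 };

    // Stand-in for getMMAVersionSafe (assumption: sm90+ may use MMAv3).
    int getMMAVersionSafeStub(int computeCapability) {
      return computeCapability >= 90 ? 3 : 2;
    }

    int main() {
      int computeCapability = 90;
      ScaleDotElemType bType = ScaleDotElemType::BF16;
      // Mirrors the hunk: default to MMAv2 and only consider MMAv3 when the
      // rhs element type is bf16, the only rhs type the MMAv3 path handles.
      int versionMajor = 2;
      if (bType == ScaleDotElemType::BF16)
        versionMajor = getMMAVersionSafeStub(computeCapability);
      int versionMinor = computeCapability == 75 ? 1 : 0;
      std::printf("versionMajor=%d versionMinor=%d\n", versionMajor, versionMinor);
    }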