Commit 6ab8757

Merge commit '5e59bdfed405c8020b77f4af94e1d04740de967f'
2 parents 6b6d9ac + 5e59bdf commit 6ab8757

File tree: 27 files changed, +455 -428 lines

.github/workflows/integration-tests.yml

Lines changed: 5 additions & 12 deletions

@@ -327,7 +327,7 @@ jobs:
        runner: ${{fromJson(needs.Runner-Preparation.outputs.matrix-HIP)}}
    name: Integration-Tests (${{matrix.runner[1] == 'gfx90a' && 'mi210' || 'mi300x'}})
    container:
-      image: rocm/pytorch:rocm6.1_ubuntu22.04_py3.10_pytorch_2.4
+      image: rocmshared/pytorch:rocm6.2.2_ubuntu22.04_py3.10_pytorch_2.5.1_asan
      options: --device=/dev/kfd --device=/dev/dri --security-opt seccomp=unconfined --group-add video --user root
    steps:
      - name: Checkout
@@ -396,22 +396,15 @@ jobs:

          mkdir -p ~/.ccache
          du -h -d 1 ~/.ccache
-      - name: Update PATH
-        run: |
-          echo "/opt/rocm/llvm/bin" >> $GITHUB_PATH
-      - name: Install pip dependencies
-        run: |
-          python3 -m pip install --upgrade pip
-          python3 -m pip install lit
-      - name: Install apt dependencies
+      - name: Update compiler to clang
        run: |
-          apt update
-          apt install ccache
+          export CC=/usr/bin/clang
+          export CXX=/usr/bin/clang++
      - name: Install Triton
        id: amd-install-triton
        run: |
          echo "PATH is '$PATH'"
-          pip uninstall -y triton
+          pip uninstall -y triton pytorch-triton-rocm
          cd python
          ccache --zero-stats
          pip install -v -e '.[tests]'

.github/workflows/integration-tests.yml.in

Lines changed: 5 additions & 14 deletions

@@ -374,7 +374,7 @@ jobs:
    name: Integration-Tests (${{matrix.runner[1] == 'gfx90a' && 'mi210' || 'mi300x'}})

    container:
-      image: rocm/pytorch:rocm6.1_ubuntu22.04_py3.10_pytorch_2.4
+      image: rocmshared/pytorch:rocm6.2.2_ubuntu22.04_py3.10_pytorch_2.5.1_asan
      options: --device=/dev/kfd --device=/dev/dri --security-opt seccomp=unconfined --group-add video --user root

    steps:
@@ -388,25 +388,16 @@ jobs:
      - *restore-build-artifacts-step
      - *inspect-cache-directories-step

-      - name: Update PATH
-        run: |
-          echo "/opt/rocm/llvm/bin" >> $GITHUB_PATH
-
-      - name: Install pip dependencies
-        run: |
-          python3 -m pip install --upgrade pip
-          python3 -m pip install lit
-
-      - name: Install apt dependencies
+      - name: Update compiler to clang
        run: |
-          apt update
-          apt install ccache
+          export CC=/usr/bin/clang
+          export CXX=/usr/bin/clang++

      - name: Install Triton
        id: amd-install-triton
        run: |
          echo "PATH is '$PATH'"
-          pip uninstall -y triton
+          pip uninstall -y triton pytorch-triton-rocm
          cd python
          ccache --zero-stats
          pip install -v -e '.[tests]'

CMakeLists.txt

Lines changed: 9 additions & 0 deletions

@@ -84,6 +84,9 @@ if(NOT WIN32)
endif()

if(TRITON_BUILD_UT)
+  # This is an aggregate target for all unit tests.
+  add_custom_target(TritonUnitTests)
+  set_target_properties(TritonUnitTests PROPERTIES FOLDER "Triton/Tests")
  include(AddTritonUnitTest)
endif()

@@ -355,4 +358,10 @@ add_subdirectory(test)

if(TRITON_BUILD_UT)
  add_subdirectory(unittest)
+  # This target runs all the unit tests.
+  add_custom_target(check-triton-unit-tests
+    COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure
+    DEPENDS TritonUnitTests
+    USES_TERMINAL
+  )
endif()

cmake/AddTritonUnitTest.cmake

Lines changed: 3 additions & 0 deletions

@@ -36,4 +36,7 @@ function(add_triton_ut)
  # laptop. I think the issue may be that the very first time you run a program
  # it's a bit slow.
  gtest_discover_tests(${__NAME} DISCOVERY_TIMEOUT 60)
+
+  # Add the unit test to the top-level unit test target.
+  add_dependencies(TritonUnitTests ${__NAME})
endfunction()

include/triton/Dialect/TritonGPU/IR/Dialect.h

Lines changed: 4 additions & 0 deletions

@@ -214,6 +214,10 @@ LinearLayout ensureLayoutNotSmallerThan(
    const LinearLayout &layout,
    const llvm::SmallDenseMap<StringAttr, int64_t> &shape);

+SmallVector<StringAttr> standardOutDimNames(MLIRContext *ctx, int rank);
+LinearLayout identityStandardND(StringAttr inDimName, ArrayRef<unsigned> shape,
+                                ArrayRef<unsigned> order);
+
// Dump information about which threads/registers contain each of the tensor
// elements.
void dumpLayout(RankedTensorType tensorType);

lib/Conversion/TritonGPUToLLVM/DecomposeUnsupportedConversions.cpp

Lines changed: 1 addition & 1 deletion

@@ -91,7 +91,7 @@ void decomposeBlockedToDotLayoutConversion(ModuleOp module) {
        dyn_cast<triton::gpu::DotOperandEncodingAttr>(dstType.getEncoding());
    if (srcBlocked && dstDotOp) {
      auto dotParent = dyn_cast<NvidiaMmaEncodingAttr>(dstDotOp.getParent());
-      if (dotParent && dotParent.isAmpere()) {
+      if (dotParent) {
        return;
      }
      Attribute sharedMemorySpace =

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 30 additions & 0 deletions

@@ -658,6 +658,36 @@ LinearLayout ensureLayoutNotSmallerThan(
  return ret;
}

+// Returns ["dim0", "dim1", ..., "dim<rank-1>"].
+SmallVector<StringAttr> standardOutDimNames(MLIRContext *ctx, int rank) {
+  SmallVector<StringAttr> ret;
+  for (int i = 0; i < rank; i++) {
+    ret.push_back(StringAttr::get(ctx, "dim" + llvm::Twine(i)));
+  }
+  return ret;
+}
+
+// Returns a 1D -> ND layout into [dim0, dim1, ...] that's equivalent to
+// creating a 1D -> 1D mapping of size product(shape) and then reshaping to
+// permute(shape, order).
+LinearLayout identityStandardND(StringAttr inDimName, ArrayRef<unsigned> shape,
+                                ArrayRef<unsigned> order) {
+  assert(shape.size() == order.size());
+  MLIRContext *ctx = inDimName.getContext();
+  auto rank = shape.size();
+
+  // The order in triton is written wrt. [dim0, dim1, ...].
+  SmallVector<StringAttr> outDimNames = standardOutDimNames(ctx, rank);
+
+  LinearLayout ret = LinearLayout::empty();
+  for (int i = 0; i < shape.size(); i++) {
+    // Start with the most-minor dimension, which is order[0].
+    int dim = order[i];
+    ret *= LinearLayout::identity1D(shape[dim], inDimName, outDimNames[dim]);
+  }
+  return ret;
+}
+
} // namespace gpu
} // namespace triton
} // namespace mlir
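Editor's note: the two helpers above are now exported from Dialect.cpp so other passes can build standard layouts directly. The following is a minimal sketch of how they could be exercised; the standalone main(), the MLIR include paths, and the printed output are illustrative assumptions and are not part of this commit.

    // Illustrative only: build identityStandardND over shape [16, 2] with
    // order [1, 0] (dim1 most minor) and check the resulting output-dim sizes.
    #include "mlir/IR/BuiltinAttributes.h"
    #include "mlir/IR/MLIRContext.h"
    #include "triton/Dialect/TritonGPU/IR/Dialect.h"

    #include <cstdio>

    int main() {
      mlir::MLIRContext ctx;
      auto kRegister = mlir::StringAttr::get(&ctx, "register");

      // Equivalent to a 1D identity of size 16 * 2 reshaped to
      // permute({16, 2}, {1, 0}).
      auto layout =
          mlir::triton::gpu::identityStandardND(kRegister, {16, 2}, {1, 0});

      auto dims = mlir::triton::gpu::standardOutDimNames(&ctx, /*rank=*/2);
      std::printf("dim0 = %d, dim1 = %d\n", layout.getOutDimSize(dims[0]),
                  layout.getOutDimSize(dims[1]));
      return 0;
    }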

lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp

Lines changed: 0 additions & 30 deletions

@@ -33,15 +33,6 @@ namespace {

#define S(v) StringAttr::get(ctx, (v))

-// Returns ["dim0", "dim1", ..., "dim<rank-1>"].
-SmallVector<StringAttr> standardOutDimNames(MLIRContext *ctx, int rank) {
-  SmallVector<StringAttr> ret;
-  for (int i = 0; i < rank; i++) {
-    ret.push_back(S("dim" + llvm::Twine(i)));
-  }
-  return ret;
-}
-
// TODO Have order be a mandatory argument of standardOutDimNames.
SmallVector<StringAttr> permuteDimNames(const SmallVector<StringAttr> &names,
                                        const SmallVector<unsigned> &order) {
@@ -53,27 +44,6 @@ SmallVector<StringAttr> permuteDimNames(const SmallVector<StringAttr> &names,
  return ret;
}

-// Returns a 1D -> ND layout into [dim0, dim1, ...] that's equivalent to
-// creating a 1D -> 1D mapping of size product(shape) and then reshaping to
-// permute(shape, order).
-LinearLayout identityStandardND(StringAttr inDimName, ArrayRef<unsigned> shape,
-                                ArrayRef<unsigned> order) {
-  assert(shape.size() == order.size());
-  MLIRContext *ctx = inDimName.getContext();
-  auto rank = shape.size();
-
-  // The order in triton is written wrt. [dim0, dim1, ...].
-  SmallVector<StringAttr> outDimNames = standardOutDimNames(ctx, rank);
-
-  LinearLayout ret = LinearLayout::empty();
-  for (int i = 0; i < shape.size(); i++) {
-    // Start with the most-minor dimension, which is order[0].
-    int dim = order[i];
-    ret *= LinearLayout::identity1D(shape[dim], inDimName, outDimNames[dim]);
-  }
-  return ret;
-}
-
// Make a LinearLayout that maps a block-id to an N-dimensional index.
//
// The tensor is split up into CTAsPerCGA pieces, which are distributed among

lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp

Lines changed: 73 additions & 44 deletions

@@ -12,6 +12,7 @@
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
+#include "triton/Tools/StrUtil.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"

@@ -394,6 +395,10 @@ class DecomposeScaledBlocked
    auto aType = scaledDotOp.getLhsType();
    auto bType = scaledDotOp.getRhsType();

+    auto rank = oldRetType.getShape().size();
+    if (rank != 2)
+      return rewriter.notifyMatchFailure(scaledDotOp, "NYI: rank==3");
+
    assert((aType == ScaleDotElemType::E4M3 ||
            aType == ScaleDotElemType::E5M2 ||
            aType == ScaleDotElemType::E2M1) &&
@@ -430,71 +435,95 @@ class DecomposeScaledBlocked
    // `bases[warps] = {(0, 0), (0, 0), ...}`

    auto newAEncoding = DotOperandEncodingAttr::get(ctx, 0, mmaEnc, aKWidth);
-    auto rank = mmaEnc.getInstrShape().size();
+
    // MMAv3 uses the first dimension for the M dimension, while MMAv2 uses the
    // penultimate (ugh)
-    auto instrShapeM = mmaEnc.getInstrShape()[versionMajor == 3 ? 0 : rank - 2];
+    auto instrShapeM =
+        mmaEnc.getInstrShape()[versionMajor == 3
+                                   ? 0
+                                   : mmaEnc.getInstrShape().size() - 2];
    auto warpSize = getWarpSize(newAEncoding);
    assert(instrShapeM <= warpSize);
    // Necessary choice to leave all the scales of the tile in that given warp
    auto threadsPerWarp =
        SmallVector<unsigned>{instrShapeM, warpSize / instrShapeM};

-    assert(versionMajor == 2 &&
-           "NYI: MMAv3. Need to rethink the scale layout otherwise");
-
-    // Copy the bases
-
+    // This has to align with the order in UpcastMXFPOp
+    auto order = getMatrixOrder(rank, /*rowMajor=*/true);
    Attribute newScaleEncoding = triton::gpu::BlockedEncodingAttr::get(
-        ctx, {1, 1}, threadsPerWarp, newAEncoding.getWarpsPerCTA(),
-        newAEncoding.getCTAOrder(), mmaEnc.getCTALayout());
+        ctx, {1, 1}, threadsPerWarp, newAEncoding.getWarpsPerCTA(), order,
+        mmaEnc.getCTALayout());

+    // Lezcano: In the future we could just use the LLs unconditionally
+    // Not doing it now as they are not as performant as Blocked encoding at
+    // times E.g., we bail on them in the backwardMaterialization pass
    auto dotBroadcastsWarpLevel = mmaEnc.getWarpsPerCTA()[1] != 1;
    if (dotBroadcastsWarpLevel) {
-      // If mma has warpsPerCTA == {2, 2}, then newAEncoding has
-      // warpsPerCTA == {2, 1}. In this case, we need to broadcast the warps
-      // on the second dimension as per
-      // A: 0 1 | 0 1
-      //    - - | - -
-      //    2 3 | 2 3
-      // This broadcasting is not representable by standard blocked encodings,
-      // so we need to use linear layouts.
-      // This broadcasting is implemented in ampereDotToLinearLayout
-      auto blocked = cast<BlockedEncodingAttr>(newScaleEncoding);
-      auto blockedLL = *blocked.toLinearLayout(a.getType().getShape());
-      LinearLayout::BasesT scaleBases = blockedLL.getBases();
-      auto nBases = llvm::Log2_32(mmaEnc.getWarpsPerCTA()[1]);
-      auto &warps = scaleBases[StringAttr::get(ctx, "warp")];
-      // Prepend the vector of zeros to the warpBases
-      warps.insert(warps.begin(), nBases, std::vector<int32_t>(rank, 0));
-      auto outDims = llvm::to_vector(blockedLL.getOutDimNames());
-      auto newLL = LinearLayout(scaleBases, outDims);
-      auto llEncoding = LinearEncodingAttr::get(ctx, std::move(newLL));
-      // Adjust the shape of the layout to match the scale operand
-      auto scaleShape = scale.getType().getShape();
-      newScaleEncoding =
-          LinearEncodingAttr::get(ctx, *llEncoding.toLinearLayout(scaleShape));
+      auto kRegister = StringAttr::get(ctx, "register");
+      auto regs = identityStandardND(kRegister, {1, 1}, order);
+      auto lanes =
+          identityStandardND(StringAttr::get(ctx, "lane"), {16, 2}, order);
+
+      // Extract warp layout from dotAEncoding
+      // In the future we'll have some nice division utils, but until then...
+      auto dotLL = *newAEncoding.toLinearLayout(a.getType().getShape());
+      LinearLayout::BasesT scaleBases = dotLL.getBases();
+      auto kWarp = StringAttr::get(ctx, "warp");
+      auto &warpBases = scaleBases[kWarp];
+      // The tile shape was [16, 2 * 4 * kWidth] with broadcasting in K
+      // We divide the M dimension by 16
+      auto div = 16;
+      for (auto &warpBase : warpBases) {
+        if (warpBase[rank - 2] != 0) {
+          assert(warpBase[rank - 2] % div == 0);
+          warpBase[rank - 2] /= div;
+        }
+      }
+
+      LinearLayout::BasesT warpBlockBases;
+      auto standardOutDims = llvm::to_vector(dotLL.getOutDimNames());
+      warpBlockBases[kWarp] = warpBases;
+      auto kBlock = StringAttr::get(ctx, "block");
+      assert(scaleBases[kBlock].empty() && "NYI: CGAs");
+      warpBlockBases[kBlock] = {};
+      auto warpBlock = LinearLayout(std::move(warpBlockBases), standardOutDims);
+
+      auto newLL =
+          (regs * lanes) *
+          warpBlock.transposeOuts(llvm::to_vector(lanes.getOutDimNames()));
+      auto shape = scale.getType().getShape();
+
+      // Broadcast to the correct shape Equivalent to
+      // newLL = ensureLayoutNotSmallerThan(newLL.transposeOuts(getRepOrder),
+      //                                    shape);
+      for (auto d : newAEncoding.getRepOrder()) {
+        auto outDim = standardOutDims[d];
+        auto dimSize = newLL.getOutDimSize(outDim);
+        newLL *=
+            LinearLayout::identity1D(shape[d] / dimSize, kRegister, outDim);
+      }
+      newLL = newLL.transposeOuts(standardOutDims);
+      newScaleEncoding = LinearEncodingAttr::get(ctx, std::move(newLL));
    }

    a = createArg(rewriter, a, 0, aType, newAEncoding, scale, newScaleEncoding);

-    // Upcast B operand
-    assert(bType != ScaleDotElemType::E2M1 && "NYI: rhs scale for fp4");
-    auto newBEncoding = DotOperandEncodingAttr::get(ctx, 1, mmaEnc, bKWidth);
-    b = createArg(rewriter, b, 1, bType, newBEncoding,
-                  /*scale=*/std::nullopt, /*scaleEncoding=*/std::nullopt);
    Operation *newDot = nullptr;
    if (versionMajor == 2) {
+      // Upcast B operand
+      assert(bType != ScaleDotElemType::E2M1 && "NYI: rhs scale for fp4");
+      auto newBEncoding = DotOperandEncodingAttr::get(ctx, 1, mmaEnc, bKWidth);
+      b = createArg(rewriter, b, 1, bType, newBEncoding,
+                    /*scale=*/std::nullopt, /*scaleEncoding=*/std::nullopt);
      newDot = rewriter.create<DotOp>(scaledDotOp.getLoc(), newRetType, a, b,
                                      newAcc);
    } else {
      assert(versionMajor == 3);
      // At the time of this writing, this is always true
      auto allowTranspose = b.getType().getElementType().isBF16();
-      b = cast<TypedValue<RankedTensorType>>(
-          getSharedMemoryMMAOperand(b, rewriter, 1, allowTranspose));
+      auto bShmem = getSharedMemoryMMAOperand(b, rewriter, 1, allowTranspose);
      newDot = rewriter.create<triton::nvidia_gpu::WarpGroupDotOp>(
-          scaledDotOp.getLoc(), newRetType, a, b, newAcc, nullptr);
+          scaledDotOp.getLoc(), newRetType, a, bShmem, newAcc, nullptr);
    }

    // convert dot instruction
@@ -578,11 +607,11 @@ class DecomposeScaledBlocked
    auto dotOp = rewriter.create<DotOp>(
        scaledDotOp.getLoc(), scaledDotOp.getType(), a, b, scaledDotOp.getC());

-    // Waiting for https://github.com/triton-lang/triton/pull/5003 to land
-    // cf.
-    // https://github.com/triton-lang/triton/pull/5003#issuecomment-2445091746
-    // int versionMajor = getMMAVersionSafe(computeCapability, dotOp);
    int versionMajor = 2;
+    // We just support bf16 for MMAv3 on the rhs
+    if (bType == ScaleDotElemType::BF16) {
+      versionMajor = getMMAVersionSafe(computeCapability, dotOp);
+    }
    int versionMinor = computeCapability == 75 ? 1 : 0;

    RankedTensorType oldRetType = dotOp.getType();
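Editor's note: the least obvious part of the hunk above is the warp-basis rescaling inside the dotBroadcastsWarpLevel branch, which divides the M component of every nonzero warp basis by the 16-row tile height so the scale operand sees the same warp distribution as the A operand. Below is a self-contained sketch of just that step; the helper name rescaleWarpBasesAlongM and the plain std::vector stand-in for LinearLayout::BasesT entries are assumptions for illustration, not Triton API.

    #include <cassert>
    #include <cstdint>
    #include <vector>

    // Illustrative stand-in for one warp basis vector (one entry per output dim).
    using Basis = std::vector<int32_t>;

    // Divide the M component (index rank - 2) of every nonzero warp basis by the
    // 16-row tile height, mirroring the loop over `warpBases` in the diff above.
    void rescaleWarpBasesAlongM(std::vector<Basis> &warpBases, int rank,
                                int32_t tileM = 16) {
      for (Basis &b : warpBases) {
        if (b[rank - 2] != 0) {
          assert(b[rank - 2] % tileM == 0);
          b[rank - 2] /= tileM;
        }
      }
    }

    int main() {
      // Two warp bases for a rank-2 layout: one steps 16 rows in M, one steps
      // 32 columns in K. Only the M-stepping basis gets divided.
      std::vector<Basis> warps = {{16, 0}, {0, 32}};
      rescaleWarpBasesAlongM(warps, /*rank=*/2);
      assert(warps[0][0] == 1 && warps[0][1] == 0);
      assert(warps[1][0] == 0 && warps[1][1] == 32);
      return 0;
    }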

python/tutorials/02-fused-softmax.py

Lines changed: 6 additions & 8 deletions

@@ -158,14 +158,12 @@ def allocated_slm_size(size_smem):
    y = torch.empty_like(x)

    # pre-compile kernel to get register usage and compute thread occupancy.
-    kernel, num_programs = kernels.get(BLOCK_SIZE, (None, 0))
-    if kernel is None:
-        kernel = softmax_kernel.warmup(y, x, x.stride(0), y.stride(0), n_rows, n_cols, num_warps=num_warps,
-                                       threads_per_warp=WARP_SIZE, BLOCK_SIZE=BLOCK_SIZE, grid=(1, ))
-        kernel._init_handles()
-        size_smem = kernel.metadata.shared
-        num_programs = occupancy(num_warps, size_smem)
-        kernels[BLOCK_SIZE] = (kernel, num_programs)
+    kernel = softmax_kernel.warmup(y, x, x.stride(0), y.stride(0), n_rows, n_cols, num_warps=num_warps,
+                                   threads_per_warp=WARP_SIZE, BLOCK_SIZE=BLOCK_SIZE, grid=(1, ))
+    kernel._init_handles()
+    size_smem = kernel.metadata.shared
+    num_programs = occupancy(num_warps, size_smem)
+    kernels[BLOCK_SIZE] = (kernel, num_programs)

    # We will *not* launch a persistent kernel if the number of rows is lower (not needed) or that would imply each
    # program would need to process more than 2 rows. Persistent kernels save thread dispatch overhead, but cannot