Commit 8fdf504

Merge OpenAI Triton commit 4327b5b (#5565)
This PR changes the Triton base from a5b948c to 4327b5b (Nov 9). Pass rate: 95.19%
2 parents (96b9f83 + 91a9281) · commit 8fdf504

90 files changed: +1988 / −1432 lines


.github/workflows/integration-tests-amd.yml

Lines changed: 10 additions & 11 deletions
@@ -18,22 +18,23 @@ jobs:
       matrix:
         runner: ${{ fromJson(inputs.matrix) }}
         include:
-          - image: rocm/pytorch:rocm6.2.2_ubuntu22.04_py3.10_pytorch_2.5.1_asan
+          - image: rocm/pytorch:rocm7.0_ubuntu22.04_py3.10_pytorch_release_2.8.0
            runner: ["self-hosted", "gfx90a"]
            # Cache save/restore is on the host machine at directory /home/runner/.triton, while in the docker
            # container expect it at /github/home/.triton. So map here to make sure visible in docker.
            options: >-
              --device=/dev/kfd --device=/dev/dri --security-opt seccomp=unconfined --group-add video --user root
              --volume /home/runner/.triton:/github/home/.triton
-          - image: rocm/pytorch:rocm6.2.2_ubuntu22.04_py3.10_pytorch_2.5.1_asan
+          - image: rocm/pytorch:rocm7.0_ubuntu22.04_py3.10_pytorch_release_2.8.0
            runner: ["amd-gfx942"]
            # We add --env-file to pull in HIP_VISIBLE_DEVICES and ROCR_VISIBLE_DEVICES definition for GPU isolation.
            options: >-
              --device=/dev/kfd --device=/dev/dri --security-opt seccomp=unconfined --group-add video --user root
              --env-file /etc/podinfo/gha-gpu-isolation-settings
              --volume /home/runner/.triton:/github/home/.triton
-          - image: rocm/7.0-preview:rocm7.0_preview_ubuntu22.04_llama2_70b_training_mlperf_mi35X_prealpha
+          - image: rocm/pytorch:rocm7.0_ubuntu22.04_py3.10_pytorch_release_2.8.0
            runner: ["amd-gfx950"]
+           # We add --env-file to pull in HIP_VISIBLE_DEVICES and ROCR_VISIBLE_DEVICES definition for GPU isolation.
            options: >-
              --device=/dev/kfd --device=/dev/dri --security-opt seccomp=unconfined --group-add video --user root
              --env-file /etc/podinfo/gha-gpu-isolation-settings
@@ -83,14 +84,16 @@ jobs:
             ~/.triton/nvidia
             ~/.triton/json
           key: ${{ runner.os }}-${{ runner.arch }}-llvm-${{ steps.cache-key.outputs.llvm }}-nvidia-${{ steps.cache-key.outputs.nvidia }}-json-${{ steps.cache-key.outputs.json }}
+      - name: Install dependencies
+        run: apt-get install -y clang lld ccache
       - name: Inspect cache directories
         run: |
           mkdir -p ~/.triton
           du -h -d 1 ~/.triton

           mkdir -p ~/.ccache
           du -h -d 1 ~/.ccache
-      - name: Update compiler to clang
+      - name: Update compiler to Clang
         run: |
           export CC=/usr/bin/clang
           export CXX=/usr/bin/clang++
@@ -100,19 +103,15 @@ jobs:
           echo "PATH is '$PATH'"
           pip uninstall -y triton pytorch-triton-rocm

-          if [ "${{ matrix.runner[0] }}" != "amd-gfx950" ]; then
-            ccache --zero-stats
-          fi
-
+          ccache --zero-stats
           make dev-install
-      - name: CCache Stats
-        if: ${{ matrix.runner[0] != 'amd-gfx950' }}
+      - name: Print ccache stats
         run: ccache --print-stats
       - name: Run lit tests
         run: make test-lit
       - name: Run C++ unittests
         run: make test-cpp
-      - name: Run python tests on AMD
+      - name: Run Python tests on AMD
         run: |
           INSTRUMENTATION_LIB_DIR="${GITHUB_WORKSPACE}/python/triton/instrumentation"
           if [ ! -d "${INSTRUMENTATION_LIB_DIR}" ]; then

Makefile

Lines changed: 2 additions & 1 deletion
@@ -73,8 +73,9 @@ test-interpret: all

 .PHONY: test-proton
 test-proton: all
-	$(PYTEST) --tb=short -s -n 8 third_party/proton/test --ignore=third_party/proton/test/test_override.py
+	$(PYTEST) --tb=short -s -n 8 third_party/proton/test --ignore=third_party/proton/test/test_override.py -k "not test_overhead"
 	$(PYTEST) --tb=short -s third_party/proton/test/test_override.py
+	$(PYTEST) --tb=short -s third_party/proton/test/test_instrumentation.py::test_overhead

 .PHONY: test-python
 test-python: test-unit test-regression test-interpret test-proton

include/triton/Dialect/TritonGPU/IR/Dialect.h

Lines changed: 2 additions & 0 deletions
@@ -303,6 +303,8 @@ SetVector<int> getPartitionIds(Operation *op);
 SmallVector<SetVector<int>, 4> getPartitionOutputs(Operation *op);
 SetVector<int> getPartitionIds(OpOperand *use);
 bool hasPartition(Operation *op);
+bool hasWarpSpecializeTag(Operation *op);
+std::optional<int> getWarpSpecializeTag(Operation *op);

 } // namespace mlir::triton::gpu

include/triton/Dialect/TritonGPU/Transforms/Partition.h

Lines changed: 1 addition & 0 deletions
@@ -120,6 +120,7 @@ void setPartition(Operation *op, const SetVector<Partition *> &partitions);
 void setPartition(Operation *op, const SetVector<int> &partitionIds);
 void setPartitionOutputs(Operation *op,
                          ArrayRef<SetVector<int>> partitionOutputsIds);
+void setWarpSpecializeTag(Operation *op, int tag);

 } // namespace mlir::triton::gpu

include/triton/Dialect/TritonNvidiaGPU/IR/Dialect.h

Lines changed: 2 additions & 2 deletions
@@ -107,8 +107,8 @@ inline const char *getOpShape(TMemAccessAtom atom) {
   llvm_unreachable("Unknown TMemAccessAtom");
 }

-LinearLayout getTileLayout(MLIRContext *ctx, TMemAccessAtom atom,
-                           bool unpacked);
+LinearLayout getTileLayout(MLIRContext *ctx, TMemAccessAtom atom, bool unpacked,
+                           bool withWarp);

 TMemAllocation getTmemAllocSizes(gpu::MemDescType memDescType);

include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUAttrDefs.td

Lines changed: 4 additions & 1 deletion
@@ -40,14 +40,17 @@ def TTG_TensorMemoryEncodingAttr : AttrDef<TritonNvidiaGPU_Dialect, "TensorMemor
     that is, the stride between two elements in the same row.
     When colStride is 1 the tensor memory is packed. When colStride > 1, the
     tensor memory between elements is undefined.
+    `twoCTAs` indicates that the tensor memory is laid out for twoCTA mode,
+    i.e., `cta_group::2`.
   }];
   let parameters = (
     ins
     "unsigned":$blockM,
     "unsigned":$blockN,
     "unsigned":$colStride,
     DefaultValuedParameter<"unsigned", "1">:$CTASplitM,
-    DefaultValuedParameter<"unsigned", "1">:$CTASplitN
+    DefaultValuedParameter<"unsigned", "1">:$CTASplitN,
+    DefaultValuedParameter<"bool", "false">:$twoCTAs
   );
   let genVerifyDecl = 1;
   let assemblyFormat = "`<` struct(params) `>`";
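
As a quick illustration of the new parameter (not part of the commit), callers now pass `twoCTAs` as the final argument to `TensorMemoryEncodingAttr::get`, mirroring the updated call sites in AccelerateMatmul.cpp below; the concrete blockM/blockN/colStride values here are made up:

    // Sketch only: building the encoding with the added twoCTAs flag.
    // Numeric parameter values are illustrative, not taken from the commit.
    Attribute accEncoding = triton::nvidia_gpu::TensorMemoryEncodingAttr::get(
        context, /*blockM=*/128, /*blockN=*/128, /*colStride=*/1,
        /*CTASplitM=*/1, /*CTASplitN=*/1, /*twoCTAs=*/true);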

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 35 additions & 5 deletions
@@ -2806,15 +2806,34 @@ struct TritonGPUInferLayoutInterface
         mlir::dyn_cast<triton::gpu::DotOperandEncodingAttr>(operandEncodingB);
     if (!aEncoding && !bEncoding)
       return mlir::success();
-    auto mmaAEncoding =
-        mlir::dyn_cast_or_null<NvidiaMmaEncodingAttr>(aEncoding.getParent());
-    if (mmaAEncoding && mmaAEncoding.isHopper())
-      return success();
-    // Verify that the encodings are valid.
     if (!aEncoding || !bEncoding)
       return op->emitError("mismatching encoding between A and B operands");
+    // Verify that the encodings are valid.
     if (aEncoding.getKWidth() != bEncoding.getKWidth())
       return op->emitError("mismatching kWidth between A and B operands");
+
+    // Check if we have already selected an MMA version for Nvidia. If so,
+    // validate that the encodings are correct and compatible.
+    auto mmaAEncoding =
+        dyn_cast_or_null<NvidiaMmaEncodingAttr>(aEncoding.getParent());
+    auto mmaBEncoding =
+        dyn_cast_or_null<NvidiaMmaEncodingAttr>(bEncoding.getParent());
+    auto dotOp = cast<DotOp>(op);
+    auto resEnc = dotOp.getResult().getType().getEncoding();
+    auto mmaResEncoding = dyn_cast<NvidiaMmaEncodingAttr>(resEnc);
+    if (mmaAEncoding || mmaBEncoding || mmaResEncoding) {
+      // Check that they are all set and have the same version.
+      if (!mmaAEncoding || !mmaBEncoding || !mmaResEncoding)
+        return op->emitError("mismatching MMA encoding");
+      auto mmaBEncoding = cast<NvidiaMmaEncodingAttr>(bEncoding.getParent());
+      if (mmaAEncoding.getVersionMajor() != mmaBEncoding.getVersionMajor() ||
+          mmaAEncoding.getVersionMajor() != mmaResEncoding.getVersionMajor()) {
+        return op->emitError("mismatched MMA version.");
+      }
+      // Verify that the operands are supported on the selected MMA version.
+      if (!supportMMA(dotOp, mmaResEncoding.getVersionMajor()))
+        return op->emitError("unsupported MMA version");
+    }
     return success();
   }

@@ -4032,3 +4051,14 @@ SetVector<int> triton::gpu::getPartitionIds(OpOperand *use) {
 bool triton::gpu::hasPartition(Operation *op) {
   return op && op->hasAttr(kPartitionAttrName);
 }
+
+bool triton::gpu::hasWarpSpecializeTag(Operation *op) {
+  return op && op->hasAttr(kWarpSpecializeTagAttrName);
+}
+
+std::optional<int> triton::gpu::getWarpSpecializeTag(Operation *op) {
+  if (hasWarpSpecializeTag(op)) {
+    return cast<IntegerAttr>(op->getAttr(kWarpSpecializeTagAttrName)).getInt();
+  }
+  return std::nullopt;
+}
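
A brief usage sketch (not from this commit) showing how a pass might consume the new tag helpers; the enclosing function and include set are hypothetical:

    // Hypothetical caller: read the warp-specialize tag off an op, if present.
    #include "llvm/Support/raw_ostream.h"
    #include "triton/Dialect/TritonGPU/IR/Dialect.h"

    void inspectOp(mlir::Operation *op) {
      namespace ttg = mlir::triton::gpu;
      if (std::optional<int> tag = ttg::getWarpSpecializeTag(op)) {
        // *tag is the integer stored in the op's warp-specialize tag attribute.
        llvm::errs() << "warp-specialize tag: " << *tag << "\n";
      }
    }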

lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp

Lines changed: 31 additions & 103 deletions
@@ -470,93 +470,6 @@ AMDMfmaEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
   return combineCtaCgaWithShape(tileLayout, getCTALayout(), shape);
 }

-std::optional<LinearLayout>
-chooseLLDsReadTrLayout(Attribute enc, ArrayRef<int64_t> shape,
-                       int32_t elemBitWidth, unsigned instBitWidth,
-                       unsigned numLanesInShuffleGroup) {
-  using BaseTy = std::vector<std::vector<int32_t>>;
-  // This function will derive the layout for the ds_read_tr instruction
-  // based on the input layout (LL/DotLayout/...)
-  // The ds_read_tr instruction works on instBitWidth per lane and in groups of
-  // numLanesInShuffleGroup lanes.
-
-  // In this example we look at ds_read_b64_tr (instBitWidth = 64) and
-  // numLanesInShuffleGroup = 16 with 64 lanes per warp. Using M-continuous
-  // 16-bit input tensor A as an example. Each lane will load 4 consecutive
-  // elements (64-bit in total) along M. There are 4 consecutive lanes in total
-  // along M. Then the loaded elements are exchanged within the MxK=16x4 "base
-  // unit".
-  //      K0  K1  K2  K3
-  //     +---+---+---+---+
-  //  M0 |   |   |   |   |   M0, K[0-3]: T0
-  //  M1 | T | T | T | T |   M1, K[0-3]: T1
-  //  M2 | 0 | 4 | 8 |12 |   M2, K[0-3]: T2
-  //  M3 |   |   |   |   |   M3, K[0-3]: T3
-  //     +---+---+---+---+
-  //  M4 |   |   |   |   |   M4, K[0-3]: T4
-  //  M5 | T | T | T | T |   M5, K[0-3]: T5
-  //  M6 | 1 | 5 | 9 |13 |   M6, K[0-3]: T6
-  //  M7 |   |   |   |   |   M7, K[0-3]: T7
-  //     +---+---+---+---+        ==>
-  //  M8 |   |   |   |   |   M8, K[0-3]: T8
-  //  M9 | T | T | T | T |   M9, K[0-3]: T9
-  // M10 | 2 | 6 |10 |14 |   M10, K[0-3]: T10
-  // M11 |   |   |   |   |   M11, K[0-3]: T11
-  //     +---+---+---+---+
-  // M12 |   |   |   |   |   M12, K[0-3]: T12
-  // M13 | T | T | T | T |   M13, K[0-3]: T13
-  // M14 | 3 | 7 |11 |15 |   M14, K[0-3]: T14
-  // M15 |   |   |   |   |   M15, K[0-3]: T15
-  //     +---+---+---+---+
-
-  // Given the layout represented by `enc` and shape, we can derive the layout
-  // that ds_read_b64_tr need to have in order to perform a vectorized load of
-  // the elements. This can be done by rearranging the inner 4x16 element base
-  // unit in the LL by rearranging the first numReg register bases and the
-  // first numLane lane bases.
-  auto rotatePrefixes = [](BaseTy &regBase, std::size_t numReg,
-                           BaseTy &laneBase, std::size_t numLane) {
-    // Concatenate prefixes of the two vectors. Lane first and then regs.
-    // C D E F | A B
-    // Then copy over numReg to the regBase and numLane to laneBase
-    // C D | E F A B
-    BaseTy baseUnit(laneBase.begin(), laneBase.begin() + numLane);
-    llvm::append_range(
-        baseUnit, llvm::make_range(regBase.begin(), regBase.begin() + numReg));
-
-    std::copy(baseUnit.begin(), baseUnit.begin() + numReg, regBase.begin());
-    std::copy(baseUnit.begin() + numReg, baseUnit.end(), laneBase.begin());
-  };
-
-  auto ctx = enc.getContext();
-  assert(elemBitWidth == 8 || elemBitWidth == 16);
-  // Get how many reg bases and tile bases the ds_read_tr tile spans
-  unsigned numRegBases = llvm::Log2_32(instBitWidth / elemBitWidth);
-  unsigned numLaneBases = llvm::Log2_32(numLanesInShuffleGroup);
-
-  auto ldsTransLayout = triton::gpu::toLinearLayout(shape, enc);
-  auto bases = ldsTransLayout.getBases();
-  auto kRegister = S("register");
-  auto kLane = S("lane");
-
-  // Make sure that we have enough register bases to rotate, otherwise we
-  // can't return a valid ds_read_tr layout
-  if (ldsTransLayout.getInDimSizeLog2(kRegister) < numRegBases) {
-    return std::nullopt;
-  }
-  // We should always have enough lanes
-  assert(ldsTransLayout.getInDimSizeLog2(kLane) >= numLaneBases);
-  rotatePrefixes(bases[kRegister], numRegBases, bases[kLane], numLaneBases);
-  // Scale types double the elements for a total of 16 vgpr (still only 16
-  // elements contiguous). Need to adjust the lane basis to reflect that
-  if (elemBitWidth == 8 && numLanesInShuffleGroup == 8) {
-    assert(ldsTransLayout.getInDimSizeLog2(kLane) >= (numLaneBases + 1));
-    std::swap(bases[kLane][numLaneBases - 1], bases[kLane][numLaneBases]);
-  }
-
-  return LinearLayout(bases, ldsTransLayout.getOutDims(), false);
-}
-
 std::optional<LinearLayout>
 chooseDotDsReadTrLayout(DotOperandEncodingAttr dotMfmaLayout,
                         ArrayRef<int64_t> shape, int32_t elemBitWidth,
@@ -1192,20 +1105,39 @@ LinearLayout tensorMemoryToLinearLayout(ArrayRef<int64_t> shape,
         LinearLayout::identity1D(encoding.getCTASplitN(), kCol, dims[1]);
     auto newEncoding = TensorMemoryEncodingAttr::get(
         ctx, encoding.getBlockM(), encoding.getBlockN(),
-        encoding.getColStride(), encoding.getCTASplitM(), 1);
+        encoding.getColStride(), encoding.getCTASplitM(), 1,
+        encoding.getTwoCTAs());
     return tensorMemoryToLinearLayout(
                {shape[0], shape[1] / encoding.getCTASplitN()}, newEncoding) *
            split;
   }
   if (encoding.getCTASplitM() > 1) {
-    auto split =
-        LinearLayout::identity1D(encoding.getCTASplitM(), kCol, dims[0]);
+    auto splitM = encoding.getCTASplitM();
+    auto blockM = encoding.getBlockM();
+    bool isM64TwoCTA = blockM == 64 && encoding.getTwoCTAs();
+    if (isM64TwoCTA) {
+      // blockM == 64 and twoCTAs is laid out as the transpose of 128xblockN
+      // https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-data-path-layout-b
+      blockM *= 2;
+      splitM /= 2;
+    }
+    auto split = LinearLayout::identity1D(splitM, kCol, dims[0]);
     auto newEncoding = TensorMemoryEncodingAttr::get(
-        ctx, encoding.getBlockM(), encoding.getBlockN(),
-        encoding.getColStride(), 1, encoding.getCTASplitN());
-    return tensorMemoryToLinearLayout(
-               {shape[0] / encoding.getCTASplitM(), shape[1]}, newEncoding) *
-           split;
+        ctx, blockM, encoding.getBlockN(), encoding.getColStride(), 1,
+        encoding.getCTASplitN(), encoding.getTwoCTAs());
+    auto ret =
+        tensorMemoryToLinearLayout({shape[0] / splitM, shape[1]}, newEncoding) *
+        split;
+    // In this case, we swap the basis of the last row and last column as per
+    // https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-data-path-layout-bny
+    if (isM64TwoCTA) {
+      auto bases = ret.getBases();
+      auto &rowBases = bases[kRow];
+      auto &colBases = bases[kCol];
+      std::swap(rowBases[rowBases.size() - 1], colBases[colBases.size() - 1]);
+      ret = LinearLayout(bases, ret.getOutDims(), ret.isSurjective());
+    }
+    return ret;
   }
   assert(encoding.getCTASplitM() == 1 && encoding.getCTASplitN() == 1);

@@ -1461,14 +1393,10 @@ std::optional<LinearLayout>
 chooseDsReadTrLayout(Attribute enc, ArrayRef<int64_t> shape,
                      int32_t elemBitWidth, unsigned instBitWidth,
                      unsigned numLanesInShuffleGroup) {
-  if (elemBitWidth == 4) {
-    auto dot = cast<DotOperandEncodingAttr>(enc);
-    return chooseDotDsReadTrLayout(dot, shape, elemBitWidth, instBitWidth,
-                                   numLanesInShuffleGroup);
-  } else {
-    return chooseLLDsReadTrLayout(enc, shape, elemBitWidth, instBitWidth,
-                                  numLanesInShuffleGroup);
-  }
+  assert(elemBitWidth == 4);
+  auto dot = cast<DotOperandEncodingAttr>(enc);
+  return chooseDotDsReadTrLayout(dot, shape, elemBitWidth, instBitWidth,
+                                 numLanesInShuffleGroup);
 }

 LinearLayout chooseScaledWmmaScaleLayout(MLIRContext *ctx, int dotOperandIdx,

lib/Dialect/TritonGPU/IR/Types.cpp

Lines changed: 8 additions & 0 deletions
@@ -143,6 +143,14 @@ LogicalResult MemDescType::verify(function_ref<InFlightDiagnostic()> emitError,
              << ll.getOutDimSize(dims[0]) << "x"
              << ll.getOutDimSize(dims[1]);
     }
+    // Note the following holds for both M=64 and M=128 with 2CTA
+    auto nCol = ll.getInDimSize(StringAttr::get(ctx, "col"));
+    if (nCol / (enc.getCTASplitM() * enc.getCTASplitN()) >
+        512 * 32 / bitwidth) {
+      return emitError() << "nCol / (CTASplitM * CTASplitN) must be less than "
+                            "or equal to 512 * 32 / bitwidth but got "
+                         << nCol / (enc.getCTASplitM() * enc.getCTASplitN());
+    }
   } else if (auto enc = dyn_cast<SharedEncodingTrait>(encoding)) {
     if (memorySpace != SharedMemorySpaceAttr::get(ctx)) {
       return emitError()
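
To make the new bound concrete, a small illustrative calculation (not part of the commit): the check caps the per-CTA-split column count at 512 * 32 / bitwidth, which presumably reflects a budget of 512 32-bit tensor-memory columns.

    // Illustrative arithmetic only: the column cap enforced by the new check.
    // 512 * 32 / bitwidth -> 512 cols at 32-bit, 1024 at 16-bit, 2048 at 8-bit.
    unsigned maxColsPerSplit(unsigned bitwidth) { return 512u * 32u / bitwidth; }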

lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp

Lines changed: 2 additions & 2 deletions
@@ -566,7 +566,7 @@ class BlockedToMMAv5 : public mlir::OpRewritePattern<DotOp> {
     unsigned colStride = 32 / bitwidth;
     Attribute accEncoding = triton::nvidia_gpu::TensorMemoryEncodingAttr::get(
         context, instrShape[0], instrShape[1], colStride, CTASplitNum[0],
-        CTASplitNum[1]);
+        CTASplitNum[1], useTwoCTAs);
     Attribute tensorMemorySpace =
         triton::nvidia_gpu::TensorMemorySpaceAttr::get(context);
     MemDescType accMemDescType =
@@ -847,7 +847,7 @@ class ScaledBlockedToMMAv5
     auto bitwidth = oldRetType.getElementType().getIntOrFloatBitWidth();
     unsigned colStride = 32 / bitwidth;
     Attribute accEncoding = triton::nvidia_gpu::TensorMemoryEncodingAttr::get(
-        context, m, n, colStride, CTASplitNum[0], CTASplitNum[1]);
+        context, m, n, colStride, CTASplitNum[0], CTASplitNum[1], false);
     Attribute tensorMemorySpace =
         triton::nvidia_gpu::TensorMemorySpaceAttr::get(context);
     MemDescType accMemDescType =
