
Commit 0ca3ce3

Merge OpenAI Triton commit 4dfdc32 (#4445)
This PR changes the Triton base from 6af4919 to 4dfdc32 (Jun 5). Pass rate: 97.23%.
2 parents: e675298 + b669c0d

114 files changed: +2175, -1299 lines

.github/workflows/documentation.yml

Lines changed: 3 additions & 2 deletions
@@ -4,7 +4,8 @@ on:
   schedule:
     - cron: "0 0 * * *"
 
-permissions: read-all
+permissions:
+  contents: write
 
 jobs:
   Build-Documentation:
@@ -15,7 +16,7 @@ jobs:
       - name: Checkout branch
        uses: actions/checkout@v4
        with:
-          token: ${{ secrets.CI_PAT }}
+          token: ${{ secrets.GITHUB_TOKEN }}
          fetch-depth: 0
 
      - name: Clear docs

.github/workflows/integration-tests-amd.yml

Lines changed: 0 additions & 32 deletions
@@ -60,26 +60,6 @@ jobs:
             ~/.triton/nvidia
             ~/.triton/json
           key: ${{ runner.os }}-${{ runner.arch }}-llvm-${{ steps.cache-key.outputs.llvm }}-nvidia-${{ steps.cache-key.outputs.nvidia }}-json-${{ steps.cache-key.outputs.json }}
-      - # Cache ~/.cache/ccache to speed up compilation.
-        #
-        # On branch `main` we always start from an empty cache, i.e. we skip the
-        # "restore" step. This is to prevent the caches from accumulating stale
-        # files over time.
-        name: Restore cache of ccache and Triton compilation artifacts
-        id: restore-build-cache
-        if: github.ref != 'refs/heads/main'
-        uses: actions/cache/restore@v4
-        with:
-          path: |
-            ~/.ccache
-          # Restore the most recent cache entry.
-          restore-keys: |
-            triton-artifacts-${{ runner.os }}-${{ runner.arch }}-${{ env.RUNNER_TYPE }}-llvm-${{ steps.cache-key.outputs.llvm }}-
-            triton-artifacts-${{ runner.os }}-${{ runner.arch }}-${{ env.RUNNER_TYPE }}-
-          # We expect this cache key never to hit and for us to fall back
-          # unconditionally to the restore-key, so it doesn't actually matter
-          # what we put here (so long as it doesn't hit an existing key).
-          key: triton-artifacts-${{ runner.os }}-${{ runner.arch }}-${{ env.RUNNER_TYPE }}-llvm-${{ steps.cache-key.outputs.llvm }}-${{ steps.cache-key.outputs.datetime }}
       - name: Inspect cache directories
         run: |
           mkdir -p ~/.triton
@@ -152,18 +132,6 @@ jobs:
 
           mkdir -p ~/.ccache
           du -h -d 1 ~/.ccache
-      - # If we're on branch `main`, save the ccache Triton compilation artifacts
-        # to the cache so they can be used by other (non-main) CI runs.
-        #
-        # (It wouldn't be a problem to save the cache on every run, because github
-        # evicts cache entries LRU, but maybe this saves a bit of time in CI.)
-        name: Save ccache and Triton compilation artifacts to cache
-        if: github.ref == 'refs/heads/main'
-        uses: actions/cache/save@v4
-        with:
-          path: |
-            ~/.ccache
-          key: triton-artifacts-${{ runner.os }}-${{ runner.arch }}-${{ env.RUNNER_TYPE }}-llvm-${{ steps.cache-key.outputs.llvm }}-${{ steps.cache-key.outputs.datetime }}
       - name: Clean up caches
         # Always cleanup the worker, even if builds or tests failed
         if: always()

.github/workflows/integration-tests-nvidia.yml

Lines changed: 0 additions & 32 deletions
@@ -57,26 +57,6 @@ jobs:
             ~/.triton/nvidia
             ~/.triton/json
           key: ${{ runner.os }}-${{ runner.arch }}-llvm-${{ steps.cache-key.outputs.llvm }}-nvidia-${{ steps.cache-key.outputs.nvidia }}-json-${{ steps.cache-key.outputs.json }}
-      - # Cache ~/.cache/ccache to speed up compilation.
-        #
-        # On branch `main` we always start from an empty cache, i.e. we skip the
-        # "restore" step. This is to prevent the caches from accumulating stale
-        # files over time.
-        name: Restore cache of ccache and Triton compilation artifacts
-        id: restore-build-cache
-        if: github.ref != 'refs/heads/main'
-        uses: actions/cache/restore@v4
-        with:
-          path: |
-            ~/.ccache
-          # Restore the most recent cache entry.
-          restore-keys: |
-            triton-artifacts-${{ runner.os }}-${{ runner.arch }}-${{ env.RUNNER_TYPE }}-llvm-${{ steps.cache-key.outputs.llvm }}-
-            triton-artifacts-${{ runner.os }}-${{ runner.arch }}-${{ env.RUNNER_TYPE }}-
-          # We expect this cache key never to hit and for us to fall back
-          # unconditionally to the restore-key, so it doesn't actually matter
-          # what we put here (so long as it doesn't hit an existing key).
-          key: triton-artifacts-${{ runner.os }}-${{ runner.arch }}-${{ env.RUNNER_TYPE }}-llvm-${{ steps.cache-key.outputs.llvm }}-${{ steps.cache-key.outputs.datetime }}
       - name: Inspect cache directories
         run: |
           mkdir -p ~/.triton
@@ -130,15 +110,3 @@ jobs:
 
           mkdir -p ~/.ccache
           du -h -d 1 ~/.ccache
-      - # If we're on branch `main`, save the ccache Triton compilation artifacts
-        # to the cache so they can be used by other (non-main) CI runs.
-        #
-        # (It wouldn't be a problem to save the cache on every run, because github
-        # evicts cache entries LRU, but maybe this saves a bit of time in CI.)
-        name: Save ccache and Triton compilation artifacts to cache
-        if: github.ref == 'refs/heads/main'
-        uses: actions/cache/save@v4
-        with:
-          path: |
-            ~/.ccache
-          key: triton-artifacts-${{ runner.os }}-${{ runner.arch }}-${{ env.RUNNER_TYPE }}-llvm-${{ steps.cache-key.outputs.llvm }}-${{ steps.cache-key.outputs.datetime }}

Makefile

Lines changed: 2 additions & 2 deletions
@@ -106,9 +106,9 @@ dev-install-llvm:
 
 .PHONY: golden-samples
 golden-samples: triton-opt
-    $(TRITON_OPT) test/TritonGPU/samples/simulated-grouped-gemm.mlir.in -tritongpu-pipeline -canonicalize | \
+    $(TRITON_OPT) test/TritonGPU/samples/simulated-grouped-gemm.mlir.in -tritongpu-assign-latencies -tritongpu-schedule-loops -tritongpu-pipeline -canonicalize | \
     $(PYTHON) utils/generate-test-checks.py --source test/TritonGPU/samples/simulated-grouped-gemm.mlir.in --source_delim_regex="\bmodule" \
     -o test/TritonGPU/samples/simulated-grouped-gemm.mlir
-    $(TRITON_OPT) test/TritonGPU/samples/descriptor-matmul-pipeline.mlir.in -tritongpu-pipeline -canonicalize | \
+    $(TRITON_OPT) test/TritonGPU/samples/descriptor-matmul-pipeline.mlir.in -tritongpu-assign-latencies -tritongpu-schedule-loops -tritongpu-pipeline -canonicalize | \
     $(PYTHON) utils/generate-test-checks.py --source test/TritonGPU/samples/descriptor-matmul-pipeline.mlir.in --source_delim_regex="\bmodule" \
     -o test/TritonGPU/samples/descriptor-matmul-pipeline.mlir

include/triton/Dialect/Triton/IR/TritonOps.td

Lines changed: 8 additions & 68 deletions
@@ -865,19 +865,24 @@ def TT_ElementwiseInlineAsmOp : TT_Op<"elementwise_inline_asm", [
 //
 // Histogram Op
 //
-def TT_HistogramOp : TT_Op<"histogram", [Pure]> {
+def TT_HistogramOp : TT_Op<"histogram", [Pure,
+                           TypesMatchWith<"mask type matches src type",
+                                          "src", "mask", "getI1SameShape($_self)",
+                                          "($_op.getOperands().size() <= 1) || std::equal_to<>()">]> {
   let summary = "return a histogram of the inputs.";
   let description = [{
     Return the histogram of the input tensor. The number of bins is equal to
    the dimension of the output tensor. Each bins has a width of 1 and bins
    start at 0.
  }];
 
-  let arguments = (ins TT_IntTensor:$src);
+  let arguments = (ins TT_IntTensor:$src,
+                       Optional<TT_BoolLike>:$mask);
+
   let results = (outs TT_IntTensor:$result);
 
   let assemblyFormat = [{
-    $src attr-dict `:` type($src) `->` type($result)
+    $src (`,` $mask^)? attr-dict `:` type($src) `->` type($result)
   }];
 }
 
@@ -1028,22 +1033,6 @@ def TT_MakeTensorDescOp : TT_Op<"make_tensor_descriptor", [
   }];
 }
 
-def ReinterpretTensorDescOp : TT_Op<"reinterpret_tensor_descriptor", [Pure]> {
-  let summary = "Reinterpret a pointer as a tensor descriptor";
-
-  let description = [{
-    This Op exists to help the transition from untyped raw TMA objects to typed Tensor descriptor objects.
-    Ideally, we can remove this once the APIs are fully fleshed out.
-  }];
-
-  let arguments = (ins TT_Ptr:$rawDesc);
-  let results = (outs TT_TensorDescType:$result);
-
-  let assemblyFormat = [{
-    $rawDesc attr-dict `:` qualified(type($rawDesc)) `to` qualified(type($result))
-  }];
-}
-
 // The following ops, including `call`, `func`, and `return` are copied and modified from
 // https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Dialect/Func/IR/FuncOps.td
 // We could revert it back once MLIR has a better inliner interface.
@@ -1385,54 +1374,5 @@ def TT_DescriptorScatterOp : TT_Op<"descriptor_scatter", [TT_DescriptorStoreLike
   let hasVerifier = 1;
 }
 
-def TT_ExperimentalTensormapCreateOp: TT_Op<
-  "experimental_tensormap_create",
-  [
-    MemoryEffects<[MemRead<GlobalMemory>, MemWrite<GlobalMemory>]>,
-    AttrSizedOperandSegments,
-  ]
-> {
-  let summary = "Create a new TMA descriptor on device";
-  let arguments = (
-    ins
-    TT_PtrType:$desc_ptr,
-    TT_PtrType:$global_address,
-    Variadic<I32>:$box_dim,
-    Variadic<I32>:$global_dim,
-    Variadic<I64>:$global_stride,
-    Variadic<I32>:$element_stride,
-    ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<15>]>:$elem_type,
-    ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<2>]>:$interleave_layout,
-    ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<3>]>:$swizzle_mode,
-    ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<1>]>:$fill_mode
-  );
-  let extraClassDeclaration = [{
-    int32_t getRank() {
-      return getBoxDim().size();
-    }
-  }];
-  let assemblyFormat = [{
-    $desc_ptr `,` $global_address `,`
-    `[` $box_dim `]` `,`
-    `[` $global_dim `]` `,`
-    `[` $global_stride `]` `,`
-    `[` $element_stride `]`
-    attr-dict `:` functional-type(operands, results)
-  }];
-
-  let hasVerifier = 1;
-}
-
-def TT_ExperimentalTensormapFenceproxyAcquireOp: TT_Op<
-  "experimental_tensormap_fenceproxy_acquire",
-  [MemoryEffects<[MemWrite<GlobalMemory>]>]
-> {
-  let summary = "Acquire fence on a tensormap object";
-  let arguments = (ins TT_PtrType:$desc_ptr);
-  let assemblyFormat = [{
-    $desc_ptr attr-dict `:` qualified(type($desc_ptr))
-  }];
-}
-
 
 #endif // Triton_OPS
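
As context for the TT_HistogramOp change above, which adds an optional boolean mask operand, the following is a minimal host-side C++ sketch of the documented semantics (bins of width 1 starting at 0, with masked-off elements excluded from the count). The function histogramWithMask and its container types are illustrative assumptions, not Triton APIs.

#include <cstddef>
#include <cstdint>
#include <vector>

// Sketch only: mirrors the op description (bin width 1, bins start at 0) and
// the new optional mask; an empty mask means "count every element".
std::vector<int32_t> histogramWithMask(const std::vector<int32_t> &src,
                                       const std::vector<bool> &mask,
                                       int numBins) {
  std::vector<int32_t> bins(numBins, 0);
  for (std::size_t i = 0; i < src.size(); ++i) {
    if (!mask.empty() && !mask[i])
      continue; // masked-off elements are not counted
    if (src[i] >= 0 && src[i] < numBins)
      ++bins[src[i]]; // bin i counts values equal to i
  }
  return bins;
}

When no mask is supplied the behaviour reduces to the pre-existing unmasked histogram, which is why the updated assembly format marks $mask as optional.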

include/triton/Dialect/Triton/IR/TritonTypes.td

Lines changed: 1 addition & 1 deletion
@@ -92,7 +92,7 @@ def TT_TensorPtr : TT_PtrOf<[TT_Tensor]>;
 // Any Type in Triton IR
 def TT_Type : AnyTypeOf<[TT_FloatLike, TT_IntLike, TT_PtrLike, TT_TensorPtr]>;
 
-// Result type of ExperimentalMakeTensorDescriptor
+// Result type of MakeTensorDescriptor
 def TT_TensorDescType : TritonTypeDef<"TensorDesc", "tensordesc", []> {
   let summary = "Tensor descriptor type (`::mlir::triton::TensorDescType`) in Triton IR type system";
 

include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h

Lines changed: 5 additions & 0 deletions
@@ -287,6 +287,11 @@ LinearLayout chooseScaledMfmaScaleLayout(
     const std::vector<std::vector<int32_t>> &dotOperandWarpBasis,
     ArrayRef<int64_t> dotOperandShape, unsigned mfmaMDim);
 
+// Create LinearLayout for nvidia mma tile.
+LinearLayout nvidiaMmaTile(MLIRContext *ctx, ArrayRef<unsigned> tileShape,
+                           unsigned kWidth, ArrayRef<unsigned> order,
+                           ArrayRef<unsigned> repOrder);
+
 // Create a LinearLayout similar to mfmaLayout, but changing each thread to hold
 // 8 elements. This layout is useful for emitting the widest 128-bit global
 // store instructions. Since it closely resembles mfmaLayout, conversion between

include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td

Lines changed: 32 additions & 40 deletions
@@ -312,54 +312,46 @@ When vec=2, elements are swizzled in pairs of 2. In other words, the element at
     if(!mmaEnc)
       return get(context, 1, 1, 1, order, CTALayout);
 
-    int opIdx = dotOpEnc.getOpIdx();
-    auto shapePerCTA = getShapePerCTA(CTALayout.getCTASplitNum(), shape);
-
-    // number of rows per phase
-
-    // index of the inner dimension in `order`
-    unsigned inner = (opIdx == 0) ? 0 : 1;
-
     // ---- begin Ampere & Hopper ----
     if (mmaEnc.isAmpere() || mmaEnc.isHopper()) {
-      int perPhase = 128 / (std::max<int>(1, shapePerCTA[order[0]] * 4 / dotOpEnc.getKWidth()));
-      perPhase = std::max<int>(perPhase, 1);
-      std::vector<size_t> matShape = {8, 8, 4 * dotOpEnc.getKWidth()};
-      int vecWidth = 32 / typeWidthInBit;
-      if (vecWidth != dotOpEnc.getKWidth() && order[0] == inner) {
-        perPhase = std::max<int>(perPhase, 2 * vecWidth);
-      }
-      int rank = order.size();
-      // --- handle A operand ---
-      if (opIdx == 0) { // compute swizzling for A operand
-        int m = (needTrans) ? matShape[2] : matShape[0];
-        int k = (needTrans) ? matShape[0] : matShape[2];
-        int vec = (order[0] == rank-1) ? k : m;
-        int mmaStride = (order[0] == rank-1) ? m : k;
-        int maxPhase = std::max(mmaStride / perPhase, 1);
-        return get(context, vec, perPhase, maxPhase, order, CTALayout);
-      }
-
-      // --- handle B operand ---
-      if (opIdx == 1) {
-        // we compute vec and maxPhase m, n and k size of the mma
-        // instruction. when matmul operands is transposed, we should
-        // consider that to get m, n and k.
-        int n = needTrans ? matShape[2] : matShape[1];
-        int k = needTrans ? matShape[1] : matShape[2];
-        int vec = (order[0] == rank-1) ? n : k;
-        int mmaStride = (order[0] == rank-1) ? k : n;
-        int maxPhase = std::max(mmaStride / perPhase, 1);
-        return get(context, vec, perPhase, maxPhase, order, CTALayout);
-      }
-
-      llvm_unreachable("invalid operand index");
+      return get(context, dotOpEnc.getOpIdx(), dotOpEnc.getKWidth(), shape, order, CTALayout, typeWidthInBit, needTrans);
     }
 
     // ---- not implemented ----
     llvm_unreachable("unsupported swizzling for provided MMA version");
   }]>,
 
+  // NVIDIA constructor!
+  // TODO(lezcano): We should totally get rid of all these constructors...
+  AttrBuilder<(ins "int":$opIdx,
+                   "unsigned":$kWidth,
+                   "ArrayRef<int64_t>":$shape,
+                   "ArrayRef<unsigned>":$order,
+                   "CTALayoutAttr":$CTALayout,
+                   "unsigned":$bitwidth,
+                   "bool":$needTrans), [{
+    int K = getShapePerCTA(CTALayout.getCTASplitNum(), shape)[order[0]];
+    // Elems necessary to cover all the banks divided by the inner dimension
+    // This packs a few rows together for small K
+    int perPhase = std::max<int>(1024 / (bitwidth * K), 1);
+
+    int mmaStride = 8;
+    int vec = 4 * kWidth;
+    // needsTrans is equiv. to flipping the opIdx
+    if (needTrans)
+      std::swap(vec, mmaStride);
+    assert(opIdx == 0 || opIdx == 1);
+    int rank = order.size();
+    int kDim = opIdx == 0 ? rank-1 : rank-2;
+    if (order[0] != kDim)
+      std::swap(vec, mmaStride);
+    // Count how many vec elements are needed to cover all the banks
+    int maxPhase = std::max(std::min<int>(mmaStride, 1024 / (vec * bitwidth)), 1);
+    // Account for the row packing from perPhase: mmaStride / perPhase
+    maxPhase = std::max(maxPhase / perPhase, 1);
+    return get(context, vec, perPhase, maxPhase, order, CTALayout);
+  }]>,
+
   AttrBuilder<(ins "DotOperandEncodingAttr":$dotOpEnc,
                    "ArrayRef<int64_t>":$shape,
                    "ArrayRef<unsigned>":$order,
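The new NVIDIA AttrBuilder in the diff above replaces the separate A/B-operand branches with a single computation of the swizzle parameters. Below is a standalone C++ sketch of that computation, for illustration only: computeNvidiaSwizzle and SwizzleParams are invented names, K stands for the per-CTA size of the inner dimension that the builder reads via getShapePerCTA, and the 1024 constant is presumably the bit width of the 32 four-byte shared-memory banks.

#include <algorithm>
#include <cassert>
#include <utility>
#include <vector>

struct SwizzleParams {
  int vec;      // contiguous elements swizzled together
  int perPhase; // rows packed into one phase
  int maxPhase; // number of distinct phases
};

// Sketch of the parameter computation in the new NVIDIA constructor above.
SwizzleParams computeNvidiaSwizzle(int opIdx, unsigned kWidth, int K,
                                   const std::vector<unsigned> &order,
                                   unsigned bitwidth, bool needTrans) {
  // Elements needed to cover all the banks, divided by the inner dimension;
  // this packs a few rows together for small K.
  int perPhase = std::max<int>(1024 / (bitwidth * K), 1);

  int mmaStride = 8;
  int vec = 4 * kWidth;
  // needTrans is equivalent to flipping the operand index.
  if (needTrans)
    std::swap(vec, mmaStride);

  assert(opIdx == 0 || opIdx == 1);
  int rank = static_cast<int>(order.size());
  int kDim = opIdx == 0 ? rank - 1 : rank - 2;
  if (order[0] != kDim)
    std::swap(vec, mmaStride);

  // Count how many vec elements are needed to cover all the banks, then
  // account for the row packing already done by perPhase.
  int maxPhase = std::max(std::min<int>(mmaStride, 1024 / (vec * bitwidth)), 1);
  maxPhase = std::max(maxPhase / perPhase, 1);
  return {vec, perPhase, maxPhase};
}

Compared to the deleted code path, the vec/mmaStride swaps express both the operand index and the transpose in one place instead of duplicating the A- and B-operand branches.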