Commit 86d9987

Merge OpenAI Triton commit 6e390f3 (#5216)
This PR changes the Triton base from dbc85fc to 6e390f3 (Sep 23). Pass rate: 96.23% -> 96.98%
2 parents e9d399a + 39029bf commit 86d9987

File tree

92 files changed: +4801 −2193 lines changed


.gitignore

Lines changed: 1 addition & 0 deletions

@@ -93,6 +93,7 @@ docs/sg_execution_times.rst
 /compile_commands.json
 .vscode
 .vs
+.cursor

 # Vim
 *.swp

README.md

Lines changed: 1 addition & 0 deletions

@@ -244,6 +244,7 @@ See [`python/triton/knobs.py`](python/triton/knobs.py) for the full list of conf
 - `TRITON_FRONT_END_DEBUGGING=1` disables exception wrapping when an error occurs in the compiler frontend, allowing the full stack trace to be seen.
 - `TRITON_DISABLE_LINE_INFO=1` removes all line information from the module.
 - `PTXAS_OPTIONS` passes additional command-line options to the PTX assembler `ptxas` (only on NVIDIA).
+- `LLVM_EXTRACT_DI_LOCAL_VARIABLES` emits full debug info, allowing values to be evaluated in GPU debuggers (e.g., cuda-gdb, rocm-gdb).

 > [!NOTE]
 > Some of these environment variables don't have a knob in `knobs.py` -- those are only relevant to the C++ layer(s), hence they don't exist in the python layer.
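As a usage sketch, these knobs are ordinary environment variables set per invocation. The kernel script name below is hypothetical; the first command only demonstrates that a per-invocation setting reaches the child process and does not require Triton to be installed:

```shell
# Demonstrate that a per-invocation knob reaches the child process
# (Triton itself is not needed for this check).
TRITON_DISABLE_LINE_INFO=1 python3 -c \
  'import os; print(os.environ.get("TRITON_DISABLE_LINE_INFO"))'

# Typical real invocation (my_kernel.py is a hypothetical script):
#   LLVM_EXTRACT_DI_LOCAL_VARIABLES=1 PTXAS_OPTIONS="-v" python3 my_kernel.py
```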

bin/RegisterTritonDialects.h

Lines changed: 13 additions & 0 deletions

@@ -49,6 +49,12 @@
 #include "mlir/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.h"
 #include "mlir/InitAllPasses.h"

+#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
+#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
+#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
+#include "mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h"
+#include "mlir/Conversion/UBToLLVM/UBToLLVM.h"
+
 namespace mlir {
 namespace test {
 namespace intel {

@@ -108,13 +114,20 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
   mlir::triton::registerTritonGENToSPIRVPasses();
   mlir::LLVM::registerInlinerInterface(registry);
   mlir::NVVM::registerInlinerInterface(registry);
+  mlir::registerLLVMDILocalVariable();

   // TritonAMDGPUToLLVM passes
   mlir::triton::registerAllocateAMDGPUSharedMemory();
   mlir::triton::registerConvertTritonAMDGPUToLLVM();
   mlir::triton::registerConvertBuiltinFuncToLLVM();
   mlir::triton::registerOptimizeAMDLDSUsage();

+  mlir::ub::registerConvertUBToLLVMInterface(registry);
+  mlir::registerConvertNVVMToLLVMInterface(registry);
+  mlir::registerConvertMathToLLVMInterface(registry);
+  mlir::cf::registerConvertControlFlowToLLVMInterface(registry);
+  mlir::arith::registerConvertArithToLLVMInterface(registry);
+
   // TritonAMDGPUTransforms passes
   mlir::registerTritonAMDGPUAccelerateMatmul();
   mlir::registerTritonAMDGPUOptimizeEpilogue();

include/triton/Dialect/TritonGPU/IR/Dialect.h

Lines changed: 1 addition & 0 deletions

@@ -48,6 +48,7 @@ constexpr static char AttrNumThreadsPerWarp[] = "ttg.threads-per-warp";

 // Find the contextual number of warps on which this operation is executed.
 int lookupNumWarps(Operation *op);
+int lookupNumWarps(Region *region);
 // Try to find the contextual number of warps on which this operation is
 // executed. Returns nullopt if a warp size cannot be find. This is used for
 // verifiers.

include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td

Lines changed: 19 additions & 14 deletions

@@ -1036,15 +1036,17 @@ An encoding for tensors that have been produced by MFMA matrix core instructions
 available on AMD Instinct GPUs of CDNA architectures.

 It is characterized by the following parameters:
-- `version` indicates the GPU architecture:
+- `version`: The GPU architecture:
   - 1: gfx908: CDNA1
   - 2: gfx90a: CDNA2
   - 3: gfx942: CDNA3
   - 4: gfx950: CDNA4
-- `warpsPerCTA` indicates the warp layout in the block.
-- `MDim` and `NDim` indicate the dimension of the output of the mfma instruction.
-- `isTransposed` indicates the result tensor is transposed so that it can be converted to dotOperand layout
+- `warpsPerCTA`: The warp layout in the block.
+- `instrShape`: The shape in the form of (M, N, K) of the matrix.
+- `isTransposed`: Indicates the result tensor is transposed so that it can be converted to dotOperand layout
   without going to shared memory. This is used in the case of chained dot (E.g. Flash-Attention kernel).
+- `tilesPerWarp`: The tile layout within a warp. Defaults to unit tile layout, i.e., single tile on all dimensions.
+- `elementBitWidth`: Bit width of the output element type. Supported values are 32 and 64. Defaults to 32.

 Example 1:
 Suppose we have a tensor with a shape of [32, 64], warpsPerCTA set to [1, 2] and MDim=NDim=32.

@@ -1154,25 +1156,27 @@ w2 w2 w3 w3
   ins
     "unsigned": $version,
     ArrayRefParameter<"unsigned">:$warpsPerCTA,
-    ArrayRefParameter<"unsigned">:$tilesPerWarp,
-    "unsigned":$MDim,
-    "unsigned":$NDim,
+    ArrayRefParameter<"unsigned">:$instrShape,
     "bool":$isTransposed,
     "CTALayoutAttr":$CTALayout,
-    DefaultValuedParameter<"std::optional<Type>", "FloatType::get($_ctxt, 32)">:$elementType
+    ArrayRefParameter<"unsigned">:$tilesPerWarp,
+    "unsigned":$elementBitWidth
   );

   let builders = [
     AttrBuilder<(ins "unsigned":$version,
                      "ArrayRef<unsigned>":$warpsPerCTA,
-                     "unsigned":$MDim,
-                     "unsigned":$NDim,
+                     "ArrayRef<unsigned>":$instrShape,
                      "bool":$isTransposed,
                      "CTALayoutAttr":$CTALayout,
-                     "std::optional<Type>":$elementType), [{
-      SmallVector<unsigned> tilesPerWarp(warpsPerCTA.size(), 1);
-
-      return $_get(context, version, warpsPerCTA, tilesPerWarp, MDim, NDim, isTransposed, CTALayout, elementType);
+                     CArg<"ArrayRef<unsigned>", "{}">:$tpw,
+                     CArg<"unsigned", "0">:$elementBitWidth), [{
+      SmallVector<unsigned> tilesPerWarp(tpw);
+      if (tilesPerWarp.empty())
+        tilesPerWarp = SmallVector<unsigned>(warpsPerCTA.size(), 1);
+      if (elementBitWidth == 0)
+        elementBitWidth = 32;
+      return $_get($_ctxt, version, warpsPerCTA, instrShape, isTransposed, CTALayout, tilesPerWarp, elementBitWidth);
     }]>
   ];

@@ -1194,6 +1198,7 @@ w2 w2 w3 w3

   let genVerifyDecl = 1;
   let hasCustomAssemblyFormat = 1;
+  let skipDefaultBuilders = 1;
 }

 def AMDWmmaEncodingAttr : DistributedEncoding<"AMDWmmaEncoding", "amd_wmma_encoding", [MmaEncodingTrait]> {
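The builder's new defaulting behavior (an empty `tpw` becomes a unit tile layout; `elementBitWidth == 0` means 32) can be restated in Python. Names mirror the TableGen parameters, but the function itself is purely illustrative:

```python
def build_mfma_params(warps_per_cta, tiles_per_warp=(), element_bit_width=0):
    """Illustrative restatement of the AttrBuilder defaulting logic."""
    tpw = list(tiles_per_warp)
    if not tpw:
        # Unit tile layout: one tile per warp on every dimension.
        tpw = [1] * len(warps_per_cta)
    if element_bit_width == 0:
        element_bit_width = 32  # default output element width
    return tpw, element_bit_width

# Defaults applied:
print(build_mfma_params([1, 2]))              # ([1, 1], 32)
# Explicit values pass through:
print(build_mfma_params([4, 1], (2, 2), 64))  # ([2, 2], 64)
```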

include/triton/Dialect/TritonGPU/Transforms/MMAv5PipelineUtility.h

Lines changed: 8 additions & 0 deletions

@@ -53,6 +53,14 @@ class MMAv5PipelineableOperandsHelper {
   bool isOperandPipelineable(Value v, Operation *&foundDef);
 };

+bool areScalesPipelineable(TCGen5MMAScaledOp scaledOp, scf::ForOp forOp);
+bool isOperandPipelineableBase(
+    Value v, scf::ForOp forOp, Operation *&foundDef,
+    std::function<bool(Operation *)> isPipelineable =
+        [](Operation *) { return false; },
+    std::function<bool(Operation *)> isLoadToBePipelined =
+        [](Operation *) { return false; });
+
 //===----------------------------------------------------------------------===//
 // MMA Pipeline Rewriters
 //===----------------------------------------------------------------------===//
Lines changed: 70 additions & 0 deletions

@@ -0,0 +1,70 @@

# Triton Instrument Dialect and Concurrency Sanitizer (ConSan)

## Overview

ConSan instruments Triton IR to detect illegal concurrent accesses to shared and Tensor Core memory under warp specialization. It tracks per-buffer visibility of reads and writes across threads, models barrier-based synchronization, and models commit-count–based synchronization (cp.async, wgmma).

Auxiliary state is kept in distributed tensors and global scratch memory, with types created on demand per warp-specialization partition.
## Thread model

- Base threads: 16 warp-specialization (WS) threads (allowing for up to 16 partitions).
- Peer classes: +16 Tensor Core (TC) threads and +16 TMA threads, modeling the lack of ordering with base threads.
- Total logical threads: 48. Bitmasks are sized to the next power of two: 64.

Indexing uses a logical thread id in [0, 48), with column vectors sized to 64 for layout convenience.
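A small sketch of this sizing; only the 16/48/64 figures come from the text, and the per-class offsets below are an illustrative assumption:

```python
NUM_PARTITIONS = 16  # base warp-specialization threads

# Peer classes model ops that are unordered w.r.t. base threads.
# The concrete per-class offsets are illustrative assumptions.
BASE, TMA, TC = 0, 16, 32

def logical_thread(klass: int, partition: int) -> int:
    """Logical thread id in [0, 48)."""
    assert 0 <= partition < NUM_PARTITIONS
    return klass + partition

def mask_width(num_threads: int) -> int:
    """Bitmasks are padded to the next power of two."""
    width = 1
    while width < num_threads:
        width *= 2
    return width

print(mask_width(48))  # 64: the bitmask width used throughout
```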
## Auxiliary data structures

All types are generated on demand (per partition) based on:

- B: number of tracked buffers (power-of-two padded)
- K: number of mbarriers (power-of-two padded)
- T_bits: 64 (bitmask width)
- T_commits: 16 (base threads; commit counters do not apply to TC/TMA helpers)

“tensor” means a distributed Triton tensor; “scratch” means a pointer into global scratch memory. Shapes below are logical; actual encodings are partition-local blocked layouts.

- buffers (tensor, <B x i64>): Base pointers of all (sub)buffers per memory space
- barriers (tensor, <K x i64>): Pointers of all mbarriers
- writeVisibility (scratch, <B x i64>): Per-buffer bitmask. Bit i set ⇒ thread i can see the latest completed write to that buffer
- readVisibility (scratch, <B x 64 x i64>): Per-buffer, per-thread lanes. Each lane stores a 64-bit mask of other threads whose reads are visible to that lane’s thread
- writeTracking (scratch, <B x K x i8>): Map buffers → barriers tracking writes (boolean stored in i8)
- readTracking (scratch, <B x K x i64>): Map buffers → barriers tracking reads (bitmask of threads)
- outstandingCommits (scratch, <B x 16 x i8>): Per-buffer, per-base-thread commit counters for cp.async and wgmma
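The shapes above can be made concrete with a host-side mock (a sketch only; the real state lives in distributed tensors and global scratch memory):

```python
T_BITS = 64      # bitmask width (48 logical threads, padded to a power of two)
T_COMMITS = 16   # commit counters exist for base threads only

def make_consan_state(B: int, K: int) -> dict:
    """Host-side mock of the per-partition auxiliary tables."""
    return {
        "writeVisibility": [0] * B,                          # <B x i64> bitmasks
        "readVisibility": [[0] * T_BITS for _ in range(B)],  # <B x 64 x i64>
        "writeTracking": [[False] * K for _ in range(B)],    # <B x K x i8>
        "readTracking": [[0] * K for _ in range(B)],         # <B x K x i64>
        "outstandingCommits": [[0] * T_COMMITS for _ in range(B)],  # <B x 16 x i8>
    }

state = make_consan_state(B=4, K=2)
```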
## Visibility and legality rules

- Reads are legal iff the reading thread sees the most recent write to the buffer (writeVisibility). There can be only one write in flight.
- Writes are legal iff the writing thread sees both all prior writes and all reads completed for that buffer.

ConSan enforces these via two checks emitted before memory ops:

- experimental_verify_write_visibility: “no one else is writing, or I can see the write”
- experimental_verify_read_visibility: “my read-visibility lane is a superset of the OR of all lanes”
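On the host-side mock, the two checks reduce to bitmask predicates. This is an interpretation sketch of the rules as stated, not the emitted IR:

```python
def verify_write_visibility(write_visibility: int, thread: int) -> bool:
    """Legal if no completed write is recorded, or this thread can see it."""
    return write_visibility == 0 or bool(write_visibility & (1 << thread))

def verify_read_visibility(read_lanes: list[int], thread: int) -> bool:
    """Legal if this thread's lane is a superset of the OR of all lanes."""
    all_reads = 0
    for lane in read_lanes:
        all_reads |= lane
    return (read_lanes[thread] & all_reads) == all_reads

lanes = [0] * 64
lanes[2] = 1 << 5          # thread 2 saw a read by thread 5
print(verify_read_visibility(lanes, 2))  # True: lane 2 covers all reads
print(verify_read_visibility(lanes, 0))  # False: lane 0 misses thread 5's read
```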
## Barrier-based synchronization

ConSan separates “tracking” from “visibility transfer”:

- At memory ops that are tracked by a barrier (loads/stores, some TMEM ops):
  - experimental_set_read_visibility / experimental_set_write_visibility updates the appropriate visibility table for the current thread and buffer.
  - experimental_track_visible_reads / experimental_track_visible_writes snapshots the current per-buffer visibility into readTracking/writeTracking for the given barrier.
- At arrive/commit sites (e.g., tc commit, arrive on mbarrier): ConSan emits the track ops for both reads and writes.
- At waits: experimental_transfer_visible_reads / experimental_transfer_visible_writes propagates tracked visibility from the barrier back into the waiting thread’s visibility; the transfer is repeated for peer threads (base, TMA, TC) to keep the three classes consistent.
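The track-then-transfer flow for reads can be mocked on the host. This is a sketch of the described semantics on plain Python bitmasks, not the real ops:

```python
# Minimal mock state: readVisibility is <B x 64 x i64>, readTracking <B x K x i64>.
B, K = 2, 1
read_visibility = [[0] * 64 for _ in range(B)]
read_tracking = [[0] * K for _ in range(B)]

def track_visible_reads(buf: int, barrier: int) -> None:
    """Arrive: snapshot the OR of all visibility lanes into the barrier slot."""
    all_reads = 0
    for lane in read_visibility[buf]:
        all_reads |= lane
    read_tracking[buf][barrier] |= all_reads

def transfer_visible_reads(buf: int, barrier: int, thread: int) -> None:
    """Wait: fold the tracked reads back into the waiting thread's lane.
    (In ConSan this transfer is repeated for the thread's TMA/TC peers.)"""
    read_visibility[buf][thread] |= read_tracking[buf][barrier]

read_visibility[0][5] = 1 << 5                      # thread 5 sees its own read
track_visible_reads(buf=0, barrier=0)               # arrive on the barrier
transfer_visible_reads(buf=0, barrier=0, thread=2)  # thread 2 waits
print(bool(read_visibility[0][2] & (1 << 5)))       # True: thread 2 now sees it
```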
## Commit-count–based synchronization

Some hardware ops synchronize via a “number of outstanding commits” rather than mbarriers.

- Stage: experimental_stage_access_for_commit marks the current thread’s buffer lane with -1 (staged) in outstandingCommits[B x 16].
- Commit: experimental_commit_accesses turns -1 into 1 and increments positive entries in the committing thread’s column.
- Wait (cp.async): experimental_clear_outstanding_commits_set_write(thread, commits, writeVisibility, N) clears entries with count > N for the current thread, and sets the writeVisibility bit for rows where any thread’s entry was cleared.
- Wait (wgmma): experimental_clear_outstanding_commits_set_read(thread, commits, readVisibility, N) clears entries with count > N for the current thread, and sets the readVisibility bit for rows where any thread’s entry was cleared.

Legality checks for commit-count flows:

- For writes to shared memory affected by cp.async: experimental_check_outstanding_commits(buffer, commits, "async_copy_global_to_shared") asserts the row for the buffer is all zeros (no pending writes), across all base-thread columns.
- For reads of wgmma operands in shared memory: experimental_check_outstanding_commits(buffer, commits, "warpgroup_mma operand read") asserts the row is all zeros (no pending reads).

Note: The check op has no “thread” operand; it inspects the whole row for the buffer.
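The counter lifecycle above can be mocked on the host. The column-per-base-thread layout follows the description; everything else (function names, the scenario) is illustrative:

```python
T_COMMITS = 16
B = 2
commits = [[0] * T_COMMITS for _ in range(B)]  # outstandingCommits mock
STAGED = -1

def stage(buf: int, thread: int) -> None:
    """Mark the thread's lane for this buffer as staged."""
    commits[buf][thread] = STAGED

def commit(thread: int) -> None:
    """Turn staged (-1) into 1; age every already-committed entry by one."""
    for row in commits:
        if row[thread] == STAGED:
            row[thread] = 1
        elif row[thread] > 0:
            row[thread] += 1

def clear_outstanding(thread: int, n: int) -> list[int]:
    """Wait for <= n outstanding groups: entries with count > n are complete."""
    completed = []
    for buf, row in enumerate(commits):
        if row[thread] > n:
            row[thread] = 0
            completed.append(buf)  # the visibility bit would be set here
    return completed

def check_outstanding(buf: int) -> bool:
    """Check op: the whole row must be zero (no pending accesses)."""
    return all(c == 0 for c in commits[buf])

stage(0, thread=3); commit(thread=3)      # buffer 0: count 1
stage(1, thread=3); commit(thread=3)      # buffer 0 ages to 2, buffer 1: count 1
print(clear_outstanding(thread=3, n=1))   # [0]: only buffer 0 is old enough
print(check_outstanding(0), check_outstanding(1))  # True False
```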
