Skip to content

Commit b870f9b

Browse files
Merge OpenAI Triton commit 51021fb (#5238)
This PR change the Triton base from 22b1a44 to 51021fb (Sep 29). Pass rate: 92.67%->92.74%
2 parents 9ab8556 + a02513c commit b870f9b

File tree

47 files changed

+2105
-740
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

47 files changed

+2105
-740
lines changed

bin/RegisterTritonDialects.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,8 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
134134
mlir::registerTritonAMDGPUHoistLayoutConversions();
135135
mlir::registerTritonAMDGPUReorderInstructions();
136136
mlir::registerTritonAMDGPUBlockPingpong();
137-
mlir::registerTritonAMDGPUStreamPipeline();
137+
mlir::registerTritonAMDGPUPipeline();
138+
mlir::registerTritonAMDGPUScheduleLoops();
138139
mlir::registerTritonAMDGPUCanonicalizePointers();
139140
mlir::registerTritonAMDGPUConvertToBufferOps();
140141
mlir::registerTritonAMDGPUInThreadTranspose();

include/triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,8 @@ getLastUseOfPipelinedOp(ArrayRef<Operation *> ops, scf::ForOp forOp,
182182
CoarseSchedule &schedule,
183183
std::function<bool(Operation *)> filterUse = nullptr);
184184

185+
// Clean up attributes passing over schedules across stages in pipelining
186+
void removePipeliningAttributes(ModuleOp moduleOp);
185187
} // namespace triton
186188
} // namespace mlir
187189

include/triton/Dialect/TritonInstrument/IR/TritonInstrument.md

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@ All types are generated on-demand (per partition) based on:
3131
- readVisibility (scratch, <B x 64 x i64>): Per-buffer, per-thread lanes. Each lane stores a 64-bit mask of other threads whose reads are visible to that lane’s thread
3232
- writeTracking (scratch, <B x K x i8>): Map buffers → barriers tracking writes (boolean stored in i8)
3333
- readTracking (scratch, <B x K x i64>): Map buffers → barriers tracking reads (bitmask of threads)
34+
- barrierStates (scratch, <K x i32>): Packed barrier metadata. Bit 0 stores the current phase, bits [1..8] the initial arrival count, bits [9..16] the current arrival count. The verifier checks underflow before updating, and flips the phase when the current count reaches zero.
35+
- waiting (scratch, <K x i32>): Per-barrier bitfield describing waiting threads. Each base thread gets two bits: bit (2 * thread + 0) is the waiting flag, bit (2 * thread + 1) stores the phase the thread is waiting on.
3436
- outstandingCommits (scratch, <B x 16 x i8>): Per-buffer, per-base-thread commit counters for cp.async and wgmma
3537

3638
## Visibility and legality rules
@@ -53,6 +55,20 @@ ConSan separates “tracking” from “visibility transfer”:
5355
- At arrive/commit sites (e.g., tc commit, arrive on mbarrier): ConSan emits the track ops for both reads and writes.
5456
- At waits: experimental_transfer_visible_reads / experimental_transfer_visible_writes propagates tracked visibility from the barrier back into the waiting thread’s visibility, and this transfer is repeated to peer threads (base, TMA, TC) to keep the three classes consistent.
5557

58+
### Barrier phase/count tracking
59+
60+
- experimental_init_barrier_state(barrier, count, barrierStates) initializes the per-barrier state with phase = 0 and both initial/current arrival counts = `count`.
61+
- experimental_verify_barrier_arrive(barrier, count, barrierStates) checks that subtracting `count` from the current arrival count would not underflow. The codegen emits an assert if it would.
62+
- experimental_update_barrier_state(barrier, count, barrierStates) applies the arrive: subtracts `count`, flips the phase when the count reaches zero, and reloads the current count from the initial count.
63+
64+
### Deadlock detection
65+
66+
ConSan records which phase each thread is waiting on:
67+
68+
- experimental_set_waiting(barrier, baseThread, phase, barriers, waiting) sets the waiting flag for `baseThread` and stores the requested `phase`. The flag/phase bits share the waiting bitfield (two bits per base thread).
69+
- experimental_check_all_active_waiting(activeMask, barriers, waiting, barrierStates) filters waiting threads to those whose stored phase matches the current barrier phase. If all active threads are waiting on matching phases, it raises a deadlock assert.
70+
- experimental_clear_waiting(barrier, baseThread, barriers, waiting) clears the waiting bits for `baseThread`. Each wait clears its own state after the wait completes.
71+
5672
## Commit-count–based synchronization
5773

5874
Some hardware ops synchronize via “number of outstanding commits” rather than mbarriers.

include/triton/Dialect/TritonInstrument/IR/TritonInstrumentOps.td

Lines changed: 171 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,16 +12,18 @@ include "triton/Dialect/TritonInstrument/IR/TritonInstrumentAttrDefs.td"
1212
// ConSan keeps auxilary data requied for tracking memory accesses in tensors.
1313
// These tensors are stored as a distributed tensor or in global scratch memory.
1414
//
15-
// Name | Storage | Rank/Type | Description
16-
// ----------------|---------|-----------------|------------
17-
// buffers | tensor | <B x i64> | Base pointers of all (sub)buffers
18-
// barriers | tensor | <K x i64> | Pointers to all individual mbarriers
19-
// writeVisibility | scratch | <B x i64> | Per-buffer thread-visibility bitmask (bit i => thread i visible)
20-
// readVisibility | scratch | <B x T x i64> | Per-buffer, per-thread visibility lanes (row-updated; values are bitmasks)
21-
// writeTracking | scratch | <B x K x i8> | Map buffers -> barriers that track writes
22-
// readTracking | scratch | <B x K x i64> | Map buffers -> barriers that track reads
15+
// Name | Storage | Rank/Type | Description
16+
// ------------------|---------|-----------------|------------
17+
// buffers | tensor | <B x i64> | Base pointers of all (sub)buffers
18+
// barriers | tensor | <K x i64> | Pointers to all individual mbarriers
19+
// barrierStates | scratch | <K x i32> | Packed barrier phase (bit 0) and arrival counts (bits[1..8] init, [9..16] current)
20+
// waiting | scratch | <K x i32> | Two bits per thread: waiting flag bit (LSB), stored phase bit (bit 1)
21+
// writeVisibility | scratch | <B x i64> | Per-buffer thread-visibility bitmask (bit i => thread i visible)
22+
// readVisibility | scratch | <B x T x i64> | Per-buffer, per-thread visibility lanes (row-updated; values are bitmasks)
23+
// writeTracking | scratch | <B x K x i8> | Map buffers -> barriers that track writes
24+
// readTracking | scratch | <B x K x i64> | Map buffers -> barriers that track reads
2325
// outstandingCommits
24-
// (async/wgmma) | scratch | <B x T x i8> | Number of outstanding commits per buffer/thread (2D replaces prior 1D)
26+
// (async/wgmma) | scratch | <B x T x i8> | Number of outstanding commits per buffer/thread (2D replaces prior 1D)
2527

2628
//
2729
// Interfaces
@@ -62,6 +64,9 @@ def TTI_ExperimentalBufferPointersOp : TTI_Op<"experimental_buffer_pointers", [P
6264
}
6365

6466

67+
// ===== Critical section lock ops =====
68+
69+
6570
def TTI_ExperimentalLockAcquireOp : TTI_Op<"experimental_lock_acquire", [MemoryEffects<[MemWrite<GlobalMemory>]>]> {
6671
let summary = "Acquire a lock.";
6772
let description = [{
@@ -86,6 +91,9 @@ def TTI_ExperimentalLockReleaseOp : TTI_Op<"experimental_lock_release", [MemoryE
8691
}
8792

8893

94+
// ===== Barrier based synchronization ops =====
95+
96+
8997
def TTI_ExperimentalSetWriteVisibilityOp : TTI_Op<"experimental_set_write_visibility", [MemoryEffects<[MemWrite<GlobalMemory>]>]> {
9098
let summary = "set the write visibility for a buffer";
9199
let description = [{
@@ -315,6 +323,9 @@ def TTI_ExperimentalVerifyReadVisibilityOp : TTI_Op<"experimental_verify_read_vi
315323
}
316324

317325

326+
// ===== Commit-count–based synchronization ops =====
327+
328+
318329
def TTI_ExperimentalStageAccessForCommitOp : TTI_Op<"experimental_stage_access_for_commit", [MemoryEffects<[MemWrite<GlobalMemory>]>]> {
319330
let summary = "";
320331
let description = [{
@@ -409,4 +420,155 @@ def TTI_ExperimentalCheckOutstandingCommitsOp : TTI_Op<"experimental_check_outst
409420
let hasVerifier = 1;
410421
}
411422

423+
424+
// ===== Barrier state ops =====
425+
426+
def TTI_ExperimentalInitBarrierStateOp : TTI_Op<"experimental_init_barrier_state", [MemoryEffects<[MemWrite<GlobalMemory>]>]> {
427+
let summary = "initialize the auxiliary barrier state";
428+
let description = [{
429+
Initialize the tracked barrier state to phase 0 and set both the initial and current arrival counts.
430+
}];
431+
let arguments = (ins
432+
TTG_MemDescType:$mbar,
433+
I32Attr:$count,
434+
TT_Tensor:$barriers,
435+
TT_PtrLike:$states,
436+
TypeAttr:$statesType
437+
);
438+
let assemblyFormat = [{
439+
$mbar `,` $count `{` $barriers `,` $states `(` $statesType `)` `}` attr-dict `:` type($mbar) `,` type($barriers) `,` type($states)
440+
}];
441+
}
442+
443+
def TTI_ExperimentalVerifyBarrierArriveOp : TTI_Op<"experimental_verify_barrier_arrive", [MemoryEffects<[MemWrite<GlobalMemory>]>]> {
444+
let summary = "verify an arrive count against the tracked barrier state";
445+
let description = [{
446+
Check that applying the arrive count would not drive the tracked current count negative. Triggers an assertion on failure.
447+
}];
448+
let arguments = (ins
449+
TTG_MemDescType:$mbar,
450+
I32Attr:$count,
451+
TT_Tensor:$barriers,
452+
TT_PtrLike:$states,
453+
TypeAttr:$statesType,
454+
Optional<I1>:$pred
455+
);
456+
let assemblyFormat = [{
457+
$mbar `,` $count `{` $barriers `,` $states `(` $statesType `)` `}` (`,` $pred^)? attr-dict `:` type($mbar) `,` type($barriers) `,` type($states)
458+
}];
459+
}
460+
461+
def TTI_ExperimentalUpdateBarrierStateOp : TTI_Op<"experimental_update_barrier_state", [MemoryEffects<[MemWrite<GlobalMemory>]>]> {
462+
let summary = "update the auxiliary barrier state after a verified arrive";
463+
let description = [{
464+
Apply an arrive count to the tracked barrier state, toggling the phase when the count reaches zero and reloading the current count from the initial count.
465+
}];
466+
let arguments = (ins
467+
TTG_MemDescType:$mbar,
468+
I32Attr:$count,
469+
TT_Tensor:$barriers,
470+
TT_PtrLike:$states,
471+
TypeAttr:$statesType,
472+
Optional<I1>:$pred
473+
);
474+
let assemblyFormat = [{
475+
$mbar `,` $count `{` $barriers `,` $states `(` $statesType `)` `}` (`,` $pred^)? attr-dict `:` type($mbar) `,` type($barriers) `,` type($states)
476+
}];
477+
}
478+
479+
// ===== Deadlock detection ops =====
480+
481+
def TTI_ExperimentalSetWaitingOp : TTI_Op<"experimental_set_waiting", [MemoryEffects<[MemWrite<GlobalMemory>]>]> {
482+
let summary = "Mark the base thread as waiting on the given barrier phase";
483+
let description = [{
484+
For the barrier row matching mbar, set the waiting flag for baseThread and record the barrier phase being waited on.
485+
}];
486+
let arguments = (ins
487+
TTG_MemDescType:$mbar,
488+
I32Attr:$baseThread,
489+
I32:$phase,
490+
TT_Tensor:$barriers,
491+
TT_PtrLike:$waiting,
492+
TypeAttr:$waitingType,
493+
Optional<I1>:$pred
494+
);
495+
let assemblyFormat = [{
496+
$mbar `,` $baseThread `,` $phase `{` $barriers `,` $waiting `(` $waitingType `)` `}` (`,` $pred^)? attr-dict `:` type($mbar) `,` type($barriers) `,` type($waiting)
497+
}];
498+
}
499+
500+
def TTI_ExperimentalCheckAllActiveWaitingOp : TTI_Op<"experimental_check_all_active_waiting", [MemoryEffects<[MemWrite<GlobalMemory>]>]> {
501+
let summary = "Assert that not all active threads are waiting on matching phases";
502+
let description = [{
503+
Filter waiting threads to those whose recorded phase matches the current barrier phase, OR-reduce across barriers, and assert that (waitingMask & activeMask) != activeMask.
504+
}];
505+
let arguments = (ins
506+
I32Attr:$activeMask,
507+
TT_Tensor:$barriers,
508+
TT_PtrLike:$waiting,
509+
TypeAttr:$waitingType,
510+
TT_PtrLike:$barrierStates,
511+
TypeAttr:$barrierStatesType,
512+
Optional<I1>:$pred
513+
);
514+
let assemblyFormat = [{
515+
$activeMask `,` $barriers `,` $waiting `(` $waitingType `)` `,` $barrierStates `(` $barrierStatesType `)` (`,` $pred^)? attr-dict `:` type($barriers) `,` type($waiting) `,` type($barrierStates)
516+
}];
517+
}
518+
519+
def TTI_ExperimentalClearWaitingOp : TTI_Op<"experimental_clear_waiting", [MemoryEffects<[MemWrite<GlobalMemory>]>]> {
520+
let summary = "Clear the waiting state for the given base thread";
521+
let description = [{
522+
For the barrier row matching mbar, clear both the waiting flag and stored phase for baseThread.
523+
}];
524+
let arguments = (ins
525+
TTG_MemDescType:$mbar,
526+
I32Attr:$baseThread,
527+
TT_Tensor:$barriers,
528+
TT_PtrLike:$waiting,
529+
TypeAttr:$waitingType,
530+
Optional<I1>:$pred
531+
);
532+
let assemblyFormat = [{
533+
$mbar `,` $baseThread `{` $barriers `,` $waiting `(` $waitingType `)` `}` (`,` $pred^)? attr-dict `:` type($mbar) `,` type($barriers) `,` type($waiting)
534+
}];
535+
}
536+
537+
538+
// ===== Visibility replication ops =====
539+
540+
def TTI_ExperimentalCopyWriteVisibilityOp : TTI_Op<"experimental_copy_write_visibility", [MemoryEffects<[MemWrite<GlobalMemory>]>]> {
541+
let summary = "replicate write visibility from one thread to others";
542+
let description = [{
543+
Copy the write-visibility bit of sourceThread to every thread listed in destMask for all buffers. Destination bits are overwritten.
544+
}];
545+
let arguments = (ins
546+
I32Attr:$sourceThread,
547+
I64Attr:$destMask,
548+
TT_PtrLike:$writeVisibility,
549+
TypeAttr:$writeVisibilityType,
550+
Optional<I1>:$pred
551+
);
552+
let assemblyFormat = [{
553+
$sourceThread `,` $destMask `{` $writeVisibility `(` $writeVisibilityType `)` `}` (`,` $pred^)? attr-dict `:` type($writeVisibility)
554+
}];
555+
}
556+
557+
def TTI_ExperimentalCopyReadVisibilityOp : TTI_Op<"experimental_copy_read_visibility", [MemoryEffects<[MemWrite<GlobalMemory>]>]> {
558+
let summary = "replicate read visibility rows from one thread to others";
559+
let description = [{
560+
Copy the read-visibility row of sourceThread to every thread listed in destMask for all buffers. Destination rows are overwritten.
561+
}];
562+
let arguments = (ins
563+
I32Attr:$sourceThread,
564+
I64Attr:$destMask,
565+
TT_PtrLike:$readVisibility,
566+
TypeAttr:$readVisibilityType,
567+
Optional<I1>:$pred
568+
);
569+
let assemblyFormat = [{
570+
$sourceThread `,` $destMask `{` $readVisibility `(` $readVisibilityType `)` `}` (`,` $pred^)? attr-dict `:` type($readVisibility)
571+
}];
572+
}
573+
412574
#endif // TRITONINSTRUMENT_OPS

include/triton/Dialect/TritonInstrument/IR/Utility.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@ Operation *createLoadScratchMemory(OpBuilder &b, Location loc, Value alloc,
1919
Value expandOuterSlicedDim(OpBuilder &b, Location loc, Value tensor);
2020
TypedValue<RankedTensorType> createConstIntTensor(OpBuilder &builder,
2121
Location loc, int64_t val,
22-
RankedTensorType tensorType);
22+
RankedTensorType tensorType,
23+
bool isSigned = false);
2324
FuncOp getEntryPoint(ModuleOp module);
2425
gpu::DistributedEncodingTrait
2526
getSingleDimSliceEncoding(gpu::BlockedEncodingAttr encoding, int dim);
@@ -43,8 +44,11 @@ struct AuxDataMap {
4344
Region *getEnclosingParitionOrFunctionRegion(Operation *op);
4445
};
4546

47+
// Please see TritonInstrumentOps.td for more information on the auxiliary
48+
// data structures.
4649
RegionToValueMap buffers[numMemTypes];
4750
RegionToValueMap barriers;
51+
RegionToValueMap barrierStates;
4852

4953
RegionToValueMap writeVisibility[numMemTypes];
5054
RegionToValueMap writeTracking[numMemTypes];
@@ -53,6 +57,7 @@ struct AuxDataMap {
5357
RegionToValueMap asyncCpCommits;
5458
RegionToValueMap wgmmaCommits;
5559
RegionToValueMap lock;
60+
RegionToValueMap waiting;
5661

5762
void populateAndPassToWarpSpecialize(ModuleOp module);
5863

include/triton/Tools/Sys/GetEnv.hpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,6 @@ inline const std::set<std::string> CACHE_INVALIDATING_ENV_VARS = {
3232
"TRITON_DEFAULT_FP_FUSION",
3333
"TRITON_DISABLE_LINE_INFO",
3434
"TRITON_ENABLE_LLVM_DEBUG",
35-
"TRITON_HIP_GLOBAL_PREFETCH",
36-
"TRITON_HIP_LOCAL_PREFETCH",
3735
"TRITON_HIP_USE_ASYNC_COPY",
3836
"TRITON_HIP_USE_BLOCK_PINGPONG",
3937
"TRITON_HIP_USE_IN_THREAD_TRANSPOSE",

0 commit comments

Comments
 (0)