Skip to content

Commit d5d1c6d

Browse files
[ConSan] Add support for WarpSpecialization (#8189)
Introducing the support for warp specialized kernels in ConSan. This PR changes replaces some of the ConSan auxiliary data structures and augments others to enable tracking of read/write visibility of the buffers between different "threads" (partitions). It introduces concepts of TC and TMA logical threads, to model HW guarantees of intra and inter-thread operation ordering, and verifying that these guarantees are not being violated by the program.
1 parent 55613a7 commit d5d1c6d

File tree

14 files changed

+3394
-1550
lines changed

14 files changed

+3394
-1550
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ docs/sg_execution_times.rst
8585
/compile_commands.json
8686
.vscode
8787
.vs
88+
.cursor
8889

8990
# Vim
9091
*.swp

include/triton/Dialect/TritonGPU/IR/Dialect.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ constexpr static char AttrNumThreadsPerWarp[] = "ttg.threads-per-warp";
4848

4949
// Find the contextual number of warps on which this operation is executed.
5050
int lookupNumWarps(Operation *op);
51+
int lookupNumWarps(Region *region);
5152
// Try to find the contextual number of warps on which this operation is
5253
// executed. Returns nullopt if a warp size cannot be find. This is used for
5354
// verifiers.
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
# Triton Instrument Dialect and Concurrency Sanitizer (ConSan)
2+
3+
### Overview
4+
5+
ConSan instruments Triton IR to detect illegal concurrent accesses to shared and Tensor Core memory under warp specialization. It tracks per-buffer visibility of reads and writes across threads, models barrier-based synchronization, and models commit-count–based synchronization (cp.async, wgmma).
6+
7+
Auxiliary state is kept in distributed tensors and global scratch memory, with types created on-demand per warp-specialization partition.
8+
9+
### Thread model
10+
11+
- Base threads: 16 warp-specialization (WS) threads (allowing for up to 16 partitions).
12+
- Peer classes: +16 Tensor Core (TC) threads and +16 TMA threads to model lack of ordering with base threads.
13+
- Total logical threads: 48. Bitmasks are sized to the next power of two: 64.
14+
15+
Indexing uses a logical thread id in [0, 48), with column vectors sized to 64 for layout convenience.
16+
17+
## Auxiliary data structures
18+
19+
All types are generated on-demand (per partition) based on:
20+
21+
- B: number of tracked buffers (power-of-two padded)
22+
- K: number of mbarriers (power-of-two padded)
23+
- T_bits: 64 (bitmask width)
24+
- T_commits: 16 (base threads; commit counters do not apply to TC/TMA helpers)
25+
26+
“tensor” means a distributed Triton tensor; “scratch” means a pointer into global scratch memory. Shapes below are logical; actual encodings are partition-local blocked layouts.
27+
28+
- buffers (tensor, <B x i64>): Base pointers of all (sub)buffers per memory space
29+
- barriers (tensor, <K x i64>): Pointers of all mbarriers
30+
- writeVisibility (scratch, <B x i64>): Per-buffer bitmask. Bit i set ⇒ thread i can see latest completed write to that buffer
31+
- readVisibility (scratch, <B x 64 x i64>): Per-buffer, per-thread lanes. Each lane stores a 64-bit mask of other threads whose reads are visible to that lane’s thread
32+
- writeTracking (scratch, <B x K x i8>): Map buffers → barriers tracking writes (boolean stored in i8)
33+
- readTracking (scratch, <B x K x i64>): Map buffers → barriers tracking reads (bitmask of threads)
34+
- outstandingCommits (scratch, <B x 16 x i8>): Per-buffer, per-base-thread commit counters for cp.async and wgmma
35+
36+
## Visibility and legality rules
37+
38+
- Reads are legal iff the reading thread sees the most recent write to the buffer (writeVisibility). There can be only one write in-flight.
39+
- Writes are legal iff the writing thread sees both all prior writes and all reads completed for that buffer.
40+
41+
ConSan enforces these via two checks emitted before memory ops:
42+
43+
- experimental_verify_write_visibility: “no one else is writing, or I can see the write”
44+
- experimental_verify_read_visibility: “my read-visibility lane is a superset of the OR of all lanes”
45+
46+
## Barrier-based synchronization
47+
48+
ConSan separates “tracking” from “visibility transfer”:
49+
50+
- At memory ops that are tracked by a barrier (loads/stores, some TMEM ops):
51+
- experimental_set_read_visibility / experimental_set_write_visibility updates the appropriate visibility table for the current thread and buffer.
52+
- experimental_track_visible_reads / experimental_track_visible_writes snapshots current per-buffer visibility into readTracking/writeTracking for the given barrier.
53+
- At arrive/commit sites (e.g., tc commit, arrive on mbarrier): ConSan emits the track ops for both reads and writes.
54+
- At waits: experimental_transfer_visible_reads / experimental_transfer_visible_writes propagates tracked visibility from the barrier back into the waiting thread’s visibility, and this transfer is repeated to peer threads (base, TMA, TC) to keep the three classes consistent.
55+
56+
## Commit-count–based synchronization
57+
58+
Some hardware ops synchronize via “number of outstanding commits” rather than mbarriers.
59+
60+
- Stage: experimental_stage_access_for_commit marks the current thread’s buffer lane with -1 (staged) in outstandingCommits[B x 16].
61+
- Commit: experimental_commit_accesses turns -1 into 1 and increments positive entries for the committing thread column.
62+
- Wait (cp.async): experimental_clear_outstanding_commits_set_write(thread, commits, writeVisibility, N) clears entries with count > N for the current thread, and sets the writeVisibility bit for rows where any thread’s entry was cleared.
63+
- Wait (wgmma): experimental_clear_outstanding_commits_set_read(thread, commits, readVisibility, N) clears entries with count > N for the current thread, and sets the readVisibility bit for rows where any thread’s entry was cleared.
64+
65+
Legality checks for commit-count flows:
66+
67+
- For writes to shared memory affected by cp.async: experimental_check_outstanding_commits(buffer, commits, "async_copy_global_to_shared") asserts the row for the buffer is all zeros (no pending writes), across all base-thread columns.
68+
- For reads of wgmma operands in shared memory: experimental_check_outstanding_commits(buffer, commits, "warpgroup_mma operand read") asserts the row is all zeros (no pending reads).
69+
70+
Note: The check op has no “thread” operand; it inspects the whole row for the buffer.

0 commit comments

Comments
 (0)