Skip to content

Commit 239e17f

Browse files
[ConSan] Add support for Tensor Memory (#7562)
This PR covers three main changes to Concurrency Sanitizer: 1. Adding support for Tensor Memory - the implementation for shared and tensor memory is 95% common. 2. Adding support for writes tracked by multiple buffers and tcgen05_commit. The commit mechanism is generic and is modeled fairly closely on the HW - any 'outstanding' write (one that hasn't been waited on yet) starts being tracked by a barrier when the commit is called. 3. Implicit HW pipelining of tcgen05 operations. For now this is modeled loosely, by creating a class of ops that are "hwPipelined"; these ops do not assert when they write to memory that is being written by another "hwPipelined" op. This is in general overly optimistic, as hw pipelining is not commutative and also depends on the operands. We can decide if we need to make this check more robust. What's coming next: 1. Checks for wgmma (should be fairly easy since it is similar to cp_async) 2. Adding checks for the rest of shmem operations (just emit the check ops) 3. A bit of a refactor to cp_async and to llvm lowering to make it easier to write and maintain
1 parent 7bc948c commit 239e17f

File tree

12 files changed

+1133
-331
lines changed

12 files changed

+1133
-331
lines changed

include/triton/Dialect/TritonInstrument/IR/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ add_mlir_doc(TritonInstrumentDialect TritonInstrumentDialect dialects/ -gen-dial
88
set(LLVM_TARGET_DEFINITIONS TritonInstrumentOps.td)
99
mlir_tablegen(Ops.h.inc -gen-op-decls)
1010
mlir_tablegen(Ops.cpp.inc -gen-op-defs)
11+
mlir_tablegen(OpsEnums.h.inc -gen-enum-decls)
12+
mlir_tablegen(OpsEnums.cpp.inc -gen-enum-defs)
1113
add_mlir_doc(TritonInstrumentOps TritonInstrumentOps dialects/ -gen-op-doc)
1214

1315
add_public_tablegen_target(TritonInstrumentTableGen)

include/triton/Dialect/TritonInstrument/IR/Dialect.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
#include "triton/Dialect/Triton/IR/Dialect.h"
66
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
77

8+
#include "triton/Dialect/TritonInstrument/IR/OpsEnums.h.inc"
9+
810
#define GET_OP_CLASSES
911
#include "triton/Dialect/TritonInstrument/IR/Dialect.h.inc"
1012
#include "triton/Dialect/TritonInstrument/IR/Ops.h.inc"
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
#ifndef TRITONINSTRUMENT_ATTR_DEFS
2+
#define TRITONINSTRUMENT_ATTR_DEFS
3+
4+
include "mlir/IR/EnumAttr.td"
5+
6+
def TT_MemTypeAttr : I32EnumAttr<
7+
"MemType", "",
8+
[
9+
I32EnumAttrCase<"SHARED", 0, "shared">,
10+
I32EnumAttrCase<"TENSOR", 1, "tensor">,
11+
]> {
12+
let cppNamespace = "::mlir::triton::instrument";
13+
}
14+
15+
#endif // TRITONINSTRUMENT_ATTR_DEFS

include/triton/Dialect/TritonInstrument/IR/TritonInstrumentOps.td

Lines changed: 76 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ include "triton/Dialect/TritonGPU/IR/TritonGPUTypes.td"
66
include "triton/Dialect/Triton/IR/TritonTypes.td"
77
include "mlir/IR/OpBase.td"
88
include "mlir/Interfaces/SideEffectInterfaces.td"
9+
include "triton/Dialect/TritonInstrument/IR/TritonInstrumentAttrDefs.td"
910

1011
//
1112
// Interfaces
@@ -33,33 +34,43 @@ def TTI_ExperimentalAssertInThreadOp : TTI_Op<"experimental_assert_in_thread", [
3334
}
3435

3536

36-
def TTI_ExperimentalSharedBufferPointersOp : TTI_Op<"experimental_shared_buffer_pointers", [Pure]> {
37+
def TTI_ExperimentalBufferPointersOp : TTI_Op<"experimental_buffer_pointers", [Pure]> {
3738
let summary = "define an array of pointers to shared memory buffers";
3839
let description = [{
3940
Create a tensor of pointers to shared memory buffers.
4041
}];
41-
let arguments = (ins DenseI32ArrayAttr:$offsets);
42+
let arguments = (ins DenseI32ArrayAttr:$offsets, TT_MemTypeAttr:$memType);
4243
let results = (outs TT_Tensor:$result);
4344
let assemblyFormat = [{
44-
attr-dict `:` type($result)
45+
$offsets `,` $memType attr-dict `:` type($result)
4546
}];
4647
}
4748

4849

4950
def TTI_ExperimentalCheckOutstandingWritesOp : TTI_Op<"experimental_check_outstanding_writes", [MemoryEffects<[MemWrite<GlobalMemory>]>]> {
5051
let summary = "check if there are outstanding writes to a buffer guarded by a mbar";
5152
let description = [{
52-
Check if there are outstanding writes to a buffer guarded by a mbar.
53+
Check if the writeState tensor has non-zero value associated with the buffer.
54+
55+
`writeState` is a tensor of 8b bitfields, where:
56+
- bit 0: 1 if the buffer is being written to
57+
- bit 1: 1 if the write is *not* hwPipelined
58+
59+
If hwPipelined is true, shift the bitfield by 1 to check the second bit - this
60+
means that the error won't be triggered if another pipelined write is outstanding.
5361
}];
5462
let arguments = (ins
5563
TTG_MemDescType:$buf,
5664
TT_Tensor:$buffers,
5765
TT_PtrLike:$writeBars,
5866
TypeAttr:$writeBarsType,
67+
TT_PtrLike:$writeState,
68+
TypeAttr:$writeStateType,
69+
I1Attr:$hwPipelined,
5970
Optional<I1>:$pred
6071
);
6172
let assemblyFormat = [{
62-
$buf `{` $buffers `,` $writeBars `(` $writeBarsType `)` `}` (`,` $pred^)? attr-dict `:` type($buf) `,` type($buffers) `,` type($writeBars)
73+
$buf `{` $buffers `,` $writeBars `(` $writeBarsType `)` `,` $writeState `(` $writeStateType `)` `}` (`,` $pred^)? `pipelined` $hwPipelined attr-dict `:` type($buf) `,` type($buffers) `,` type($writeBars) `,` type($writeState)
6374
}];
6475
let hasVerifier = 1;
6576
}
@@ -87,18 +98,49 @@ def TTI_ExperimentalCheckOutstandingReadsOp : TTI_Op<"experimental_check_outstan
8798
def TTI_ExperimentalMarkAsWriteOp : TTI_Op<"experimental_mark_as_write", [MemoryEffects<[MemWrite<GlobalMemory>]>]> {
8899
let summary = "mark a buffer as being written to using mbar as a guard";
89100
let description = [{
90-
Mark a buffer as being written to using mbar as a guard.
101+
Mark a buffer as being written to. It is not yet tracked by a barrier, until
102+
`commit_write_with_barrier` is called, at which point all the buffers being written
103+
to are marked as tracked by the barrier.
104+
105+
`writeState` is a tensor of 8b bitfields, where:
106+
- bit 0: 1 if the buffer is being written to
107+
- bit 1: 1 if the write is *not* hwPipelined
108+
109+
If hwPipelined is true, the write won't trigger an error if another pipelined
110+
write is executed later without waiting for the barrier.
91111
}];
92112
let arguments = (ins
93113
TTG_MemDescType:$buf,
94-
TTG_MemDescType:$mbar,
95114
TT_Tensor:$buffers,
115+
TT_PtrLike:$writeState,
116+
TypeAttr:$writeStateType,
117+
I1Attr:$hwPipelined,
118+
Optional<I1>:$pred
119+
);
120+
let assemblyFormat = [{
121+
$buf `{` $buffers `,` $writeState `(` $writeStateType `)` `}` (`,` $pred^)? `pipelined` $hwPipelined attr-dict `:` type($buf) `,` type($buffers) `,` type($writeState)
122+
}];
123+
let hasVerifier = 1;
124+
}
125+
126+
127+
def TTI_ExperimentalCommitWriteWithBarrierOp : TTI_Op<"experimental_commit_write_with_barrier", [MemoryEffects<[MemWrite<GlobalMemory>]>]> {
128+
let summary = "Mark all buffers being currently written as tracked by the barrier.";
129+
let description = [{
130+
For all buffers currently marked in writeState tensor, mark them as tracked by the mbar in
131+
writeBars tensor.
132+
}];
133+
let arguments = (ins
134+
TTG_MemDescType:$mbar,
135+
TT_Tensor:$barriers,
96136
TT_PtrLike:$writeBars,
97137
TypeAttr:$writeBarsType,
138+
TT_PtrLike:$writeState,
139+
TypeAttr:$writeStateType,
98140
Optional<I1>:$pred
99141
);
100142
let assemblyFormat = [{
101-
$buf `,` $mbar `{` $buffers `,` $writeBars `(` $writeBarsType `)` `}` (`,` $pred^)? attr-dict `:` type($buf) `,` type($mbar) `,` type($buffers) `,` type($writeBars)
143+
$mbar `{` $barriers `,` $writeBars `(` $writeBarsType `)` `,` $writeState `(` $writeStateType `)` `}` (`,` $pred^)? attr-dict `:` type($mbar) `,` type($barriers) `,` type($writeBars) `,` type($writeState)
102144
}];
103145
let hasVerifier = 1;
104146
}
@@ -132,13 +174,17 @@ def TTI_ExperimentalClearWriteBarrierOp : TTI_Op<"experimental_clear_write_barri
132174
}];
133175
let arguments = (ins
134176
TTG_MemDescType:$mbar,
177+
TT_Tensor:$barriers,
135178
TT_PtrLike:$writeBars,
136179
TypeAttr:$writeBarsType,
180+
TT_PtrLike:$writeState,
181+
TypeAttr:$writeStateType,
137182
Optional<I1>:$pred
138183
);
139184
let assemblyFormat = [{
140-
$mbar `{` $writeBars `(` $writeBarsType `)` `}` (`,` $pred^)? attr-dict `:` type($mbar) `,` type($writeBars)
185+
$mbar `{` $barriers `,` $writeBars `(` $writeBarsType `)` `,` $writeState `(` $writeStateType `)` `}` (`,` $pred^)? attr-dict `:` type($mbar) `,` type($barriers) `,` type($writeBars) `,` type($writeState)
141186
}];
187+
let hasVerifier = 1;
142188
}
143189

144190

@@ -160,6 +206,27 @@ def TTI_ExperimentalClearReadBarrierOp : TTI_Op<"experimental_clear_read_barrier
160206
let hasVerifier = 1;
161207
}
162208

209+
210+
def TTI_ExperimentalCheckBarrierWritesClearedOp : TTI_Op<"experimental_check_barrier_writes_cleared", [MemoryEffects<[MemWrite<GlobalMemory>]>]> {
211+
let summary = "verify that the barrier is not used to track any writes";
212+
let description = [{
213+
Verify that the barrier is not used to track any writes.
214+
}];
215+
let arguments = (ins
216+
TTG_MemDescType:$mbar,
217+
TT_Tensor:$barriers,
218+
TT_PtrLike:$writeBars,
219+
TypeAttr:$writeBarsType,
220+
Optional<I1>:$pred
221+
);
222+
let assemblyFormat = [{
223+
$mbar `{` $barriers `,` $writeBars `(` $writeBarsType `)` `}` (`,` $pred^)? attr-dict `:` type($mbar) `,` type($barriers) `,` type($writeBars)
224+
}];
225+
let hasVerifier = 1;
226+
}
227+
228+
229+
// TODO: Potentially resolve the naming/functionality clash with commit_write_with_barrier
163230
def TTI_ExperimentalStageWriteForCommitOp : TTI_Op<"experimental_stage_write_for_commit", [MemoryEffects<[MemWrite<GlobalMemory>]>]> {
164231
let summary = "Prepare for an async copy of a buffer. Staged until commit_group is called.";
165232
let description = [{

lib/Conversion/TritonInstrumentToLLVM/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,5 @@ add_triton_library(TritonInstrumentToLLVM
77
TritonIR
88
TritonGPUIR
99
TritonInstrumentIR
10+
TritonNvidiaGPUIR
1011
)

0 commit comments

Comments
 (0)