Skip to content

Commit 46f4bbd

Browse files
[ConSan] Add support for mmav5, checks for buffer reads (#7433)
Main functionality included in this change is a support for tracking and checking outstanding reads. With this we can catch issues like insufficient multibuffering of the mma operands - added test confirming the functionality works. This change also features slight refactor of the IR - instead of having ops per the TritonGPU op that we want to instrument, I have introduced more modular ops that do individual checks/aux data updates. This should make adding further instrumentation much easier.
1 parent 8a45291 commit 46f4bbd

File tree

14 files changed

+1402
-273
lines changed

14 files changed

+1402
-273
lines changed

include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -411,12 +411,6 @@ def TTG_GlobalScratchAllocOp : TTG_Op<"global_scratch_alloc"> {
411411
);
412412
let results = (outs Arg<TT_Ptr, "", [MemAlloc<GlobalMemory>]>:$result);
413413

414-
let builders = [
415-
OpBuilder<(ins "Type":$result, "int32_t":$nbytes, "int32_t":$alignment),
416-
[{ build($_builder, $_state, result,
417-
$_builder.getI32IntegerAttr(nbytes), $_builder.getI32IntegerAttr(alignment)); }]>
418-
];
419-
420414
let assemblyFormat = [{attr-dict `:` qualified(type($result))}];
421415
}
422416

include/triton/Dialect/TritonInstrument/IR/TritonInstrumentOps.td

Lines changed: 114 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,32 @@ include "triton/Dialect/Triton/IR/TritonTypes.td"
77
include "mlir/IR/OpBase.td"
88
include "mlir/Interfaces/SideEffectInterfaces.td"
99

10+
//
11+
// Interfaces
12+
//
13+
def GlobalMemory : Resource<"::mlir::triton::GlobalMemory">;
14+
15+
//
16+
// Ops
17+
//
18+
1019
class TTI_Op<string mnemonic, list<Trait> traits = []> :
1120
Op<TritonInstrument_Dialect, mnemonic, traits> {
1221
}
1322

14-
// Define an array of pointers to shared memory buffers
23+
def TTI_ExperimentalAssertInThreadOp : TTI_Op<"experimental_assert_in_thread", [MemoryEffects<[MemWrite<GlobalMemory>]>]> {
24+
let summary = "assert the condition within the current thread";
25+
let description = [{
26+
Assert that the condition is true given all the values are available in the current thread.
27+
If the condition is false, the message is printed, and the program is aborted.
28+
If check_any is true, any of the values in the condition must be true. Otherwise, all the
29+
values in the condition must be true.
30+
}];
31+
let arguments = (ins I1Tensor:$condition, StrAttr:$message, BoolAttr:$check_any);
32+
let assemblyFormat = "$condition `,` $message attr-dict `:` type($condition)";
33+
}
34+
35+
1536
def TTI_ExperimentalSharedBufferPointersOp : TTI_Op<"experimental_shared_buffer_pointers", [Pure]> {
1637
let summary = "definte an array of pointers to shared memory buffers";
1738
let description = [{
@@ -24,59 +45,119 @@ def TTI_ExperimentalSharedBufferPointersOp : TTI_Op<"experimental_shared_buffer_
2445
}];
2546
}
2647

27-
// Check if writing to a buffer guarded by a mbar is valid
28-
def TTI_ExperimentalCheckAsyncWriteWithMbarSharedOp : TTI_Op<"experimental_check_async_write_with_mbar_shared", [Pure]> {
29-
let summary = "check if writing to a buffer guarded by a mbar is valid";
48+
49+
def TTI_ExperimentalCheckOutstandingWritesOp : TTI_Op<"experimental_check_outstanding_writes", [MemoryEffects<[MemWrite<GlobalMemory>]>]> {
50+
let summary = "check if there are outstanding writes to a buffer guarded by a mbar";
3051
let description = [{
31-
Check if writing to a shared memory buffer guarded by a mbar is valid.
32-
Update the buffer state and assert if the buffer is being read or written.
52+
Check if there are outstanding writes to a buffer guarded by a mbar.
3353
}];
3454
let arguments = (ins
35-
TTG_MemDescType:$buffer,
36-
TTG_MemDescType:$mbar,
55+
TTG_MemDescType:$buf,
3756
TT_Tensor:$buffers,
38-
TT_Tensor:$states,
39-
TT_Tensor:$barriers
57+
TT_PtrLike:$writeBars,
58+
TypeAttr:$writeBarsType,
59+
Optional<I1>:$pred
4060
);
41-
let results = (outs
42-
TT_Tensor:$outStates,
43-
TT_Tensor:$outBarriers
61+
let assemblyFormat = [{
62+
$buf `{` $buffers `,` $writeBars `(` $writeBarsType `)` `}` (`,` $pred^)? attr-dict `:` type($buf) `,` type($buffers) `,` type($writeBars)
63+
}];
64+
let hasVerifier = 1;
65+
}
66+
67+
68+
def TTI_ExperimentalCheckOutstandingReadsOp : TTI_Op<"experimental_check_outstanding_reads", [MemoryEffects<[MemWrite<GlobalMemory>]>]> {
69+
let summary = "check if there are outstanding reads from a buffer guarded by a mbar";
70+
let description = [{
71+
Check if there are outstanding reads from a buffer guarded by a mbar.
72+
}];
73+
let arguments = (ins
74+
TTG_MemDescType:$buf,
75+
TT_Tensor:$buffers,
76+
TT_PtrLike:$readBars,
77+
TypeAttr:$readBarsType,
78+
Optional<I1>:$pred
4479
);
4580
let assemblyFormat = [{
46-
$buffer `,` $mbar `{` $buffers `,` $states `,` $barriers `}` attr-dict `:` type($buffer) `,` type($mbar) `,` type($buffers) `,` type($states) `,` type($barriers) `->` type($outStates) `,` type($outBarriers)
81+
$buf `{` $buffers `,` $readBars `(` $readBarsType `)` `}` (`,` $pred^)? attr-dict `:` type($buf) `,` type($buffers) `,` type($readBars)
4782
}];
48-
let builders = [
49-
OpBuilder<(ins "Value":$buffer, "Value":$mbar, "Value":$buffers, "Value":$states, "Value":$barriers),[{
50-
build($_builder, $_state, {states.getType(), barriers.getType()}, buffer, mbar, buffers, states, barriers);
51-
}]>
52-
];
83+
let hasVerifier = 1;
5384
}
5485

55-
def TTI_ExperimentalCheckWaitMbarOp : TTI_Op<"experimental_check_wait_mbar", [Pure]> {
56-
let summary = "check if waiting on a mbar is valid and update the barrier state";
86+
87+
def TTI_ExperimentalMarkAsWriteOp : TTI_Op<"experimental_mark_as_write", [MemoryEffects<[MemWrite<GlobalMemory>]>]> {
88+
let summary = "mark a buffer as being written to using mbar as a guard";
5789
let description = [{
58-
Check if waiting on a mbar is valid and update the barrier state.
90+
Mark a buffer as being written to using mbar as a guard.
5991
}];
6092
let arguments = (ins
93+
TTG_MemDescType:$buf,
6194
TTG_MemDescType:$mbar,
95+
TT_Tensor:$buffers,
96+
TT_PtrLike:$writeBars,
97+
TypeAttr:$writeBarsType,
98+
Optional<I1>:$pred
99+
);
100+
let assemblyFormat = [{
101+
$buf `,` $mbar `{` $buffers `,` $writeBars `(` $writeBarsType `)` `}` (`,` $pred^)? attr-dict `:` type($buf) `,` type($mbar) `,` type($buffers) `,` type($writeBars)
102+
}];
103+
let hasVerifier = 1;
104+
}
105+
106+
107+
def TTI_ExperimentalMarkAsReadOp : TTI_Op<"experimental_mark_as_read", [MemoryEffects<[MemWrite<GlobalMemory>]>]> {
108+
let summary = "mark a buffer as being read from using mbar as a guard";
109+
let description = [{
110+
Mark a buffer as being read from using mbar as a guard.
111+
}];
112+
let arguments = (ins
113+
TTG_MemDescType:$buf,
114+
TTG_MemDescType:$mbar,
115+
TT_Tensor:$buffers,
62116
TT_Tensor:$barriers,
63-
TT_Tensor:$states
64-
);
117+
TT_PtrLike:$readBars,
118+
TypeAttr:$readBarsType,
119+
Optional<I1>:$pred
120+
);
121+
let assemblyFormat = [{
122+
$buf `,` $mbar `{` $buffers `,` $barriers `,` $readBars `(` $readBarsType `)` `}` (`,` $pred^)? attr-dict `:` type($buf) `,` type($mbar) `,` type($buffers) `,` type($barriers) `,` type($readBars)
123+
}];
124+
let hasVerifier = 1;
125+
}
65126

66-
let results = (outs
67-
TT_Tensor:$outStates,
68-
TT_Tensor:$outBarriers);
69127

128+
def TTI_ExperimentalClearWriteBarrierOp : TTI_Op<"experimental_clear_write_barrier", [MemoryEffects<[MemWrite<GlobalMemory>]>]> {
129+
let summary = "clear the write state for buffers being guarded by an mbar";
130+
let description = [{
131+
Clear the write state for buffers being guarded by an mbar.
132+
}];
133+
let arguments = (ins
134+
TTG_MemDescType:$mbar,
135+
TT_PtrLike:$writeBars,
136+
TypeAttr:$writeBarsType,
137+
Optional<I1>:$pred
138+
);
70139
let assemblyFormat = [{
71-
$mbar `{` $states `,` $barriers `}` attr-dict `:` type($mbar) `,` type($states) `,` type($barriers) `->` type($outStates) `,` type($outBarriers)
140+
$mbar `{` $writeBars `(` $writeBarsType `)` `}` (`,` $pred^)? attr-dict `:` type($mbar) `,` type($writeBars)
72141
}];
142+
}
73143

74-
let builders = [
75-
OpBuilder<(ins "Value":$mbar, "Value":$barriers, "Value":$states),
76-
[{
77-
build($_builder, $_state, {states.getType(), barriers.getType()}, mbar, barriers, states);
78-
}]>];
79144

145+
def TTI_ExperimentalClearReadBarrierOp : TTI_Op<"experimental_clear_read_barrier", [MemoryEffects<[MemWrite<GlobalMemory>]>]> {
146+
let summary = "clear the read state for buffers being guarded by an mbar";
147+
let description = [{
148+
Clear the read state for buffers being guarded by an mbar.
149+
}];
150+
let arguments = (ins
151+
TTG_MemDescType:$mbar,
152+
TT_Tensor:$barriers,
153+
TT_PtrLike:$readBars,
154+
TypeAttr:$readBarsType,
155+
Optional<I1>:$pred
156+
);
157+
let assemblyFormat = [{
158+
$mbar `{` $barriers `,` $readBars `(` $readBarsType `)` `}` (`,` $pred^)? attr-dict `:` type($mbar) `,` type($barriers) `,` type($readBars)
159+
}];
160+
let hasVerifier = 1;
80161
}
81162

82163
#endif // TRITONINSTRUMENT_OPS
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
#include "triton/Dialect/TritonInstrument/IR/Dialect.h"
2+
3+
namespace mlir::triton::instrument {
4+
5+
Operation *createStoreScratchMemory(OpBuilder &b, Location loc, Value alloc,
6+
Value tensor, RankedTensorType tensorType);
7+
Operation *createLoadScratchMemory(OpBuilder &b, Location loc, Value alloc,
8+
RankedTensorType tensorType);
9+
Value expandOuterSlicedDim(OpBuilder &b, Location loc, Value tensor);
10+
TypedValue<RankedTensorType> createConstIntTensor(OpBuilder &builder,
11+
Location loc, int val,
12+
RankedTensorType tensorType);
13+
14+
} // namespace mlir::triton::instrument

0 commit comments

Comments
 (0)