Skip to content

Commit 39f3c20

Browse files
[ConSan] Support for WGMMA. Checks on non-async shmem and tmem accesses (#7712)
This PR covers: * Adding support for wgmma checks (ops marked as read until wgmma_wait) * Adding checks on any shmem and tmem access - this finalizes the basic support for pipelined kernels * Slight refactor of IR - op and attribute names
1 parent 8e5db20 commit 39f3c20

File tree

9 files changed

+974
-325
lines changed

9 files changed

+974
-325
lines changed

include/triton/Dialect/TritonInstrument/IR/TritonInstrumentAttrDefs.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@ include "mlir/IR/EnumAttr.td"
66
def TT_MemTypeAttr : I32EnumAttr<
77
"MemType", "",
88
[
9-
I32EnumAttrCase<"SHARED", 0, "shared">,
10-
I32EnumAttrCase<"TENSOR", 1, "tensor">,
9+
I32EnumAttrCase<"SHARED_MEM", 0, "shared_mem">,
10+
I32EnumAttrCase<"TENSOR_MEM", 1, "tensor_mem">,
1111
]> {
1212
let cppNamespace = "::mlir::triton::instrument";
1313
}

include/triton/Dialect/TritonInstrument/IR/TritonInstrumentOps.td

Lines changed: 33 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def TTI_ExperimentalBufferPointersOp : TTI_Op<"experimental_buffer_pointers", [P
4747
}
4848

4949

50-
def TTI_ExperimentalCheckOutstandingWritesOp : TTI_Op<"experimental_check_outstanding_writes", [MemoryEffects<[MemWrite<GlobalMemory>]>]> {
50+
def TTI_ExperimentalCheckWriteStateOp : TTI_Op<"experimental_check_write_state", [MemoryEffects<[MemWrite<GlobalMemory>]>]> {
5151
let summary = "check if there are outstanding writes to a buffer guarded by a mbar";
5252
let description = [{
5353
Check if the writeState tensor has non-zero value associated with the buffer.
@@ -76,7 +76,7 @@ def TTI_ExperimentalCheckOutstandingWritesOp : TTI_Op<"experimental_check_outsta
7676
}
7777

7878

79-
def TTI_ExperimentalCheckOutstandingReadsOp : TTI_Op<"experimental_check_outstanding_reads", [MemoryEffects<[MemWrite<GlobalMemory>]>]> {
79+
def TTI_ExperimentalCheckReadBarriersOp : TTI_Op<"experimental_check_read_barriers", [MemoryEffects<[MemWrite<GlobalMemory>]>]> {
8080
let summary = "check if there are outstanding reads from a buffer guarded by a mbar";
8181
let description = [{
8282
Check if there are outstanding reads from a buffer guarded by a mbar.
@@ -95,8 +95,8 @@ def TTI_ExperimentalCheckOutstandingReadsOp : TTI_Op<"experimental_check_outstan
9595
}
9696

9797

98-
def TTI_ExperimentalMarkAsWriteOp : TTI_Op<"experimental_mark_as_write", [MemoryEffects<[MemWrite<GlobalMemory>]>]> {
99-
let summary = "mark a buffer as being written to using mbar as a guard";
98+
def TTI_ExperimentalSetWriteStateOp : TTI_Op<"experimental_set_write_state", [MemoryEffects<[MemWrite<GlobalMemory>]>]> {
99+
let summary = "mark a buffer as being written in writeState tensor";
100100
let description = [{
101101
Mark a buffer as being written to. It is not yet tracked by a barrier, until
102102
`commit_write_with_barrier` is called, at which point all the buffers being written
@@ -146,7 +146,7 @@ def TTI_ExperimentalCommitWriteWithBarrierOp : TTI_Op<"experimental_commit_write
146146
}
147147

148148

149-
def TTI_ExperimentalMarkAsReadOp : TTI_Op<"experimental_mark_as_read", [MemoryEffects<[MemWrite<GlobalMemory>]>]> {
149+
def TTI_ExperimentalSetReadBarrierOp : TTI_Op<"experimental_set_read_barrier", [MemoryEffects<[MemWrite<GlobalMemory>]>]> {
150150
let summary = "mark a buffer as being read from using mbar as a guard";
151151
let description = [{
152152
Mark a buffer as being read from using mbar as a guard.
@@ -226,72 +226,70 @@ def TTI_ExperimentalCheckBarrierWritesClearedOp : TTI_Op<"experimental_check_bar
226226
}
227227

228228

229-
// TODO: Potentially resolve the naming/functionality clash with commit_write_with_barrier
230-
def TTI_ExperimentalStageWriteForCommitOp : TTI_Op<"experimental_stage_write_for_commit", [MemoryEffects<[MemWrite<GlobalMemory>]>]> {
231-
let summary = "Preapre to an async copy of a buffer. Staged until commit_group is called.";
229+
def TTI_ExperimentalStageAccessForCommitOp : TTI_Op<"experimental_stage_access_for_commit", [MemoryEffects<[MemWrite<GlobalMemory>]>]> {
230+
let summary = "";
232231
let description = [{
233-
Preapre to an async copy of a buffer. Staged until commit_group is called. The implementation will write `-1` to the
234-
`write_commits` tensor under the indices corresponding to the buffer.
232+
For operations that use `outstanding` to track the number of outstanding commits (rather than mbarriers),
233+
mark the buffer as being accessed, but not commited yet, by marking it with `-1`.
235234
}];
236235
let arguments = (ins
237236
TTG_MemDescType:$buf,
238237
TT_Tensor:$buffers,
239-
TT_PtrLike:$writeCommits,
240-
TypeAttr:$writeCommitsType,
238+
TT_PtrLike:$outstandingCommits,
239+
TypeAttr:$outstandingCommitsType,
241240
Optional<I1>:$pred
242241
);
243242
let assemblyFormat = [{
244-
$buf `{` $buffers `,` $writeCommits `(` $writeCommitsType `)` `}` (`,` $pred^)? attr-dict `:` type($buf) `,` type($buffers) `,` type($writeCommits)
243+
$buf `{` $buffers `,` $outstandingCommits `(` $outstandingCommitsType `)` `}` (`,` $pred^)? attr-dict `:` type($buf) `,` type($buffers) `,` type($outstandingCommits)
245244
}];
246-
// let hasVerifier = 1;
245+
let hasVerifier = 1;
247246
}
248247

249-
def TTI_ExperimentalCommitWritesOp : TTI_Op<"experimental_commit_writes", [MemoryEffects<[MemWrite<GlobalMemory>]>]> {
250-
let summary = "Commit all the staged writes for all the buffers.";
248+
def TTI_ExperimentalCommitAccessesOp : TTI_Op<"experimental_commit_accesses", [MemoryEffects<[MemWrite<GlobalMemory>]>]> {
249+
let summary = "Commit all the staged accesses for all the buffers.";
251250
let description = [{
252-
Commit all the staged writes for all the buffers.
251+
Commit all the staged accesses for all the buffers.
253252
}];
254253
let arguments = (ins
255-
TT_PtrLike:$writeCommits,
256-
TypeAttr:$writeCommitsType,
254+
TT_PtrLike:$outstandingCommits,
255+
TypeAttr:$outstandingCommitsType,
257256
Optional<I1>:$pred);
258257
let assemblyFormat = [{
259-
`{` $writeCommits `(` $writeCommitsType `)` `}` (`,` $pred^)? attr-dict `:` type($writeCommits)
258+
`{` $outstandingCommits `(` $outstandingCommitsType `)` `}` (`,` $pred^)? attr-dict `:` type($outstandingCommits)
260259
}];
261-
// let hasVerifier = 1;
262260
}
263261

264-
def TTI_ExperimentalClearWriteCommitsOp : TTI_Op<"experimental_clear_write_commits", [MemoryEffects<[MemWrite<GlobalMemory>]>]> {
265-
let summary = "Clear all the write commits more distant than `outstandingNum.";
262+
def TTI_ExperimentalClearOutstandingCommitsOp : TTI_Op<"experimental_clear_outstanding_commits", [MemoryEffects<[MemWrite<GlobalMemory>]>]> {
263+
let summary = "Clear all the outstanding commits more distant than `outstandingNum.";
266264
let description = [{
267-
Clear all the write commits more distant than `outstandingNum` from the current thread.
265+
Clear all the outstanding commits more distant than `outstandingNum` from the current thread.
268266
}];
269267
let arguments = (ins
270-
TT_PtrLike:$writeCommits,
271-
TypeAttr:$writeCommitsType,
268+
TT_PtrLike:$outstandingCommits,
269+
TypeAttr:$outstandingCommitsType,
272270
I32Attr:$outstandingNum,
273271
Optional<I1>:$pred);
274272
let assemblyFormat = [{
275-
`{` $writeCommits `(` $writeCommitsType `)` `}` `,` $outstandingNum (`,` $pred^)? attr-dict `:` type($writeCommits)
273+
`{` $outstandingCommits `(` $outstandingCommitsType `)` `}` `,` $outstandingNum (`,` $pred^)? attr-dict `:` type($outstandingCommits)
276274
}];
277-
// let hasVerifier = 1;
278275
}
279276

280-
def TTI_ExperimentalCheckWriteCommitOp : TTI_Op<"experimental_check_write_commit", [MemoryEffects<[MemWrite<GlobalMemory>]>]> {
281-
let summary = "Check if the buffer has an outstanding write commit.";
277+
def TTI_ExperimentalCheckOutstandingCommitsOp : TTI_Op<"experimental_check_outstanding_commits", [MemoryEffects<[MemWrite<GlobalMemory>]>]> {
278+
let summary = "Check if the buffer has an outstanding commit.";
282279
let description = [{
283-
Check if the buffer has an outstanding write commit.
280+
Check if the buffer has an outstanding commit.
284281
}];
285282
let arguments = (ins
286283
TTG_MemDescType:$buf,
287284
TT_Tensor:$buffers,
288-
TT_PtrLike:$writeCommits,
289-
TypeAttr:$writeCommitsType,
285+
TT_PtrLike:$outstandingCommits,
286+
TypeAttr:$outstandingCommitsType,
287+
StrAttr:$pendingAccessType,
290288
Optional<I1>:$pred);
291289
let assemblyFormat = [{
292-
$buf `{` $buffers `,` $writeCommits `(` $writeCommitsType `)` `}` (`,` $pred^)? attr-dict `:` type($buf) `,` type($buffers) `,` type($writeCommits)
290+
$buf `{` $buffers `,` $outstandingCommits `(` $outstandingCommitsType `)` `}` `,` $pendingAccessType (`,` $pred^)? attr-dict `:` type($buf) `,` type($buffers) `,` type($outstandingCommits)
293291
}];
294-
// let hasVerifier = 1;
292+
let hasVerifier = 1;
295293
}
296294

297295
#endif // TRITONINSTRUMENT_OPS

lib/Conversion/TritonGPUToLLVM/AssertOpToLLVM.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,9 @@ struct AssertOpConversion : public ConvertOpToLLVMPattern<triton::AssertOp> {
6363
while (auto callLoc = dyn_cast<CallSiteLoc>(loc))
6464
loc = callLoc.getCallee();
6565

66+
while (auto nameLoc = dyn_cast<NameLoc>(loc))
67+
loc = nameLoc.getChildLoc();
68+
6669
if (auto fileLineColLoc = dyn_cast<FileLineColLoc>(loc)) {
6770
file = fileLineColLoc.getFilename();
6871
line = fileLineColLoc.getLine();

0 commit comments

Comments
 (0)