Skip to content

Commit cb30573

Browse files
[ConSan] Concurrency Sanitizer - initial scaffolding and introduction of TritonInstrument dialect (#7157)
ConcurrencySanitizer (ConSan) is a mechanism intended to help debug concurrency issues in Triton and Gluon. It consists of: 1. `TritonInstrumentDialect` (`tti`), intended to be used for all instrumentation purposes. The reason for introducing a new dialect is to make it easy to move all of this functionality out of Triton if needed, and to clearly distinguish between functional and instrumentation opcodes in the IR. 2. A ConcurrencySanitizer pass that adds instrumentation to the IR during lowering to LLVM IR. 3. A set of TTG Ops that create and update auxiliary data structures and perform checks. This change is an initial implementation of the base functionality, sufficient to enable checks in a simple test and catch intentionally introduced issues. The current version only supports modules with a single function, because the auxiliary tensors are passed around by value. This will change in the future, when the tensors are passed through global scratch memory and multi-function support is added.
1 parent 526d168 commit cb30573

File tree

33 files changed

+1048
-4
lines changed

33 files changed

+1048
-4
lines changed

bin/RegisterTritonDialects.h

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include "third_party/proton/dialect/include/Dialect/Proton/IR/Dialect.h"
77
#include "triton/Dialect/Triton/IR/Dialect.h"
88
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
9+
#include "triton/Dialect/TritonInstrument/IR/Dialect.h"
910
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
1011

1112
// Below headers will allow registration to ROCm passes
@@ -15,6 +16,7 @@
1516

1617
#include "triton/Dialect/Triton/Transforms/Passes.h"
1718
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
19+
#include "triton/Dialect/TritonInstrument/Transforms/Passes.h"
1820
#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h"
1921

2022
#include "nvidia/hopper/include/Transforms/Passes.h"
@@ -47,6 +49,7 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
4749
mlir::triton::registerTritonPasses();
4850
mlir::triton::gpu::registerTritonGPUPasses();
4951
mlir::triton::nvidia_gpu::registerTritonNvidiaGPUPasses();
52+
mlir::triton::instrument::registerTritonInstrumentPasses();
5053
mlir::test::registerTestAliasPass();
5154
mlir::test::registerTestAlignmentPass();
5255
mlir::test::registerAMDTestAlignmentPass();
@@ -95,9 +98,10 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
9598
registry.insert<
9699
mlir::triton::TritonDialect, mlir::cf::ControlFlowDialect,
97100
mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect,
98-
mlir::triton::gpu::TritonGPUDialect, mlir::math::MathDialect,
99-
mlir::arith::ArithDialect, mlir::scf::SCFDialect, mlir::gpu::GPUDialect,
100-
mlir::LLVM::LLVMDialect, mlir::NVVM::NVVMDialect,
101+
mlir::triton::gpu::TritonGPUDialect,
102+
mlir::triton::instrument::TritonInstrumentDialect,
103+
mlir::math::MathDialect, mlir::arith::ArithDialect, mlir::scf::SCFDialect,
104+
mlir::gpu::GPUDialect, mlir::LLVM::LLVMDialect, mlir::NVVM::NVVMDialect,
101105
mlir::triton::nvgpu::NVGPUDialect, mlir::triton::nvws::NVWSDialect,
102106
mlir::triton::amdgpu::TritonAMDGPUDialect,
103107
mlir::triton::proton::ProtonDialect, mlir::ROCDL::ROCDLDialect>();

include/triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,11 @@ void populatePrintOpToLLVMPattern(LLVMTypeConverter &typeConverter,
102102
const TargetInfoBase &targetInfo,
103103
PatternBenefit benefit);
104104

105+
void populateInstrumentationToLLVMPatterns(LLVMTypeConverter &typeConverter,
106+
const TargetInfoBase &targetInfo,
107+
RewritePatternSet &patterns,
108+
PatternBenefit benefit);
109+
105110
} // namespace triton
106111
} // namespace mlir
107112

include/triton/Dialect/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
add_subdirectory(Triton)
22
add_subdirectory(TritonGPU)
33
add_subdirectory(TritonNvidiaGPU)
4+
add_subdirectory(TritonInstrument)
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
add_subdirectory(IR)
2+
add_subdirectory(Transforms)
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
set(MLIR_BINARY_DIR ${CMAKE_BINARY_DIR})
2+
3+
set(LLVM_TARGET_DEFINITIONS TritonInstrumentDialect.td)
4+
mlir_tablegen(Dialect.h.inc -gen-dialect-decls -dialect=tti)
5+
mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs -dialect=tti)
6+
add_mlir_doc(TritonInstrumentDialect TritonInstrumentDialect dialects/ -gen-dialect-doc)
7+
8+
set(LLVM_TARGET_DEFINITIONS TritonInstrumentOps.td)
9+
mlir_tablegen(Ops.h.inc -gen-op-decls)
10+
mlir_tablegen(Ops.cpp.inc -gen-op-defs)
11+
add_mlir_doc(TritonInstrumentOps TritonInstrumentOps dialects/ -gen-op-doc)
12+
13+
add_public_tablegen_target(TritonInstrumentTableGen)
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
#ifndef TRITON_DIALECT_TRITONINSTRUMENT_IR_DIALECT_H_
2+
#define TRITON_DIALECT_TRITONINSTRUMENT_IR_DIALECT_H_
3+
4+
// TritonInstrument depends on Triton and TritonGPU
5+
#include "triton/Dialect/Triton/IR/Dialect.h"
6+
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
7+
8+
#define GET_OP_CLASSES
9+
#include "triton/Dialect/TritonInstrument/IR/Dialect.h.inc"
10+
#include "triton/Dialect/TritonInstrument/IR/Ops.h.inc"
11+
12+
#endif // TRITON_DIALECT_TRITONINSTRUMENT_IR_DIALECT_H_
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
#ifndef TRITONINSTRUMENT_DIALECT
2+
#define TRITONINSTRUMENT_DIALECT
3+
4+
include "mlir/IR/OpBase.td"
5+
6+
def TritonInstrument_Dialect : Dialect {
7+
let name = "tti";
8+
let cppNamespace = "::mlir::triton::instrument";
9+
}
10+
11+
#endif // TRITONINSTRUMENT_DIALECT
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
#ifndef TRITONINSTRUMENT_OPS
2+
#define TRITONINSTRUMENT_OPS
3+
4+
include "triton/Dialect/TritonInstrument/IR/TritonInstrumentDialect.td"
5+
include "triton/Dialect/TritonGPU/IR/TritonGPUTypes.td"
6+
include "triton/Dialect/Triton/IR/TritonTypes.td"
7+
include "mlir/IR/OpBase.td"
8+
include "mlir/Interfaces/SideEffectInterfaces.td"
9+
10+
class TTI_Op<string mnemonic, list<Trait> traits = []> :
11+
Op<TritonInstrument_Dialect, mnemonic, traits> {
12+
}
13+
14+
// Define an array of pointers to shared memory buffers
15+
def TTI_ExperimentalSharedBufferPointersOp : TTI_Op<"experimental_shared_buffer_pointers", [Pure]> {
16+
let summary = "define an array of pointers to shared memory buffers";
17+
let description = [{
18+
Create a tensor of pointers to shared memory buffers.
19+
}];
20+
let arguments = (ins DenseI32ArrayAttr:$offsets);
21+
let results = (outs TT_Tensor:$result);
22+
let assemblyFormat = [{
23+
attr-dict `:` type($result)
24+
}];
25+
}
26+
27+
// Check if writing to a buffer guarded by a mbar is valid
28+
def TTI_ExperimentalCheckAsyncWriteWithMbarSharedOp : TTI_Op<"experimental_check_async_write_with_mbar_shared", [Pure]> {
29+
let summary = "check if writing to a buffer guarded by a mbar is valid";
30+
let description = [{
31+
Check if writing to a shared memory buffer guarded by a mbar is valid.
32+
Update the buffer state and assert if the buffer is being read or written.
33+
}];
34+
let arguments = (ins
35+
TTG_MemDescType:$buffer,
36+
TTG_MemDescType:$mbar,
37+
TT_Tensor:$buffers,
38+
TT_Tensor:$states,
39+
TT_Tensor:$barriers
40+
);
41+
let results = (outs
42+
TT_Tensor:$outStates,
43+
TT_Tensor:$outBarriers
44+
);
45+
let assemblyFormat = [{
46+
$buffer `,` $mbar `{` $buffers `,` $states `,` $barriers `}` attr-dict `:` type($buffer) `,` type($mbar) `,` type($buffers) `,` type($states) `,` type($barriers) `->` type($outStates) `,` type($outBarriers)
47+
}];
48+
let builders = [
49+
OpBuilder<(ins "Value":$buffer, "Value":$mbar, "Value":$buffers, "Value":$states, "Value":$barriers),[{
50+
build($_builder, $_state, {states.getType(), barriers.getType()}, buffer, mbar, buffers, states, barriers);
51+
}]>
52+
];
53+
}
54+
55+
def TTI_ExperimentalCheckWaitMbarOp : TTI_Op<"experimental_check_wait_mbar", [Pure]> {
56+
let summary = "check if waiting on a mbar is valid and update the barrier state";
57+
let description = [{
58+
Check if waiting on a mbar is valid and update the barrier state.
59+
}];
60+
let arguments = (ins
61+
TTG_MemDescType:$mbar,
62+
TT_Tensor:$barriers,
63+
TT_Tensor:$states
64+
);
65+
66+
let results = (outs
67+
TT_Tensor:$outStates,
68+
TT_Tensor:$outBarriers);
69+
70+
let assemblyFormat = [{
71+
$mbar `{` $states `,` $barriers `}` attr-dict `:` type($mbar) `,` type($states) `,` type($barriers) `->` type($outStates) `,` type($outBarriers)
72+
}];
73+
74+
let builders = [
75+
OpBuilder<(ins "Value":$mbar, "Value":$barriers, "Value":$states),
76+
[{
77+
build($_builder, $_state, {states.getType(), barriers.getType()}, mbar, barriers, states);
78+
}]>];
79+
80+
}
81+
82+
#endif // TRITONINSTRUMENT_OPS
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
set(LLVM_TARGET_DEFINITIONS Passes.td)
2+
mlir_tablegen(Passes.h.inc -gen-pass-decls -name TritonInstrument)
3+
add_public_tablegen_target(TritonInstrumentTransformsIncGen)
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
#ifndef TRITON_DIALECT_TRITONINSTRUMENT_TRANSFORMS_PASSES_H_
2+
#define TRITON_DIALECT_TRITONINSTRUMENT_TRANSFORMS_PASSES_H_
3+
4+
#include "mlir/Pass/Pass.h"
5+
#include "triton/Dialect/TritonInstrument/IR/Dialect.h"
6+
7+
namespace mlir {
8+
namespace triton {
9+
namespace instrument {
10+
11+
// Generate the pass class declarations.
12+
#define GEN_PASS_DECL
13+
#include "triton/Dialect/TritonInstrument/Transforms/Passes.h.inc"
14+
15+
/// Generate the code for registering passes.
16+
#define GEN_PASS_REGISTRATION
17+
#include "triton/Dialect/TritonInstrument/Transforms/Passes.h.inc"
18+
19+
} // namespace instrument
20+
} // namespace triton
21+
} // namespace mlir
22+
#endif

0 commit comments

Comments
 (0)