Skip to content

Commit b3cf593

Browse files
[CONSAN] Convert consan instrumentation opcodes to function calls (#8559)
This change aims to reduce the compilation time of kernels instrumented with ConSan. It moves the implementation of the primitives from MLIR ops to function calls. Only a few crucial ops are left as operations of the TritonInstrument dialect. The TritonInstrument::FunctionBuilder class is introduced for creating and managing instrumentation functions.
1 parent a1316c4 commit b3cf593

File tree

12 files changed

+2408
-2965
lines changed

12 files changed

+2408
-2965
lines changed
Lines changed: 224 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,224 @@
1+
#ifndef TRITONINSTRUMENT_FUNCTIONBUILDER_H
2+
#define TRITONINSTRUMENT_FUNCTIONBUILDER_H
3+
4+
#include "triton/Dialect/TritonInstrument/IR/Utility.h"
5+
6+
#include <string>
7+
#include <variant>
8+
9+
#include "llvm/ADT/ArrayRef.h"
10+
#include "llvm/ADT/StringRef.h"
11+
12+
namespace mlir {
13+
class ImplicitLocOpBuilder;
14+
class ModuleOp;
15+
class Operation;
16+
class RankedTensorType;
17+
class Type;
18+
class Value;
19+
} // namespace mlir
20+
21+
namespace mlir::triton {
22+
class FuncOp;
23+
24+
namespace instrument {
25+
26+
class ManglingArgs {
27+
public:
28+
using Arg = std::variant<Type, int, std::string>;
29+
30+
ManglingArgs() = default;
31+
ManglingArgs(const ManglingArgs &) = default;
32+
ManglingArgs(ManglingArgs &&) = default;
33+
ManglingArgs &operator=(const ManglingArgs &) = default;
34+
ManglingArgs &operator=(ManglingArgs &&) = default;
35+
36+
ManglingArgs(std::initializer_list<Arg> args) : args(args) {}
37+
38+
~ManglingArgs() = default;
39+
40+
template <typename T> void append(T arg) { args.push_back(arg); }
41+
42+
template <typename T> void append(ArrayRef<T> arg) {
43+
for (auto &a : arg) {
44+
args.push_back(a);
45+
}
46+
}
47+
48+
void append(ManglingArgs &other) {
49+
args.append(other.args.begin(), other.args.end());
50+
}
51+
52+
std::string mangleArg(Arg arg) const {
53+
if (auto type = std::get_if<Type>(&arg)) {
54+
auto hash = static_cast<uint64_t>(mlir::hash_value(*type));
55+
return std::string("_T") + llvm::utohexstr(hash);
56+
} else if (auto intVal = std::get_if<int>(&arg)) {
57+
return std::string("_I") + std::to_string(*intVal);
58+
} else if (auto stringVal = std::get_if<std::string>(&arg)) {
59+
return *stringVal;
60+
}
61+
llvm_unreachable("Unsupported argument type");
62+
}
63+
64+
std::string mangle(std::string baseName, int numWarps) const {
65+
std::string name = "__triton_consan_";
66+
name += baseName;
67+
name += "_nw" + std::to_string(numWarps);
68+
for (auto arg : args)
69+
name += mangleArg(arg);
70+
return name;
71+
}
72+
73+
private:
74+
SmallVector<Arg> args;
75+
};
76+
77+
/// Utility to mangle helper function names produced by the instrumentation
78+
/// passes. The mangled name encodes the base name, number of warps and the
79+
/// participating types.
80+
std::string mangleInstrumentHelperName(const std::string &baseName,
81+
int numWarps,
82+
llvm::ArrayRef<Type> types);
83+
84+
class FunctionBuilder {
85+
public:
86+
FunctionBuilder(ModuleOp module, AuxDataMap &auxData)
87+
: module(module), auxData(auxData) {}
88+
89+
// setWaiting: mark the base thread as waiting on the given barrier phase and
90+
// record that phase for deadlock detection.
91+
void createSetWaitingCall(ImplicitLocOpBuilder &b, Value mbar, int thread,
92+
Value phase, Value pred, Operation *insertPoint);
93+
// clearWaiting: clear the waiting flag and stored phase for the base thread.
94+
void createClearWaitingCall(ImplicitLocOpBuilder &b, Value mbar, int thread,
95+
Value pred, Operation *insertPoint);
96+
// checkAllActiveWaiting: assert that not all active threads are waiting on
97+
// matching barrier phases.
98+
void createCheckAllActiveWaitingCall(ImplicitLocOpBuilder &b, int activeMask,
99+
Value pred, Operation *insertPoint);
100+
// initBarrierState: Initialize the tracked barrier state to phase 0 and set
101+
// both the initial and current arrival counts.
102+
void createInitBarrierStateCall(ImplicitLocOpBuilder &b, Value mbar,
103+
int count, Operation *insertPoint);
104+
// verifyBarrierArrive: Check that applying the arrive count would not drive
105+
// the tracked current count negative. Triggers an assertion on failure.
106+
void createVerifyBarrierArriveCall(ImplicitLocOpBuilder &b, Value mbar,
107+
int count, Value pred,
108+
Operation *insertPoint);
109+
// updateBarrierState: Apply an arrive count to the tracked barrier state,
110+
// toggling the phase when the count reaches zero and reloading the current
111+
// count from the initial count.
112+
void createUpdateBarrierStateCall(ImplicitLocOpBuilder &b, Value mbar,
113+
int count, Value pred,
114+
Operation *insertPoint);
115+
// setWriteVisibility: Set the write visibility for a buffer. Marks the buffer
116+
// as visible to the threads set in threadMask. Clears out any other threads
117+
// from the visibility bitmask. We know this is safe because there cannot be
118+
// outstanding writes to this buffer at this point.
119+
void createSetWriteVisibilityCall(ImplicitLocOpBuilder &b, Value buf,
120+
uint64_t threadMask, Value pred,
121+
MemType memType, Operation *insertPoint);
122+
// setReadVisibility: add the threads set in threadMask to the buffer's read
123+
// visibility bitmask.
124+
void createSetReadVisibilityCall(ImplicitLocOpBuilder &b, Value buf,
125+
uint64_t threadMask, Value pred,
126+
MemType memType, Operation *insertPoint);
127+
// clearWriteTracking: clear all the information about threads writing to a
128+
// buffer.
129+
void createClearWriteTrackingCall(ImplicitLocOpBuilder &b, Value buf,
130+
Value pred, MemType memType,
131+
Operation *insertPoint);
132+
// clearReadVisibility: clear the read visibility for a buffer.
133+
void createClearReadVisibilityCall(ImplicitLocOpBuilder &b, Value buf,
134+
Value pred, MemType memType,
135+
Operation *insertPoint);
136+
// clearReadTracking: clear the read tracking for a buffer.
137+
void createClearReadTrackingCall(ImplicitLocOpBuilder &b, Value buf,
138+
Value pred, MemType memType,
139+
Operation *insertPoint);
140+
// trackVisibleWrites: snapshot buffers currently visible to the thread into
141+
// the tracking table for a barrier.
142+
void createTrackVisibleWritesCall(ImplicitLocOpBuilder &b, Value mbar,
143+
int thread, Value pred, MemType memType,
144+
Operation *insertPoint);
145+
// trackVisibleReads: snapshot buffers currently visible to the thread into
146+
// the read tracking table for a barrier.
147+
void createTrackVisibleReadsCall(ImplicitLocOpBuilder &b, Value mbar,
148+
int thread, Value pred, MemType memType,
149+
Operation *insertPoint);
150+
// transferVisibleWrites: transfer write visibility tracked by a barrier to
151+
// all threads in threadMask.
152+
void createTransferVisibleWritesCall(ImplicitLocOpBuilder &b, Value mbar,
153+
uint64_t threadMask, Value pred,
154+
MemType memType, Operation *insertPoint);
155+
// transferVisibleReads: transfer read visibility tracked by a barrier to all
156+
// threads in threadMask.
157+
void createTransferVisibleReadsCall(ImplicitLocOpBuilder &b, Value mbar,
158+
uint64_t threadMask, Value pred,
159+
MemType memType, Operation *insertPoint);
160+
// verifyWriteVisibility: ensure the thread either sees the latest write or no
161+
// other thread is writing the buffer.
162+
void createVerifyWriteVisibilityCall(ImplicitLocOpBuilder &b, Value buf,
163+
int thread, StringRef operandName,
164+
Value pred, MemType memType,
165+
Operation *insertPoint);
166+
// verifyReadVisibility: ensure all reads from the buffer are visible to the
167+
// thread.
168+
void createVerifyReadVisibilityCall(ImplicitLocOpBuilder &b, Value buf,
169+
int thread, StringRef operandName,
170+
Value pred, MemType memType,
171+
Operation *insertPoint);
172+
// copyWriteVisibility: replicate the write visibility bit of sourceThread to
173+
// every destination thread in destMask.
174+
void createCopyWriteVisibilityCall(ImplicitLocOpBuilder &b, int sourceThread,
175+
uint64_t destMask, Value pred,
176+
MemType memType, Operation *insertPoint);
177+
// copyReadVisibility: replicate the read visibility row of sourceThread to
178+
// every destination thread in destMask.
179+
void createCopyReadVisibilityCall(ImplicitLocOpBuilder &b, int sourceThread,
180+
uint64_t destMask, Value pred,
181+
MemType memType, Operation *insertPoint);
182+
// stageAccessForCommit: mark the buffer as staged (value -1) in the
183+
// outstanding commit table for this thread.
184+
void createStageAccessForCommitCall(ImplicitLocOpBuilder &b, Value buf,
185+
int thread, Value pred, ValueType buffers,
186+
ValueType outstandingCommits,
187+
Operation *insertPoint);
188+
// commitAccesses: convert staged entries to 1 and increment outstanding
189+
// commits greater than zero for the committing thread.
190+
void createCommitAccessesCall(ImplicitLocOpBuilder &b, int thread, Value pred,
191+
ValueType outstandingCommits,
192+
Operation *insertPoint);
193+
// clearOutstandingCommitsTransferWrites: clear entries farther than
194+
// outstandingNum from the thread and set write visibility for threads in
195+
// transferThreadMask.
196+
void createClearOutstandingCommitsTransferWritesCall(
197+
ImplicitLocOpBuilder &b, int thread, uint64_t transferThreadMask,
198+
int outstandingNum, Value pred, ValueType outstandingCommits,
199+
ValueType writeVisibility, Operation *insertPoint);
200+
// clearOutstandingCommitsTransferReads: clear entries farther than
201+
// outstandingNum from the thread and set read visibility for threads in
202+
// transferThreadMask.
203+
void createClearOutstandingCommitsTransferReadsCall(
204+
ImplicitLocOpBuilder &b, int thread, uint64_t transferThreadMask,
205+
int outstandingNum, Value pred, ValueType outstandingCommits,
206+
ValueType readVisibility, Operation *insertPoint);
207+
// checkOutstandingCommits: assert that the outstanding commit row for the
208+
// buffer is zero before the access described by pendingAccessType.
209+
void createCheckOutstandingCommitsCall(ImplicitLocOpBuilder &b, Value buf,
210+
int thread,
211+
StringRef pendingAccessType,
212+
Value pred, ValueType buffers,
213+
ValueType outstandingCommits,
214+
Operation *insertPoint);
215+
216+
private:
217+
ModuleOp module;
218+
AuxDataMap &auxData;
219+
};
220+
221+
} // namespace instrument
222+
} // namespace mlir::triton
223+
224+
#endif

0 commit comments

Comments
 (0)