Skip to content

Commit c8e2ad8

Browse files
authored
[PROTON] Add proton.state utility (#5110)
`state` is different from `scope` in several ways: 1. State is not recursive; each operation can have only a single state. Inner most state will overwrite the outer most state. 2. A states is a suffix, meaning that the original call path will append a state above the name of each kernel. 3. State is compatible with both Python and shadow contexts.
1 parent 1d004c0 commit c8e2ad8

File tree

22 files changed

+318
-36
lines changed

22 files changed

+318
-36
lines changed

third_party/proton/README.md

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,29 @@ proton-viewer -h
162162

163163
## Advanced features
164164

165+
### State annotation
166+
167+
In addition to `proton.scope`, we can also customize the call path of each GPU operation using `proton.state`.
168+
169+
`state` is different from `scope` in several ways:
170+
171+
1. State is not recursive; each operation can have only a single state. Inner most state will overwrite the outer most state.
172+
2. A states is a suffix, meaning that the original call path will append a state above the name of each kernel.
173+
3. State is compatible with both Python and shadow contexts.
174+
175+
The following example demonstrates a basic use of state:
176+
177+
```python
178+
with proton.scope("test"):
179+
with proton.state("state0"):
180+
with proton.scope("test0"):
181+
foo0[1,](x, y)
182+
with proton.scope("test1"):
183+
foo1[1,](x, y)
184+
```
185+
186+
The call path of `foo1` will be `test->test1->state0`.
187+
165188
### Instrumentation (experimental)
166189

167190
In addition to profiling, Proton also incorporates MLIR/LLVM based compiler instrumentation passes to get Triton level analysis

third_party/proton/csrc/Proton.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,13 @@ void initProton(pybind11::module &&m) {
6464
SessionManager::instance().exitOp(Scope(scopeId, name));
6565
});
6666

67+
m.def("enter_state", [](const std::string &state) {
68+
SessionManager::instance().setState(state);
69+
});
70+
71+
m.def("exit_state",
72+
[]() { SessionManager::instance().setState(std::nullopt); });
73+
6774
m.def("add_metrics",
6875
[](size_t scopeId,
6976
const std::map<std::string, MetricValueType> &metrics) {

third_party/proton/csrc/include/Context/Context.h

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#include <limits>
66
#include <map>
77
#include <mutex>
8+
#include <optional>
89
#include <string>
910
#include <vector>
1011

@@ -31,7 +32,20 @@ class ContextSource {
3132
public:
3233
ContextSource() = default;
3334
virtual ~ContextSource() = default;
34-
virtual std::vector<Context> getContexts() = 0;
35+
36+
std::vector<Context> getContexts() {
37+
auto contexts = getContextsImpl();
38+
if (state.has_value()) {
39+
contexts.push_back(state.value());
40+
}
41+
return contexts;
42+
}
43+
44+
void setState(std::optional<Context> state) { ContextSource::state = state; }
45+
46+
protected:
47+
virtual std::vector<Context> getContextsImpl() = 0;
48+
static thread_local std::optional<Context> state;
3549
};
3650

3751
/// A scope is a context with a unique identifier.

third_party/proton/csrc/include/Context/Python.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,10 @@ namespace proton {
88
/// Unwind the Python stack and early return a list of contexts.
99
class PythonContextSource : public ContextSource {
1010
public:
11-
std::vector<Context> getContexts() override;
11+
PythonContextSource() = default;
12+
13+
private:
14+
std::vector<Context> getContextsImpl() override;
1215
};
1316

1417
} // namespace proton

third_party/proton/csrc/include/Context/Shadow.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,12 @@ class ShadowContextSource : public ContextSource, public ScopeInterface {
1212
public:
1313
ShadowContextSource() = default;
1414

15-
std::vector<Context> getContexts() override { return contextStack; }
16-
1715
void enterScope(const Scope &scope) override;
1816

1917
void exitScope(const Scope &scope) override;
2018

2119
private:
20+
std::vector<Context> getContextsImpl() override { return contextStack; }
2221
std::vector<Context> contextStack;
2322
};
2423

third_party/proton/csrc/include/Session/Session.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,8 @@ class SessionManager : public Singleton<SessionManager> {
9595
const std::map<std::string, MetricValueType> &metrics,
9696
bool aggregable);
9797

98+
void setState(std::optional<Context> context);
99+
98100
private:
99101
std::unique_ptr<Session> makeSession(size_t id, const std::string &path,
100102
const std::string &profilerName,
@@ -146,6 +148,8 @@ class SessionManager : public Singleton<SessionManager> {
146148
std::map<ScopeInterface *, size_t> scopeInterfaceCounts;
147149
// op -> active count
148150
std::map<OpInterface *, size_t> opInterfaceCounts;
151+
// context source -> active count
152+
std::map<ContextSource *, size_t> contextSourceCounts;
149153
};
150154

151155
} // namespace proton

third_party/proton/csrc/lib/Context/Context.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22

33
namespace proton {
44

5+
/*static*/ thread_local std::optional<Context> ContextSource::state =
6+
std::nullopt;
7+
58
std::atomic<size_t> Scope::scopeIdCounter{1};
69

710
/*static*/ thread_local std::map<ThreadLocalOpInterface *, bool>

third_party/proton/csrc/lib/Context/Python.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ std::string unpackPyobject(PyObject *pyObject) {
7171

7272
} // namespace
7373

74-
std::vector<Context> PythonContextSource::getContexts() {
74+
std::vector<Context> PythonContextSource::getContextsImpl() {
7575
pybind11::gil_scoped_acquire gil;
7676

7777
PyFrameObject *frame = PyEval_GetFrame();

third_party/proton/csrc/lib/Session/Session.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ void SessionManager::activateSessionImpl(size_t sessionId) {
111111
sessions[sessionId]->activate();
112112
registerInterface<ScopeInterface>(sessionId, scopeInterfaceCounts);
113113
registerInterface<OpInterface>(sessionId, opInterfaceCounts);
114+
registerInterface<ContextSource>(sessionId, contextSourceCounts);
114115
}
115116

116117
void SessionManager::deActivateSessionImpl(size_t sessionId) {
@@ -122,6 +123,7 @@ void SessionManager::deActivateSessionImpl(size_t sessionId) {
122123
sessions[sessionId]->deactivate();
123124
unregisterInterface<ScopeInterface>(sessionId, scopeInterfaceCounts);
124125
unregisterInterface<OpInterface>(sessionId, opInterfaceCounts);
126+
unregisterInterface<ContextSource>(sessionId, contextSourceCounts);
125127
}
126128

127129
void SessionManager::removeSession(size_t sessionId) {
@@ -226,4 +228,14 @@ void SessionManager::addMetrics(
226228
}
227229
}
228230

231+
void SessionManager::setState(std::optional<Context> context) {
232+
std::shared_lock<std::shared_mutex> lock(mutex);
233+
for (auto iter : contextSourceCounts) {
234+
auto [contextSource, count] = iter;
235+
if (count > 0) {
236+
contextSource->setState(context);
237+
}
238+
}
239+
}
240+
229241
} // namespace proton

third_party/proton/proton/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# flake8: noqa
22
from .scope import scope, enter_scope, exit_scope
3+
from .state import state, enter_state, exit_state
34
from .profile import (
45
start,
56
activate,

0 commit comments

Comments
 (0)