Skip to content

Commit 65939a0

Browse files
authored
[PROTON] Get the context depth of profiling sessions (#6158)
1 parent 8d0bc0a commit 65939a0

File tree

12 files changed

+74
-0
lines changed

12 files changed

+74
-0
lines changed

third_party/proton/csrc/Proton.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,10 @@ void initProton(pybind11::module &&m) {
7878
SessionManager::instance().addMetrics(scopeId, metrics);
7979
});
8080

81+
m.def("get_context_depth", [](size_t sessionId) {
82+
return SessionManager::instance().getContextDepth(sessionId);
83+
});
84+
8185
pybind11::bind_map<std::map<std::string, MetricValueType>>(m, "MetricMap");
8286
}
8387

third_party/proton/csrc/include/Context/Context.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@ class ContextSource {
4343

4444
void setState(std::optional<Context> state) { ContextSource::state = state; }
4545

46+
virtual size_t getDepth() = 0;
47+
4648
protected:
4749
virtual std::vector<Context> getContextsImpl() = 0;
4850
static thread_local std::optional<Context> state;

third_party/proton/csrc/include/Context/Python.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ class PythonContextSource : public ContextSource {
1010
public:
1111
PythonContextSource() = default;
1212

13+
size_t getDepth() override;
14+
1315
private:
1416
std::vector<Context> getContextsImpl() override;
1517
};

third_party/proton/csrc/include/Context/Shadow.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ class ShadowContextSource : public ContextSource, public ScopeInterface {
2525

2626
void exitScope(const Scope &scope) override;
2727

28+
size_t getDepth() override;
29+
2830
private:
2931
std::vector<Context> getContextsImpl() override;
3032

third_party/proton/csrc/include/Session/Session.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@ class Session {
3131

3232
void finalize(OutputFormat outputFormat);
3333

34+
size_t getContextDepth();
35+
3436
private:
3537
Session(size_t id, const std::string &path, Profiler *profiler,
3638
std::unique_ptr<ContextSource> contextSource,
@@ -88,6 +90,8 @@ class SessionManager : public Singleton<SessionManager> {
8890

8991
void deactivateAllSessions();
9092

93+
size_t getContextDepth(size_t sessionId);
94+
9195
void enterScope(const Scope &scope);
9296

9397
void exitScope(const Scope &scope);

third_party/proton/csrc/lib/Context/Python.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,4 +94,6 @@ std::vector<Context> PythonContextSource::getContextsImpl() {
9494
return contexts;
9595
}
9696

97+
size_t PythonContextSource::getDepth() { return getContextsImpl().size(); }
98+
9799
} // namespace proton

third_party/proton/csrc/lib/Context/Shadow.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,11 @@ std::vector<Context> ShadowContextSource::getContextsImpl() {
2626
return threadContextStack;
2727
}
2828

29+
size_t ShadowContextSource::getDepth() {
30+
initializeThreadContext();
31+
return threadContextStack.size();
32+
}
33+
2934
void ShadowContextSource::exitScope(const Scope &scope) {
3035
if (threadContextStack.empty()) {
3136
throw std::runtime_error("Context stack is empty");

third_party/proton/csrc/lib/Session/Session.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,8 @@ void Session::finalize(OutputFormat outputFormat) {
6969
data->dump(outputFormat);
7070
}
7171

72+
size_t Session::getContextDepth() { return contextSource->getDepth(); }
73+
7274
std::unique_ptr<Session> SessionManager::makeSession(
7375
size_t id, const std::string &path, const std::string &profilerName,
7476
const std::string &profilerPath, const std::string &contextSourceName,
@@ -240,4 +242,10 @@ void SessionManager::setState(std::optional<Context> context) {
240242
}
241243
}
242244

245+
size_t SessionManager::getContextDepth(size_t sessionId) {
246+
std::lock_guard<std::mutex> lock(mutex);
247+
throwIfSessionNotInitialized(sessions, sessionId);
248+
return sessions[sessionId]->getContextDepth();
249+
}
250+
243251
} // namespace proton

third_party/proton/proton/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,4 @@
99
profile,
1010
DEFAULT_PROFILE_NAME,
1111
)
12+
from . import context
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
from typing import Optional
2+
from triton._C.libproton import proton as libproton
3+
from .flags import get_profiling_on
4+
5+
6+
def depth(session: Optional[int] = 0) -> Optional[int]:
7+
"""
8+
Get the depth of the context.
9+
10+
Args:
11+
session (int): The session ID of the profiling session. Defaults to 0.
12+
13+
Returns:
14+
depth (int or None): The depth of the context. If profiling is off, returns None.
15+
"""
16+
if not get_profiling_on():
17+
return None
18+
return libproton.get_context_depth(session)

0 commit comments

Comments
 (0)