Skip to content

Commit bfbdc55

Browse files
Reapply "[PROTON] Improve hook manager tests and fix a session_id=0 problem (#7745)" (#4960)
This reverts commit f93ddb1.
2 parents db00ded + 25b1b46 commit bfbdc55

File tree

3 files changed

+34
-2
lines changed

3 files changed

+34
-2
lines changed

third_party/proton/proton/hooks/hook.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -103,10 +103,10 @@ def register(hook: Hook, session: int) -> None:
103103

104104
@staticmethod
105105
def unregister(session: Optional[int] = None) -> None:
106-
if session and session not in HookManager.session_hooks:
106+
if session is not None and session not in HookManager.session_hooks:
107107
return
108108

109-
if not session:
109+
if session is None:
110110
for hook in HookManager.active_hooks:
111111
hook.deactivate()
112112
HookManager.active_hooks.clear()

third_party/proton/proton/hooks/instrumentation.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,8 @@ def _get_backend_name() -> str:
123123
return "nvidia"
124124
elif backend == "hip":
125125
return "amd"
126+
elif backend == "xpu":
127+
return "intel"
126128
else:
127129
raise RuntimeError(f"Unsupported backend: {backend}")
128130

third_party/proton/test/test_api.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
import triton.profiler as proton
33
import pathlib
44
from triton.profiler.hooks.hook import HookManager
5+
from triton.profiler.hooks.launch import LaunchHook
6+
from triton.profiler.hooks.instrumentation import InstrumentationHook
57

68

79
def test_profile_single_session(tmp_path: pathlib.Path):
@@ -117,6 +119,34 @@ def test_hook(tmp_path: pathlib.Path):
117119
assert temp_file.exists()
118120

119121

122+
def test_hook_manager(tmp_path: pathlib.Path):
123+
# Launch hook is a singleton
124+
HookManager.register(LaunchHook(), 0)
125+
HookManager.register(LaunchHook(), 0)
126+
assert len(HookManager.active_hooks) == 1
127+
assert isinstance(HookManager.active_hooks[0], LaunchHook)
128+
assert HookManager.session_hooks[0][HookManager.active_hooks[0]] is True
129+
130+
# Only unregister one session
131+
HookManager.register(LaunchHook(), 1)
132+
HookManager.unregister(0)
133+
assert len(HookManager.active_hooks) == 1
134+
HookManager.unregister(1)
135+
assert len(HookManager.active_hooks) == 0
136+
137+
# Heterogenous hooks
138+
HookManager.register(InstrumentationHook(""), 2)
139+
HookManager.register(LaunchHook(), 2)
140+
assert len(HookManager.active_hooks) == 2
141+
# Launch hook has a higher priority
142+
assert isinstance(HookManager.active_hooks[0], LaunchHook)
143+
assert isinstance(HookManager.active_hooks[1], InstrumentationHook)
144+
assert HookManager.session_hooks[2][HookManager.active_hooks[0]] is True
145+
assert HookManager.session_hooks[2][HookManager.active_hooks[1]] is True
146+
HookManager.unregister()
147+
assert len(HookManager.active_hooks) == 0
148+
149+
120150
def test_scope_metrics(tmp_path: pathlib.Path):
121151
temp_file = tmp_path / "test_scope_metrics.hatchet"
122152
session_id = proton.start(str(temp_file.with_suffix("")))

0 commit comments

Comments
 (0)