Skip to content

Commit d354fe1

Browse files
authored
[PROTON] Fix hook manager when a session is deactivated multiple times (#7743)
1 parent 029056e commit d354fe1

File tree

2 files changed

+11
-1
lines changed

2 files changed

+11
-1
lines changed

third_party/proton/proton/hooks/hook.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,8 +77,9 @@ def deactivate(session: Optional[int] = None) -> None:
7777
deactivated_hooks = set()
7878
for session in sessions:
7979
for hook in HookManager.session_hooks[session]:
80+
if hook in HookManager.active_hooks:
81+
deactivated_hooks.add(hook)
8082
HookManager.session_hooks[session][hook] = False
81-
deactivated_hooks.add(hook)
8283

8384
# Check if any other sessions rely on this hook
8485
for hook in deactivated_hooks:

third_party/proton/test/test_api.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import json
22
import triton.profiler as proton
33
import pathlib
4+
from triton.profiler.hooks.hook import HookManager
45

56

67
def test_profile_single_session(tmp_path: pathlib.Path):
@@ -103,6 +104,14 @@ def test_hook(tmp_path: pathlib.Path):
103104
temp_file = tmp_path / "test_hook.hatchet"
104105
session_id0 = proton.start(str(temp_file.with_suffix("")), hook="triton")
105106
proton.activate(session_id0)
107+
proton.activate(session_id0)
108+
assert len(
109+
HookManager.active_hooks) == 1, ("Activate a session multiple times should maintain a single instance of hook")
110+
assert list(HookManager.session_hooks[session_id0].values())[0] is True
111+
proton.deactivate(session_id0)
112+
assert list(HookManager.session_hooks[session_id0].values())[0] is False
113+
assert len(HookManager.active_hooks) == 0
114+
# Deactivate a session multiple times should not raise an error
106115
proton.deactivate(session_id0)
107116
proton.finalize(None)
108117
assert temp_file.exists()

0 commit comments

Comments
 (0)