|
2 | 2 | import triton.profiler as proton |
3 | 3 | import pathlib |
4 | 4 | from triton.profiler.hooks.hook import HookManager |
| 5 | +from triton.profiler.hooks.launch import LaunchHook |
| 6 | +from triton.profiler.hooks.instrumentation import InstrumentationHook |
5 | 7 |
|
6 | 8 |
|
7 | 9 | def test_profile_single_session(tmp_path: pathlib.Path): |
@@ -117,6 +119,34 @@ def test_hook(tmp_path: pathlib.Path): |
117 | 119 | assert temp_file.exists() |
118 | 120 |
|
119 | 121 |
|
| 122 | +def test_hook_manager(tmp_path: pathlib.Path): |
| 123 | + # Launch hook is a singleton |
| 124 | + HookManager.register(LaunchHook(), 0) |
| 125 | + HookManager.register(LaunchHook(), 0) |
| 126 | + assert len(HookManager.active_hooks) == 1 |
| 127 | + assert isinstance(HookManager.active_hooks[0], LaunchHook) |
| 128 | + assert HookManager.session_hooks[0][HookManager.active_hooks[0]] is True |
| 129 | + |
| 130 | + # Only unregister one session |
| 131 | + HookManager.register(LaunchHook(), 1) |
| 132 | + HookManager.unregister(0) |
| 133 | + assert len(HookManager.active_hooks) == 1 |
| 134 | + HookManager.unregister(1) |
| 135 | + assert len(HookManager.active_hooks) == 0 |
| 136 | + |
| 137 | + # Heterogenous hooks |
| 138 | + HookManager.register(InstrumentationHook(""), 2) |
| 139 | + HookManager.register(LaunchHook(), 2) |
| 140 | + assert len(HookManager.active_hooks) == 2 |
| 141 | + # Launch hook has a higher priority |
| 142 | + assert isinstance(HookManager.active_hooks[0], LaunchHook) |
| 143 | + assert isinstance(HookManager.active_hooks[1], InstrumentationHook) |
| 144 | + assert HookManager.session_hooks[2][HookManager.active_hooks[0]] is True |
| 145 | + assert HookManager.session_hooks[2][HookManager.active_hooks[1]] is True |
| 146 | + HookManager.unregister() |
| 147 | + assert len(HookManager.active_hooks) == 0 |
| 148 | + |
| 149 | + |
120 | 150 | def test_scope_metrics(tmp_path: pathlib.Path): |
121 | 151 | temp_file = tmp_path / "test_scope_metrics.hatchet" |
122 | 152 | session_id = proton.start(str(temp_file.with_suffix(""))) |
|
0 commit comments