Skip to content

Commit cf0db92

Browse files
authored
[Runtime] Make AsyncCompileMode thread safe (#7701)
In the same vein as #7685 We rely on the global variable `active_mode`, so if you were to use `AsyncCompileMode` in two different threads they may race with each other.
1 parent dd26258 commit cf0db92

File tree

3 files changed

+7
-8
lines changed

3 files changed

+7
-8
lines changed

python/triton/runtime/_allocation.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,5 +29,4 @@ def set_allocator(allocator: Allocator):
2929
The allocator function is called during kernel launch for kernels that
3030
require additional global memory workspace.
3131
"""
32-
global _allocator
3332
_allocator.set(allocator)

python/triton/runtime/_async_compile.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
from __future__ import annotations
22
from typing import Callable, Optional
33
from concurrent.futures import Executor, as_completed, Future
4+
from contextvars import ContextVar
45

5-
active_mode: Optional[AsyncCompileMode] = None
6+
active_mode: ContextVar[Optional[AsyncCompileMode]] = ContextVar("async_compile_active_mode", default=None)
67

78

89
class FutureKernel:
@@ -42,14 +43,13 @@ def submit(self, key, compile_fn, finalize_fn):
4243
return future_kernel
4344

4445
def __enter__(self):
45-
global active_mode
46-
if active_mode is not None:
46+
if active_mode.get() is not None:
4747
raise RuntimeError("Another AsyncCompileMode is already active")
48-
active_mode = self
48+
active_mode.set(self)
49+
return self
4950

5051
def __exit__(self, exc_type, exc_value, traceback):
51-
global active_mode
5252
# Finalize any outstanding compiles
5353
for future in as_completed(self.raw_futures):
5454
self.future_kernels[future._key].result()
55-
active_mode = None
55+
active_mode.set(None)

python/triton/runtime/jit.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -757,7 +757,7 @@ def _do_compile(self, key, signature, device, constexprs, options, attrs, warmup
757757
return None
758758
src = self.ASTSource(self, signature, constexprs, attrs)
759759

760-
async_mode = _async_compile.active_mode
760+
async_mode = _async_compile.active_mode.get()
761761
if async_mode is not None:
762762

763763
env_vars = get_cache_invalidating_env_vars()

0 commit comments

Comments
 (0)