Skip to content

Commit 43625fc

Browse files
authored
[Runtime] Make set_allocator thread safe (#7685)
This changes the allocator state to a `ContextVar` which is thread-local, so calling it from different threads won't cause a race. Note that this does mean you won't inherit allocators in child threads, and instead would need to set it in each thread independently. For most use cases this is probably fine though.
1 parent 67af519 commit 43625fc

File tree

2 files changed

+5
-3
lines changed

2 files changed

+5
-3
lines changed

python/triton/runtime/_allocation.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from typing import Optional, Protocol
2+
from contextvars import ContextVar
23

34

45
class Buffer(Protocol):
@@ -20,7 +21,7 @@ def __call__(self, size: int, alignment: int, stream: Optional[int]) -> Buffer:
2021
"Use triton.set_allocator to specify an allocator.")
2122

2223

23-
_allocator: Allocator = NullAllocator()
24+
_allocator: ContextVar[Allocator] = ContextVar("_allocator", default=NullAllocator())
2425

2526

2627
def set_allocator(allocator: Allocator):
@@ -29,4 +30,4 @@ def set_allocator(allocator: Allocator):
2930
require additional global memory workspace.
3031
"""
3132
global _allocator
32-
_allocator = allocator
33+
_allocator.set(allocator)

third_party/nvidia/backend/driver.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -702,7 +702,8 @@ def __call__(self, gridX, gridY, gridZ, stream, function, *args):
702702
if self.global_scratch_size > 0:
703703
grid_size = gridX * gridY * gridZ
704704
alloc_size = grid_size * self.num_ctas * self.global_scratch_size
705-
global_scratch = _allocation._allocator(alloc_size, self.global_scratch_align, stream)
705+
alloc_fn = _allocation._allocator.get()
706+
global_scratch = alloc_fn(alloc_size, self.global_scratch_align, stream)
706707
else:
707708
global_scratch = None
708709
self.launch(gridX, gridY, gridZ, stream, function, self.launch_cooperative_grid, self.launch_pdl,

0 commit comments

Comments
 (0)