Commit 227cf9f

Better fix for triton allocator error (#427)
1 parent 4718678 · commit 227cf9f

27 files changed: +19 −909 lines

helion/_compiler/device_function.py

Lines changed: 1 addition & 1 deletion

@@ -419,7 +419,7 @@ def sorted_args(self) -> list[Argument]:
 
     def codegen_function_def(self) -> list[ast.stmt]:
         prefix = []
-        if CompileEnvironment.current().settings.set_triton_allocator:
+        if self._tensor_descriptor_args:
             prefix.append(
                 statement_from_string("helion.runtime.set_triton_allocator()")
             )
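
With this change, the generated kernel wrapper emits the allocator setup only when the kernel actually takes tensor-descriptor arguments, instead of gating it on a settings flag. As a rough sketch of how such a prefix statement can be built (assuming statement_from_string simply parses source text into an ast.stmt, which this diff does not show):

import ast

def statement_from_string_sketch(source: str) -> ast.stmt:
    # ast.parse returns a Module; its body holds the parsed statements.
    return ast.parse(source).body[0]

# Stand-in for self._tensor_descriptor_args on the function being compiled.
has_tensor_descriptor_args = True

prefix: list[ast.stmt] = []
if has_tensor_descriptor_args:
    prefix.append(
        statement_from_string_sketch("helion.runtime.set_triton_allocator()")
    )

print(ast.unparse(prefix[0]))  # helion.runtime.set_triton_allocator()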

helion/runtime/__init__.py

Lines changed: 18 additions & 3 deletions

@@ -1,9 +1,10 @@
 from __future__ import annotations
 
+import contextvars
 import functools
+from typing import TYPE_CHECKING
 
 import torch
-import triton
 
 from .config import Config as Config
 from .kernel import Kernel as Kernel
@@ -12,15 +13,29 @@
 from .triton_helpers import triton_wait_multiple_signal as triton_wait_multiple_signal
 from .triton_helpers import triton_wait_signal as triton_wait_signal
 
+if TYPE_CHECKING:
+    import triton
+
 
 def _alloc_fn(size: int, alignment: int, stream: int | None) -> torch.Tensor:
     return torch.empty(size, device="cuda", dtype=torch.int8)
 
 
 @functools.cache
 def set_triton_allocator() -> None:
-    if hasattr(triton, "set_allocator"):
-        triton.set_allocator(_alloc_fn)
+    try:
+        from triton import set_allocator
+        from triton.runtime._allocation import NullAllocator
+        from triton.runtime._allocation import _allocator
+    except ImportError:
+        return
+    if isinstance(_allocator, contextvars.ContextVar):
+        existing = _allocator.get()
+    else:  # older versions of Triton
+        existing = _allocator
+    # if allocator isn't NullAllocator, we assume it is set by the user
+    if isinstance(existing, NullAllocator):
+        set_allocator(_alloc_fn)
 
 
 def get_num_sm(device: torch.device) -> int:
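
The rewritten set_triton_allocator now degrades gracefully on Triton versions without the allocator API and, more importantly, only installs Helion's allocator when the active allocator is still Triton's default NullAllocator, so a user-installed allocator is never overwritten. A hypothetical usage sketch of that behavior (my_pool_alloc is an illustrative name, not part of Helion or Triton):

import torch
import triton

from helion.runtime import set_triton_allocator

def my_pool_alloc(size: int, alignment: int, stream: int | None) -> torch.Tensor:
    # User-provided scratch allocator, e.g. backed by a custom caching pool.
    return torch.empty(size, device="cuda", dtype=torch.int8)

# The user installs their own allocator before any Helion kernel runs ...
triton.set_allocator(my_pool_alloc)

# ... so this call sees a non-NullAllocator and leaves the user's choice in place.
set_triton_allocator()

If no allocator was set, the same call installs Helion's _alloc_fn, and the functools.cache decorator keeps repeated calls from generated kernel prefixes cheap.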
