Skip to content

Commit dabeb36

Browse files
authored
Always set triton allocator (#416)
1 parent aabe823 commit dabeb36

27 files changed

+882
-10
lines changed

helion/_compiler/device_function.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -419,7 +419,7 @@ def sorted_args(self) -> list[Argument]:
419419

420420
def codegen_function_def(self) -> list[ast.stmt]:
421421
prefix = []
422-
if self._tensor_descriptor_args:
422+
if CompileEnvironment.current().settings.set_triton_allocator:
423423
prefix.append(
424424
statement_from_string("helion.runtime.set_triton_allocator()")
425425
)

helion/runtime/__init__.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,15 +19,8 @@ def _alloc_fn(size: int, alignment: int, stream: int | None) -> torch.Tensor:
1919

2020
@functools.cache
2121
def set_triton_allocator() -> None:
22-
try:
23-
from triton.runtime._allocation import NullAllocator
24-
from triton.runtime._allocation import _allocator
25-
26-
if not isinstance(_allocator, NullAllocator):
27-
return
28-
except ImportError:
29-
pass
30-
triton.set_allocator(_alloc_fn)
22+
if hasattr(triton, "set_allocator"):
23+
triton.set_allocator(_alloc_fn)
3124

3225

3326
def get_num_sm(device: torch.device) -> int:

helion/runtime/settings.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ class _Settings:
9595
RefMode.EAGER if os.environ.get("HELION_INTERPRET", "") == "1" else RefMode.OFF
9696
)
9797
autotuner_fn: AutotunerFunction = default_autotuner_fn
98+
set_triton_allocator: bool = True
9899

99100

100101
class Settings(_Settings):
@@ -117,6 +118,7 @@ class Settings(_Settings):
117118
"allow_warp_specialize": "If True, allow warp specialization for tl.range calls on CUDA devices.",
118119
"ref_mode": "Reference mode for kernel execution. Can be RefMode.OFF or RefMode.EAGER.",
119120
"autotuner_fn": "Function to create an autotuner",
121+
"set_triton_allocator": "If True, insert helion.runtime.set_triton_allocator() call in generated code. Default is True.",
120122
}
121123
assert __slots__.keys() == {field.name for field in dataclasses.fields(_Settings)}
122124

0 commit comments

Comments
 (0)