
Commit 84bdc4d

[WIP] Improve autotune infra to catch more error cases
1 parent 41fe6e9 commit 84bdc4d

3 files changed: +24 −10 lines

helion/_compiler/tile_dispatch.py

Lines changed: 1 addition & 1 deletion
@@ -94,7 +94,7 @@ def _add_reduction_strategies(self, fn: DeviceFunction, config: Config) -> None:
             reduction_loop = env.config_spec.reduction_loops.config_get(
                 config.reduction_loops, block_id, None
             )
-            if reduction_loop is None:
+            if reduction_loop is None or reduction_loop <= 1:
                 strategy: TileStrategy = PersistentReductionStrategy(fn, block_id)
             else:
                 strategy = LoopedReductionStrategy(fn, block_id, reduction_loop)
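
The effect of this one-line change, shown in isolation: a reduction loop size of 0 or 1 now falls back to the persistent strategy instead of building a degenerate looped reduction. A minimal sketch, with an illustrative helper name and string results standing in for the real Helion strategy classes:

def pick_reduction_strategy(reduction_loop: int | None) -> str:
    # None previously meant "persistent"; after this commit, loop sizes of
    # 0 or 1 are treated the same way, since a looped reduction with at most
    # one iteration degenerates to the persistent case anyway.
    if reduction_loop is None or reduction_loop <= 1:
        return "persistent"
    return "looped"


assert pick_reduction_strategy(None) == "persistent"
assert pick_reduction_strategy(1) == "persistent"
assert pick_reduction_strategy(8) == "looped"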

helion/autotuner/base_search.py

Lines changed: 21 additions & 7 deletions
@@ -21,6 +21,7 @@
 
 from torch._inductor.runtime.triton_compat import OutOfResources
 from torch._inductor.runtime.triton_compat import PTXASError
+from triton.compiler.errors import CompilationError
 import torch.multiprocessing as mp
 from triton.testing import do_bench
 
@@ -43,7 +44,7 @@
 from . import ConfigSpec
 
 _expected_errors_regexp: re.Pattern[str] = re.compile(
-    r"|".join(map(re.escape, ["[CUDA]: invalid argument"]))
+    r"|".join(map(re.escape, ["[CUDA]: invalid argument", "exceeds triton maximum tensor numel"]))
 )
 
 
@@ -88,10 +89,13 @@ def benchmark(self, config: Config) -> float:
         Returns:
             The performance of the configuration in seconds.
         """
-        fn = self.kernel.compile_config(config, allow_print=False)
-        if self.start_precompile_and_check_for_hangs(config, fn)():
-            return self.benchmark_function(config, fn)
-        return inf
+        try:
+            fn = self.kernel.compile_config(config, allow_print=False)
+            if self.start_precompile_and_check_for_hangs(config, fn)():
+                return self.benchmark_function(config, fn)
+            return inf
+        except Exception as e:
+            return inf
 
     def benchmark_function(self, config: Config, fn: CompiledConfig) -> float:
         """
@@ -125,8 +129,10 @@ def benchmark_function(self, config: Config, fn: CompiledConfig) -> float:
             self.log.debug("Benchmarking failed: OutOfResources")
         except PTXASError:
             self.log.warning(f"PTXASError compiling config: {config}")
+        except CompilationError:
+            self.log.debug("Benchmarking failed: Triton CompilationError")
         except Exception as e:
-            if not _expected_errors_regexp.search(str(e)):
+            if not _expected_errors_regexp.search(str(e)) and not "exceeds triton maximum tensor numel" in str(e):
                 raise exc.TritonError(f"{type(e).__qualname__}: {e}", config) from e
             self.log.debug(f"Benchmarking failed: {type(e).__name__}: {e}")
         return inf
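
The widened _expected_errors_regexp, together with the extra substring check in the generic except branch, treats these messages as a normal "this config is invalid" signal rather than a hard failure. A small self-contained sketch of how the matching behaves; the full error strings below are made-up stand-ins for real CUDA/Triton messages:

import re

expected = re.compile(
    r"|".join(
        map(
            re.escape,
            ["[CUDA]: invalid argument", "exceeds triton maximum tensor numel"],
        )
    )
)

# re.escape makes the brackets literal, so this is a plain substring match
# anywhere inside the stringified exception.
assert expected.search("Triton Error [CUDA]: invalid argument")
assert expected.search("tensor size exceeds triton maximum tensor numel")
assert not expected.search("some unrelated RuntimeError")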
@@ -149,6 +155,8 @@ def start_precompile_and_check_for_hangs(
         """
         if not self.settings.autotune_precompile:
             return PrecompileFuture.skip(self, config, True)
+        if fn is None:
+            return PrecompileFuture.skip(self, config, False)
         ctx = mp.get_context("fork")
 
         def extract_launcher(
@@ -188,7 +196,13 @@ def parallel_benchmark(self, configs: list[Config]) -> list[tuple[Config, float]
         Returns:
             A list of tuples containing configurations and their performance.
         """
-        fns = [self.kernel.compile_config(c, allow_print=False) for c in configs]
+        fns = []
+        for c in configs:
+            try:
+                compile_result = self.kernel.compile_config(c, allow_print=False)
+                fns.append(compile_result)
+            except Exception as e:
+                fns.append(None)
+        if self.settings.autotune_precompile:
+            is_workings = PrecompileFuture.wait_for_all(
+                [
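
parallel_benchmark now compiles each config individually and stores None for any config whose compile raises, so the fns list stays index-aligned with configs; the new "fn is None" branch in start_precompile_and_check_for_hangs then marks those entries as not working instead of launching a precompile subprocess. A minimal sketch of the placeholder pattern, with compile_config standing in for self.kernel.compile_config:

def compile_all(compile_config, configs):
    fns = []
    for c in configs:
        try:
            fns.append(compile_config(c))
        except Exception:
            fns.append(None)  # placeholder keeps fns aligned with configs
    return fns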

helion/autotuner/config_spec.py

Lines changed: 2 additions & 2 deletions
@@ -411,8 +411,8 @@ def _flat_config(
         default = min(high, 4096)
         value = fn(BlockSizeFragment(low, high, default))
         assert isinstance(value, int)
-        if value >= self.size_hint:
-            return None  # max size becomes persistent reduction
+        if value >= self.size_hint or value < low:
+            return None  # max size or invalid value becomes persistent reduction
         return value
 
     def _normalize(self, name: str, value: object) -> int | None:
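
The extra "value < low" guard in _flat_config means that any block size the tuner proposes outside the valid range collapses to None, i.e. a persistent reduction, rather than propagating an invalid value. A minimal sketch with illustrative numbers; low and size_hint are not taken from a real kernel:

def flatten_reduction_block(value: int, low: int, size_hint: int) -> int | None:
    if value >= size_hint or value < low:
        return None  # max size or invalid value becomes persistent reduction
    return value


assert flatten_reduction_block(4096, low=8, size_hint=4096) is None  # at/above hint
assert flatten_reduction_block(2, low=8, size_hint=4096) is None     # below the minimum
assert flatten_reduction_block(256, low=8, size_hint=4096) == 256    # valid, kept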
