Skip to content

Commit 9c7371b

Browse files
committed
try to catch more errors
1 parent 07a5b9b commit 9c7371b

File tree

2 files changed

+26
-9
lines changed

2 files changed

+26
-9
lines changed

helion/autotuner/base_search.py

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
from typing import NamedTuple
1717
from typing import NoReturn
1818

19+
from triton.compiler.errors import CompilationError
20+
1921
if TYPE_CHECKING:
2022
from triton.runtime.jit import JITFunction
2123

@@ -97,10 +99,13 @@ def benchmark(self, config: Config) -> float:
9799
Returns:
98100
The performance of the configuration in seconds.
99101
"""
100-
fn = self.kernel.compile_config(config, allow_print=False)
101-
if self.start_precompile_and_check_for_hangs(config, fn)():
102-
return self.benchmark_function(config, fn)
103-
return inf
102+
try:
103+
fn = self.kernel.compile_config(config, allow_print=False)
104+
if self.start_precompile_and_check_for_hangs(config, fn)():
105+
return self.benchmark_function(config, fn)
106+
return inf
107+
except Exception as e:
108+
return inf
104109

105110
def benchmark_function(self, config: Config, fn: CompiledConfig) -> float:
106111
"""
@@ -134,9 +139,11 @@ def benchmark_function(self, config: Config, fn: CompiledConfig) -> float:
134139
self.log.debug("Benchmarking failed: OutOfResources")
135140
except PTXASError:
136141
self.log.warning(f"PTXASError compiling config: {config}")
142+
except CompilationError:
143+
self.log.debug("Benchmarking failed: Triton CompilationError")
137144
except Exception as e:
138-
if not _expected_errors_regexp.search(str(e)):
139-
raise exc.TritonError(f"{type(e).__qualname__}: {e}", config) from e
145+
# if not _expected_errors_regexp.search(str(e)):
146+
# raise exc.TritonError(f"{type(e).__qualname__}: {e}", config) from e
140147
self.log.debug(f"Benchmarking failed: {type(e).__name__}: {e}")
141148
return inf
142149

@@ -158,6 +165,8 @@ def start_precompile_and_check_for_hangs(
158165
"""
159166
if not self.settings.autotune_precompile:
160167
return PrecompileFuture.skip(self, config, True)
168+
if fn is None:
169+
return PrecompileFuture.skip(self, config, False)
161170
ctx = mp.get_context("fork")
162171

163172
def extract_launcher(
@@ -178,6 +187,8 @@ def extract_launcher(
178187
precompiler = make_precompiler(e.kernel)(*e.args, **e.kwargs)
179188
if precompiler is already_compiled:
180189
return PrecompileFuture.skip(self, config, True)
190+
except Exception as e:
191+
return PrecompileFuture.skip(self, config, False)
181192
process: mp.Process = ctx.Process(target=precompiler) # pyright: ignore[reportAssignmentType]
182193
process.start()
183194
return PrecompileFuture(
@@ -197,7 +208,13 @@ def parallel_benchmark(self, configs: list[Config]) -> list[tuple[Config, float]
197208
Returns:
198209
A list of tuples containing configurations and their performance.
199210
"""
200-
fns = [self.kernel.compile_config(c, allow_print=False) for c in configs]
211+
fns = []
212+
for c in configs:
213+
try:
214+
compile_result = self.kernel.compile_config(c, allow_print=False)
215+
fns.append(compile_result)
216+
except Exception as e:
217+
fns.append(None)
201218
if self.settings.autotune_precompile:
202219
is_workings = PrecompileFuture.wait_for_all(
203220
[

helion/autotuner/config_spec.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -411,8 +411,8 @@ def _flat_config(
411411
default = min(high, 4096)
412412
value = fn(BlockSizeFragment(low, high, default))
413413
assert isinstance(value, int)
414-
if value >= self.size_hint:
415-
return None # max size becomes persistent reduction
414+
if value >= self.size_hint or value < low:
415+
return None # max size or invalid value becomes persistent reduction
416416
return value
417417

418418
def _normalize(self, name: str, value: object) -> int | None:

0 commit comments

Comments
 (0)