16
16
from typing import NamedTuple
17
17
from typing import NoReturn
18
18
19
+ from triton .compiler .errors import CompilationError
20
+
19
21
if TYPE_CHECKING :
20
22
from triton .runtime .jit import JITFunction
21
23
@@ -97,10 +99,13 @@ def benchmark(self, config: Config) -> float:
97
99
Returns:
98
100
The performance of the configuration in seconds.
99
101
"""
100
- fn = self .kernel .compile_config (config , allow_print = False )
101
- if self .start_precompile_and_check_for_hangs (config , fn )():
102
- return self .benchmark_function (config , fn )
103
- return inf
102
+ try :
103
+ fn = self .kernel .compile_config (config , allow_print = False )
104
+ if self .start_precompile_and_check_for_hangs (config , fn )():
105
+ return self .benchmark_function (config , fn )
106
+ return inf
107
+ except Exception as e :
108
+ return inf
104
109
105
110
def benchmark_function (self , config : Config , fn : CompiledConfig ) -> float :
106
111
"""
@@ -134,9 +139,11 @@ def benchmark_function(self, config: Config, fn: CompiledConfig) -> float:
134
139
self .log .debug ("Benchmarking failed: OutOfResources" )
135
140
except PTXASError :
136
141
self .log .warning (f"PTXASError compiling config: { config } " )
142
+ except CompilationError :
143
+ self .log .debug ("Benchmarking failed: Triton CompilationError" )
137
144
except Exception as e :
138
- if not _expected_errors_regexp .search (str (e )):
139
- raise exc .TritonError (f"{ type (e ).__qualname__ } : { e } " , config ) from e
145
+ # if not _expected_errors_regexp.search(str(e)):
146
+ # raise exc.TritonError(f"{type(e).__qualname__}: {e}", config) from e
140
147
self .log .debug (f"Benchmarking failed: { type (e ).__name__ } : { e } " )
141
148
return inf
142
149
@@ -158,6 +165,8 @@ def start_precompile_and_check_for_hangs(
158
165
"""
159
166
if not self .settings .autotune_precompile :
160
167
return PrecompileFuture .skip (self , config , True )
168
+ if fn is None :
169
+ return PrecompileFuture .skip (self , config , False )
161
170
ctx = mp .get_context ("fork" )
162
171
163
172
def extract_launcher (
@@ -178,6 +187,8 @@ def extract_launcher(
178
187
precompiler = make_precompiler (e .kernel )(* e .args , ** e .kwargs )
179
188
if precompiler is already_compiled :
180
189
return PrecompileFuture .skip (self , config , True )
190
+ except Exception as e :
191
+ return PrecompileFuture .skip (self , config , False )
181
192
process : mp .Process = ctx .Process (target = precompiler ) # pyright: ignore[reportAssignmentType]
182
193
process .start ()
183
194
return PrecompileFuture (
@@ -197,7 +208,13 @@ def parallel_benchmark(self, configs: list[Config]) -> list[tuple[Config, float]
197
208
Returns:
198
209
A list of tuples containing configurations and their performance.
199
210
"""
200
- fns = [self .kernel .compile_config (c , allow_print = False ) for c in configs ]
211
+ fns = []
212
+ for c in configs :
213
+ try :
214
+ compile_result = self .kernel .compile_config (c , allow_print = False )
215
+ fns .append (compile_result )
216
+ except Exception as e :
217
+ fns .append (None )
201
218
if self .settings .autotune_precompile :
202
219
is_workings = PrecompileFuture .wait_for_all (
203
220
[
0 commit comments