 from typing import NamedTuple
 from typing import NoReturn
 
+from triton.compiler.errors import CompilationError
+
 if TYPE_CHECKING:
     from triton.runtime.jit import JITFunction
 
@@ -108,10 +110,13 @@ def benchmark(self, config: Config) -> float:
         Returns:
             The performance of the configuration in seconds.
         """
-        fn = self.kernel.compile_config(config, allow_print=False)
-        if self.start_precompile_and_check_for_hangs(config, fn)():
-            return self.benchmark_function(config, fn)
-        return inf
+        try:
+            fn = self.kernel.compile_config(config, allow_print=False)
+            if self.start_precompile_and_check_for_hangs(config, fn)():
+                return self.benchmark_function(config, fn)
+            return inf
+        except Exception:
+            return inf
 
     def benchmark_function(self, config: Config, fn: CompiledConfig) -> float:
         """
@@ -145,9 +150,11 @@ def benchmark_function(self, config: Config, fn: CompiledConfig) -> float:
             self.log.debug("Benchmarking failed: OutOfResources")
         except PTXASError:
             self.log.warning(f"PTXASError compiling config: {config}")
+        except CompilationError:
+            self.log.debug("Benchmarking failed: Triton CompilationError")
         except Exception as e:
-            if not _expected_errors_regexp.search(str(e)):
-                raise exc.TritonError(f"{type(e).__qualname__}: {e}", config) from e
+            # if not _expected_errors_regexp.search(str(e)):
+            #     raise exc.TritonError(f"{type(e).__qualname__}: {e}", config) from e
             self.log.debug(f"Benchmarking failed: {type(e).__name__}: {e}")
         return inf
 
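Python matches except clauses in source order, so the specific CompilationError handler has to precede the catch-all except Exception. A small standalone sketch of that ordering, assuming Triton is installed; classify and its argument are hypothetical:

    import logging
    from triton.compiler.errors import CompilationError

    log = logging.getLogger(__name__)

    def classify(run) -> None:
        try:
            run()
        except CompilationError:  # specific handler matches first
            log.debug("Benchmarking failed: Triton CompilationError")
        except Exception as e:  # anything else falls through to here
            log.debug(f"Benchmarking failed: {type(e).__name__}: {e}")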
@@ -169,6 +176,8 @@ def start_precompile_and_check_for_hangs(
         """
         if not self.settings.autotune_precompile:
             return PrecompileFuture.skip(self, config, True)
+        if fn is None:
+            return PrecompileFuture.skip(self, config, False)
         ctx = mp.get_context("fork")
 
         def extract_launcher(
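This guard pairs with parallel_benchmark below, which stores None for configs whose compilation failed; precompilation short-circuits and the config is reported as not working. The sentinel pattern in isolation, with a hypothetical skip callback standing in for PrecompileFuture.skip:

    from typing import Callable, Optional

    def precompile_or_skip(fn: Optional[Callable], skip: Callable):
        if fn is None:  # compile already failed; nothing to precompile
            return skip(False)  # report the config as not working
        return skip(True)  # real code would fork and exercise fn here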
@@ -189,6 +198,8 @@ def extract_launcher(
             precompiler = make_precompiler(e.kernel)(*e.args, **e.kwargs)
             if precompiler is already_compiled:
                 return PrecompileFuture.skip(self, config, True)
+        except Exception:
+            return PrecompileFuture.skip(self, config, False)
         process: mp.Process = ctx.Process(target=precompiler)  # pyright: ignore[reportAssignmentType]
         process.start()
         return PrecompileFuture(
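Any exception raised while building the precompiler now marks the config as failed rather than crashing the parent search process. A sketch of guarding subprocess setup this way; start_guarded and make_target are hypothetical names:

    import multiprocessing as mp
    from typing import Callable, Optional

    def start_guarded(make_target: Callable) -> Optional[mp.Process]:
        try:
            target = make_target()  # may raise before any fork happens
        except Exception:
            return None  # signal failure; the parent keeps running
        proc = mp.get_context("fork").Process(target=target)
        proc.start()
        return proc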
@@ -208,7 +219,13 @@ def parallel_benchmark(self, configs: list[Config]) -> list[tuple[Config, float]]:
         Returns:
             A list of tuples containing configurations and their performance.
         """
-        fns = [self.kernel.compile_config(c, allow_print=False) for c in configs]
+        fns = []
+        for c in configs:
+            try:
+                compile_result = self.kernel.compile_config(c, allow_print=False)
+                fns.append(compile_result)
+            except Exception:
+                fns.append(None)
         if self.settings.autotune_precompile:
             is_workings = PrecompileFuture.wait_for_all(
                 [
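Unrolling the comprehension keeps fns index-aligned with configs even when some compiles fail, using None as the placeholder. The same shape in isolation, with a hypothetical compile_one:

    from typing import Callable, Optional

    def compile_all(configs: list, compile_one: Callable) -> list[Optional[object]]:
        fns: list[Optional[object]] = []
        for c in configs:
            try:
                fns.append(compile_one(c))
            except Exception:
                fns.append(None)  # placeholder keeps indices aligned
        return fns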
@@ -387,11 +404,12 @@ def population_statistics(population: list[PopulationMember]) -> str:
         working = [x for x in population if not math.isinf(x.perf)]
         return (
             f"failed={len(population) - len(working)} "
+        ) + (
             f"min={working[0].perf:.4f} "
             f"mid={working[len(working) // 2].perf:.4f} "
             f"max={working[-1].perf:.4f} "
             f"best={population[0].config!s} "
-        )
+        ) if len(working) > 0 else "all failed!"
     return (
         f"min={population[0].perf:.4f} "
         f"mid={population[len(population) // 2].perf:.4f} "