 from typing import NamedTuple
 from typing import NoReturn

+from triton.compiler.errors import CompilationError
+
 if TYPE_CHECKING:
     from triton.runtime.jit import JITFunction

@@ -97,10 +99,13 @@ def benchmark(self, config: Config) -> float:
         Returns:
             The performance of the configuration in seconds.
         """
-        fn = self.kernel.compile_config(config, allow_print=False)
-        if self.start_precompile_and_check_for_hangs(config, fn)():
-            return self.benchmark_function(config, fn)
-        return inf
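+        # Compilation of a candidate config can itself raise; score any such
+        # failure as inf (non-working) rather than aborting the autotune run.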
+        try:
+            fn = self.kernel.compile_config(config, allow_print=False)
+            if self.start_precompile_and_check_for_hangs(config, fn)():
+                return self.benchmark_function(config, fn)
+            return inf
+        except Exception:
+            return inf

     def benchmark_function(self, config: Config, fn: CompiledConfig) -> float:
         """
@@ -134,9 +139,11 @@ def benchmark_function(self, config: Config, fn: CompiledConfig) -> float:
             self.log.debug("Benchmarking failed: OutOfResources")
         except PTXASError:
             self.log.warning(f"PTXASError compiling config: {config}")
+        except CompilationError:
+            self.log.debug("Benchmarking failed: Triton CompilationError")
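+        # Other unexpected errors are now logged and scored as inf rather than
+        # re-raised as exc.TritonError; the old check is kept below for reference.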
         except Exception as e:
-            if not _expected_errors_regexp.search(str(e)):
-                raise exc.TritonError(f"{type(e).__qualname__}: {e}", config) from e
+            # if not _expected_errors_regexp.search(str(e)):
+            #     raise exc.TritonError(f"{type(e).__qualname__}: {e}", config) from e
             self.log.debug(f"Benchmarking failed: {type(e).__name__}: {e}")
         return inf

@@ -158,6 +165,8 @@ def start_precompile_and_check_for_hangs(
         """
         if not self.settings.autotune_precompile:
             return PrecompileFuture.skip(self, config, True)
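+        # fn is None when compile_config failed earlier (see parallel_benchmark),
+        # so the config is already known to be broken; skip precompiling it.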
+        if fn is None:
+            return PrecompileFuture.skip(self, config, False)
         ctx = mp.get_context("fork")

         def extract_launcher(
@@ -178,6 +187,8 @@ def extract_launcher(
             precompiler = make_precompiler(e.kernel)(*e.args, **e.kwargs)
             if precompiler is already_compiled:
                 return PrecompileFuture.skip(self, config, True)
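+        # Any failure while extracting launch args or building the precompiler
+        # also marks the config as not working instead of propagating the error.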
+        except Exception:
+            return PrecompileFuture.skip(self, config, False)
         process: mp.Process = ctx.Process(target=precompiler)  # pyright: ignore[reportAssignmentType]
         process.start()
         return PrecompileFuture(
@@ -197,7 +208,13 @@ def parallel_benchmark(self, configs: list[Config]) -> list[tuple[Config, float]]:
         Returns:
             A list of tuples containing configurations and their performance.
         """
-        fns = [self.kernel.compile_config(c, allow_print=False) for c in configs]
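+        # Compile configs one at a time so a single failing config doesn't
+        # abort the whole batch; None is recorded as a placeholder on failure.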
+        fns = []
+        for c in configs:
+            try:
+                compile_result = self.kernel.compile_config(c, allow_print=False)
+                fns.append(compile_result)
+            except Exception:
+                fns.append(None)
         if self.settings.autotune_precompile:
             is_workings = PrecompileFuture.wait_for_all(
                 [
@@ -376,11 +393,12 @@ def population_statistics(population: list[PopulationMember]) -> str:
         working = [x for x in population if not math.isinf(x.perf)]
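+        # If every config failed, return "all failed!" below instead of
+        # indexing into an empty `working` list.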
         return (
             f"failed={len(population) - len(working)} "
+        ) + (
             f"min={working[0].perf:.4f} "
             f"mid={working[len(working) // 2].perf:.4f} "
             f"max={working[-1].perf:.4f} "
             f"best={population[0].config!s} "
-        )
+        ) if len(working) > 0 else "all failed!"
     return (
         f"min={population[0].perf:.4f} "
         f"mid={population[len(population) // 2].perf:.4f} "