21
21
22
22
from torch ._inductor .runtime .triton_compat import OutOfResources
23
23
from torch ._inductor .runtime .triton_compat import PTXASError
24
+ from triton .compiler .errors import CompilationError
24
25
import torch .multiprocessing as mp
25
26
from triton .testing import do_bench
26
27
43
44
from . import ConfigSpec
44
45
45
46
_expected_errors_regexp : re .Pattern [str ] = re .compile (
46
- r"|" .join (map (re .escape , ["[CUDA]: invalid argument" ]))
47
+ r"|" .join (map (re .escape , ["[CUDA]: invalid argument" , "exceeds triton maximum tensor numel" ]))
47
48
)
48
49
49
50
@@ -88,10 +89,13 @@ def benchmark(self, config: Config) -> float:
88
89
Returns:
89
90
The performance of the configuration in seconds.
90
91
"""
91
- fn = self .kernel .compile_config (config , allow_print = False )
92
- if self .start_precompile_and_check_for_hangs (config , fn )():
93
- return self .benchmark_function (config , fn )
94
- return inf
92
+ try :
93
+ fn = self .kernel .compile_config (config , allow_print = False )
94
+ if self .start_precompile_and_check_for_hangs (config , fn )():
95
+ return self .benchmark_function (config , fn )
96
+ return inf
97
+ except Exception as e :
98
+ return inf
95
99
96
100
def benchmark_function (self , config : Config , fn : CompiledConfig ) -> float :
97
101
"""
@@ -125,8 +129,10 @@ def benchmark_function(self, config: Config, fn: CompiledConfig) -> float:
125
129
self .log .debug ("Benchmarking failed: OutOfResources" )
126
130
except PTXASError :
127
131
self .log .warning (f"PTXASError compiling config: { config } " )
132
+ except CompilationError :
133
+ self .log .debug ("Benchmarking failed: Triton CompilationError" )
128
134
except Exception as e :
129
- if not _expected_errors_regexp .search (str (e )):
135
+ if not _expected_errors_regexp .search (str (e )) and not "exceeds triton maximum tensor numel" in str ( e ) :
130
136
raise exc .TritonError (f"{ type (e ).__qualname__ } : { e } " , config ) from e
131
137
self .log .debug (f"Benchmarking failed: { type (e ).__name__ } : { e } " )
132
138
return inf
@@ -149,6 +155,8 @@ def start_precompile_and_check_for_hangs(
149
155
"""
150
156
if not self .settings .autotune_precompile :
151
157
return PrecompileFuture .skip (self , config , True )
158
+ if fn is None :
159
+ return PrecompileFuture .skip (self , config , False )
152
160
ctx = mp .get_context ("fork" )
153
161
154
162
def extract_launcher (
@@ -188,7 +196,13 @@ def parallel_benchmark(self, configs: list[Config]) -> list[tuple[Config, float]
188
196
Returns:
189
197
A list of tuples containing configurations and their performance.
190
198
"""
191
- fns = [self .kernel .compile_config (c , allow_print = False ) for c in configs ]
199
+ fns = []
200
+ for c in configs :
201
+ try :
202
+ compile_result = self .kernel .compile_config (c , allow_print = False )
203
+ fns .append (compile_result )
204
+ except Exception as e :
205
+ fns .append (None )
192
206
if self .settings .autotune_precompile :
193
207
is_workings = PrecompileFuture .wait_for_all (
194
208
[
0 commit comments