@@ -166,6 +166,60 @@ def _do_bench_inductor(fn, warmup, rep, return_mode="all", grad_to_none=None):
    return _summarize_statistics(times, quantiles=None, return_mode=return_mode)


+def _do_bench_cudagraph_with_cache_clear(
+    fn, rep=20, grad_to_none=None, quantiles=None, return_mode="mean"
+):
+    """Clone of triton.testing.do_bench_cudagraph with explicit L2 cache clearing."""
+    assert return_mode in ["min", "max", "mean", "median", "all"]
+
+    cache = triton.runtime.driver.active.get_empty_cache_for_benchmark()
+
+    with torch.cuda.stream(torch.cuda.Stream()):
+        cache.zero_()
+        fn()
+        if grad_to_none is not None:
+            for x in grad_to_none:
+                x.detach_()
+                x.requires_grad_(True)
+                x.grad = None
+
+        start_event = torch.cuda.Event(enable_timing=True)
+        end_event = torch.cuda.Event(enable_timing=True)
+        start_event.record()
+        for _ in range(5):
+            cache.zero_()
+            fn()
+        end_event.record()
+        torch.cuda.synchronize()
+        estimate_ms = start_event.elapsed_time(end_event) / 5
+
+        n_repeat = 1000 if estimate_ms == 0 else max(1, int(rep / estimate_ms))
+
+        g = torch.cuda.CUDAGraph()
+        with torch.cuda.graph(g):
+            for _ in range(n_repeat):
+                if grad_to_none is not None:
+                    for x in grad_to_none:
+                        x.grad = None
+                cache.zero_()
+                fn()
+        torch.cuda.synchronize()
+
+        ret = []
+        n_retries = 10
+        for _ in range(n_retries):
+            start_event = torch.cuda.Event(enable_timing=True)
+            end_event = torch.cuda.Event(enable_timing=True)
+            start_event.record()
+            g.replay()
+            end_event.record()
+            torch.cuda.synchronize()
+            ret.append(start_event.elapsed_time(end_event) / n_repeat)
+
+        times = torch.tensor(ret, dtype=torch.float)
+        return _summarize_statistics(times, quantiles, return_mode)
+
+
def _do_bench_profiler(
    fn, warmup, rep, return_mode="all", grad_to_none=None, use_cudagraph=False
):
@@ -383,7 +437,7 @@ def do_bench_wrapper(
    if latency_measure_mode == "profiler":
        bench_fn = partial(_do_bench_profiler, warmup=1, use_cudagraph=True)
    else:
-        bench_fn = triton.testing.do_bench_cudagraph
+        bench_fn = _do_bench_cudagraph_with_cache_clear

    return Latency(
        times=bench_fn(
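
For context on how the new helper would be exercised, here is a minimal usage sketch. It assumes a CUDA device, a working Triton install, and that `_do_bench_cudagraph_with_cache_clear` from this diff is in scope; the matmul workload and tensor shapes are purely illustrative and not part of the change.

```python
import torch

# Illustrative workload only: any CUDA-graph-capturable callable works,
# i.e. no host-device synchronization or CPU-dependent control flow inside fn.
a = torch.randn(4096, 4096, device="cuda", dtype=torch.float16)
b = torch.randn(4096, 4096, device="cuda", dtype=torch.float16)
fn = lambda: torch.mm(a, b)

# Each graph replay runs fn() n_repeat times with the L2 cache zeroed before
# every call; return_mode selects how the per-replay averages are reduced.
mean_ms = _do_bench_cudagraph_with_cache_clear(fn, rep=20, return_mode="mean")
print(f"mean latency: {mean_ms:.3f} ms")
```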