@@ -205,18 +205,40 @@ def _do_bench_cudagraph_with_cache_clear(
205
205
fn ()
206
206
torch .cuda .synchronize ()
207
207
208
- ret = []
208
+ cache_clear_graph = torch .cuda .CUDAGraph ()
209
+ with torch .cuda .graph (cache_clear_graph ):
210
+ for _ in range (n_repeat ):
211
+ cache .zero_ ()
212
+ torch .cuda .synchronize ()
213
+
209
214
n_retries = 10
215
+ cache_clear_times = []
216
+ total_times = []
210
217
for _ in range (n_retries ):
218
+ cache_clear_start_event = torch .cuda .Event (enable_timing = True )
219
+ cache_clear_end_event = torch .cuda .Event (enable_timing = True )
220
+ cache_clear_start_event .record ()
221
+ cache_clear_graph .replay ()
222
+ cache_clear_end_event .record ()
223
+ torch .cuda .synchronize ()
224
+ cache_clear_times .append (
225
+ cache_clear_start_event .elapsed_time (cache_clear_end_event ) / n_repeat
226
+ )
227
+
211
228
start_event = torch .cuda .Event (enable_timing = True )
212
229
end_event = torch .cuda .Event (enable_timing = True )
213
230
start_event .record ()
214
231
g .replay ()
215
232
end_event .record ()
216
233
torch .cuda .synchronize ()
217
- ret .append (start_event .elapsed_time (end_event ) / n_repeat )
234
+ total_times .append (start_event .elapsed_time (end_event ) / n_repeat )
218
235
219
- times = torch .tensor (ret , dtype = torch .float )
236
+ all_kernel_times = []
237
+ for total_time , cache_clear_time in zip (total_times , cache_clear_times ):
238
+ kernel_time = total_time - cache_clear_time
239
+ all_kernel_times .append (kernel_time )
240
+
241
+ times = torch .tensor (all_kernel_times , dtype = torch .float )
220
242
return _summarize_statistics (times , quantiles , return_mode )
221
243
222
244
0 commit comments