@@ -236,17 +236,7 @@ def run_iteration():
         "with_stack": False,
     }
 
-    for _ in range(n_profiler_runs):
-        # Profile execution
-        with torch.profiler.profile(**profiler_config) as prof:
-            if use_cudagraph:
-                g.replay()
-            else:
-                # Execute multiple iterations for regular mode
-                for _ in range(iterations_per_profiler_run):
-                    run_iteration()
-        torch.cuda.synchronize()
-
+    def _trace_handler(prof: torch.profiler.profile) -> None:
         # Collect all kernel execution intervals
         kernel_intervals = []
@@ -299,10 +289,23 @@ def run_iteration():
         )
 
         # Convert to milliseconds and normalize by iterations
-        total_kernel_time_ms = (
+        kernel_time_per_iteration_ms = (
             total_kernel_time_us / 1000.0
         ) / iterations_per_profiler_run
-        all_kernel_times.append(total_kernel_time_ms)
+        all_kernel_times.append(kernel_time_per_iteration_ms)
+
+    for _ in range(n_profiler_runs):
+        # Profile execution
+        with torch.profiler.profile(
+            **profiler_config, on_trace_ready=_trace_handler
+        ) as prof:
+            if use_cudagraph:
+                g.replay()
+            else:
+                # Execute multiple iterations for regular mode
+                for _ in range(iterations_per_profiler_run):
+                    run_iteration()
+        torch.cuda.synchronize()
 
     times = torch.tensor(all_kernel_times, dtype=torch.float)
     return _summarize_statistics(times, quantiles=None, return_mode=return_mode)
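For reference, a minimal self-contained sketch of the on_trace_ready pattern this patch adopts: the handler receives the profiler object once the trace is ready (here, when the context manager exits, since no schedule is set), so the kernel-time analysis can live alongside the profiling setup instead of after the with block. The handler name, the matmul workload, and the key_averages summary below are illustrative only and are not part of this change.

import torch
import torch.profiler

def _print_kernel_summary(prof: torch.profiler.profile) -> None:
    # Invoked automatically when the trace is ready; print the top GPU kernels.
    print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=5))

with torch.profiler.profile(
    activities=[
        torch.profiler.ProfilerActivity.CPU,
        torch.profiler.ProfilerActivity.CUDA,
    ],
    on_trace_ready=_print_kernel_summary,
):
    x = torch.randn(1024, 1024, device="cuda")
    for _ in range(10):
        x = x @ x  # illustrative GPU workload
    torch.cuda.synchronize()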