 CACHE_CLEAR_KERNEL = "void at::native::vectorized_elementwise_kernel<4, at::native::FillFunctor<int>, std::array<char*, 1ul> >(int, at::native::FillFunctor<int>, std::array<char*, 1ul>)"
 
 
-def _kineto_events_to_latency(prof):
+def _kineto_events_to_latency(prof, n_repeat):
     prof_averages = prof.key_averages(group_by_input_shape=False)
     cuda_event_names = [
         event.key
@@ -33,22 +33,16 @@ def _kineto_events_to_latency(prof):
             kernel_duration_name_map[event.name()] = []
         kernel_duration_name_map[event.name()].append(event.duration_ns() / 1e6)
 
-    kernel_hits = [len(kernel_duration_name_map[k]) for k in kernel_duration_name_map]
-    assert all(
-        x == kernel_hits[0] for x in kernel_hits
-    ), "Error: Not all kernels run the same time."
+    op_time = 0.0
+    for name in kernel_duration_name_map:
+        op_time += sum(kernel_duration_name_map[name])
 
-    op_latencies = []
-    for x in range(kernel_hits[0]):
-        op_time = 0.0
-        for name in kernel_duration_name_map:
-            op_time += kernel_duration_name_map[name][x]
-        op_latencies.append(op_time)
+    op_time = op_time / n_repeat
 
     print(
         prof.key_averages(group_by_input_shape=False).table(sort_by="cuda_time_total")
     )
-    return Latency(times=op_latencies)
+    return op_time
 
 
 def _do_bench_cuda_time_cudagraph(
@@ -59,7 +53,7 @@ def _do_bench_cuda_time_cudagraph(
     n_repeat: int,
     grad_to_none: bool,
     bypass_fail: bool = False,
-) -> Latency:
+) -> float:
     with torch.cuda.stream(torch.cuda.Stream()):
         g = torch.cuda.CUDAGraph()
         with torch.cuda.graph(g):
@@ -87,7 +81,7 @@ def _do_bench_cuda_time_cudagraph(
             prof.step()
         synchronize_with_timing()
 
-    return _kineto_events_to_latency(prof)
+    return _kineto_events_to_latency(prof, n_repeat)
 
 
 def do_bench_cuda_time(
@@ -97,7 +91,7 @@ def do_bench_cuda_time(
     grad_to_none: bool,
     use_cuda_graphs: bool = False,
     bypass_fail: bool = False,
-) -> Latency:
+) -> float:
     """
     Return the aggregated CUDA time of a benchmarked operator backend.
     """
@@ -156,4 +150,4 @@ def synchronize_with_timing():
             prof.step()
         synchronize_with_timing()
 
-    return _kineto_events_to_latency(prof)
+    return _kineto_events_to_latency(prof, n_repeat)
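The refactor above replaces the per-iteration latency list with a single averaged figure: every Kineto kernel duration is summed across kernel names and the total is divided by n_repeat. A minimal standalone sketch of that aggregation, using hypothetical kernel names and durations rather than real profiler output:

# Hypothetical per-launch durations (ms) keyed by kernel name, standing in for
# the map built from Kineto events inside _kineto_events_to_latency.
kernel_duration_name_map = {
    "kernel_a": [0.10, 0.11, 0.09],
    "kernel_b": [0.20, 0.21, 0.19],
}
n_repeat = 3  # number of benchmark iterations recorded by the profiler

# Sum every launch of every kernel, then average over the repeat count.
op_time = 0.0
for name in kernel_duration_name_map:
    op_time += sum(kernel_duration_name_map[name])
op_time = op_time / n_repeat

print(op_time)  # (0.30 + 0.60) / 3 ≈ 0.30 ms of CUDA time per iteration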