@@ -149,7 +149,7 @@ def do_bench_elapsed_time(fn, warmup=25, rep=100, grad_to_none=None, quantiles=N
149149
150150
151151def do_bench_upstream_pytorch_profiler (fn , warmup = 25 , rep = 100 , grad_to_none = None , quantiles = None , return_mode = "mean" ,
152- device = "xpu" , sync_submitting = True , kernel_name = None ):
152+ device = "xpu" , sync_submitting = True , kernel_name = None ): # pylint: disable=unused-argument
153153 """
154154 Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with
155155 the 20-th and 80-th performance percentile.
@@ -168,7 +168,7 @@ def do_bench_upstream_pytorch_profiler(fn, warmup=25, rep=100, grad_to_none=None
168168
169169 assert return_mode in ["min" , "max" , "mean" , "median" ]
170170 import torch
171- from torch .profiler import profile , ProfilerActivity
171+ from torch .profiler import profile , ProfilerActivity , record_function
172172
173173 fn ()
174174 synchronize ()
@@ -210,22 +210,24 @@ def do_bench_upstream_pytorch_profiler(fn, warmup=25, rep=100, grad_to_none=None
210210 if sync_submitting :
211211 synchronize ()
212212 # record time of `fn`
213- fn ()
213+ with record_function ("__profile_kernel_of_func" ):
214+ fn ()
214215 # Record clocks
215216 synchronize ()
216217
217- function_events = prof .events ()
218+ profiling_func_filter = filter (lambda x : x .name .startswith ("__profile_kernel_of_func" ), prof .events ())
219+ functions = list (profiling_func_filter )
218220
219- functions = []
220- if isinstance (kernel_name , str ):
221- kernel_name = [kernel_name ]
222- for ker_name in kernel_name :
223- functions .extend (list (filter (lambda x : x .name .startswith (ker_name ), function_events ))) # pylint: disable=cell-var-from-loop
224- # profiling_func_filter = filter(lambda x: x.name.startswith("__profile_kernel_of_func"), function_events)
221+ def extract_kernels (funcs ):
222+ kernels = []
223+ kernels += list (itertools .chain .from_iterable (map (lambda func : extract_kernels (func .cpu_children ), funcs )))
224+ kernels += list (itertools .chain .from_iterable ([func .kernels for func in funcs ]))
225+ return kernels
225226
226- assert len (functions ) == n_repeat , f"the profiling number not match, { len (functions )} "
227+ kernels = [extract_kernels (func .cpu_children ) for func in functions ]
228+ assert len (kernels ) == n_repeat , "the profiling number not match"
227229 # Make the time to the milliseconds.
228- times = torch .tensor ([f . self_device_time_total * 1e-3 for f in functions ], dtype = torch .float )
230+ times = torch .tensor ([sum ([ k . duration for k in ks ]) * 1e-3 for ks in kernels ], dtype = torch .float )
229231 return _summarize_statistics (times , quantiles , return_mode )
230232
231233
0 commit comments