Skip to content

Commit 1254163

Browse files
committed
Fix upstream profiler for several kernels
Signed-off-by: Anatoly Myachev <[email protected]>
1 parent 4355afd commit 1254163

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

benchmarks/triton_kernels_benchmark/benchmark_testing.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -213,16 +213,18 @@ def do_bench_upstream_pytorch_profiler(fn, n_warmup=25, n_repeat=100, grad_to_no
213213

214214
function_events = prof.events()
215215

216-
functions = []
216+
all_functions = []
217217
if isinstance(kernel_name, str):
218218
kernel_name = [kernel_name]
219219
for ker_name in kernel_name:
220-
functions.extend(list(filter(lambda x: x.name.startswith(ker_name), function_events))) # pylint: disable=cell-var-from-loop
220+
functions = list(filter(lambda x: x.name.startswith(ker_name), function_events)) # pylint: disable=cell-var-from-loop
221+
assert len(functions) == n_repeat, f"the profiling number for kernel: '{ker_name}' not match, {len(functions)}"
222+
all_functions.append(functions)
221223
# profiling_func_filter = filter(lambda x: x.name.startswith("__profile_kernel_of_func"), function_events)
222224

223-
assert len(functions) == n_repeat, f"the profiling number not match, {len(functions)}"
224225
# Make the time to the milliseconds.
225-
times = torch.tensor([f.self_device_time_total * 1e-3 for f in functions], dtype=torch.float)
226+
times = torch.tensor([sum(map(lambda elem: elem.self_device_time_total), f) * 1e-3 for f in all_functions],
227+
dtype=torch.float)
226228
return _summarize_statistics(times, quantiles, return_mode)
227229

228230

0 commit comments

Comments
 (0)