
Commit 2f332f5

POC: Enable upstream pytorch profiler for microbenchmarks (#2343)
Closes #2344

Signed-off-by: Anatoly Myachev <[email protected]>
1 parent 9bda03d commit 2f332f5

12 files changed, +159 -21 lines
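The new backend is selected through the BENCHMARKING_METHOD environment variable, which the benchmark suite reads once at import time (and which now defaults to the upstream profiler when IPEX is not in use). A minimal opt-in sketch; the import path is an assumption based on the repository layout in this commit, not something the commit itself shows:

import os

# Must be set before the benchmark module is imported, since do_bench is
# chosen once at import time based on this variable.
os.environ["BENCHMARKING_METHOD"] = "UPSTREAM_PYTORCH_PROFILER"

from triton_kernels_benchmark import benchmark_testing as benchmark_suit  # import path assumed

print(benchmark_suit.BENCHMARKING_METHOD)  # -> UPSTREAM_PYTORCH_PROFILER
print(benchmark_suit.do_bench.__name__)    # -> do_bench_upstream_pytorch_profiler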

.github/workflows/triton-benchmarks.yml

Lines changed: 1 addition & 0 deletions
@@ -17,6 +17,7 @@ on:
         options:
           - PYTORCH_LEGACY_PROFILER_USING_IPEX
           - ELAPSED_TIME
+          - UPSTREAM_PYTORCH_PROFILER
         default: PYTORCH_LEGACY_PROFILER_USING_IPEX
   schedule:
     - cron: "5 23 * * *"

benchmarks/triton_kernels_benchmark/benchmark_testing.py

Lines changed: 91 additions & 3 deletions
@@ -7,7 +7,7 @@
 if USE_IPEX_OPTION:
     BENCHMARKING_METHOD = "PYTORCH_LEGACY_PROFILER_USING_IPEX"
 else:
-    BENCHMARKING_METHOD = os.getenv("BENCHMARKING_METHOD", "ELAPSED_TIME")
+    BENCHMARKING_METHOD = os.getenv("BENCHMARKING_METHOD", "UPSTREAM_PYTORCH_PROFILER")


 def synchronize():
@@ -37,7 +37,7 @@ def _summarize_statistics(times, quantiles, return_mode):


 def do_bench_ipex(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, fast_flush=True, return_mode="mean",
-                  device="xpu", sync_submitting=True):
+                  device="xpu", sync_submitting=True, kernel_name=None):  # pylint: disable=unused-argument
     """
     Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with
     the 20-th and 80-th performance percentile.
@@ -127,7 +127,7 @@ def extract_kernels(funcs):


 def do_bench_elapsed_time(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, fast_flush=True,
-                          return_mode="mean", device="xpu"):
+                          return_mode="mean", device="xpu", kernel_name=None):  # pylint: disable=unused-argument
     """
     Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with
     the 20-th and 80-th performance percentile.
@@ -155,10 +155,98 @@ def do_bench_elapsed_time(fn, warmup=25, rep=100, grad_to_none=None, quantiles=N
     return _summarize_statistics(times, quantiles, return_mode)


+def do_bench_upstream_pytorch_profiler(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, fast_flush=True,
+                                       return_mode="mean", device="xpu", sync_submitting=True, kernel_name=None):
+    """
+    Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with
+    the 20-th and 80-th performance percentile.
+
+    :param fn: Function to benchmark
+    :type fn: Callable
+    :param warmup: Warmup time (in ms)
+    :type warmup: int
+    :param rep: Repetition time (in ms)
+    :type rep: int
+    :param grad_to_none: Reset the gradient of the provided tensor to None
+    :type grad_to_none: torch.tensor, optional
+    :param quantiles: Performance percentile to return in addition to the median.
+    :type quantiles: list[float]
+    :param fast_flush: Use faster kernel to flush L2 between measurements
+    :type fast_flush: bool
+    """
+
+    assert return_mode in ["min", "max", "mean", "median"]
+    import torch
+    from torch.profiler import profile, ProfilerActivity
+
+    fn()
+    synchronize()
+
+    # We maintain a buffer of 256 MB that we clear
+    # before each kernel call to make sure that the L2
+    # doesn't contain any input data before the run
+    cache_size = 256 * 1024 * 1024
+    if fast_flush:
+        cache = torch.empty(int(cache_size // 4), dtype=torch.int, device=device)
+    else:
+        cache = torch.empty(int(cache_size), dtype=torch.int8, device=device)
+
+    # Estimate the runtime of the function
+    start_event = torch.xpu.Event(enable_timing=True)
+    end_event = torch.xpu.Event(enable_timing=True)
+    start_event.record()
+    for _ in range(5):
+        cache.zero_()
+        fn()
+    end_event.record()
+    synchronize()
+    estimate_ms = start_event.elapsed_time(end_event) / 5
+
+    # compute number of warmup and repeat
+    n_warmup = max(1, int(warmup / estimate_ms))
+    n_repeat = max(1, int(rep / estimate_ms))
+    # Warm-up
+    for _ in range(n_warmup):
+        fn()
+    # Benchmark
+    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.XPU]) as prof:
+        for _ in range(n_repeat):
+            # we don't want `fn` to accumulate gradient values
+            # if it contains a backward pass. So we clear the
+            # provided gradients
+            if grad_to_none is not None:
+                for x in grad_to_none:
+                    x.grad = None
+            # we clear the L2 cache before each run
+            cache.zero_()
+            if sync_submitting:
+                synchronize()
+            # record time of `fn`
+            fn()
+        # Record clocks
+        synchronize()
+
+    function_events = prof.events()
+
+    functions = []
+    if isinstance(kernel_name, str):
+        kernel_name = [kernel_name]
+    for ker_name in kernel_name:
+        functions.extend(list(filter(lambda x: x.name.startswith(ker_name), function_events)))  # pylint: disable=cell-var-from-loop
+    # profiling_func_filter = filter(lambda x: x.name.startswith("__profile_kernel_of_func"), function_events)
+
+    assert len(functions) == n_repeat, f"the profiling number not match, {len(functions)}"
+    # Make the time to the milliseconds.
+    times = torch.tensor([f.self_device_time_total * 1e-3 for f in functions], dtype=torch.float)
+    return _summarize_statistics(times, quantiles, return_mode)
+
+
 if BENCHMARKING_METHOD == "PYTORCH_LEGACY_PROFILER_USING_IPEX":
     do_bench = do_bench_ipex
 elif BENCHMARKING_METHOD == "ELAPSED_TIME":
     do_bench = do_bench_elapsed_time
+elif BENCHMARKING_METHOD == "UPSTREAM_PYTORCH_PROFILER":
+    do_bench = do_bench_upstream_pytorch_profiler
 else:
     raise NotImplementedError(f"BENCHMARKING_METHOD: {BENCHMARKING_METHOD} isn't implemented")
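For reference, a standalone sketch of the measurement technique introduced above: run the workload under the upstream profiler with XPU activity enabled, keep only the events whose name starts with the requested kernel prefix, and read their device time. The matmul workload and the "gemm" prefix are illustrative assumptions, not names taken from this commit:

import torch
from torch.profiler import profile, ProfilerActivity

a = torch.randn((1024, 1024), device="xpu", dtype=torch.bfloat16)
b = torch.randn((1024, 1024), device="xpu", dtype=torch.bfloat16)

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.XPU]) as prof:
    for _ in range(10):
        torch.matmul(a, b)
    torch.xpu.synchronize()

# Keep only the device kernels of interest; "gemm" is a placeholder prefix.
kernels = [evt for evt in prof.events() if evt.name.startswith("gemm")]
# self_device_time_total is reported in microseconds; convert to milliseconds.
times_ms = [evt.self_device_time_total * 1e-3 for evt in kernels]
print(times_ms)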

benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py

Lines changed: 4 additions & 2 deletions
@@ -257,7 +257,8 @@ def benchmark(Z, H, N_CTX, D_HEAD, CAUSAL, provider):
             ), attn_mask=None, dropout_p=0.0, is_causal=CAUSAL, scale=sm_scale).to(torch.float32)
         atol = 1e-1 if N_CTX == 16384 else 1e-2
         benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=atol, rtol=1e-3, err_msg='triton to torch')
-        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles)
+        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles,
+                                                              kernel_name='_attn_fwd')

     elif provider == 'xetla':
         module_name = f'flash_attn_causal_{CAUSAL}'.lower()
@@ -272,7 +273,8 @@ def benchmark(Z, H, N_CTX, D_HEAD, CAUSAL, provider):
         l = torch.empty((size_ml, ), device='xpu', dtype=torch.float)

         xetla_fn = lambda: func(q, k, v, out, dropout_mask, bias, m, l, Z, H, D_HEAD, N_CTX, N_CTX, sm_scale)
-        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(xetla_fn, warmup=10, rep=10, quantiles=quantiles)
+        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(xetla_fn, warmup=10, rep=10, quantiles=quantiles,
+                                                              kernel_name='gpu::xetla::fmha::FmhaForwardKernel<')

     else:
         raise NotImplementedError(f'Unsupported provider {provider}')

benchmarks/triton_kernels_benchmark/fused_softmax.py

Lines changed: 13 additions & 2 deletions
@@ -131,7 +131,8 @@ def benchmark(M, N, provider):
         triton_fn = lambda: softmax(x, out)
         torch_fn = lambda: torch.softmax(x, axis=-1)
         benchmark_suit.assert_close(triton_fn(), torch_fn(), err_msg="triton to torch")
-        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, quantiles=quantiles, warmup=10, rep=10)
+        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, quantiles=quantiles, warmup=10, rep=10,
+                                                              kernel_name="softmax_kernel")

     elif provider == "torch-jit":
         _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(lambda: naive_softmax(x), quantiles=quantiles, warmup=10,
@@ -144,7 +145,17 @@ def benchmark(M, N, provider):
         xetla_fn = lambda: func(x, out, 0)
         torch_fn = lambda: torch.softmax(x, axis=-1)
         # benchmark_suit.assert_close(xetla_fn(), torch_fn(), err_msg="xetla to torch")
-        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(xetla_fn, quantiles=quantiles, warmup=10, rep=10)
+        kernels_name = {
+            "softmax_shape_4096_256": "mat1_4096x256_bf16_cfg0",
+            "softmax_shape_4096_1024": "mat1_4096x1024_bf16_cfg0",
+            "softmax_shape_4096_2048": "mat1_4096x2048_bf16_cfg0",
+            "softmax_shape_4096_4096": "mat1_4096x4096_bf16_cfg0",
+            "softmax_shape_4096_8192": "mat1_4096x8k_bf16_cfg0",
+            "softmax_shape_4096_16384": "mat1_4096x16k_bf16_cfg0",
+            "softmax_shape_4096_32768": "mat1_4096x32k_bf16_cfg0",
+        }
+        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(xetla_fn, quantiles=quantiles, warmup=10, rep=10,
+                                                              kernel_name=kernels_name[name])

     else:
         raise NotImplementedError(f"Unsupported provider {provider}")

benchmarks/triton_kernels_benchmark/gemm_benchmark.py

Lines changed: 32 additions & 2 deletions
@@ -262,7 +262,8 @@ def benchmark(B, M, N, K, provider):
         torch_fn = lambda: torch.matmul(a, b).to(torch.float32)
         rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
         benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch')
-        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles)
+        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles,
+                                                                 kernel_name='matmul_kernel_with_block_pointers')
     elif provider == 'xetla':
         if B == 1:
             c = torch.empty((M, N), device='xpu', dtype=torch.float32)
@@ -276,8 +277,37 @@ def benchmark(B, M, N, K, provider):
         func = getattr(xetla_kernel, name)
         xetla_fn = lambda: func(a, b, c, acc, cnt)
         torch_fn = lambda: torch.matmul(a, b).to(torch.float32)
+
+        kernels_name = {
+            'gemm_shape_1_1024_1024_1024': 'Test_1x1024x1024x1024_row_row',
+            'gemm_shape_1_2048_2048_2048': 'Test_1x2048x2048x2048_row_row',
+            'gemm_shape_1_4096_4096_4096': 'Test_1x4096x4096x4096_row_row',
+            'gemm_shape_1_8192_8192_8192': 'Test_1x8192x8192x8192_row_row',
+            'gemm_shape_1_1_5120_13824': 'Test_1x1x5120x13824_row_row',
+            'gemm_shape_1_4_4096_12288': 'Test_1x4x4096x12288_row_row',
+            'gemm_shape_1_512_8192_8192': 'Test_1x512x8192x8192_row_row',
+            'gemm_shape_1_512_8192_32768': 'Test_1x512x8192x32768_row_row',
+            'gemm_shape_1_512_32768_8192': 'Test_1x512x32768x8192_row_row',
+            'gemm_shape_1_1024_16384_8192': 'Test_1x1024x16384x8192_row_row',
+            'gemm_shape_1_1024_28672_8192': 'Test_1x1024x28672x8192_row_row',
+            'gemm_shape_1_3072_4096_3072': 'Test_1x3072x4096x3072_row_row',
+            'gemm_shape_1_4096_16384_8192': 'Test_1x4096x16384x8192_row_row',
+            'gemm_shape_1_8192_16384_1024': 'Test_1x8192x16384x1024_row_row',
+            'gemm_shape_1_8192_16384_4096': 'Test_1x8192x16384x4096_row_row',
+            'gemm_shape_1_16384_1024_8192': 'Test_1x16384x1024x8192_row_row',
+            'gemm_shape_1_16384_4096_8192': 'Test_1x16384x4096x8192_row_row',
+            'gemm_shape_1_16384_8192_1024': 'Test_1x16384x8192x1024_row_row',
+            'gemm_shape_1_16384_8192_4096': 'Test_1x16384x8192x4096_row_row',
+            'gemm_shape_4_32768_128_4096': 'Test_4x32768x128x4096_row_row',
+            'gemm_shape_4_32768_4096_128': 'Test_4x32768x4096x128_row_row',
+            'gemm_shape_32_4096_4096_128': 'Test_32x4096x4096x128_row_row',
+            'gemm_shape_4096_8_128_16384': 'Test_4096x8x128x16384_row_row',
+            'gemm_shape_4096_8_16384_128': 'Test_4096x8x16384x128_row_row',
+        }
+
         # benchmark_suit.assert_close(xetla_fn(), torch_fn(), atol=1e-4, rtol=1.0, err_msg='xetla to torch')
-        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(xetla_fn, warmup=10, rep=10, quantiles=quantiles)
+        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(xetla_fn, warmup=10, rep=10, quantiles=quantiles,
+                                                                 kernel_name=kernels_name[name])
     else:
         raise NotImplementedError(f'Unsupported provider {provider}')

benchmarks/triton_kernels_benchmark/gemm_postop_addmatrix_benchmark.py

Lines changed: 4 additions & 1 deletion
@@ -266,14 +266,17 @@ def benchmark(B, M, N, K, provider):
         assert len(a.shape) == len(b.shape), 'Incompatible sizes'
         if len(a.shape) == 3:
             c = torch.empty((B, M, N), device='xpu', dtype=torch.float32)
+            kernel_name = 'matmul_kernel_with_block_pointers_batched'
         else:
             assert len(a.shape) == 2, 'Expecting shape of length 2'
             c = torch.empty((M, N), device='xpu', dtype=torch.float32)
+            kernel_name = 'matmul_kernel_with_block_pointers'
         triton_fn = lambda: matmul(a, b, d, c)
         torch_fn = lambda: torch.matmul(a, b).to(torch.float32) + d
         rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
         benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch')
-        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles)
+        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles,
+                                                                 kernel_name=kernel_name)
     else:
         raise NotImplementedError(f'Unsupported provider {provider}')

benchmarks/triton_kernels_benchmark/gemm_postop_gelu_benchmark.py

Lines changed: 4 additions & 1 deletion
@@ -268,14 +268,17 @@ def benchmark(B, M, N, K, provider):
         assert len(a.shape) == len(b.shape), 'Incompatible sizes'
         if len(a.shape) == 3:
             c = torch.empty((B, M, N), device='xpu', dtype=torch.float32)
+            kernel_name = 'matmul_kernel_with_block_pointers_batched'
         else:
             assert len(a.shape) == 2, 'Expecting shape of length 2'
             c = torch.empty((M, N), device='xpu', dtype=torch.float32)
+            kernel_name = 'matmul_kernel_with_block_pointers'
         triton_fn = lambda: matmul(a, b, c)
         torch_fn = lambda: torch.nn.functional.gelu(torch.matmul(a, b).to(torch.float32))
         rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
         benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch')
-        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles)
+        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles,
+                                                                 kernel_name=kernel_name)
     else:
         raise NotImplementedError(f'Unsupported provider {provider}')

benchmarks/triton_kernels_benchmark/gemm_preop_exp_benchmark.py

Lines changed: 4 additions & 1 deletion
@@ -256,14 +256,17 @@ def benchmark(B, M, N, K, provider):
         assert len(a.shape) == len(b.shape), 'Incompatible sizes'
         if len(a.shape) == 3:
             c = torch.empty((B, M, N), device='xpu', dtype=torch.float32)
+            kernel_name = 'matmul_kernel_with_block_pointers_batched'
         else:
             assert len(a.shape) == 2, 'Expecting shape of length 2'
             c = torch.empty((M, N), device='xpu', dtype=torch.float32)
+            kernel_name = 'matmul_kernel_with_block_pointers'
         triton_fn = lambda: matmul(a, b, c)
         torch_fn = lambda: torch.matmul(torch.exp(a), b).to(torch.float32)
         rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
         benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch')
-        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles)
+        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles,
+                                                                 kernel_name=kernel_name)
     else:
         raise NotImplementedError(f'Unsupported provider {provider}')

benchmarks/triton_kernels_benchmark/gemm_splitk_benchmark.py

Lines changed: 2 additions & 1 deletion
@@ -157,7 +157,8 @@ def benchmark(M, N, K, provider):
         torch_fn = lambda: torch.matmul(a, b).to(torch.float32)
         rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
         benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch')
-        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles)
+        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles,
+                                                              kernel_name='_kernel')
     else:
         raise NotImplementedError(f'Unsupported provider {provider}')

benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py

Lines changed: 2 additions & 1 deletion
@@ -278,7 +278,8 @@ def benchmark(M, N, K, provider):
         triton_fn = lambda: matmul(a, b, c)
         torch_fn = lambda: torch.matmul(a, b).to(torch.float32)
         benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=1e-2, err_msg='triton to torch')
-        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles)
+        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles,
+                                                              kernel_name=['first_wave', 'full_tiles'])
     else:
         raise NotImplementedError(f'Unsupported provider {provider}')
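Unlike the other benchmarks, the stream-K case filters on two kernel-name prefixes, so kernel_name is passed as a list. A small sketch of the prefix matching that do_bench_upstream_pytorch_profiler applies in that case; the event names below are made up for illustration, not taken from a real profiler trace:

# Accept a single prefix or a list of prefixes, as in do_bench_upstream_pytorch_profiler.
kernel_name = ["first_wave", "full_tiles"]
if isinstance(kernel_name, str):
    kernel_name = [kernel_name]

# Hypothetical profiler event names; real names come from prof.events().
event_names = ["first_wave_0d1d2d", "full_tiles_0d1d2d", "memset", "naive_softmax"]

matched = [name for name in event_names if any(name.startswith(p) for p in kernel_name)]
print(matched)  # -> ['first_wave_0d1d2d', 'full_tiles_0d1d2d']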
