69 changes: 63 additions & 6 deletions benchmarks/triton_kernels_benchmark/benchmark_testing.py
@@ -33,7 +33,7 @@ def _summarize_statistics(times, quantiles, return_mode):


def do_bench_ipex(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, fast_flush=True, return_mode="mean",
Contributor Author:
I haven't changed the default implementation yet, so I can switch between the two and see how the new implementation behaves relative to the old one.

device="xpu", sync_submitting=True):
device="xpu", sync_submitting=True, kernel_name=None): # pylint: disable=unused-argument
"""
Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with
the 20-th and 80-th performance percentile.
@@ -123,7 +123,7 @@ def extract_kernels(funcs):


def do_bench_no_ipex(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, fast_flush=True, return_mode="mean",
device="xpu"):
device="xpu", sync_submitting=True, kernel_name=None):
"""
Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with
the 20-th and 80-th performance percentile.
@@ -141,13 +141,70 @@ def do_bench_no_ipex(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None,
:param fast_flush: Use faster kernel to flush L2 between measurements
:type fast_flush: bool
"""

    assert return_mode in ["min", "max", "mean", "median"]
    import torch
    from triton.testing import do_bench as triton_do_bench
    from torch.profiler import profile, ProfilerActivity

    fn()
    synchronize()

    # We maintain a buffer of 256 MB that we clear
    # before each kernel call to make sure that the L2
    # doesn't contain any input data before the run
    cache_size = 256 * 1024 * 1024
    if fast_flush:
        cache = torch.empty(int(cache_size // 4), dtype=torch.int, device=device)
    else:
        cache = torch.empty(int(cache_size), dtype=torch.int8, device=device)

    times = triton_do_bench(fn, warmup=warmup, rep=rep, grad_to_none=grad_to_none, fast_flush=fast_flush,
                            return_mode="all", device_type=device)
    times = torch.tensor(times, dtype=torch.float)
    # Estimate the runtime of the function
    start_event = torch.xpu.Event(enable_timing=True)
    end_event = torch.xpu.Event(enable_timing=True)
    start_event.record()
    for _ in range(5):
        cache.zero_()
        fn()
    end_event.record()
    synchronize()
    estimate_ms = start_event.elapsed_time(end_event) / 5

    # compute number of warmup and repeat
    n_warmup = max(1, int(warmup / estimate_ms))
    n_repeat = max(1, int(rep / estimate_ms))
    # Warm-up
    for _ in range(n_warmup):
        fn()
    # Benchmark
    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.XPU]) as prof:
        for _ in range(n_repeat):
            # we don't want `fn` to accumulate gradient values
            # if it contains a backward pass. So we clear the
            # provided gradients
            if grad_to_none is not None:
                for x in grad_to_none:
                    x.grad = None
            # we clear the L2 cache before each run
            cache.zero_()
            if sync_submitting:
                synchronize()
            # record time of `fn`
            fn()
Comment on lines +224 to +225
Contributor Author:
Most of this function is a copy of the do_bench_ipex function.

However, we can't use the following code here, because the kernel we need doesn't appear among the subevents of this record_function event (presumably due to a bug):

with record_function("__profile_kernel_of_func"):
    fn()

Contributor:

Please ensure a ticket is created for this if it is not already.

        # Record clocks
        synchronize()

    function_events = prof.events()

    functions = []
    if isinstance(kernel_name, str):
        kernel_name = [kernel_name]
    for ker_name in kernel_name:
        functions.extend(list(filter(lambda x: x.name.startswith(ker_name), function_events)))  # pylint: disable=cell-var-from-loop
    # profiling_func_filter = filter(lambda x: x.name.startswith("__profile_kernel_of_func"), function_events)

    assert len(functions) == n_repeat, f"the number of profiled kernel events ({len(functions)}) does not match n_repeat ({n_repeat})"
    # Convert the kernel times to milliseconds.
    times = torch.tensor([f.self_device_time_total * 1e-3 for f in functions], dtype=torch.float)
    return _summarize_statistics(times, quantiles, return_mode)
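For illustration only — the sketch below is not part of the diff. It contrasts the two event-selection strategies discussed in the review comments above: reading kernels from the `__profile_kernel_of_func` range (the do_bench_ipex approach that did not work here) versus matching device kernel events directly by name prefix, which is what the new `kernel_name` argument enables. The workload and the "gemm" prefix are hypothetical placeholders.

import torch
from torch.profiler import profile, ProfilerActivity, record_function

def run_once():
    # Hypothetical stand-in for the `fn` passed to do_bench_no_ipex.
    a = torch.randn(1024, 1024, device="xpu")
    b = torch.randn(1024, 1024, device="xpu")
    return a @ b

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.XPU]) as prof:
    with record_function("__profile_kernel_of_func"):
        run_once()

events = prof.events()

# Strategy 1 (do_bench_ipex): find the "__profile_kernel_of_func" range and
# read the device kernels recorded underneath it.
wrapper_events = [e for e in events if e.name.startswith("__profile_kernel_of_func")]

# Strategy 2 (this PR's do_bench_no_ipex): match device kernel events directly
# by name prefix; this still works when the kernel is not attached to the
# wrapper event's subtree.
kernel_events = [e for e in events if e.name.startswith("gemm")]  # hypothetical prefix

print(len(wrapper_events), len(kernel_events))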


@@ -241,7 +241,7 @@ def benchmark(Z, H, N_CTX, D_HEAD, provider):
atol = 1e-1 if N_CTX == 16384 else 1e-2
benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=atol, rtol=1e-3, err_msg='triton to torch')
_, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles,
fast_flush=False)
fast_flush=False, kernel_name='_attn_fwd')

elif provider == 'xetla':
module_name = f'flash_attn_causal_{causal}'.lower()
@@ -257,7 +257,8 @@ def benchmark(Z, H, N_CTX, D_HEAD, provider):

xetla_fn = lambda: func(q, k, v, out, dropout_mask, bias, m, l, Z, H, D_HEAD, N_CTX, N_CTX, sm_scale)
_, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(xetla_fn, warmup=10, rep=10, quantiles=quantiles,
fast_flush=False)
fast_flush=False,
kernel_name='gpu::xetla::fmha::FmhaForwardKernel<')

else:
raise NotImplementedError(f'Unsupported provider {provider}')
15 changes: 13 additions & 2 deletions benchmarks/triton_kernels_benchmark/fused_softmax.py
@@ -131,7 +131,8 @@ def benchmark(M, N, provider):
triton_fn = lambda: softmax(x, out)
torch_fn = lambda: torch.softmax(x, axis=-1)
benchmark_suit.assert_close(triton_fn(), torch_fn(), err_msg="triton to torch")
_, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, quantiles=quantiles, warmup=10, rep=10)
_, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, quantiles=quantiles, warmup=10, rep=10,
kernel_name="softmax_kernel")

elif provider == "torch-jit":
_, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(lambda: naive_softmax(x), quantiles=quantiles, warmup=10,
@@ -144,7 +145,17 @@ def benchmark(M, N, provider):
xetla_fn = lambda: func(x, out, 0)
torch_fn = lambda: torch.softmax(x, axis=-1)
# benchmark_suit.assert_close(xetla_fn(), torch_fn(), err_msg="xetla to torch")
_, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(xetla_fn, quantiles=quantiles, warmup=10, rep=10)
kernels_name = {
"softmax_shape_4096_256": "mat1_4096x256_bf16_cfg0",
"softmax_shape_4096_1024": "mat1_4096x1024_bf16_cfg0",
"softmax_shape_4096_2048": "mat1_4096x2048_bf16_cfg0",
"softmax_shape_4096_4096": "mat1_4096x4096_bf16_cfg0",
"softmax_shape_4096_8192": "mat1_4096x8k_bf16_cfg0",
"softmax_shape_4096_16384": "mat1_4096x16k_bf16_cfg0",
"softmax_shape_4096_32768": "mat1_4096x32k_bf16_cfg0",
}
_, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(xetla_fn, quantiles=quantiles, warmup=10, rep=10,
kernel_name=kernels_name[name])

else:
raise NotImplementedError(f"Unsupported provider {provider}")
33 changes: 31 additions & 2 deletions benchmarks/triton_kernels_benchmark/gemm_benchmark.py
@@ -263,7 +263,8 @@ def benchmark(B, M, N, K, provider):
rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch')
_, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles,
fast_flush=False)
fast_flush=False,
kernel_name='matmul_kernel_with_block_pointers')
elif provider == 'xetla':
if B == 1:
c = torch.empty((M, N), device='xpu', dtype=torch.float32)
@@ -277,9 +278,37 @@ def benchmark(B, M, N, K, provider):
func = getattr(xetla_kernel, name)
xetla_fn = lambda: func(a, b, c, acc, cnt)
torch_fn = lambda: torch.matmul(a, b).to(torch.float32)

kernels_name = {
'gemm_shape_1_1024_1024_1024': 'Test_1x1024x1024x1024_row_row',
'gemm_shape_1_2048_2048_2048': 'Test_1x2048x2048x2048_row_row',
'gemm_shape_1_4096_4096_4096': 'Test_1x4096x4096x4096_row_row',
'gemm_shape_1_8192_8192_8192': 'Test_1x8192x8192x8192_row_row',
'gemm_shape_1_1_5120_13824': 'Test_1x1x5120x13824_row_row',
'gemm_shape_1_4_4096_12288': 'Test_1x4x4096x12288_row_row',
'gemm_shape_1_512_8192_8192': 'Test_1x512x8192x8192_row_row',
'gemm_shape_1_512_8192_32768': 'Test_1x512x8192x32768_row_row',
'gemm_shape_1_512_32768_8192': 'Test_1x512x32768x8192_row_row',
'gemm_shape_1_1024_16384_8192': 'Test_1x1024x16384x8192_row_row',
'gemm_shape_1_1024_28672_8192': 'Test_1x1024x28672x8192_row_row',
'gemm_shape_1_3072_4096_3072': 'Test_1x3072x4096x3072_row_row',
'gemm_shape_1_4096_16384_8192': 'Test_1x4096x16384x8192_row_row',
'gemm_shape_1_8192_16384_1024': 'Test_1x8192x16384x1024_row_row',
'gemm_shape_1_8192_16384_4096': 'Test_1x8192x16384x4096_row_row',
'gemm_shape_1_16384_1024_8192': 'Test_1x16384x1024x8192_row_row',
'gemm_shape_1_16384_4096_8192': 'Test_1x16384x4096x8192_row_row',
'gemm_shape_1_16384_8192_1024': 'Test_1x16384x8192x1024_row_row',
'gemm_shape_1_16384_8192_4096': 'Test_1x16384x8192x4096_row_row',
'gemm_shape_4_32768_128_4096': 'Test_4x32768x128x4096_row_row',
'gemm_shape_4_32768_4096_128': 'Test_4x32768x4096x128_row_row',
'gemm_shape_32_4096_4096_128': 'Test_32x4096x4096x128_row_row',
'gemm_shape_4096_8_128_16384': 'Test_4096x8x128x16384_row_row',
'gemm_shape_4096_8_16384_128': 'Test_4096x8x16384x128_row_row',
}

# benchmark_suit.assert_close(xetla_fn(), torch_fn(), atol=1e-4, rtol=1.0, err_msg='xetla to torch')
_, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(xetla_fn, warmup=10, rep=10, quantiles=quantiles,
fast_flush=False)
fast_flush=False, kernel_name=kernels_name[name])
else:
raise NotImplementedError(f'Unsupported provider {provider}')

@@ -266,15 +266,17 @@ def benchmark(B, M, N, K, provider):
assert len(a.shape) == len(b.shape), 'Incompatible sizes'
if len(a.shape) == 3:
c = torch.empty((B, M, N), device='xpu', dtype=torch.float32)
kernel_name = 'matmul_kernel_with_block_pointers_batched'
else:
assert len(a.shape) == 2, 'Expecting shape of length 2'
c = torch.empty((M, N), device='xpu', dtype=torch.float32)
kernel_name = 'matmul_kernel_with_block_pointers'
triton_fn = lambda: matmul(a, b, d, c)
torch_fn = lambda: torch.matmul(a, b).to(torch.float32) + d
rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch')
_, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles,
fast_flush=False)
fast_flush=False, kernel_name=kernel_name)
else:
raise NotImplementedError(f'Unsupported provider {provider}')

@@ -268,15 +268,17 @@ def benchmark(B, M, N, K, provider):
assert len(a.shape) == len(b.shape), 'Incompatible sizes'
if len(a.shape) == 3:
c = torch.empty((B, M, N), device='xpu', dtype=torch.float32)
kernel_name = 'matmul_kernel_with_block_pointers_batched'
else:
assert len(a.shape) == 2, 'Expecting shape of length 2'
c = torch.empty((M, N), device='xpu', dtype=torch.float32)
kernel_name = 'matmul_kernel_with_block_pointers'
triton_fn = lambda: matmul(a, b, c)
torch_fn = lambda: torch.nn.functional.gelu(torch.matmul(a, b).to(torch.float32))
rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch')
_, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles,
fast_flush=False)
fast_flush=False, kernel_name=kernel_name)
else:
raise NotImplementedError(f'Unsupported provider {provider}')

@@ -256,15 +256,17 @@ def benchmark(B, M, N, K, provider):
assert len(a.shape) == len(b.shape), 'Incompatible sizes'
if len(a.shape) == 3:
c = torch.empty((B, M, N), device='xpu', dtype=torch.float32)
kernel_name = 'matmul_kernel_with_block_pointers_batched'
else:
assert len(a.shape) == 2, 'Expecting shape of length 2'
c = torch.empty((M, N), device='xpu', dtype=torch.float32)
kernel_name = 'matmul_kernel_with_block_pointers'
triton_fn = lambda: matmul(a, b, c)
torch_fn = lambda: torch.matmul(torch.exp(a), b).to(torch.float32)
rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch')
_, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles,
fast_flush=False)
fast_flush=False, kernel_name=kernel_name)
else:
raise NotImplementedError(f'Unsupported provider {provider}')

@@ -158,7 +158,7 @@ def benchmark(M, N, K, provider):
rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch')
_, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles,
fast_flush=False)
fast_flush=False, kernel_name='_kernel')
else:
raise NotImplementedError(f'Unsupported provider {provider}')

@@ -279,7 +279,8 @@ def benchmark(M, N, K, provider):
torch_fn = lambda: torch.matmul(a, b).to(torch.float32)
benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=1e-2, err_msg='triton to torch')
_, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles,
fast_flush=False)
fast_flush=False,
kernel_name=['first_wave', 'full_tiles'])
else:
raise NotImplementedError(f'Unsupported provider {provider}')

3 changes: 2 additions & 1 deletion benchmarks/triton_kernels_benchmark/prefix_sums.py
@@ -44,7 +44,8 @@ def benchmark(M, N, AXIS, provider):

if provider == 'triton':
triton_fn = lambda: scan_kernel[(1, )](x, BLOCK_SIZE_M=M, BLOCK_SIZE_N=N, AXIS=AXIS)
_, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, quantiles=quantiles, fast_flush=False)
_, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, quantiles=quantiles, fast_flush=False,
kernel_name='scan_kernel')
else:
raise NotImplementedError(f'Unsupported provider {provider}')
