32 changes: 16 additions & 16 deletions benchmarks/triton_kernels_benchmark/benchmark_testing.py
@@ -37,7 +37,7 @@ def _summarize_statistics(times, quantiles, return_mode):


def do_bench_ipex(fn, n_warmup=25, n_repeat=100, grad_to_none=None, quantiles=None, return_mode="mean", device="xpu",
-                  sync_submitting=True, kernel_name=None):  # pylint: disable=unused-argument
+                  sync_submitting=True):
"""
Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with
the 20-th and 80-th performance percentile.
@@ -108,7 +108,7 @@ def extract_kernels(funcs):


def do_bench_elapsed_time(fn, n_warmup=25, n_repeat=100, grad_to_none=None, quantiles=None, return_mode="mean",
device="xpu", kernel_name=None): # pylint: disable=unused-argument
device="xpu"):
"""
Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with
the 20-th and 80-th performance percentile.
@@ -159,7 +159,7 @@ def do_bench_elapsed_time(fn, n_warmup=25, n_repeat=100, grad_to_none=None, quan


def do_bench_upstream_pytorch_profiler(fn, n_warmup=25, n_repeat=100, grad_to_none=None, quantiles=None,
return_mode="mean", device="xpu", sync_submitting=True, kernel_name=None):
return_mode="mean", device="xpu", sync_submitting=True):
"""
Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with
the 20-th and 80-th performance percentile.
@@ -178,7 +178,7 @@ def do_bench_upstream_pytorch_profiler(fn, n_warmup=25, n_repeat=100, grad_to_no

assert return_mode in ["min", "max", "mean", "median"]
import torch
-    from torch.profiler import profile, ProfilerActivity
+    from torch.profiler import profile, ProfilerActivity, record_function

fn()
synchronize()
@@ -206,24 +206,24 @@ def do_bench_upstream_pytorch_profiler(fn, n_warmup=25, n_repeat=100, grad_to_no
if sync_submitting:
synchronize()
# record time of `fn`
-            fn()
+            with record_function("__profile_kernel_of_func"):
+                fn()
# Record clocks
synchronize()

-    function_events = prof.events()
+    profiling_func_filter = filter(lambda x: x.name.startswith("__profile_kernel_of_func"), prof.events())
+    functions = list(profiling_func_filter)

-    all_functions = []
-    if isinstance(kernel_name, str):
-        kernel_name = [kernel_name]
-    for ker_name in kernel_name:
-        functions = list(filter(lambda x: x.name.startswith(ker_name), function_events))  # pylint: disable=cell-var-from-loop
-        assert len(functions) == n_repeat, f"the profiling number for kernel: '{ker_name}' not match, {len(functions)}"
-        all_functions.append(functions)
-    # profiling_func_filter = filter(lambda x: x.name.startswith("__profile_kernel_of_func"), function_events)
+    def extract_kernels(funcs):
+        kernels = []
+        kernels += list(itertools.chain.from_iterable(map(lambda func: extract_kernels(func.cpu_children), funcs)))
+        kernels += list(itertools.chain.from_iterable([func.kernels for func in funcs]))
+        return kernels

+    kernels = [extract_kernels(func.cpu_children) for func in functions]
+    assert len(kernels) == n_repeat, "the profiling number not match"
# Make the time to the milliseconds.
-    times = torch.tensor([sum(map(lambda elem: elem.self_device_time_total, f)) * 1e-3 for f in zip(*all_functions)],
-                         dtype=torch.float)
+    times = torch.tensor([sum([k.duration for k in ks]) * 1e-3 for ks in kernels], dtype=torch.float)
return _summarize_statistics(times, quantiles, return_mode)


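For orientation, the record_function pattern this file settles on can be exercised stand-alone. The sketch below is a minimal version, assuming a CUDA device so it runs without an XPU toolchain (the suite itself passes device="xpu" and uses its own synchronize()); the label string and the recursive kernel walk mirror the diff above, while the workload and shapes are illustrative.

```python
# Minimal sketch of record_function-based kernel timing, assuming a CUDA device.
import torch
from torch.profiler import profile, ProfilerActivity, record_function

def bench_sketch(fn, n_repeat=10):
    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
        for _ in range(n_repeat):
            with record_function("__profile_kernel_of_func"):
                fn()
        torch.cuda.synchronize()
    # Keep only the events created by the label above, one per repeat.
    functions = [e for e in prof.events() if e.name.startswith("__profile_kernel_of_func")]

    def extract_kernels(funcs):
        # Walk CPU children recursively and collect the GPU kernels they launched.
        kernels = []
        for f in funcs:
            kernels += extract_kernels(f.cpu_children)
            kernels += f.kernels
        return kernels

    # Kernel durations are reported in microseconds; convert each repeat to ms.
    return [sum(k.duration for k in extract_kernels(f.cpu_children)) * 1e-3
            for f in functions]

a = torch.randn(1024, 1024, device="cuda")
times_ms = bench_sketch(lambda: torch.matmul(a, a))
```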
@@ -265,8 +265,7 @@ def benchmark(Z, H, N_CTX, D_HEAD, CAUSAL, provider):
), attn_mask=None, dropout_p=0.0, is_causal=CAUSAL, scale=sm_scale).to(torch.float32)
atol = 1e-1 if N_CTX == 16384 else 1e-2
benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=atol, rtol=1e-3, err_msg='triton to torch')
-        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10, quantiles=quantiles,
-                                                              kernel_name='_attn_fwd')
+        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10, quantiles=quantiles)

elif provider == 'xetla':
module_name = f'flash_attn_causal_{CAUSAL}'.lower()
@@ -281,8 +280,7 @@ def benchmark(Z, H, N_CTX, D_HEAD, CAUSAL, provider):
l = torch.empty((size_ml, ), device='xpu', dtype=torch.float)

xetla_fn = lambda: func(q, k, v, out, dropout_mask, bias, m, l, Z, H, D_HEAD, N_CTX, N_CTX, sm_scale)
-        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(xetla_fn, n_warmup=10, n_repeat=10, quantiles=quantiles,
-                                                              kernel_name='gpu::xetla::fmha::FmhaForwardKernel<')
+        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(xetla_fn, n_warmup=10, n_repeat=10, quantiles=quantiles)

else:
raise NotImplementedError(f'Unsupported provider {provider}')
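The torch reference in these call sites (unchanged by this diff) is plain scaled_dot_product_attention. A free-standing version of that correctness check could look like the sketch below, where torch.testing.assert_close stands in for the suite's assert_close wrapper and all shapes and tolerances are made up.

```python
# Stand-alone version of the triton-vs-torch tolerance check pattern above.
import torch

Z, H, N_CTX, D_HEAD, CAUSAL = 1, 2, 128, 64, True
sm_scale = 1.0 / (D_HEAD ** 0.5)
q, k, v = (torch.randn(Z, H, N_CTX, D_HEAD) for _ in range(3))

torch_fn = lambda: torch.nn.functional.scaled_dot_product_attention(
    q, k, v, attn_mask=None, dropout_p=0.0, is_causal=CAUSAL, scale=sm_scale).to(torch.float32)

# In the suite, the left-hand side would be triton_fn(); reusing torch_fn here
# only demonstrates the call shape and tolerances.
torch.testing.assert_close(torch_fn(), torch_fn(), atol=1e-2, rtol=1e-3)
```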
15 changes: 2 additions & 13 deletions benchmarks/triton_kernels_benchmark/fused_softmax.py
@@ -131,8 +131,7 @@ def benchmark(M, N, provider):
triton_fn = lambda: softmax(x, out)
torch_fn = lambda: torch.softmax(x, axis=-1)
benchmark_suit.assert_close(triton_fn(), torch_fn(), err_msg="triton to torch")
-        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, quantiles=quantiles, n_warmup=10, n_repeat=10,
-                                                              kernel_name="softmax_kernel")
+        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, quantiles=quantiles, n_warmup=10, n_repeat=10)

elif provider == "torch-jit":
_, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(lambda: naive_softmax(x), quantiles=quantiles,
@@ -145,17 +144,7 @@ def benchmark(M, N, provider):
xetla_fn = lambda: func(x, out, 0)
torch_fn = lambda: torch.softmax(x, axis=-1)
# benchmark_suit.assert_close(xetla_fn(), torch_fn(), err_msg="xetla to torch")
-        kernels_name = {
-            "softmax_shape_4096_256": "mat1_4096x256_bf16_cfg0",
-            "softmax_shape_4096_1024": "mat1_4096x1024_bf16_cfg0",
-            "softmax_shape_4096_2048": "mat1_4096x2048_bf16_cfg0",
-            "softmax_shape_4096_4096": "mat1_4096x4096_bf16_cfg0",
-            "softmax_shape_4096_8192": "mat1_4096x8k_bf16_cfg0",
-            "softmax_shape_4096_16384": "mat1_4096x16k_bf16_cfg0",
-            "softmax_shape_4096_32768": "mat1_4096x32k_bf16_cfg0",
-        }
-        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(xetla_fn, quantiles=quantiles, n_warmup=10, n_repeat=10,
-                                                              kernel_name=kernels_name[name])
+        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(xetla_fn, quantiles=quantiles, n_warmup=10, n_repeat=10)

else:
raise NotImplementedError(f"Unsupported provider {provider}")
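Every call site in this PR unpacks five numbers from do_bench. _summarize_statistics itself is outside this diff, so the sketch below only illustrates the usual meaning of such a summary, reading cv as the coefficient of variation (std/mean) and the quantiles as the docstring's percentiles; both are assumptions, not quotes of the suite's code.

```python
# Illustrative summary of raw per-repeat times; not the suite's implementation.
import torch

times = torch.tensor([1.02, 0.98, 1.10, 1.01])                    # ms per repeat
p50, p20, p80 = torch.quantile(times, torch.tensor([0.5, 0.2, 0.8])).tolist()
min_ms, max_ms = times.min().item(), times.max().item()
mean_ms = times.mean().item()
cv = (times.std() / times.mean()).item()                          # std / mean
print(p50, p20, p80, min_ms, max_ms, mean_ms, cv)
```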
35 changes: 3 additions & 32 deletions benchmarks/triton_kernels_benchmark/gemm_benchmark.py
@@ -288,7 +288,7 @@ def benchmark(B, M, N, K, provider):
# Legacy profiler shows ~6000TFLOPS GeoMean for onednn measurements, so use more reliable method
do_bench = do_bench_elapsed_time
_, min_ms, max_ms, mean_ms, cv = do_bench(lambda: torch.matmul(torch_a, torch_b), n_warmup=10, n_repeat=10,
-                                                  quantiles=quantiles, kernel_name='gemm_kernel')
+                                                  quantiles=quantiles)
elif provider == 'triton':
assert len(a.shape) == len(b.shape), 'Incompatible sizes'
if len(a.shape) == 3:
@@ -301,8 +301,7 @@ def benchmark(B, M, N, K, provider):
rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch')
_, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10,
-                                                                 quantiles=quantiles,
-                                                                 kernel_name='matmul_kernel_with_block_pointers')
+                                                                 quantiles=quantiles)
elif provider == 'xetla':
if B == 1:
c = torch.zeros((M, N), device='xpu', dtype=torch.float32)
Expand All @@ -321,37 +320,9 @@ def benchmark(B, M, N, K, provider):
xetla_fn = lambda: func(a, b, c, acc, cnt)
torch_fn = lambda: torch.matmul(a, b).to(torch.float32)

-        kernels_name = {
-            'gemm_shape_1_1024_1024_1024': 'Test_1x1024x1024x1024_row_row',
-            'gemm_shape_1_2048_2048_2048': 'Test_1x2048x2048x2048_row_row',
-            'gemm_shape_1_4096_4096_4096': 'Test_1x4096x4096x4096_row_row',
-            'gemm_shape_1_8192_8192_8192': 'Test_1x8192x8192x8192_row_row',
-            'gemm_shape_1_1_5120_13824': 'Test_1x1x5120x13824_row_row',
-            'gemm_shape_1_4_4096_12288': 'Test_1x4x4096x12288_row_row',
-            'gemm_shape_1_512_8192_8192': 'Test_1x512x8192x8192_row_row',
-            'gemm_shape_1_512_8192_32768': 'Test_1x512x8192x32768_row_row',
-            'gemm_shape_1_512_32768_8192': 'Test_1x512x32768x8192_row_row',
-            'gemm_shape_1_1024_16384_8192': 'Test_1x1024x16384x8192_row_row',
-            'gemm_shape_1_1024_28672_8192': 'Test_1x1024x28672x8192_row_row',
-            'gemm_shape_1_3072_4096_3072': 'Test_1x3072x4096x3072_row_row',
-            'gemm_shape_1_4096_16384_8192': 'Test_1x4096x16384x8192_row_row',
-            'gemm_shape_1_8192_16384_1024': 'Test_1x8192x16384x1024_row_row',
-            'gemm_shape_1_8192_16384_4096': 'Test_1x8192x16384x4096_row_row',
-            'gemm_shape_1_16384_1024_8192': 'Test_1x16384x1024x8192_row_row',
-            'gemm_shape_1_16384_4096_8192': 'Test_1x16384x4096x8192_row_row',
-            'gemm_shape_1_16384_8192_1024': 'Test_1x16384x8192x1024_row_row',
-            'gemm_shape_1_16384_8192_4096': 'Test_1x16384x8192x4096_row_row',
-            'gemm_shape_4_32768_128_4096': 'Test_4x32768x128x4096_row_row',
-            'gemm_shape_4_32768_4096_128': 'Test_4x32768x4096x128_row_row',
-            'gemm_shape_32_4096_4096_128': 'Test_32x4096x4096x128_row_row',
-            'gemm_shape_4096_8_128_16384': 'Test_4096x8x128x16384_row_row',
-            'gemm_shape_4096_8_16384_128': 'Test_4096x8x16384x128_row_row',
-            'gemm_streamk_shape_3072_4096_3072': 'stream_k_gemm_run',
-        }

# benchmark_suit.assert_close(xetla_fn(), torch_fn(), atol=1e-4, rtol=1.0, err_msg='xetla to torch')
_, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(xetla_fn, n_warmup=10, n_repeat=10,
-                                                                 quantiles=quantiles, kernel_name=kernels_name[name])
+                                                                 quantiles=quantiles)
else:
raise NotImplementedError(f'Unsupported provider {provider}')

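The onednn comment in this file quotes a GeoMean in TFLOPS; the conversion from a measured mean time to that metric is the standard 2·B·M·N·K flop count over elapsed seconds. A formula sketch, not code from this PR:

```python
# Standard GEMM throughput conversion; illustrative helper, not part of the PR.
def tflops(B: int, M: int, N: int, K: int, mean_ms: float) -> float:
    flops = 2 * B * M * N * K            # one multiply + one add per MAC
    return flops / (mean_ms * 1e-3) / 1e12

print(tflops(1, 4096, 4096, 4096, 5.0))  # ~27.5 TFLOPS for a 5 ms 4096^3 GEMM
```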
@@ -266,17 +266,15 @@ def benchmark(B, M, N, K, provider):
assert len(a.shape) == len(b.shape), 'Incompatible sizes'
if len(a.shape) == 3:
c = torch.empty((B, M, N), device='xpu', dtype=torch.float32)
-            kernel_name = 'matmul_kernel_with_block_pointers_batched'
else:
assert len(a.shape) == 2, 'Expecting shape of length 2'
c = torch.empty((M, N), device='xpu', dtype=torch.float32)
-            kernel_name = 'matmul_kernel_with_block_pointers'
triton_fn = lambda: matmul(a, b, d, c)
torch_fn = lambda: torch.matmul(a, b).to(torch.float32) + d
rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch')
_, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10,
-                                                                 quantiles=quantiles, kernel_name=kernel_name)
+                                                                 quantiles=quantiles)
else:
raise NotImplementedError(f'Unsupported provider {provider}')

@@ -268,17 +268,15 @@ def benchmark(B, M, N, K, provider):
assert len(a.shape) == len(b.shape), 'Incompatible sizes'
if len(a.shape) == 3:
c = torch.empty((B, M, N), device='xpu', dtype=torch.float32)
-            kernel_name = 'matmul_kernel_with_block_pointers_batched'
else:
assert len(a.shape) == 2, 'Expecting shape of length 2'
c = torch.empty((M, N), device='xpu', dtype=torch.float32)
-            kernel_name = 'matmul_kernel_with_block_pointers'
triton_fn = lambda: matmul(a, b, c)
torch_fn = lambda: torch.nn.functional.gelu(torch.matmul(a, b).to(torch.float32))
rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch')
_, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10,
-                                                                 quantiles=quantiles, kernel_name=kernel_name)
+                                                                 quantiles=quantiles)
else:
raise NotImplementedError(f'Unsupported provider {provider}')

@@ -256,17 +256,15 @@ def benchmark(B, M, N, K, provider):
assert len(a.shape) == len(b.shape), 'Incompatible sizes'
if len(a.shape) == 3:
c = torch.empty((B, M, N), device='xpu', dtype=torch.float32)
-            kernel_name = 'matmul_kernel_with_block_pointers_batched'
else:
assert len(a.shape) == 2, 'Expecting shape of length 2'
c = torch.empty((M, N), device='xpu', dtype=torch.float32)
-            kernel_name = 'matmul_kernel_with_block_pointers'
triton_fn = lambda: matmul(a, b, c)
torch_fn = lambda: torch.matmul(torch.exp(a), b).to(torch.float32)
rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch')
_, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10,
-                                                                 quantiles=quantiles, kernel_name=kernel_name)
+                                                                 quantiles=quantiles)
else:
raise NotImplementedError(f'Unsupported provider {provider}')

4 changes: 2 additions & 2 deletions benchmarks/triton_kernels_benchmark/gemm_splitk_benchmark.py
@@ -159,7 +159,7 @@ def benchmark(M, N, K, provider):
rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch')
_, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10,
-                                                                 quantiles=quantiles, kernel_name='_kernel')
+                                                                 quantiles=quantiles)
elif provider == 'xetla':
c = torch.zeros((M, N), device='xpu', dtype=torch.float32)
acc = torch.zeros((M, N), device='xpu', dtype=torch.float32)
@@ -172,7 +172,7 @@ def benchmark(M, N, K, provider):

# benchmark_suit.assert_close(xetla_fn(), torch_fn(), atol=1e-4, rtol=1.0, err_msg='xetla to torch')
_, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(xetla_fn, n_warmup=10, n_repeat=10,
-                                                                 quantiles=quantiles, kernel_name='split_k_gemm_run')
+                                                                 quantiles=quantiles)
else:
raise NotImplementedError(f'Unsupported provider {provider}')

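As a semantic note on the file name: split-k partitions the reduction dimension, computes partial GEMMs, and reduces them. The torch sketch below shows only that decomposition; it says nothing about the XeTLA or Triton kernels' internals.

```python
# Split-K in plain torch: chunk K, multiply the pieces, sum the partials.
# Purely illustrative; real kernels perform this reduction on-device.
import torch

a, b, split_k = torch.randn(256, 512), torch.randn(512, 128), 4
partials = [ak @ bk for ak, bk in zip(torch.chunk(a, split_k, dim=1),
                                      torch.chunk(b, split_k, dim=0))]
c = sum(partials)
torch.testing.assert_close(c, a @ b, atol=1e-4, rtol=1e-4)  # equal up to round-off
```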
5 changes: 2 additions & 3 deletions benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py
@@ -280,8 +280,7 @@ def benchmark(M, N, K, provider):
torch_fn = lambda: torch.matmul(a, b).to(torch.float32)
benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=1e-2, err_msg='triton to torch')
_, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10,
-                                                                 quantiles=quantiles,
-                                                                 kernel_name=['first_wave', 'full_tiles'])
+                                                                 quantiles=quantiles)
elif provider == 'xetla':
c = torch.zeros((M, N), device='xpu', dtype=torch.float32)
acc = torch.zeros((M, N), device='xpu', dtype=torch.float32)
@@ -294,7 +293,7 @@ def benchmark(M, N, K, provider):

# benchmark_suit.assert_close(xetla_fn(), torch_fn(), atol=1e-4, rtol=1.0, err_msg='xetla to torch')
_, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(xetla_fn, n_warmup=10, n_repeat=10,
-                                                                 quantiles=quantiles, kernel_name='stream_k_gemm_run')
+                                                                 quantiles=quantiles)
else:
raise NotImplementedError(f'Unsupported provider {provider}')

3 changes: 1 addition & 2 deletions benchmarks/triton_kernels_benchmark/prefix_sums.py
@@ -44,8 +44,7 @@ def benchmark(M, N, AXIS, provider):

if provider == 'triton':
triton_fn = lambda: scan_kernel[(1, )](x, BLOCK_SIZE_M=M, BLOCK_SIZE_N=N, AXIS=AXIS)
-        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, quantiles=quantiles,
-                                                                 kernel_name='scan_kernel')
+        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, quantiles=quantiles)
else:
raise NotImplementedError(f'Unsupported provider {provider}')

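Finally, for readers new to this file: the scan_kernel benchmark times an (M, N) prefix sum along AXIS. torch.cumsum gives the reference semantics; this diff carries no correctness check for it, so the sketch below is purely illustrative.

```python
# Reference semantics for the prefix sum being timed above.
import torch

M, N, AXIS = 32, 64, 1
x = torch.randn(M, N)
y = torch.cumsum(x, dim=AXIS)   # y[i, j] == x[i, :j + 1].sum() when AXIS == 1
```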