1 change: 1 addition & 0 deletions .github/workflows/triton-benchmarks.yml
@@ -17,6 +17,7 @@ on:
options:
- PYTORCH_LEGACY_PROFILER_USING_IPEX
- ELAPSED_TIME
- UPSTREAM_PYTORCH_PROFILER
default: PYTORCH_LEGACY_PROFILER_USING_IPEX
schedule:
- cron: "5 23 * * *"
94 changes: 91 additions & 3 deletions benchmarks/triton_kernels_benchmark/benchmark_testing.py
@@ -7,7 +7,7 @@
if USE_IPEX_OPTION:
BENCHMARKING_METHOD = "PYTORCH_LEGACY_PROFILER_USING_IPEX"
else:
BENCHMARKING_METHOD = os.getenv("BENCHMARKING_METHOD", "ELAPSED_TIME")
BENCHMARKING_METHOD = os.getenv("BENCHMARKING_METHOD", "UPSTREAM_PYTORCH_PROFILER")


def synchronize():
@@ -37,7 +37,7 @@ def _summarize_statistics(times, quantiles, return_mode):


def do_bench_ipex(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, fast_flush=True, return_mode="mean",
Contributor Author:

I haven't changed the default implementation yet, so I can switch and see how the new implementation behaves relative to the old one.

device="xpu", sync_submitting=True):
device="xpu", sync_submitting=True, kernel_name=None): # pylint: disable=unused-argument
"""
Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with
the 20-th and 80-th performance percentile.
@@ -127,7 +127,7 @@ def extract_kernels(funcs):


def do_bench_elapsed_time(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, fast_flush=True,
return_mode="mean", device="xpu"):
return_mode="mean", device="xpu", kernel_name=None): # pylint: disable=unused-argument
"""
Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with
the 20-th and 80-th performance percentile.
@@ -155,10 +155,98 @@ def do_bench_elapsed_time(fn, warmup=25, rep=100, grad_to_none=None, quantiles=N
return _summarize_statistics(times, quantiles, return_mode)


def do_bench_upstream_pytorch_profiler(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, fast_flush=True,
return_mode="mean", device="xpu", sync_submitting=True, kernel_name=None):
"""
Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with
the 20-th and 80-th performance percentile.

:param fn: Function to benchmark
:type fn: Callable
:param warmup: Warmup time (in ms)
:type warmup: int
:param rep: Repetition time (in ms)
:type rep: int
:param grad_to_none: Reset the gradient of the provided tensor to None
:type grad_to_none: torch.tensor, optional
:param quantiles: Performance percentile to return in addition to the median.
:type quantiles: list[float]
:param fast_flush: Use faster kernel to flush L2 between measurements
:type fast_flush: bool
"""

assert return_mode in ["min", "max", "mean", "median"]
import torch
from torch.profiler import profile, ProfilerActivity

fn()
synchronize()

# We maintain a buffer of 256 MB that we clear
# before each kernel call to make sure that the L2
# doesn't contain any input data before the run
cache_size = 256 * 1024 * 1024
if fast_flush:
cache = torch.empty(int(cache_size // 4), dtype=torch.int, device=device)
else:
cache = torch.empty(int(cache_size), dtype=torch.int8, device=device)

# Estimate the runtime of the function
start_event = torch.xpu.Event(enable_timing=True)
end_event = torch.xpu.Event(enable_timing=True)
start_event.record()
for _ in range(5):
cache.zero_()
fn()
end_event.record()
synchronize()
estimate_ms = start_event.elapsed_time(end_event) / 5

# compute number of warmup and repeat
n_warmup = max(1, int(warmup / estimate_ms))
n_repeat = max(1, int(rep / estimate_ms))
# Warm-up
for _ in range(n_warmup):
fn()
# Benchmark
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.XPU]) as prof:
for _ in range(n_repeat):
# we don't want `fn` to accumulate gradient values
# if it contains a backward pass. So we clear the
# provided gradients
if grad_to_none is not None:
for x in grad_to_none:
x.grad = None
# we clear the L2 cache before each run
cache.zero_()
if sync_submitting:
synchronize()
# record time of `fn`
fn()
Contributor Author, commenting on lines +224 to +225:
Most of this function is a copy of the do_bench_ipex function.

However, we can't use the following code because the kernel we need isn't among the subevents of this event (because of a bug I guess):

with record_function("__profile_kernel_of_func"):
    fn()

Contributor:
Please ensure a ticket is created for this if it is not already.

# Record clocks
synchronize()

function_events = prof.events()

functions = []
if isinstance(kernel_name, str):
kernel_name = [kernel_name]
for ker_name in kernel_name:
functions.extend(list(filter(lambda x: x.name.startswith(ker_name), function_events))) # pylint: disable=cell-var-from-loop
# profiling_func_filter = filter(lambda x: x.name.startswith("__profile_kernel_of_func"), function_events)

assert len(functions) == n_repeat, f"the profiling number does not match, {len(functions)}"
# Convert the times to milliseconds.
times = torch.tensor([f.self_device_time_total * 1e-3 for f in functions], dtype=torch.float)
return _summarize_statistics(times, quantiles, return_mode)
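
For reference, a minimal usage sketch of the new profiler-based path (hypothetical toy kernel and sizes, assuming an XPU-enabled PyTorch build and that the profiler reports the Triton kernel under the name of the @triton.jit function):

import torch
import triton
import triton.language as tl


@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK: tl.constexpr):
    # Hypothetical toy kernel, used only to illustrate the call below.
    offsets = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)


x = torch.randn(1 << 20, device="xpu")
y = torch.randn(1 << 20, device="xpu")
out = torch.empty_like(x)
grid = (triton.cdiv(x.numel(), 1024), )

# kernel_name is matched by prefix against the profiler event names, so the
# @triton.jit function name is enough here.
mean_ms = do_bench_upstream_pytorch_profiler(
    lambda: add_kernel[grid](x, y, out, x.numel(), BLOCK=1024),
    warmup=10, rep=10, kernel_name="add_kernel")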


if BENCHMARKING_METHOD == "PYTORCH_LEGACY_PROFILER_USING_IPEX":
do_bench = do_bench_ipex
elif BENCHMARKING_METHOD == "ELAPSED_TIME":
do_bench = do_bench_elapsed_time
elif BENCHMARKING_METHOD == "UPSTREAM_PYTORCH_PROFILER":
do_bench = do_bench_upstream_pytorch_profiler
else:
raise NotImplementedError(f"BENCHMARKING_METHOD: {BENCHMARKING_METHOD} isn't implemented")
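
Because do_bench is bound when the module is imported, the benchmarking method has to be chosen through the environment before the first import. A small sketch of how a caller might select the new method (the package import path is an assumption, and USE_IPEX must not be forcing the legacy path):

import os

# BENCHMARKING_METHOD is read at import time, so it must be set before the
# benchmarking helpers are imported. UPSTREAM_PYTORCH_PROFILER is now the
# default when IPEX is not used, so this only makes the choice explicit.
os.environ["BENCHMARKING_METHOD"] = "UPSTREAM_PYTORCH_PROFILER"

from triton_kernels_benchmark import benchmark_testing  # hypothetical import path

# do_bench now points at the profiler-based implementation.
assert benchmark_testing.do_bench is benchmark_testing.do_bench_upstream_pytorch_profiler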

Original file line number Diff line number Diff line change
@@ -257,7 +257,8 @@ def benchmark(Z, H, N_CTX, D_HEAD, CAUSAL, provider):
), attn_mask=None, dropout_p=0.0, is_causal=CAUSAL, scale=sm_scale).to(torch.float32)
atol = 1e-1 if N_CTX == 16384 else 1e-2
benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=atol, rtol=1e-3, err_msg='triton to torch')
_, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles)
_, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles,
kernel_name='_attn_fwd')

elif provider == 'xetla':
module_name = f'flash_attn_causal_{CAUSAL}'.lower()
@@ -272,7 +273,8 @@ def benchmark(Z, H, N_CTX, D_HEAD, CAUSAL, provider):
l = torch.empty((size_ml, ), device='xpu', dtype=torch.float)

xetla_fn = lambda: func(q, k, v, out, dropout_mask, bias, m, l, Z, H, D_HEAD, N_CTX, N_CTX, sm_scale)
_, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(xetla_fn, warmup=10, rep=10, quantiles=quantiles)
_, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(xetla_fn, warmup=10, rep=10, quantiles=quantiles,
kernel_name='gpu::xetla::fmha::FmhaForwardKernel<')

else:
raise NotImplementedError(f'Unsupported provider {provider}')
15 changes: 13 additions & 2 deletions benchmarks/triton_kernels_benchmark/fused_softmax.py
@@ -131,7 +131,8 @@ def benchmark(M, N, provider):
triton_fn = lambda: softmax(x, out)
torch_fn = lambda: torch.softmax(x, axis=-1)
benchmark_suit.assert_close(triton_fn(), torch_fn(), err_msg="triton to torch")
_, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, quantiles=quantiles, warmup=10, rep=10)
_, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, quantiles=quantiles, warmup=10, rep=10,
kernel_name="softmax_kernel")

elif provider == "torch-jit":
_, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(lambda: naive_softmax(x), quantiles=quantiles, warmup=10,
@@ -144,7 +145,17 @@ def benchmark(M, N, provider):
xetla_fn = lambda: func(x, out, 0)
torch_fn = lambda: torch.softmax(x, axis=-1)
# benchmark_suit.assert_close(xetla_fn(), torch_fn(), err_msg="xetla to torch")
_, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(xetla_fn, quantiles=quantiles, warmup=10, rep=10)
kernels_name = {
"softmax_shape_4096_256": "mat1_4096x256_bf16_cfg0",
"softmax_shape_4096_1024": "mat1_4096x1024_bf16_cfg0",
"softmax_shape_4096_2048": "mat1_4096x2048_bf16_cfg0",
"softmax_shape_4096_4096": "mat1_4096x4096_bf16_cfg0",
"softmax_shape_4096_8192": "mat1_4096x8k_bf16_cfg0",
"softmax_shape_4096_16384": "mat1_4096x16k_bf16_cfg0",
"softmax_shape_4096_32768": "mat1_4096x32k_bf16_cfg0",
}
_, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(xetla_fn, quantiles=quantiles, warmup=10, rep=10,
kernel_name=kernels_name[name])

else:
raise NotImplementedError(f"Unsupported provider {provider}")
34 changes: 32 additions & 2 deletions benchmarks/triton_kernels_benchmark/gemm_benchmark.py
@@ -262,7 +262,8 @@ def benchmark(B, M, N, K, provider):
torch_fn = lambda: torch.matmul(a, b).to(torch.float32)
rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch')
_, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles)
_, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles,
kernel_name='matmul_kernel_with_block_pointers')
elif provider == 'xetla':
if B == 1:
c = torch.empty((M, N), device='xpu', dtype=torch.float32)
@@ -276,8 +277,37 @@ def benchmark(B, M, N, K, provider):
func = getattr(xetla_kernel, name)
xetla_fn = lambda: func(a, b, c, acc, cnt)
torch_fn = lambda: torch.matmul(a, b).to(torch.float32)

kernels_name = {
'gemm_shape_1_1024_1024_1024': 'Test_1x1024x1024x1024_row_row',
'gemm_shape_1_2048_2048_2048': 'Test_1x2048x2048x2048_row_row',
'gemm_shape_1_4096_4096_4096': 'Test_1x4096x4096x4096_row_row',
'gemm_shape_1_8192_8192_8192': 'Test_1x8192x8192x8192_row_row',
'gemm_shape_1_1_5120_13824': 'Test_1x1x5120x13824_row_row',
'gemm_shape_1_4_4096_12288': 'Test_1x4x4096x12288_row_row',
'gemm_shape_1_512_8192_8192': 'Test_1x512x8192x8192_row_row',
'gemm_shape_1_512_8192_32768': 'Test_1x512x8192x32768_row_row',
'gemm_shape_1_512_32768_8192': 'Test_1x512x32768x8192_row_row',
'gemm_shape_1_1024_16384_8192': 'Test_1x1024x16384x8192_row_row',
'gemm_shape_1_1024_28672_8192': 'Test_1x1024x28672x8192_row_row',
'gemm_shape_1_3072_4096_3072': 'Test_1x3072x4096x3072_row_row',
'gemm_shape_1_4096_16384_8192': 'Test_1x4096x16384x8192_row_row',
'gemm_shape_1_8192_16384_1024': 'Test_1x8192x16384x1024_row_row',
'gemm_shape_1_8192_16384_4096': 'Test_1x8192x16384x4096_row_row',
'gemm_shape_1_16384_1024_8192': 'Test_1x16384x1024x8192_row_row',
'gemm_shape_1_16384_4096_8192': 'Test_1x16384x4096x8192_row_row',
'gemm_shape_1_16384_8192_1024': 'Test_1x16384x8192x1024_row_row',
'gemm_shape_1_16384_8192_4096': 'Test_1x16384x8192x4096_row_row',
'gemm_shape_4_32768_128_4096': 'Test_4x32768x128x4096_row_row',
'gemm_shape_4_32768_4096_128': 'Test_4x32768x4096x128_row_row',
'gemm_shape_32_4096_4096_128': 'Test_32x4096x4096x128_row_row',
'gemm_shape_4096_8_128_16384': 'Test_4096x8x128x16384_row_row',
'gemm_shape_4096_8_16384_128': 'Test_4096x8x16384x128_row_row',
}

# benchmark_suit.assert_close(xetla_fn(), torch_fn(), atol=1e-4, rtol=1.0, err_msg='xetla to torch')
_, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(xetla_fn, warmup=10, rep=10, quantiles=quantiles)
_, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(xetla_fn, warmup=10, rep=10, quantiles=quantiles,
kernel_name=kernels_name[name])
else:
raise NotImplementedError(f'Unsupported provider {provider}')

Original file line number Diff line number Diff line change
@@ -266,14 +266,17 @@ def benchmark(B, M, N, K, provider):
assert len(a.shape) == len(b.shape), 'Incompatible sizes'
if len(a.shape) == 3:
c = torch.empty((B, M, N), device='xpu', dtype=torch.float32)
kernel_name = 'matmul_kernel_with_block_pointers_batched'
else:
assert len(a.shape) == 2, 'Expecting shape of length 2'
c = torch.empty((M, N), device='xpu', dtype=torch.float32)
kernel_name = 'matmul_kernel_with_block_pointers'
triton_fn = lambda: matmul(a, b, d, c)
torch_fn = lambda: torch.matmul(a, b).to(torch.float32) + d
rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch')
_, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles)
_, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles,
kernel_name=kernel_name)
else:
raise NotImplementedError(f'Unsupported provider {provider}')

Original file line number Diff line number Diff line change
@@ -268,14 +268,17 @@ def benchmark(B, M, N, K, provider):
assert len(a.shape) == len(b.shape), 'Incompatible sizes'
if len(a.shape) == 3:
c = torch.empty((B, M, N), device='xpu', dtype=torch.float32)
kernel_name = 'matmul_kernel_with_block_pointers_batched'
else:
assert len(a.shape) == 2, 'Expecting shape of length 2'
c = torch.empty((M, N), device='xpu', dtype=torch.float32)
kernel_name = 'matmul_kernel_with_block_pointers'
triton_fn = lambda: matmul(a, b, c)
torch_fn = lambda: torch.nn.functional.gelu(torch.matmul(a, b).to(torch.float32))
rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch')
_, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles)
_, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles,
kernel_name=kernel_name)
else:
raise NotImplementedError(f'Unsupported provider {provider}')

Original file line number Diff line number Diff line change
@@ -256,14 +256,17 @@ def benchmark(B, M, N, K, provider):
assert len(a.shape) == len(b.shape), 'Incompatible sizes'
if len(a.shape) == 3:
c = torch.empty((B, M, N), device='xpu', dtype=torch.float32)
kernel_name = 'matmul_kernel_with_block_pointers_batched'
else:
assert len(a.shape) == 2, 'Expecting shape of length 2'
c = torch.empty((M, N), device='xpu', dtype=torch.float32)
kernel_name = 'matmul_kernel_with_block_pointers'
triton_fn = lambda: matmul(a, b, c)
torch_fn = lambda: torch.matmul(torch.exp(a), b).to(torch.float32)
rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch')
_, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles)
_, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles,
kernel_name=kernel_name)
else:
raise NotImplementedError(f'Unsupported provider {provider}')

3 changes: 2 additions & 1 deletion benchmarks/triton_kernels_benchmark/gemm_splitk_benchmark.py
@@ -157,7 +157,8 @@ def benchmark(M, N, K, provider):
torch_fn = lambda: torch.matmul(a, b).to(torch.float32)
rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch')
_, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles)
_, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles,
kernel_name='_kernel')
else:
raise NotImplementedError(f'Unsupported provider {provider}')

Original file line number Diff line number Diff line change
@@ -278,7 +278,8 @@ def benchmark(M, N, K, provider):
triton_fn = lambda: matmul(a, b, c)
torch_fn = lambda: torch.matmul(a, b).to(torch.float32)
benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=1e-2, err_msg='triton to torch')
_, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles)
_, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles,
kernel_name=['first_wave', 'full_tiles'])
else:
raise NotImplementedError(f'Unsupported provider {provider}')

3 changes: 2 additions & 1 deletion benchmarks/triton_kernels_benchmark/prefix_sums.py
@@ -44,7 +44,8 @@ def benchmark(M, N, AXIS, provider):

if provider == 'triton':
triton_fn = lambda: scan_kernel[(1, )](x, BLOCK_SIZE_M=M, BLOCK_SIZE_N=N, AXIS=AXIS)
_, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, quantiles=quantiles)
_, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, quantiles=quantiles,
kernel_name='scan_kernel')
else:
raise NotImplementedError(f'Unsupported provider {provider}')

6 changes: 0 additions & 6 deletions scripts/capture-hw-details.sh
@@ -58,12 +58,6 @@ else
export COMPILER_VERSION="Not installed"
fi

if [[ "${USE_IPEX:-}" == "1" ]]; then
export BENCHMARKING_METHOD="PYTORCH_LEGACY_PROFILER_USING_IPEX"
elif [[ "${USE_IPEX:-}" == "0" ]]; then
export BENCHMARKING_METHOD="${BENCHMARKING_METHOD:-ELAPSED_TIME}"
fi

if [ "$QUIET" = false ]; then
echo "LIBIGC1_VERSION=$LIBIGC1_VERSION"
echo "LEVEL_ZERO_VERSION=$LEVEL_ZERO_VERSION"