Merged
6 changes: 0 additions & 6 deletions benchmarks/triton_kernels_benchmark/__init__.py
@@ -1,7 +1 @@
from .benchmark_testing import do_bench, assert_close, perf_report, Benchmark, USE_IPEX_OPTION # type: ignore # noqa: F401

if USE_IPEX_OPTION:
from triton.runtime import driver
from . import benchmark_driver
# replace the launcher with the profilier hook.
driver.active.launcher_cls = benchmark_driver.XPULauncher
12 changes: 5 additions & 7 deletions benchmarks/triton_kernels_benchmark/benchmark_driver.py
@@ -10,15 +10,15 @@
from triton.runtime.build import _build, quiet

import torch
import intel_extension_for_pytorch
# import intel_extension_for_pytorch

_dirname = os.getenv("ZE_PATH", default="/usr/local")

include_dir = [
os.path.join(_dirname, "include"),
os.path.join(torch.utils.cmake_prefix_path, "../../include"),
os.path.join(torch.utils.cmake_prefix_path, "../../include/torch/csrc/api/include"),
os.path.join(intel_extension_for_pytorch.cmake_prefix_path, "../../include")
# os.path.join(intel_extension_for_pytorch.cmake_prefix_path, "../../include")
]

oneapi_root = os.getenv("ONEAPI_ROOT")
@@ -31,9 +31,10 @@
library_dir = [
os.path.join(_dirname, "lib"),
os.path.join(torch.utils.cmake_prefix_path, "../../lib"),
os.path.join(intel_extension_for_pytorch.cmake_prefix_path, "../../lib")
# os.path.join(intel_extension_for_pytorch.cmake_prefix_path, "../../lib")
]
libraries = ["ze_loader", "sycl", "torch", "intel-ext-pt-gpu"]

libraries = ["ze_loader", "sycl", "torch"]


def compile_module_from_src(src, name):
@@ -150,7 +151,6 @@ def format_of(ty):
#include <level_zero/ze_api.h>
#include <sycl/sycl.hpp>
#include <torch/extension.h>
#include <ipex.h>

#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
#include <Python.h>
@@ -261,7 +261,6 @@ def format_of(ty):
static void sycl_kernel_launch(uint32_t gridX, uint32_t gridY, uint32_t gridZ, int num_warps, int threads_per_warp, int shared_memory, sycl::queue& stream, sycl::kernel& kernel_ptr {", " + arg_decls if len(arg_decls) > 0 else ""}) {{

std::string kernel_name = kernel_ptr.get_info<sycl::info::kernel::function_name>();
RECORD_FUNCTION("XPU Triton kernel:" + kernel_name, {{}});
void *params[] = {{ {", ".join(f"&arg{i}" for i in signature.keys() if i not in constants)} }};
uint32_t num_params = sizeof(params)/sizeof(params[0]);
uint32_t expected_num_params = kernel_ptr.get_info<sycl::info::kernel::num_args>();
Expand Down Expand Up @@ -291,7 +290,6 @@ def format_of(ty):
}}
}};
auto event = stream.submit(cgf);
xpu::profiler_record(kernel_name, event);
}}
// end sycl
static PyObject* launch(PyObject* self, PyObject* args) {{
99 changes: 79 additions & 20 deletions benchmarks/triton_kernels_benchmark/benchmark_testing.py
@@ -33,7 +33,7 @@ def _summarize_statistics(times, quantiles, return_mode):


def do_bench_ipex(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, fast_flush=True, return_mode="mean",
Comment (Contributor, Author): I haven't changed the default implementation yet, so I can switch and see how the new implementation behaves relative to the old one.

device="xpu", sync_submitting=True):
device="xpu", sync_submitting=True, kernel_name=None):
"""
Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with
the 20-th and 80-th performance percentile.
@@ -56,7 +56,7 @@ def do_bench_ipex(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, fas

assert return_mode in ["min", "max", "mean", "median"]
import torch
from torch.autograd.profiler import record_function
from torch.profiler import profile, ProfilerActivity

fn()
synchronize()
@@ -88,7 +88,7 @@ def do_bench_ipex(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, fas
for _ in range(n_warmup):
fn()
# Benchmark
with torch.autograd.profiler_legacy.profile(True, use_xpu=True) as prof:
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.XPU]) as prof:
for _ in range(n_repeat):
# we don't want `fn` to accumulate gradient values
# if it contains a backward pass. So we clear the
@@ -101,29 +101,27 @@ def do_bench_ipex(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, fas
if sync_submitting:
synchronize()
# record time of `fn`
with record_function("__profile_kernel_of_func"):
fn()
fn()
# Record clocks
synchronize()

profiling_func_filter = filter(lambda x: x.name.startswith("__profile_kernel_of_func"), prof.function_events)
functions = list(profiling_func_filter)
# breakpoint()
# print(prof.key_averages(group_by_stack_n=5).table(sort_by="xpu_time"))
# print(prof.key_averages(group_by_stack_n=5).table)

def extract_kernels(funcs):
kernels = []
kernels += list(itertools.chain.from_iterable(map(lambda func: extract_kernels(func.cpu_children), funcs)))
kernels += list(itertools.chain.from_iterable([func.kernels for func in funcs]))
return kernels
function_events = prof.events()
profiling_func_filter = filter(lambda x: x.name.endswith(kernel_name), function_events)
#profiling_func_filter = filter(lambda x: x.name.startswith("__profile_kernel_of_func"), function_events)
functions = list(profiling_func_filter)

kernels = [extract_kernels(func.cpu_children) for func in functions]
assert len(kernels) == n_repeat, "the profiling number not match"
assert len(functions) == n_repeat, f"the profiling number not match, {len(functions)}"
# Make the time to the milliseconds.
times = torch.tensor([sum([k.duration for k in ks]) * 1e-3 for ks in kernels], dtype=torch.float)
times = torch.tensor([f.self_device_time_total * 1e-3 for f in functions], dtype=torch.float)
return _summarize_statistics(times, quantiles, return_mode)


def do_bench_no_ipex(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, fast_flush=True, return_mode="mean",
device="xpu"):
device="xpu", sync_submitting=True, kernel_name=None):
"""
Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with
the 20-th and 80-th performance percentile.
@@ -141,13 +139,74 @@ def do_bench_no_ipex(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None,
:param fast_flush: Use faster kernel to flush L2 between measurements
:type fast_flush: bool
"""
# TODO: remove this function and switch to `do_bench_no_ipex` after
# `XPUEvent.elapsed_time` stops introducing regressions into the results.

assert return_mode in ["min", "max", "mean", "median"]
import torch
from triton.testing import do_bench as triton_do_bench
from torch.profiler import profile, ProfilerActivity

fn()
synchronize()

times = triton_do_bench(fn, warmup=warmup, rep=rep, grad_to_none=grad_to_none, fast_flush=fast_flush,
return_mode="all", device_type=device)
times = torch.tensor(times, dtype=torch.float)
# We maintain a buffer of 256 MB that we clear
# before each kernel call to make sure that the L2
# doesn't contain any input data before the run
cache_size = 256 * 1024 * 1024
if fast_flush:
cache = torch.empty(int(cache_size // 4), dtype=torch.int, device=device)
else:
cache = torch.empty(int(cache_size), dtype=torch.int8, device=device)

# Estimate the runtime of the function
start_event = torch.xpu.Event(enable_timing=True)
end_event = torch.xpu.Event(enable_timing=True)
start_event.record()
for _ in range(5):
cache.zero_()
fn()
end_event.record()
synchronize()
estimate_ms = start_event.elapsed_time(end_event) / 5

# compute number of warmup and repeat
n_warmup = max(1, int(warmup / estimate_ms))
n_repeat = max(1, int(rep / estimate_ms))
# Warm-up
for _ in range(n_warmup):
fn()
# Benchmark
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.XPU]) as prof:
for _ in range(n_repeat):
# we don't want `fn` to accumulate gradient values
# if it contains a backward pass. So we clear the
# provided gradients
if grad_to_none is not None:
for x in grad_to_none:
x.grad = None
# we clear the L2 cache before each run
cache.zero_()
if sync_submitting:
synchronize()
# record time of `fn`
fn()
Comment on lines +224 to +225 (Contributor, Author): Most of this function is a copy of the do_bench_ipex function. However, we can't use the following code, because the kernel we need isn't among the sub-events of this event (because of a bug, I guess):

with record_function("__profile_kernel_of_func"):
    fn()

Reply (Contributor): Please ensure a ticket is created for this if it is not already.

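For reference, the workaround used here is to scan the profiler's flat event list for the kernel's own name and read its device time directly. A minimal self-contained sketch of that pattern, assuming a PyTorch build with XPU profiler support (the helper name and signature below are illustrative, not part of this PR):

from torch.profiler import profile, ProfilerActivity

def measure_kernel_ms(fn, kernel_name, n_repeat=10):
    # Profile n_repeat launches of `fn` on both CPU and XPU.
    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.XPU]) as prof:
        for _ in range(n_repeat):
            fn()
    # Match events by kernel name instead of relying on record_function()
    # parent/child nesting, which is what the comment above is working around.
    events = [e for e in prof.events() if e.name.startswith(kernel_name)]
    # self_device_time_total is reported in microseconds; convert to milliseconds.
    return [e.self_device_time_total * 1e-3 for e in events]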
# Record clocks
synchronize()

# print(prof.key_averages(group_by_stack_n=5).table(sort_by="xpu_time"))
# print(prof.key_averages(group_by_stack_n=5).table)
function_events = prof.events()

functions = []
if isinstance(kernel_name, str):
kernel_name = [kernel_name]
for ker_name in kernel_name:
functions.extend(list(filter(lambda x: x.name.startswith(ker_name), function_events))) # pylint: disable=cell-var-from-loop
# profiling_func_filter = filter(lambda x: x.name.startswith("__profile_kernel_of_func"), function_events)

assert len(functions) == n_repeat, f"the profiling number not match, {len(functions)}"
# Make the time to the milliseconds.
times = torch.tensor([f.self_device_time_total * 1e-3 for f in functions], dtype=torch.float)
return _summarize_statistics(times, quantiles, return_mode)


@@ -205,9 +205,9 @@ def forward(q, k, v, causal, sm_scale):
line_arg='provider',
# argument name whose value corresponds to a different line in the plot
# possible values for `line_arg``
line_vals=['triton', 'xetla'],
line_vals=['xetla'],
# label name for the lines
line_names=['Triton', 'XeTLA'],
line_names=['XeTLA'],
# line styles
styles=[('green', '-'), ('green', '--'), ('blue', '-'), ('blue', '--')],
ylabel=['GB/s', 'TFlops'], # label name for the y-axis
@@ -241,7 +241,7 @@ def benchmark(Z, H, N_CTX, D_HEAD, provider):
atol = 1e-1 if N_CTX == 16384 else 1e-2
benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=atol, rtol=1e-3, err_msg='triton to torch')
_, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles,
fast_flush=False)
fast_flush=False, kernel_name='_attn_fwd')

elif provider == 'xetla':
module_name = f'flash_attn_causal_{causal}'.lower()
@@ -257,7 +257,8 @@ def benchmark(Z, H, N_CTX, D_HEAD, provider):

xetla_fn = lambda: func(q, k, v, out, dropout_mask, bias, m, l, Z, H, D_HEAD, N_CTX, N_CTX, sm_scale)
_, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(xetla_fn, warmup=10, rep=10, quantiles=quantiles,
fast_flush=False)
fast_flush=False,
kernel_name='gpu::xetla::fmha::FmhaForwardKernel<')

else:
raise NotImplementedError(f'Unsupported provider {provider}')
15 changes: 13 additions & 2 deletions benchmarks/triton_kernels_benchmark/fused_softmax.py
@@ -131,7 +131,8 @@ def benchmark(M, N, provider):
triton_fn = lambda: softmax(x, out)
torch_fn = lambda: torch.softmax(x, axis=-1)
benchmark_suit.assert_close(triton_fn(), torch_fn(), err_msg="triton to torch")
_, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, quantiles=quantiles, warmup=10, rep=10)
_, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, quantiles=quantiles, warmup=10, rep=10,
kernel_name="softmax_kernel")

elif provider == "torch-jit":
_, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(lambda: naive_softmax(x), quantiles=quantiles, warmup=10,
@@ -144,7 +145,17 @@ def benchmark(M, N, provider):
xetla_fn = lambda: func(x, out, 0)
torch_fn = lambda: torch.softmax(x, axis=-1)
# benchmark_suit.assert_close(xetla_fn(), torch_fn(), err_msg="xetla to torch")
_, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(xetla_fn, quantiles=quantiles, warmup=10, rep=10)
kernels_name = {
"softmax_shape_4096_256": "mat1_4096x256_bf16_cfg0",
"softmax_shape_4096_1024": "mat1_4096x1024_bf16_cfg0",
"softmax_shape_4096_2048": "mat1_4096x2048_bf16_cfg0",
"softmax_shape_4096_4096": "mat1_4096x4096_bf16_cfg0",
"softmax_shape_4096_8192": "mat1_4096x8k_bf16_cfg0",
"softmax_shape_4096_16384": "mat1_4096x16k_bf16_cfg0",
"softmax_shape_4096_32768": "mat1_4096x32k_bf16_cfg0",
}
_, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(xetla_fn, quantiles=quantiles, warmup=10, rep=10,
kernel_name=kernels_name[name])

else:
raise NotImplementedError(f"Unsupported provider {provider}")
33 changes: 31 additions & 2 deletions benchmarks/triton_kernels_benchmark/gemm_benchmark.py
@@ -263,7 +263,8 @@ def benchmark(B, M, N, K, provider):
rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch')
_, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles,
fast_flush=False)
fast_flush=False,
kernel_name='matmul_kernel_with_block_pointers')
elif provider == 'xetla':
if B == 1:
c = torch.empty((M, N), device='xpu', dtype=torch.float32)
@@ -277,9 +278,37 @@ def benchmark(B, M, N, K, provider):
func = getattr(xetla_kernel, name)
xetla_fn = lambda: func(a, b, c, acc, cnt)
torch_fn = lambda: torch.matmul(a, b).to(torch.float32)

kernels_name = {
'gemm_shape_1_1024_1024_1024': 'Test_1x1024x1024x1024_row_row',
'gemm_shape_1_2048_2048_2048': 'Test_1x2048x2048x2048_row_row',
'gemm_shape_1_4096_4096_4096': 'Test_1x4096x4096x4096_row_row',
'gemm_shape_1_8192_8192_8192': 'Test_1x8192x8192x8192_row_row',
'gemm_shape_1_1_5120_13824': 'Test_1x1x5120x13824_row_row',
'gemm_shape_1_4_4096_12288': 'Test_1x4x4096x12288_row_row',
'gemm_shape_1_512_8192_8192': 'Test_1x512x8192x8192_row_row',
'gemm_shape_1_512_8192_32768': 'Test_1x512x8192x32768_row_row',
'gemm_shape_1_512_32768_8192': 'Test_1x512x32768x8192_row_row',
'gemm_shape_1_1024_16384_8192': 'Test_1x1024x16384x8192_row_row',
'gemm_shape_1_1024_28672_8192': 'Test_1x1024x28672x8192_row_row',
'gemm_shape_1_3072_4096_3072': 'Test_1x3072x4096x3072_row_row',
'gemm_shape_1_4096_16384_8192': 'Test_1x4096x16384x8192_row_row',
'gemm_shape_1_8192_16384_1024': 'Test_1x8192x16384x1024_row_row',
'gemm_shape_1_8192_16384_4096': 'Test_1x8192x16384x4096_row_row',
'gemm_shape_1_16384_1024_8192': 'Test_1x16384x1024x8192_row_row',
'gemm_shape_1_16384_4096_8192': 'Test_1x16384x4096x8192_row_row',
'gemm_shape_1_16384_8192_1024': 'Test_1x16384x8192x1024_row_row',
'gemm_shape_1_16384_8192_4096': 'Test_1x16384x8192x4096_row_row',
'gemm_shape_4_32768_128_4096': 'Test_4x32768x128x4096_row_row',
'gemm_shape_4_32768_4096_128': 'Test_4x32768x4096x128_row_row',
'gemm_shape_32_4096_4096_128': 'Test_32x4096x4096x128_row_row',
'gemm_shape_4096_8_128_16384': 'Test_4096x8x128x16384_row_row',
'gemm_shape_4096_8_16384_128': 'Test_4096x8x16384x128_row_row',
}

# benchmark_suit.assert_close(xetla_fn(), torch_fn(), atol=1e-4, rtol=1.0, err_msg='xetla to torch')
_, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(xetla_fn, warmup=10, rep=10, quantiles=quantiles,
fast_flush=False)
fast_flush=False, kernel_name=kernels_name[name])
else:
raise NotImplementedError(f'Unsupported provider {provider}')

@@ -266,15 +266,17 @@ def benchmark(B, M, N, K, provider):
assert len(a.shape) == len(b.shape), 'Incompatible sizes'
if len(a.shape) == 3:
c = torch.empty((B, M, N), device='xpu', dtype=torch.float32)
kernel_name = 'matmul_kernel_with_block_pointers_batched'
else:
assert len(a.shape) == 2, 'Expecting shape of length 2'
c = torch.empty((M, N), device='xpu', dtype=torch.float32)
kernel_name = 'matmul_kernel_with_block_pointers'
triton_fn = lambda: matmul(a, b, d, c)
torch_fn = lambda: torch.matmul(a, b).to(torch.float32) + d
rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch')
_, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles,
fast_flush=False)
fast_flush=False, kernel_name=kernel_name)
else:
raise NotImplementedError(f'Unsupported provider {provider}')

@@ -268,15 +268,17 @@ def benchmark(B, M, N, K, provider):
assert len(a.shape) == len(b.shape), 'Incompatible sizes'
if len(a.shape) == 3:
c = torch.empty((B, M, N), device='xpu', dtype=torch.float32)
kernel_name = 'matmul_kernel_with_block_pointers_batched'
else:
assert len(a.shape) == 2, 'Expecting shape of length 2'
c = torch.empty((M, N), device='xpu', dtype=torch.float32)
kernel_name = 'matmul_kernel_with_block_pointers'
triton_fn = lambda: matmul(a, b, c)
torch_fn = lambda: torch.nn.functional.gelu(torch.matmul(a, b).to(torch.float32))
rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch')
_, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles,
fast_flush=False)
fast_flush=False, kernel_name=kernel_name)
else:
raise NotImplementedError(f'Unsupported provider {provider}')

@@ -256,15 +256,17 @@ def benchmark(B, M, N, K, provider):
assert len(a.shape) == len(b.shape), 'Incompatible sizes'
if len(a.shape) == 3:
c = torch.empty((B, M, N), device='xpu', dtype=torch.float32)
kernel_name = 'matmul_kernel_with_block_pointers_batched'
else:
assert len(a.shape) == 2, 'Expecting shape of length 2'
c = torch.empty((M, N), device='xpu', dtype=torch.float32)
kernel_name = 'matmul_kernel_with_block_pointers'
triton_fn = lambda: matmul(a, b, c)
torch_fn = lambda: torch.matmul(torch.exp(a), b).to(torch.float32)
rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch')
_, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles,
fast_flush=False)
fast_flush=False, kernel_name=kernel_name)
else:
raise NotImplementedError(f'Unsupported provider {provider}')
