diff --git a/.github/workflows/triton-benchmarks.yml b/.github/workflows/triton-benchmarks.yml
index ed8791c727..7bb49d6617 100644
--- a/.github/workflows/triton-benchmarks.yml
+++ b/.github/workflows/triton-benchmarks.yml
@@ -17,6 +17,7 @@ on:
         options:
           - PYTORCH_LEGACY_PROFILER_USING_IPEX
           - ELAPSED_TIME
+          - UPSTREAM_PYTORCH_PROFILER
         default: PYTORCH_LEGACY_PROFILER_USING_IPEX
   schedule:
     - cron: "5 23 * * *"
diff --git a/benchmarks/triton_kernels_benchmark/benchmark_testing.py b/benchmarks/triton_kernels_benchmark/benchmark_testing.py
index 916c4c10d8..311c060eeb 100644
--- a/benchmarks/triton_kernels_benchmark/benchmark_testing.py
+++ b/benchmarks/triton_kernels_benchmark/benchmark_testing.py
@@ -7,7 +7,7 @@
 if USE_IPEX_OPTION:
     BENCHMARKING_METHOD = "PYTORCH_LEGACY_PROFILER_USING_IPEX"
 else:
-    BENCHMARKING_METHOD = os.getenv("BENCHMARKING_METHOD", "ELAPSED_TIME")
+    BENCHMARKING_METHOD = os.getenv("BENCHMARKING_METHOD", "UPSTREAM_PYTORCH_PROFILER")
 
 
 def synchronize():
@@ -37,7 +37,7 @@ def _summarize_statistics(times, quantiles, return_mode):
 
 
 def do_bench_ipex(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, fast_flush=True, return_mode="mean",
-                  device="xpu", sync_submitting=True):
+                  device="xpu", sync_submitting=True, kernel_name=None):  # pylint: disable=unused-argument
     """
     Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with
     the 20-th and 80-th performance percentile.
@@ -127,7 +127,7 @@ def extract_kernels(funcs):
 
 
 def do_bench_elapsed_time(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, fast_flush=True,
-                          return_mode="mean", device="xpu"):
+                          return_mode="mean", device="xpu", kernel_name=None):  # pylint: disable=unused-argument
     """
     Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with
     the 20-th and 80-th performance percentile.
@@ -155,10 +155,98 @@ def do_bench_elapsed_time(fn, warmup=25, rep=100, grad_to_none=None, quantiles=N
     return _summarize_statistics(times, quantiles, return_mode)
 
 
+def do_bench_upstream_pytorch_profiler(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, fast_flush=True,
+                                       return_mode="mean", device="xpu", sync_submitting=True, kernel_name=None):
+    """
+    Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with
+    the 20-th and 80-th performance percentile.
+
+    :param fn: Function to benchmark
+    :type fn: Callable
+    :param warmup: Warmup time (in ms)
+    :type warmup: int
+    :param rep: Repetition time (in ms)
+    :type rep: int
+    :param grad_to_none: Reset the gradient of the provided tensor to None
+    :type grad_to_none: torch.tensor, optional
+    :param quantiles: Performance percentile to return in addition to the median.
+    :type quantiles: list[float]
+    :param fast_flush: Use faster kernel to flush L2 between measurements
+    :type fast_flush: bool
+    """
+
+    assert return_mode in ["min", "max", "mean", "median"]
+    import torch
+    from torch.profiler import profile, ProfilerActivity
+
+    fn()
+    synchronize()
+
+    # We maintain a buffer of 256 MB that we clear
+    # before each kernel call to make sure that the L2
+    # doesn't contain any input data before the run
+    cache_size = 256 * 1024 * 1024
+    if fast_flush:
+        cache = torch.empty(int(cache_size // 4), dtype=torch.int, device=device)
+    else:
+        cache = torch.empty(int(cache_size), dtype=torch.int8, device=device)
+
+    # Estimate the runtime of the function
+    start_event = torch.xpu.Event(enable_timing=True)
+    end_event = torch.xpu.Event(enable_timing=True)
+    start_event.record()
+    for _ in range(5):
+        cache.zero_()
+        fn()
+    end_event.record()
+    synchronize()
+    estimate_ms = start_event.elapsed_time(end_event) / 5
+
+    # compute number of warmup and repeat
+    n_warmup = max(1, int(warmup / estimate_ms))
+    n_repeat = max(1, int(rep / estimate_ms))
+    # Warm-up
+    for _ in range(n_warmup):
+        fn()
+    # Benchmark
+    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.XPU]) as prof:
+        for _ in range(n_repeat):
+            # we don't want `fn` to accumulate gradient values
+            # if it contains a backward pass. So we clear the
+            # provided gradients
+            if grad_to_none is not None:
+                for x in grad_to_none:
+                    x.grad = None
+            # we clear the L2 cache before each run
+            cache.zero_()
+            if sync_submitting:
+                synchronize()
+            # record time of `fn`
+            fn()
+        # Record clocks
+        synchronize()
+
+    function_events = prof.events()
+
+    functions = []
+    if isinstance(kernel_name, str):
+        kernel_name = [kernel_name]
+    for ker_name in kernel_name:
+        functions.extend(list(filter(lambda x: x.name.startswith(ker_name), function_events)))  # pylint: disable=cell-var-from-loop
+    # profiling_func_filter = filter(lambda x: x.name.startswith("__profile_kernel_of_func"), function_events)
+
+    assert len(functions) == n_repeat, f"expected {n_repeat} profiled kernel launches, but found {len(functions)}"
+    # Convert the times to milliseconds.
+    times = torch.tensor([f.self_device_time_total * 1e-3 for f in functions], dtype=torch.float)
+    return _summarize_statistics(times, quantiles, return_mode)
+
+
 if BENCHMARKING_METHOD == "PYTORCH_LEGACY_PROFILER_USING_IPEX":
     do_bench = do_bench_ipex
 elif BENCHMARKING_METHOD == "ELAPSED_TIME":
     do_bench = do_bench_elapsed_time
+elif BENCHMARKING_METHOD == "UPSTREAM_PYTORCH_PROFILER":
+    do_bench = do_bench_upstream_pytorch_profiler
 else:
     raise NotImplementedError(f"BENCHMARKING_METHOD: {BENCHMARKING_METHOD} isn't implemented")
diff --git a/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py b/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py
index ae49da2d0c..d881e51c29 100644
--- a/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py
+++ b/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py
@@ -257,7 +257,8 @@ def benchmark(Z, H, N_CTX, D_HEAD, CAUSAL, provider):
             ), attn_mask=None, dropout_p=0.0, is_causal=CAUSAL, scale=sm_scale).to(torch.float32)
         atol = 1e-1 if N_CTX == 16384 else 1e-2
         benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=atol, rtol=1e-3, err_msg='triton to torch')
-        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles)
+        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles,
+                                                              kernel_name='_attn_fwd')
 
     elif provider == 'xetla':
         module_name = f'flash_attn_causal_{CAUSAL}'.lower()
@@ -272,7 +273,8 @@ def benchmark(Z, H, N_CTX, D_HEAD, CAUSAL, provider):
         l = torch.empty((size_ml, ), device='xpu', dtype=torch.float)
 
         xetla_fn = lambda: func(q, k, v, out, dropout_mask, bias, m, l, Z, H, D_HEAD, N_CTX, N_CTX, sm_scale)
-        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(xetla_fn, warmup=10, rep=10, quantiles=quantiles)
+        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(xetla_fn, warmup=10, rep=10, quantiles=quantiles,
+                                                              kernel_name='gpu::xetla::fmha::FmhaForwardKernel<')
 
     else:
         raise NotImplementedError(f'Unsupported provider {provider}')
diff --git a/benchmarks/triton_kernels_benchmark/fused_softmax.py b/benchmarks/triton_kernels_benchmark/fused_softmax.py
index 0b983448e4..dfd080ac34 100644
--- a/benchmarks/triton_kernels_benchmark/fused_softmax.py
+++ b/benchmarks/triton_kernels_benchmark/fused_softmax.py
@@ -131,7 +131,8 @@ def benchmark(M, N, provider):
         triton_fn = lambda: softmax(x, out)
         torch_fn = lambda: torch.softmax(x, axis=-1)
         benchmark_suit.assert_close(triton_fn(), torch_fn(), err_msg="triton to torch")
-        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, quantiles=quantiles, warmup=10, rep=10)
+        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, quantiles=quantiles, warmup=10, rep=10,
+                                                              kernel_name="softmax_kernel")
 
     elif provider == "torch-jit":
         _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(lambda: naive_softmax(x), quantiles=quantiles, warmup=10,
@@ -144,7 +145,17 @@ def benchmark(M, N, provider):
         xetla_fn = lambda: func(x, out, 0)
         torch_fn = lambda: torch.softmax(x, axis=-1)
         # benchmark_suit.assert_close(xetla_fn(), torch_fn(), err_msg="xetla to torch")
-        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(xetla_fn, quantiles=quantiles, warmup=10, rep=10)
+        kernels_name = {
+            "softmax_shape_4096_256": "mat1_4096x256_bf16_cfg0",
+            "softmax_shape_4096_1024": "mat1_4096x1024_bf16_cfg0",
+            "softmax_shape_4096_2048": "mat1_4096x2048_bf16_cfg0",
+            "softmax_shape_4096_4096": "mat1_4096x4096_bf16_cfg0",
+            "softmax_shape_4096_8192": "mat1_4096x8k_bf16_cfg0",
+            "softmax_shape_4096_16384": "mat1_4096x16k_bf16_cfg0",
+            "softmax_shape_4096_32768": "mat1_4096x32k_bf16_cfg0",
+        }
+        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(xetla_fn, quantiles=quantiles, warmup=10, rep=10,
+                                                              kernel_name=kernels_name[name])
 
     else:
         raise NotImplementedError(f"Unsupported provider {provider}")
diff --git a/benchmarks/triton_kernels_benchmark/gemm_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_benchmark.py
index b647de0d23..8c6c14f816 100644
--- a/benchmarks/triton_kernels_benchmark/gemm_benchmark.py
+++ b/benchmarks/triton_kernels_benchmark/gemm_benchmark.py
@@ -262,7 +262,8 @@ def benchmark(B, M, N, K, provider):
         torch_fn = lambda: torch.matmul(a, b).to(torch.float32)
         rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
         benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch')
-        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles)
+        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles,
+                                                                 kernel_name='matmul_kernel_with_block_pointers')
     elif provider == 'xetla':
         if B == 1:
             c = torch.empty((M, N), device='xpu', dtype=torch.float32)
@@ -276,8 +277,37 @@ def benchmark(B, M, N, K, provider):
         func = getattr(xetla_kernel, name)
         xetla_fn = lambda: func(a, b, c, acc, cnt)
         torch_fn = lambda: torch.matmul(a, b).to(torch.float32)
+
+        kernels_name = {
+            'gemm_shape_1_1024_1024_1024': 'Test_1x1024x1024x1024_row_row',
+            'gemm_shape_1_2048_2048_2048': 'Test_1x2048x2048x2048_row_row',
+            'gemm_shape_1_4096_4096_4096': 'Test_1x4096x4096x4096_row_row',
+            'gemm_shape_1_8192_8192_8192': 'Test_1x8192x8192x8192_row_row',
+            'gemm_shape_1_1_5120_13824': 'Test_1x1x5120x13824_row_row',
+            'gemm_shape_1_4_4096_12288': 'Test_1x4x4096x12288_row_row',
+            'gemm_shape_1_512_8192_8192': 'Test_1x512x8192x8192_row_row',
+            'gemm_shape_1_512_8192_32768': 'Test_1x512x8192x32768_row_row',
+            'gemm_shape_1_512_32768_8192': 'Test_1x512x32768x8192_row_row',
+            'gemm_shape_1_1024_16384_8192': 'Test_1x1024x16384x8192_row_row',
+            'gemm_shape_1_1024_28672_8192': 'Test_1x1024x28672x8192_row_row',
+            'gemm_shape_1_3072_4096_3072': 'Test_1x3072x4096x3072_row_row',
+            'gemm_shape_1_4096_16384_8192': 'Test_1x4096x16384x8192_row_row',
+            'gemm_shape_1_8192_16384_1024': 'Test_1x8192x16384x1024_row_row',
+            'gemm_shape_1_8192_16384_4096': 'Test_1x8192x16384x4096_row_row',
+            'gemm_shape_1_16384_1024_8192': 'Test_1x16384x1024x8192_row_row',
+            'gemm_shape_1_16384_4096_8192': 'Test_1x16384x4096x8192_row_row',
+            'gemm_shape_1_16384_8192_1024': 'Test_1x16384x8192x1024_row_row',
+            'gemm_shape_1_16384_8192_4096': 'Test_1x16384x8192x4096_row_row',
+            'gemm_shape_4_32768_128_4096': 'Test_4x32768x128x4096_row_row',
+            'gemm_shape_4_32768_4096_128': 'Test_4x32768x4096x128_row_row',
+            'gemm_shape_32_4096_4096_128': 'Test_32x4096x4096x128_row_row',
+            'gemm_shape_4096_8_128_16384': 'Test_4096x8x128x16384_row_row',
+            'gemm_shape_4096_8_16384_128': 'Test_4096x8x16384x128_row_row',
+        }
+
         # benchmark_suit.assert_close(xetla_fn(), torch_fn(), atol=1e-4, rtol=1.0, err_msg='xetla to torch')
-        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(xetla_fn, warmup=10, rep=10, quantiles=quantiles)
+        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(xetla_fn, warmup=10, rep=10, quantiles=quantiles,
+                                                                 kernel_name=kernels_name[name])
 
     else:
         raise NotImplementedError(f'Unsupported provider {provider}')
diff --git a/benchmarks/triton_kernels_benchmark/gemm_postop_addmatrix_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_postop_addmatrix_benchmark.py
index 08080abf96..8be5f6ce01 100644
--- a/benchmarks/triton_kernels_benchmark/gemm_postop_addmatrix_benchmark.py
+++ b/benchmarks/triton_kernels_benchmark/gemm_postop_addmatrix_benchmark.py
@@ -266,14 +266,17 @@ def benchmark(B, M, N, K, provider):
         assert len(a.shape) == len(b.shape), 'Incompatible sizes'
         if len(a.shape) == 3:
             c = torch.empty((B, M, N), device='xpu', dtype=torch.float32)
+            kernel_name = 'matmul_kernel_with_block_pointers_batched'
         else:
             assert len(a.shape) == 2, 'Expecting shape of length 2'
             c = torch.empty((M, N), device='xpu', dtype=torch.float32)
+            kernel_name = 'matmul_kernel_with_block_pointers'
         triton_fn = lambda: matmul(a, b, d, c)
         torch_fn = lambda: torch.matmul(a, b).to(torch.float32) + d
         rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
         benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch')
-        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles)
+        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles,
+                                                                 kernel_name=kernel_name)
 
     else:
         raise NotImplementedError(f'Unsupported provider {provider}')
diff --git a/benchmarks/triton_kernels_benchmark/gemm_postop_gelu_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_postop_gelu_benchmark.py
index 788354e978..f504aa9952 100644
--- a/benchmarks/triton_kernels_benchmark/gemm_postop_gelu_benchmark.py
+++ b/benchmarks/triton_kernels_benchmark/gemm_postop_gelu_benchmark.py
@@ -268,14 +268,17 @@ def benchmark(B, M, N, K, provider):
         assert len(a.shape) == len(b.shape), 'Incompatible sizes'
         if len(a.shape) == 3:
             c = torch.empty((B, M, N), device='xpu', dtype=torch.float32)
+            kernel_name = 'matmul_kernel_with_block_pointers_batched'
         else:
             assert len(a.shape) == 2, 'Expecting shape of length 2'
             c = torch.empty((M, N), device='xpu', dtype=torch.float32)
+            kernel_name = 'matmul_kernel_with_block_pointers'
         triton_fn = lambda: matmul(a, b, c)
         torch_fn = lambda: torch.nn.functional.gelu(torch.matmul(a, b).to(torch.float32))
         rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
         benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch')
-        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles)
+        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles,
+                                                                 kernel_name=kernel_name)
 
     else:
         raise NotImplementedError(f'Unsupported provider {provider}')
diff --git a/benchmarks/triton_kernels_benchmark/gemm_preop_exp_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_preop_exp_benchmark.py
index e273cf4366..3e9836e225 100644
--- a/benchmarks/triton_kernels_benchmark/gemm_preop_exp_benchmark.py
+++ b/benchmarks/triton_kernels_benchmark/gemm_preop_exp_benchmark.py
@@ -256,14 +256,17 @@ def benchmark(B, M, N, K, provider):
         assert len(a.shape) == len(b.shape), 'Incompatible sizes'
         if len(a.shape) == 3:
             c = torch.empty((B, M, N), device='xpu', dtype=torch.float32)
+            kernel_name = 'matmul_kernel_with_block_pointers_batched'
         else:
             assert len(a.shape) == 2, 'Expecting shape of length 2'
             c = torch.empty((M, N), device='xpu', dtype=torch.float32)
+            kernel_name = 'matmul_kernel_with_block_pointers'
         triton_fn = lambda: matmul(a, b, c)
         torch_fn = lambda: torch.matmul(torch.exp(a), b).to(torch.float32)
         rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
         benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch')
-        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles)
+        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles,
+                                                                 kernel_name=kernel_name)
 
     else:
         raise NotImplementedError(f'Unsupported provider {provider}')
diff --git a/benchmarks/triton_kernels_benchmark/gemm_splitk_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_splitk_benchmark.py
index 468393be88..0ef444a3ab 100644
--- a/benchmarks/triton_kernels_benchmark/gemm_splitk_benchmark.py
+++ b/benchmarks/triton_kernels_benchmark/gemm_splitk_benchmark.py
@@ -157,7 +157,8 @@ def benchmark(M, N, K, provider):
         torch_fn = lambda: torch.matmul(a, b).to(torch.float32)
         rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
         benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch')
-        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles)
+        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles,
+                                                              kernel_name='_kernel')
 
     else:
         raise NotImplementedError(f'Unsupported provider {provider}')
diff --git a/benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py
index 65d5070212..5a6b3f4fe2 100644
--- a/benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py
+++ b/benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py
@@ -278,7 +278,8 @@ def benchmark(M, N, K, provider):
         triton_fn = lambda: matmul(a, b, c)
         torch_fn = lambda: torch.matmul(a, b).to(torch.float32)
         benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=1e-2, err_msg='triton to torch')
-        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles)
+        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles,
+                                                              kernel_name=['first_wave', 'full_tiles'])
 
     else:
         raise NotImplementedError(f'Unsupported provider {provider}')
diff --git a/benchmarks/triton_kernels_benchmark/prefix_sums.py b/benchmarks/triton_kernels_benchmark/prefix_sums.py
index bb3d2069f0..8f17fb9e9f 100644
--- a/benchmarks/triton_kernels_benchmark/prefix_sums.py
+++ b/benchmarks/triton_kernels_benchmark/prefix_sums.py
@@ -44,7 +44,8 @@ def benchmark(M, N, AXIS, provider):
 
     if provider == 'triton':
         triton_fn = lambda: scan_kernel[(1, )](x, BLOCK_SIZE_M=M, BLOCK_SIZE_N=N, AXIS=AXIS)
-        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, quantiles=quantiles)
+        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, quantiles=quantiles,
+                                                                 kernel_name='scan_kernel')
 
     else:
         raise NotImplementedError(f'Unsupported provider {provider}')
diff --git a/scripts/capture-hw-details.sh b/scripts/capture-hw-details.sh
index 4f658507e1..2fbcc9fd2c 100755
--- a/scripts/capture-hw-details.sh
+++ b/scripts/capture-hw-details.sh
@@ -58,12 +58,6 @@ else
     export COMPILER_VERSION="Not installed"
 fi
 
-if [[ "${USE_IPEX:-}" == "1" ]]; then
-    export BENCHMARKING_METHOD="PYTORCH_LEGACY_PROFILER_USING_IPEX"
-elif [[ "${USE_IPEX:-}" == "0" ]]; then
-    export BENCHMARKING_METHOD="${BENCHMARKING_METHOD:-ELAPSED_TIME}"
-fi
-
 if [ "$QUIET" = false ]; then
     echo "LIBIGC1_VERSION=$LIBIGC1_VERSION"
     echo "LEVEL_ZERO_VERSION=$LEVEL_ZERO_VERSION"
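
Note on the new benchmarking path: with BENCHMARKING_METHOD=UPSTREAM_PYTORCH_PROFILER, do_bench resolves to do_bench_upstream_pytorch_profiler, which measures kernel time by filtering torch.profiler events whose names start with kernel_name. Call sites therefore have to pass the kernel name prefix (or a list of prefixes) for the kernels they expect to launch exactly once per repetition, which is what the per-benchmark kernel_name arguments above supply. The sketch below only illustrates that calling convention; the toy kernel, launch grid, environment-variable handling and import path are assumptions, and only the do_bench keyword arguments mirror the call sites changed in this patch.

# Illustrative sketch, not part of the patch. Assumes a PyTorch build with XPU
# support and that the `triton_kernels_benchmark` package is importable.
import os

os.environ.setdefault("USE_IPEX", "0")  # assumed switch to bypass the IPEX path
os.environ.setdefault("BENCHMARKING_METHOD", "UPSTREAM_PYTORCH_PROFILER")  # read at import time

import torch
import triton
import triton.language as tl

from triton_kernels_benchmark.benchmark_testing import do_bench  # assumed import path


@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # Toy element-wise addition kernel used only to demonstrate the call.
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)


n = 1 << 20
x = torch.rand(n, device="xpu")
y = torch.rand(n, device="xpu")
out = torch.empty_like(x)
fn = lambda: add_kernel[(triton.cdiv(n, 1024), )](x, y, out, n, BLOCK_SIZE=1024)

# `kernel_name` must be a prefix of the profiled kernel's name, and exactly one
# matching launch per repetition must be recorded for the assert inside
# do_bench_upstream_pytorch_profiler to hold. The quantiles value here is an
# assumption; the benchmark scripts above define their own and unpack the
# result as `_, min_ms, max_ms, mean_ms, cv`.
result = do_bench(fn, warmup=10, rep=10, quantiles=[0.5, 0.0, 1.0], kernel_name="add_kernel")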