From 8abe5549d40897cd456dca331c9e5c5ba98b6d8b Mon Sep 17 00:00:00 2001 From: Anatoly Myachev Date: Mon, 14 Oct 2024 11:05:25 +0000 Subject: [PATCH 1/5] Remove workaround for upstream profiler Signed-off-by: Anatoly Myachev --- .../benchmark_testing.py | 26 ++++++++++--------- third_party/intel/backend/driver.py | 12 ++++++++- 2 files changed, 25 insertions(+), 13 deletions(-) diff --git a/benchmarks/triton_kernels_benchmark/benchmark_testing.py b/benchmarks/triton_kernels_benchmark/benchmark_testing.py index b289a096e2..71f3a1ea25 100644 --- a/benchmarks/triton_kernels_benchmark/benchmark_testing.py +++ b/benchmarks/triton_kernels_benchmark/benchmark_testing.py @@ -149,7 +149,7 @@ def do_bench_elapsed_time(fn, warmup=25, rep=100, grad_to_none=None, quantiles=N def do_bench_upstream_pytorch_profiler(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, return_mode="mean", - device="xpu", sync_submitting=True, kernel_name=None): + device="xpu", sync_submitting=True, kernel_name=None): # pylint: disable=unused-argument """ Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with the 20-th and 80-th performance percentile. @@ -168,7 +168,7 @@ def do_bench_upstream_pytorch_profiler(fn, warmup=25, rep=100, grad_to_none=None assert return_mode in ["min", "max", "mean", "median"] import torch - from torch.profiler import profile, ProfilerActivity + from torch.profiler import profile, ProfilerActivity, record_function fn() synchronize() @@ -210,22 +210,24 @@ def do_bench_upstream_pytorch_profiler(fn, warmup=25, rep=100, grad_to_none=None if sync_submitting: synchronize() # record time of `fn` - fn() + with record_function("__profile_kernel_of_func"): + fn() # Record clocks synchronize() - function_events = prof.events() + profiling_func_filter = filter(lambda x: x.name.startswith("__profile_kernel_of_func"), prof.events()) + functions = list(profiling_func_filter) - functions = [] - if isinstance(kernel_name, str): - kernel_name = [kernel_name] - for ker_name in kernel_name: - functions.extend(list(filter(lambda x: x.name.startswith(ker_name), function_events))) # pylint: disable=cell-var-from-loop - # profiling_func_filter = filter(lambda x: x.name.startswith("__profile_kernel_of_func"), function_events) + def extract_kernels(funcs): + kernels = [] + kernels += list(itertools.chain.from_iterable(map(lambda func: extract_kernels(func.cpu_children), funcs))) + kernels += list(itertools.chain.from_iterable([func.kernels for func in funcs])) + return kernels - assert len(functions) == n_repeat, f"the profiling number not match, {len(functions)}" + kernels = [extract_kernels(func.cpu_children) for func in functions] + assert len(kernels) == n_repeat, "the profiling number not match" # Make the time to the milliseconds. - times = torch.tensor([f.self_device_time_total * 1e-3 for f in functions], dtype=torch.float) + times = torch.tensor([sum([k.duration for k in ks]) * 1e-3 for ks in kernels], dtype=torch.float) return _summarize_statistics(times, quantiles, return_mode) diff --git a/third_party/intel/backend/driver.py b/third_party/intel/backend/driver.py index 2f01cb93df..cfa9718699 100644 --- a/third_party/intel/backend/driver.py +++ b/third_party/intel/backend/driver.py @@ -71,10 +71,11 @@ class CompilationHelper: def __init__(self): self._library_dir = None self._include_dir = None - self.libraries = ['ze_loader', 'sycl'] + self.libraries = ['ze_loader', 'sycl', 'torch'] @cached_property def _compute_compilation_options_lazy(self): + import torch ze_root = os.getenv("ZE_PATH", default="/usr/local") include_dir = [os.path.join(ze_root, "include")] @@ -82,7 +83,14 @@ def _compute_compilation_options_lazy(self): dirname = os.path.dirname(os.path.realpath(__file__)) include_dir += [os.path.join(dirname, "include")] + include_dir += [ + os.path.join(torch.utils.cmake_prefix_path, "../../include"), + os.path.join(torch.utils.cmake_prefix_path, "../../include/torch/csrc/api/include"), + ] library_dir += [os.path.join(dirname, "lib")] + library_dir += [ + os.path.join(torch.utils.cmake_prefix_path, "../../lib"), + ] self._library_dir = library_dir self._include_dir = include_dir @@ -218,6 +226,7 @@ def format_of(ty): #include #include #include + #include #define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION #include @@ -310,6 +319,7 @@ def format_of(ty): static void sycl_kernel_launch(uint32_t gridX, uint32_t gridY, uint32_t gridZ, int num_warps, int threads_per_warp, int shared_memory, sycl::queue& stream, sycl::kernel& kernel_ptr {', ' + arg_decls if len(arg_decls) > 0 else ''}) {{ std::string kernel_name = kernel_ptr.get_info(); + RECORD_FUNCTION("XPU Triton kernel: " + kernel_name, {{}}); void *params[] = {{ {', '.join(f"&arg{i}" for i in signature.keys() if i not in constants)} }}; uint32_t num_params = sizeof(params)/sizeof(params[0]); uint32_t expected_num_params = kernel_ptr.get_info(); From 0b90c7e4030a5497ac36d8d48451bf9fb3070b0f Mon Sep 17 00:00:00 2001 From: Anatoly Myachev Date: Mon, 21 Oct 2024 09:40:24 +0000 Subject: [PATCH 2/5] cleanup Signed-off-by: Anatoly Myachev --- third_party/intel/backend/driver.py | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/third_party/intel/backend/driver.py b/third_party/intel/backend/driver.py index 71d16c8056..92aeb1f44d 100644 --- a/third_party/intel/backend/driver.py +++ b/third_party/intel/backend/driver.py @@ -71,11 +71,10 @@ class CompilationHelper: def __init__(self): self._library_dir = None self._include_dir = None - self.libraries = ['ze_loader', 'sycl', 'torch'] + self.libraries = ['ze_loader', 'sycl'] @cached_property def _compute_compilation_options_lazy(self): - import torch ze_root = os.getenv("ZE_PATH", default="/usr/local") include_dir = [os.path.join(ze_root, "include")] @@ -83,14 +82,7 @@ def _compute_compilation_options_lazy(self): dirname = os.path.dirname(os.path.realpath(__file__)) include_dir += [os.path.join(dirname, "include")] - include_dir += [ - os.path.join(torch.utils.cmake_prefix_path, "../../include"), - os.path.join(torch.utils.cmake_prefix_path, "../../include/torch/csrc/api/include"), - ] library_dir += [os.path.join(dirname, "lib")] - library_dir += [ - os.path.join(torch.utils.cmake_prefix_path, "../../lib"), - ] self._library_dir = library_dir self._include_dir = include_dir @@ -226,7 +218,6 @@ def format_of(ty): #include #include #include - #include #define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION #include @@ -319,7 +310,6 @@ def format_of(ty): static void sycl_kernel_launch(uint32_t gridX, uint32_t gridY, uint32_t gridZ, int num_warps, int threads_per_warp, int shared_memory, sycl::queue& stream, sycl::kernel& kernel_ptr {', ' + arg_decls if len(arg_decls) > 0 else ''}) {{ std::string kernel_name = kernel_ptr.get_info(); - RECORD_FUNCTION("XPU Triton kernel: " + kernel_name, {{}}); void *params[] = {{ {', '.join(f"&arg{i}" for i in signature.keys() if i not in constants)} }}; uint32_t num_params = sizeof(params)/sizeof(params[0]); uint32_t expected_num_params = kernel_ptr.get_info(); From 0d5f3fcf3d45f055849236c4d97f9d6675d08f24 Mon Sep 17 00:00:00 2001 From: Anatoly Myachev Date: Mon, 21 Oct 2024 10:17:14 +0000 Subject: [PATCH 3/5] remove 'kernel_name' Signed-off-by: Anatoly Myachev --- .../benchmark_testing.py | 6 ++-- .../flash_attention_fwd_benchmark.py | 6 ++-- .../triton_kernels_benchmark/fused_softmax.py | 15 ++------ .../gemm_benchmark.py | 35 ++----------------- .../gemm_postop_addmatrix_benchmark.py | 4 +-- .../gemm_postop_gelu_benchmark.py | 4 +-- .../gemm_preop_exp_benchmark.py | 4 +-- .../gemm_splitk_benchmark.py | 3 +- .../gemm_streamk_benchmark.py | 5 ++- .../triton_kernels_benchmark/prefix_sums.py | 3 +- 10 files changed, 17 insertions(+), 68 deletions(-) diff --git a/benchmarks/triton_kernels_benchmark/benchmark_testing.py b/benchmarks/triton_kernels_benchmark/benchmark_testing.py index b85580dcbf..e840d4769c 100644 --- a/benchmarks/triton_kernels_benchmark/benchmark_testing.py +++ b/benchmarks/triton_kernels_benchmark/benchmark_testing.py @@ -37,7 +37,7 @@ def _summarize_statistics(times, quantiles, return_mode): def do_bench_ipex(fn, n_warmup=25, n_repeat=100, grad_to_none=None, quantiles=None, return_mode="mean", device="xpu", - sync_submitting=True, kernel_name=None): # pylint: disable=unused-argument + sync_submitting=True): """ Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with the 20-th and 80-th performance percentile. @@ -108,7 +108,7 @@ def extract_kernels(funcs): def do_bench_elapsed_time(fn, n_warmup=25, n_repeat=100, grad_to_none=None, quantiles=None, return_mode="mean", - device="xpu", kernel_name=None): # pylint: disable=unused-argument + device="xpu"): """ Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with the 20-th and 80-th performance percentile. @@ -160,7 +160,7 @@ def do_bench_elapsed_time(fn, n_warmup=25, n_repeat=100, grad_to_none=None, quan def do_bench_upstream_pytorch_profiler(fn, n_warmup=25, n_repeat=100, grad_to_none=None, quantiles=None, - return_mode="mean", device="xpu", sync_submitting=True, kernel_name=None): # pylint: disable=unused-argument + return_mode="mean", device="xpu", sync_submitting=True): """ Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with the 20-th and 80-th performance percentile. diff --git a/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py b/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py index dc073d0e5c..8604824cda 100644 --- a/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py +++ b/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py @@ -256,8 +256,7 @@ def benchmark(Z, H, N_CTX, D_HEAD, CAUSAL, provider): ), attn_mask=None, dropout_p=0.0, is_causal=CAUSAL, scale=sm_scale).to(torch.float32) atol = 1e-1 if N_CTX == 16384 else 1e-2 benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=atol, rtol=1e-3, err_msg='triton to torch') - _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10, quantiles=quantiles, - kernel_name='_attn_fwd') + _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10, quantiles=quantiles) elif provider == 'xetla': module_name = f'flash_attn_causal_{CAUSAL}'.lower() @@ -272,8 +271,7 @@ def benchmark(Z, H, N_CTX, D_HEAD, CAUSAL, provider): l = torch.empty((size_ml, ), device='xpu', dtype=torch.float) xetla_fn = lambda: func(q, k, v, out, dropout_mask, bias, m, l, Z, H, D_HEAD, N_CTX, N_CTX, sm_scale) - _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(xetla_fn, n_warmup=10, n_repeat=10, quantiles=quantiles, - kernel_name='gpu::xetla::fmha::FmhaForwardKernel<') + _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(xetla_fn, n_warmup=10, n_repeat=10, quantiles=quantiles) else: raise NotImplementedError(f'Unsupported provider {provider}') diff --git a/benchmarks/triton_kernels_benchmark/fused_softmax.py b/benchmarks/triton_kernels_benchmark/fused_softmax.py index 3f17ac4a55..b12ed819f7 100644 --- a/benchmarks/triton_kernels_benchmark/fused_softmax.py +++ b/benchmarks/triton_kernels_benchmark/fused_softmax.py @@ -131,8 +131,7 @@ def benchmark(M, N, provider): triton_fn = lambda: softmax(x, out) torch_fn = lambda: torch.softmax(x, axis=-1) benchmark_suit.assert_close(triton_fn(), torch_fn(), err_msg="triton to torch") - _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, quantiles=quantiles, n_warmup=10, n_repeat=10, - kernel_name="softmax_kernel") + _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, quantiles=quantiles, n_warmup=10, n_repeat=10) elif provider == "torch-jit": _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(lambda: naive_softmax(x), quantiles=quantiles, @@ -145,17 +144,7 @@ def benchmark(M, N, provider): xetla_fn = lambda: func(x, out, 0) torch_fn = lambda: torch.softmax(x, axis=-1) # benchmark_suit.assert_close(xetla_fn(), torch_fn(), err_msg="xetla to torch") - kernels_name = { - "softmax_shape_4096_256": "mat1_4096x256_bf16_cfg0", - "softmax_shape_4096_1024": "mat1_4096x1024_bf16_cfg0", - "softmax_shape_4096_2048": "mat1_4096x2048_bf16_cfg0", - "softmax_shape_4096_4096": "mat1_4096x4096_bf16_cfg0", - "softmax_shape_4096_8192": "mat1_4096x8k_bf16_cfg0", - "softmax_shape_4096_16384": "mat1_4096x16k_bf16_cfg0", - "softmax_shape_4096_32768": "mat1_4096x32k_bf16_cfg0", - } - _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(xetla_fn, quantiles=quantiles, n_warmup=10, n_repeat=10, - kernel_name=kernels_name[name]) + _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(xetla_fn, quantiles=quantiles, n_warmup=10, n_repeat=10) else: raise NotImplementedError(f"Unsupported provider {provider}") diff --git a/benchmarks/triton_kernels_benchmark/gemm_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_benchmark.py index 7e0b339dc6..c58313263c 100644 --- a/benchmarks/triton_kernels_benchmark/gemm_benchmark.py +++ b/benchmarks/triton_kernels_benchmark/gemm_benchmark.py @@ -284,7 +284,7 @@ def benchmark(B, M, N, K, provider): # Legacy profiler shows ~6000TFLOPS GeoMean for onednn measurements, so use more reliable method do_bench = do_bench_elapsed_time _, min_ms, max_ms, mean_ms, cv = do_bench(lambda: torch.matmul(torch_a, torch_b), n_warmup=10, n_repeat=10, - quantiles=quantiles, kernel_name='gemm_kernel') + quantiles=quantiles) elif provider == 'triton': assert len(a.shape) == len(b.shape), 'Incompatible sizes' if len(a.shape) == 3: @@ -297,8 +297,7 @@ def benchmark(B, M, N, K, provider): rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3 benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch') _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10, - quantiles=quantiles, - kernel_name='matmul_kernel_with_block_pointers') + quantiles=quantiles) elif provider == 'xetla': if B == 1: c = torch.empty((M, N), device='xpu', dtype=torch.float32) @@ -317,37 +316,9 @@ def benchmark(B, M, N, K, provider): xetla_fn = lambda: func(a, b, c, acc, cnt) torch_fn = lambda: torch.matmul(a, b).to(torch.float32) - kernels_name = { - 'gemm_shape_1_1024_1024_1024': 'Test_1x1024x1024x1024_row_row', - 'gemm_shape_1_2048_2048_2048': 'Test_1x2048x2048x2048_row_row', - 'gemm_shape_1_4096_4096_4096': 'Test_1x4096x4096x4096_row_row', - 'gemm_shape_1_8192_8192_8192': 'Test_1x8192x8192x8192_row_row', - 'gemm_shape_1_1_5120_13824': 'Test_1x1x5120x13824_row_row', - 'gemm_shape_1_4_4096_12288': 'Test_1x4x4096x12288_row_row', - 'gemm_shape_1_512_8192_8192': 'Test_1x512x8192x8192_row_row', - 'gemm_shape_1_512_8192_32768': 'Test_1x512x8192x32768_row_row', - 'gemm_shape_1_512_32768_8192': 'Test_1x512x32768x8192_row_row', - 'gemm_shape_1_1024_16384_8192': 'Test_1x1024x16384x8192_row_row', - 'gemm_shape_1_1024_28672_8192': 'Test_1x1024x28672x8192_row_row', - 'gemm_shape_1_3072_4096_3072': 'Test_1x3072x4096x3072_row_row', - 'gemm_shape_1_4096_16384_8192': 'Test_1x4096x16384x8192_row_row', - 'gemm_shape_1_8192_16384_1024': 'Test_1x8192x16384x1024_row_row', - 'gemm_shape_1_8192_16384_4096': 'Test_1x8192x16384x4096_row_row', - 'gemm_shape_1_16384_1024_8192': 'Test_1x16384x1024x8192_row_row', - 'gemm_shape_1_16384_4096_8192': 'Test_1x16384x4096x8192_row_row', - 'gemm_shape_1_16384_8192_1024': 'Test_1x16384x8192x1024_row_row', - 'gemm_shape_1_16384_8192_4096': 'Test_1x16384x8192x4096_row_row', - 'gemm_shape_4_32768_128_4096': 'Test_4x32768x128x4096_row_row', - 'gemm_shape_4_32768_4096_128': 'Test_4x32768x4096x128_row_row', - 'gemm_shape_32_4096_4096_128': 'Test_32x4096x4096x128_row_row', - 'gemm_shape_4096_8_128_16384': 'Test_4096x8x128x16384_row_row', - 'gemm_shape_4096_8_16384_128': 'Test_4096x8x16384x128_row_row', - 'gemm_streamk_shape_3072_4096_3072': 'stream_k_gemm_run', - } - # benchmark_suit.assert_close(xetla_fn(), torch_fn(), atol=1e-4, rtol=1.0, err_msg='xetla to torch') _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(xetla_fn, n_warmup=10, n_repeat=10, - quantiles=quantiles, kernel_name=kernels_name[name]) + quantiles=quantiles) else: raise NotImplementedError(f'Unsupported provider {provider}') diff --git a/benchmarks/triton_kernels_benchmark/gemm_postop_addmatrix_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_postop_addmatrix_benchmark.py index 307100dcfe..cefbd5abc9 100644 --- a/benchmarks/triton_kernels_benchmark/gemm_postop_addmatrix_benchmark.py +++ b/benchmarks/triton_kernels_benchmark/gemm_postop_addmatrix_benchmark.py @@ -266,17 +266,15 @@ def benchmark(B, M, N, K, provider): assert len(a.shape) == len(b.shape), 'Incompatible sizes' if len(a.shape) == 3: c = torch.empty((B, M, N), device='xpu', dtype=torch.float32) - kernel_name = 'matmul_kernel_with_block_pointers_batched' else: assert len(a.shape) == 2, 'Expecting shape of length 2' c = torch.empty((M, N), device='xpu', dtype=torch.float32) - kernel_name = 'matmul_kernel_with_block_pointers' triton_fn = lambda: matmul(a, b, d, c) torch_fn = lambda: torch.matmul(a, b).to(torch.float32) + d rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3 benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch') _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10, - quantiles=quantiles, kernel_name=kernel_name) + quantiles=quantiles) else: raise NotImplementedError(f'Unsupported provider {provider}') diff --git a/benchmarks/triton_kernels_benchmark/gemm_postop_gelu_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_postop_gelu_benchmark.py index 85bb594ade..68cec3931e 100644 --- a/benchmarks/triton_kernels_benchmark/gemm_postop_gelu_benchmark.py +++ b/benchmarks/triton_kernels_benchmark/gemm_postop_gelu_benchmark.py @@ -268,17 +268,15 @@ def benchmark(B, M, N, K, provider): assert len(a.shape) == len(b.shape), 'Incompatible sizes' if len(a.shape) == 3: c = torch.empty((B, M, N), device='xpu', dtype=torch.float32) - kernel_name = 'matmul_kernel_with_block_pointers_batched' else: assert len(a.shape) == 2, 'Expecting shape of length 2' c = torch.empty((M, N), device='xpu', dtype=torch.float32) - kernel_name = 'matmul_kernel_with_block_pointers' triton_fn = lambda: matmul(a, b, c) torch_fn = lambda: torch.nn.functional.gelu(torch.matmul(a, b).to(torch.float32)) rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3 benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch') _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10, - quantiles=quantiles, kernel_name=kernel_name) + quantiles=quantiles) else: raise NotImplementedError(f'Unsupported provider {provider}') diff --git a/benchmarks/triton_kernels_benchmark/gemm_preop_exp_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_preop_exp_benchmark.py index 30ed124d44..dd5b57c84f 100644 --- a/benchmarks/triton_kernels_benchmark/gemm_preop_exp_benchmark.py +++ b/benchmarks/triton_kernels_benchmark/gemm_preop_exp_benchmark.py @@ -256,17 +256,15 @@ def benchmark(B, M, N, K, provider): assert len(a.shape) == len(b.shape), 'Incompatible sizes' if len(a.shape) == 3: c = torch.empty((B, M, N), device='xpu', dtype=torch.float32) - kernel_name = 'matmul_kernel_with_block_pointers_batched' else: assert len(a.shape) == 2, 'Expecting shape of length 2' c = torch.empty((M, N), device='xpu', dtype=torch.float32) - kernel_name = 'matmul_kernel_with_block_pointers' triton_fn = lambda: matmul(a, b, c) torch_fn = lambda: torch.matmul(torch.exp(a), b).to(torch.float32) rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3 benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch') _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10, - quantiles=quantiles, kernel_name=kernel_name) + quantiles=quantiles) else: raise NotImplementedError(f'Unsupported provider {provider}') diff --git a/benchmarks/triton_kernels_benchmark/gemm_splitk_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_splitk_benchmark.py index 4aa1910591..4eb4c2b3e8 100644 --- a/benchmarks/triton_kernels_benchmark/gemm_splitk_benchmark.py +++ b/benchmarks/triton_kernels_benchmark/gemm_splitk_benchmark.py @@ -156,8 +156,7 @@ def benchmark(M, N, K, provider): torch_fn = lambda: torch.matmul(a, b).to(torch.float32) rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3 benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch') - _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10, quantiles=quantiles, - kernel_name='_kernel') + _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10, quantiles=quantiles) else: raise NotImplementedError(f'Unsupported provider {provider}') diff --git a/benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py index a495dca749..6969506e65 100644 --- a/benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py +++ b/benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py @@ -280,8 +280,7 @@ def benchmark(M, N, K, provider): torch_fn = lambda: torch.matmul(a, b).to(torch.float32) benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=1e-2, err_msg='triton to torch') _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10, - quantiles=quantiles, - kernel_name=['first_wave', 'full_tiles']) + quantiles=quantiles) elif provider == 'xetla': c = torch.empty((M, N), device='xpu', dtype=torch.float32) acc = torch.empty((M, N), device='xpu', dtype=torch.float32) @@ -294,7 +293,7 @@ def benchmark(M, N, K, provider): # benchmark_suit.assert_close(xetla_fn(), torch_fn(), atol=1e-4, rtol=1.0, err_msg='xetla to torch') _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(xetla_fn, n_warmup=10, n_repeat=10, - quantiles=quantiles, kernel_name='stream_k_gemm_run') + quantiles=quantiles) else: raise NotImplementedError(f'Unsupported provider {provider}') diff --git a/benchmarks/triton_kernels_benchmark/prefix_sums.py b/benchmarks/triton_kernels_benchmark/prefix_sums.py index 8f17fb9e9f..bb3d2069f0 100644 --- a/benchmarks/triton_kernels_benchmark/prefix_sums.py +++ b/benchmarks/triton_kernels_benchmark/prefix_sums.py @@ -44,8 +44,7 @@ def benchmark(M, N, AXIS, provider): if provider == 'triton': triton_fn = lambda: scan_kernel[(1, )](x, BLOCK_SIZE_M=M, BLOCK_SIZE_N=N, AXIS=AXIS) - _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, quantiles=quantiles, - kernel_name='scan_kernel') + _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, quantiles=quantiles) else: raise NotImplementedError(f'Unsupported provider {provider}') From 2a4b81888266fb1829b1d4b739001d3c2d4270cd Mon Sep 17 00:00:00 2001 From: Anatoly Myachev Date: Wed, 18 Dec 2024 16:02:55 +0000 Subject: [PATCH 4/5] try changes from #3036 Signed-off-by: Anatoly Myachev --- benchmarks/triton_kernels_benchmark/__init__.py | 2 +- benchmarks/triton_kernels_benchmark/benchmark_testing.py | 8 ++++++++ .../flash_attention_fwd_benchmark.py | 2 +- benchmarks/triton_kernels_benchmark/fused_softmax.py | 1 + benchmarks/triton_kernels_benchmark/gemm_benchmark.py | 2 ++ .../gemm_postop_addmatrix_benchmark.py | 2 ++ .../gemm_postop_gelu_benchmark.py | 2 ++ .../triton_kernels_benchmark/gemm_preop_exp_benchmark.py | 2 ++ .../triton_kernels_benchmark/gemm_splitk_benchmark.py | 1 + .../triton_kernels_benchmark/gemm_streamk_benchmark.py | 2 ++ python/triton/runtime/autotuner.py | 2 +- 11 files changed, 23 insertions(+), 3 deletions(-) diff --git a/benchmarks/triton_kernels_benchmark/__init__.py b/benchmarks/triton_kernels_benchmark/__init__.py index 02857fdd99..43b1f9722b 100644 --- a/benchmarks/triton_kernels_benchmark/__init__.py +++ b/benchmarks/triton_kernels_benchmark/__init__.py @@ -1,4 +1,4 @@ -from .benchmark_testing import do_bench, assert_close, perf_report, Benchmark, USE_IPEX_OPTION, BENCHMARKING_METHOD # type: ignore # noqa: F401 +from .benchmark_testing import do_bench, make_do_bench_for_autotune, assert_close, perf_report, Benchmark, USE_IPEX_OPTION, BENCHMARKING_METHOD # type: ignore # noqa: F401 if USE_IPEX_OPTION or BENCHMARKING_METHOD == "UPSTREAM_PYTORCH_PROFILER": from triton.runtime import driver diff --git a/benchmarks/triton_kernels_benchmark/benchmark_testing.py b/benchmarks/triton_kernels_benchmark/benchmark_testing.py index 1e088291fa..30f2467e97 100644 --- a/benchmarks/triton_kernels_benchmark/benchmark_testing.py +++ b/benchmarks/triton_kernels_benchmark/benchmark_testing.py @@ -237,6 +237,14 @@ def extract_kernels(funcs): raise NotImplementedError(f"BENCHMARKING_METHOD: {BENCHMARKING_METHOD} isn't implemented") +def make_do_bench_for_autotune(): + + def autotuner_do_bench(*args, **kwargs): + return do_bench(*args, n_warmup=10, n_repeat=10, **kwargs) + + return autotuner_do_bench + + def assert_close(x, y, atol=None, rtol=None, err_msg=""): import numpy as np import torch diff --git a/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py b/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py index 132898c023..ed63118ab0 100644 --- a/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py +++ b/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py @@ -164,7 +164,7 @@ def _attn_fwd(Q, K, V, sm_scale, M, Out, # for w in [8, 16, 32] \ ] -tuner = triton.autotune(configs, key=['N_CTX', 'BLOCK_DMODEL']) +tuner = triton.autotune(configs, key=['N_CTX', 'BLOCK_DMODEL'], do_bench=benchmark_suit.make_do_bench_for_autotune()) tune_attn_fwd = tuner(_attn_fwd) diff --git a/benchmarks/triton_kernels_benchmark/fused_softmax.py b/benchmarks/triton_kernels_benchmark/fused_softmax.py index 6782e92d6b..56cd91befe 100644 --- a/benchmarks/triton_kernels_benchmark/fused_softmax.py +++ b/benchmarks/triton_kernels_benchmark/fused_softmax.py @@ -50,6 +50,7 @@ def naive_softmax(x): triton.Config({"threads_per_warp": 16}, num_warps=4), ], key=["BLOCK_SIZE_X", "BLOCK_SIZE_Y"], + do_bench=benchmark_suit.make_do_bench_for_autotune(), ) @triton.jit def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_cols, BLOCK_SIZE_X: tl.constexpr, diff --git a/benchmarks/triton_kernels_benchmark/gemm_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_benchmark.py index 4ad3d8d5e5..a9b064714d 100644 --- a/benchmarks/triton_kernels_benchmark/gemm_benchmark.py +++ b/benchmarks/triton_kernels_benchmark/gemm_benchmark.py @@ -43,6 +43,7 @@ num_stages=s, num_warps=32) for s in [2, 3] ], key=['M', 'N', 'K'], + do_bench=benchmark_suit.make_do_bench_for_autotune(), ) @triton.jit def matmul_kernel_with_block_pointers( @@ -116,6 +117,7 @@ def matmul_kernel_with_block_pointers( num_stages=s, num_warps=4) for s in [2] ], key=['M', 'N', 'K'], + do_bench=benchmark_suit.make_do_bench_for_autotune(), ) @triton.jit def matmul_kernel_with_block_pointers_batched( diff --git a/benchmarks/triton_kernels_benchmark/gemm_postop_addmatrix_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_postop_addmatrix_benchmark.py index cefbd5abc9..7d40709845 100644 --- a/benchmarks/triton_kernels_benchmark/gemm_postop_addmatrix_benchmark.py +++ b/benchmarks/triton_kernels_benchmark/gemm_postop_addmatrix_benchmark.py @@ -35,6 +35,7 @@ num_stages=2, num_warps=32), ], key=['M', 'N', 'K'], + do_bench=benchmark_suit.make_do_bench_for_autotune(), ) @triton.jit def matmul_kernel_with_block_pointers( @@ -109,6 +110,7 @@ def matmul_kernel_with_block_pointers( num_stages=2, num_warps=4), ], key=['M', 'N', 'K'], + do_bench=benchmark_suit.make_do_bench_for_autotune(), ) @triton.jit def matmul_kernel_with_block_pointers_batched( diff --git a/benchmarks/triton_kernels_benchmark/gemm_postop_gelu_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_postop_gelu_benchmark.py index 68cec3931e..7ee5038b85 100644 --- a/benchmarks/triton_kernels_benchmark/gemm_postop_gelu_benchmark.py +++ b/benchmarks/triton_kernels_benchmark/gemm_postop_gelu_benchmark.py @@ -54,6 +54,7 @@ def gelu(x): num_stages=2, num_warps=32), ], key=['M', 'N', 'K'], + do_bench=benchmark_suit.make_do_bench_for_autotune(), ) @triton.jit def matmul_kernel_with_block_pointers( @@ -122,6 +123,7 @@ def matmul_kernel_with_block_pointers( num_stages=2, num_warps=4), ], key=['M', 'N', 'K'], + do_bench=benchmark_suit.make_do_bench_for_autotune(), ) @triton.jit def matmul_kernel_with_block_pointers_batched( diff --git a/benchmarks/triton_kernels_benchmark/gemm_preop_exp_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_preop_exp_benchmark.py index dd5b57c84f..6d821b4f30 100644 --- a/benchmarks/triton_kernels_benchmark/gemm_preop_exp_benchmark.py +++ b/benchmarks/triton_kernels_benchmark/gemm_preop_exp_benchmark.py @@ -36,6 +36,7 @@ num_stages=2, num_warps=32), ], key=['M', 'N', 'K'], + do_bench=benchmark_suit.make_do_bench_for_autotune(), ) @triton.jit def matmul_kernel_with_block_pointers( @@ -107,6 +108,7 @@ def matmul_kernel_with_block_pointers( num_stages=2, num_warps=4), ], key=['M', 'N', 'K'], + do_bench=benchmark_suit.make_do_bench_for_autotune(), ) @triton.jit def matmul_kernel_with_block_pointers_batched( diff --git a/benchmarks/triton_kernels_benchmark/gemm_splitk_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_splitk_benchmark.py index c4114c4466..8b354c8cd2 100644 --- a/benchmarks/triton_kernels_benchmark/gemm_splitk_benchmark.py +++ b/benchmarks/triton_kernels_benchmark/gemm_splitk_benchmark.py @@ -15,6 +15,7 @@ num_stages=4, num_warps=32), ], key=['M', 'N', 'K'], + do_bench=benchmark_suit.make_do_bench_for_autotune(), ) @triton.jit def _kernel(A, B, C, # diff --git a/benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py index f0743cfe64..fa209319a1 100644 --- a/benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py +++ b/benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py @@ -107,6 +107,7 @@ def mac_loop( num_stages=2, num_warps=32), ], key=['M', 'N', 'K'], + do_bench=benchmark_suit.make_do_bench_for_autotune(), ) @triton.jit def first_wave( @@ -143,6 +144,7 @@ def first_wave( num_stages=2, num_warps=32), ], key=['M', 'N', 'K'], + do_bench=benchmark_suit.make_do_bench_for_autotune(), ) @triton.jit def full_tiles( diff --git a/python/triton/runtime/autotuner.py b/python/triton/runtime/autotuner.py index 573d9d4191..7e93086214 100644 --- a/python/triton/runtime/autotuner.py +++ b/python/triton/runtime/autotuner.py @@ -357,7 +357,7 @@ def kernel(x_ptr, x_size, **META): def decorator(fn): return Autotuner(fn, fn.arg_names, configs, key, reset_to_zero, restore_value, pre_hook=pre_hook, post_hook=post_hook, prune_configs_by=prune_configs_by, warmup=warmup, rep=rep, - use_cuda_graph=use_cuda_graph) + use_cuda_graph=use_cuda_graph, do_bench=do_bench) return decorator From 0d66c8e854c95b84344e1d77ce5d111014a95e73 Mon Sep 17 00:00:00 2001 From: Anatoly Myachev Date: Wed, 18 Dec 2024 21:14:23 +0100 Subject: [PATCH 5/5] Revert "try changes from #3036" This reverts commit 2a4b81888266fb1829b1d4b739001d3c2d4270cd. --- benchmarks/triton_kernels_benchmark/__init__.py | 2 +- benchmarks/triton_kernels_benchmark/benchmark_testing.py | 8 -------- .../flash_attention_fwd_benchmark.py | 2 +- benchmarks/triton_kernels_benchmark/fused_softmax.py | 1 - benchmarks/triton_kernels_benchmark/gemm_benchmark.py | 2 -- .../gemm_postop_addmatrix_benchmark.py | 2 -- .../gemm_postop_gelu_benchmark.py | 2 -- .../triton_kernels_benchmark/gemm_preop_exp_benchmark.py | 2 -- .../triton_kernels_benchmark/gemm_splitk_benchmark.py | 1 - .../triton_kernels_benchmark/gemm_streamk_benchmark.py | 2 -- python/triton/runtime/autotuner.py | 2 +- 11 files changed, 3 insertions(+), 23 deletions(-) diff --git a/benchmarks/triton_kernels_benchmark/__init__.py b/benchmarks/triton_kernels_benchmark/__init__.py index 43b1f9722b..02857fdd99 100644 --- a/benchmarks/triton_kernels_benchmark/__init__.py +++ b/benchmarks/triton_kernels_benchmark/__init__.py @@ -1,4 +1,4 @@ -from .benchmark_testing import do_bench, make_do_bench_for_autotune, assert_close, perf_report, Benchmark, USE_IPEX_OPTION, BENCHMARKING_METHOD # type: ignore # noqa: F401 +from .benchmark_testing import do_bench, assert_close, perf_report, Benchmark, USE_IPEX_OPTION, BENCHMARKING_METHOD # type: ignore # noqa: F401 if USE_IPEX_OPTION or BENCHMARKING_METHOD == "UPSTREAM_PYTORCH_PROFILER": from triton.runtime import driver diff --git a/benchmarks/triton_kernels_benchmark/benchmark_testing.py b/benchmarks/triton_kernels_benchmark/benchmark_testing.py index 30f2467e97..1e088291fa 100644 --- a/benchmarks/triton_kernels_benchmark/benchmark_testing.py +++ b/benchmarks/triton_kernels_benchmark/benchmark_testing.py @@ -237,14 +237,6 @@ def extract_kernels(funcs): raise NotImplementedError(f"BENCHMARKING_METHOD: {BENCHMARKING_METHOD} isn't implemented") -def make_do_bench_for_autotune(): - - def autotuner_do_bench(*args, **kwargs): - return do_bench(*args, n_warmup=10, n_repeat=10, **kwargs) - - return autotuner_do_bench - - def assert_close(x, y, atol=None, rtol=None, err_msg=""): import numpy as np import torch diff --git a/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py b/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py index ed63118ab0..132898c023 100644 --- a/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py +++ b/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py @@ -164,7 +164,7 @@ def _attn_fwd(Q, K, V, sm_scale, M, Out, # for w in [8, 16, 32] \ ] -tuner = triton.autotune(configs, key=['N_CTX', 'BLOCK_DMODEL'], do_bench=benchmark_suit.make_do_bench_for_autotune()) +tuner = triton.autotune(configs, key=['N_CTX', 'BLOCK_DMODEL']) tune_attn_fwd = tuner(_attn_fwd) diff --git a/benchmarks/triton_kernels_benchmark/fused_softmax.py b/benchmarks/triton_kernels_benchmark/fused_softmax.py index 56cd91befe..6782e92d6b 100644 --- a/benchmarks/triton_kernels_benchmark/fused_softmax.py +++ b/benchmarks/triton_kernels_benchmark/fused_softmax.py @@ -50,7 +50,6 @@ def naive_softmax(x): triton.Config({"threads_per_warp": 16}, num_warps=4), ], key=["BLOCK_SIZE_X", "BLOCK_SIZE_Y"], - do_bench=benchmark_suit.make_do_bench_for_autotune(), ) @triton.jit def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_cols, BLOCK_SIZE_X: tl.constexpr, diff --git a/benchmarks/triton_kernels_benchmark/gemm_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_benchmark.py index a9b064714d..4ad3d8d5e5 100644 --- a/benchmarks/triton_kernels_benchmark/gemm_benchmark.py +++ b/benchmarks/triton_kernels_benchmark/gemm_benchmark.py @@ -43,7 +43,6 @@ num_stages=s, num_warps=32) for s in [2, 3] ], key=['M', 'N', 'K'], - do_bench=benchmark_suit.make_do_bench_for_autotune(), ) @triton.jit def matmul_kernel_with_block_pointers( @@ -117,7 +116,6 @@ def matmul_kernel_with_block_pointers( num_stages=s, num_warps=4) for s in [2] ], key=['M', 'N', 'K'], - do_bench=benchmark_suit.make_do_bench_for_autotune(), ) @triton.jit def matmul_kernel_with_block_pointers_batched( diff --git a/benchmarks/triton_kernels_benchmark/gemm_postop_addmatrix_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_postop_addmatrix_benchmark.py index 7d40709845..cefbd5abc9 100644 --- a/benchmarks/triton_kernels_benchmark/gemm_postop_addmatrix_benchmark.py +++ b/benchmarks/triton_kernels_benchmark/gemm_postop_addmatrix_benchmark.py @@ -35,7 +35,6 @@ num_stages=2, num_warps=32), ], key=['M', 'N', 'K'], - do_bench=benchmark_suit.make_do_bench_for_autotune(), ) @triton.jit def matmul_kernel_with_block_pointers( @@ -110,7 +109,6 @@ def matmul_kernel_with_block_pointers( num_stages=2, num_warps=4), ], key=['M', 'N', 'K'], - do_bench=benchmark_suit.make_do_bench_for_autotune(), ) @triton.jit def matmul_kernel_with_block_pointers_batched( diff --git a/benchmarks/triton_kernels_benchmark/gemm_postop_gelu_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_postop_gelu_benchmark.py index 7ee5038b85..68cec3931e 100644 --- a/benchmarks/triton_kernels_benchmark/gemm_postop_gelu_benchmark.py +++ b/benchmarks/triton_kernels_benchmark/gemm_postop_gelu_benchmark.py @@ -54,7 +54,6 @@ def gelu(x): num_stages=2, num_warps=32), ], key=['M', 'N', 'K'], - do_bench=benchmark_suit.make_do_bench_for_autotune(), ) @triton.jit def matmul_kernel_with_block_pointers( @@ -123,7 +122,6 @@ def matmul_kernel_with_block_pointers( num_stages=2, num_warps=4), ], key=['M', 'N', 'K'], - do_bench=benchmark_suit.make_do_bench_for_autotune(), ) @triton.jit def matmul_kernel_with_block_pointers_batched( diff --git a/benchmarks/triton_kernels_benchmark/gemm_preop_exp_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_preop_exp_benchmark.py index 6d821b4f30..dd5b57c84f 100644 --- a/benchmarks/triton_kernels_benchmark/gemm_preop_exp_benchmark.py +++ b/benchmarks/triton_kernels_benchmark/gemm_preop_exp_benchmark.py @@ -36,7 +36,6 @@ num_stages=2, num_warps=32), ], key=['M', 'N', 'K'], - do_bench=benchmark_suit.make_do_bench_for_autotune(), ) @triton.jit def matmul_kernel_with_block_pointers( @@ -108,7 +107,6 @@ def matmul_kernel_with_block_pointers( num_stages=2, num_warps=4), ], key=['M', 'N', 'K'], - do_bench=benchmark_suit.make_do_bench_for_autotune(), ) @triton.jit def matmul_kernel_with_block_pointers_batched( diff --git a/benchmarks/triton_kernels_benchmark/gemm_splitk_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_splitk_benchmark.py index 8b354c8cd2..c4114c4466 100644 --- a/benchmarks/triton_kernels_benchmark/gemm_splitk_benchmark.py +++ b/benchmarks/triton_kernels_benchmark/gemm_splitk_benchmark.py @@ -15,7 +15,6 @@ num_stages=4, num_warps=32), ], key=['M', 'N', 'K'], - do_bench=benchmark_suit.make_do_bench_for_autotune(), ) @triton.jit def _kernel(A, B, C, # diff --git a/benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py index fa209319a1..f0743cfe64 100644 --- a/benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py +++ b/benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py @@ -107,7 +107,6 @@ def mac_loop( num_stages=2, num_warps=32), ], key=['M', 'N', 'K'], - do_bench=benchmark_suit.make_do_bench_for_autotune(), ) @triton.jit def first_wave( @@ -144,7 +143,6 @@ def first_wave( num_stages=2, num_warps=32), ], key=['M', 'N', 'K'], - do_bench=benchmark_suit.make_do_bench_for_autotune(), ) @triton.jit def full_tiles( diff --git a/python/triton/runtime/autotuner.py b/python/triton/runtime/autotuner.py index 7e93086214..573d9d4191 100644 --- a/python/triton/runtime/autotuner.py +++ b/python/triton/runtime/autotuner.py @@ -357,7 +357,7 @@ def kernel(x_ptr, x_size, **META): def decorator(fn): return Autotuner(fn, fn.arg_names, configs, key, reset_to_zero, restore_value, pre_hook=pre_hook, post_hook=post_hook, prune_configs_by=prune_configs_by, warmup=warmup, rep=rep, - use_cuda_graph=use_cuda_graph, do_bench=do_bench) + use_cuda_graph=use_cuda_graph) return decorator