From fca27a5f0db17808b449c830ad0611d4314816dc Mon Sep 17 00:00:00 2001
From: dev-tomek
Date: Tue, 2 Dec 2025 10:06:20 +0000
Subject: [PATCH 1/3] empty cache between each run to avoid OOM

---
 .../triton_kernels_benchmark/gemm_postop_addmatrix_benchmark.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/benchmarks/triton_kernels_benchmark/gemm_postop_addmatrix_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_postop_addmatrix_benchmark.py
index d45075260f..4d76af9f01 100644
--- a/benchmarks/triton_kernels_benchmark/gemm_postop_addmatrix_benchmark.py
+++ b/benchmarks/triton_kernels_benchmark/gemm_postop_addmatrix_benchmark.py
@@ -311,6 +311,7 @@ def benchmark(B, M, N, K, dtype, provider):
     # Maximum across onednn=600, triton=1000
     # For onednn and triton: Some configs increase performance with warmup as a step function, but some
     # slowly decrease with saturation. Performance is best at 150-200ms range, but we want stable, not just best
+    torch.xpu.empty_cache()
     do_bench = benchmark_suite.get_do_bench(n_warmup=1000, n_repeat=10, quantiles=[0.5, 0.0, 1.0])
     res_dtype = torch.float32 if dtype.is_floating_point else torch.int32
     if dtype.is_floating_point:

From 19e276e70af9dd07b2f069c02e6d8fc5f731fa94 Mon Sep 17 00:00:00 2001
From: dev-tomek
Date: Wed, 3 Dec 2025 08:46:25 +0000
Subject: [PATCH 2/3] add cache flush to flex att bs=16

---
 .../flex_attention_benchmark_causal_mask.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/benchmarks/triton_kernels_benchmark/flex_attention_benchmark_causal_mask.py b/benchmarks/triton_kernels_benchmark/flex_attention_benchmark_causal_mask.py
index 3b4871d073..95c3e1b2a8 100644
--- a/benchmarks/triton_kernels_benchmark/flex_attention_benchmark_causal_mask.py
+++ b/benchmarks/triton_kernels_benchmark/flex_attention_benchmark_causal_mask.py
@@ -149,6 +149,7 @@ def causal_mask(_, __, q_idx, kv_idx):
 ))
 def benchmark(Z, H_q, H_kv, N_CTX_q, N_CTX_kv, D_HEAD_qk, D_HEAD_v, MODE, provider):
     # Maximum across torch=200, triton=600
+    torch.xpu.empty_cache()
     do_bench = benchmark_suite.get_do_bench(n_warmup=600, n_repeat=10, quantiles=[0.5, 0.0, 1.0])
     if MODE not in ('fwd', 'bwd'):
         raise ValueError(f"Invalid MODE: {MODE}. Expected 'fwd' or 'bwd'.")

From facac0f69af9776bd9cbf76a473db6ed4ced1d5b Mon Sep 17 00:00:00 2001
From: dev-tomek
Date: Thu, 4 Dec 2025 13:25:12 +0000
Subject: [PATCH 3/3] Revert "add cache flush to flex att bs=16"

This reverts commit 19e276e70af9dd07b2f069c02e6d8fc5f731fa94.
---
 .../flex_attention_benchmark_causal_mask.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/benchmarks/triton_kernels_benchmark/flex_attention_benchmark_causal_mask.py b/benchmarks/triton_kernels_benchmark/flex_attention_benchmark_causal_mask.py
index 95c3e1b2a8..3b4871d073 100644
--- a/benchmarks/triton_kernels_benchmark/flex_attention_benchmark_causal_mask.py
+++ b/benchmarks/triton_kernels_benchmark/flex_attention_benchmark_causal_mask.py
@@ -149,7 +149,6 @@ def causal_mask(_, __, q_idx, kv_idx):
 ))
 def benchmark(Z, H_q, H_kv, N_CTX_q, N_CTX_kv, D_HEAD_qk, D_HEAD_v, MODE, provider):
     # Maximum across torch=200, triton=600
-    torch.xpu.empty_cache()
     do_bench = benchmark_suite.get_do_bench(n_warmup=600, n_repeat=10, quantiles=[0.5, 0.0, 1.0])
     if MODE not in ('fwd', 'bwd'):
         raise ValueError(f"Invalid MODE: {MODE}. Expected 'fwd' or 'bwd'.")
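
Why the empty_cache() call helps: PyTorch's caching allocator keeps freed device
blocks reserved for reuse, so allocations left behind by one provider/config can
push the next benchmark run over the XPU memory limit. Below is a minimal sketch
of the pattern these patches apply, assuming a hypothetical `kernel` callable and
a hand-rolled timing loop standing in for `benchmark_suite.get_do_bench`; only
the `torch.xpu.empty_cache()` and `torch.xpu.synchronize()` calls mirror the real
benchmark code.

    import time
    import torch

    def bench_provider(kernel, n_warmup=1000, n_repeat=10):
        # Flush cached XPU allocator blocks left over from the previous
        # provider/config so this run starts from a clean allocator state.
        torch.xpu.empty_cache()
        for _ in range(n_warmup):
            kernel()                     # warmup, not timed
        torch.xpu.synchronize()
        times = []
        for _ in range(n_repeat):
            t0 = time.perf_counter()
            kernel()
            torch.xpu.synchronize()      # wait for the kernel to finish
            times.append((time.perf_counter() - t0) * 1e3)  # ms
        times.sort()
        # median / min / max, matching quantiles=[0.5, 0.0, 1.0]
        return times[len(times) // 2], times[0], times[-1]

Since PATCH 3 reverts the flex-attention change, only the GEMM benchmark retains
the flush; placing it before the do_bench call keeps the relatively expensive
cache release outside the measured region.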