diff --git a/benchmarks/triton_kernels_benchmark/flex_attention_benchmark_causal_mask.py b/benchmarks/triton_kernels_benchmark/flex_attention_benchmark_causal_mask.py
index 6f86d4d953..e08a0135bd 100644
--- a/benchmarks/triton_kernels_benchmark/flex_attention_benchmark_causal_mask.py
+++ b/benchmarks/triton_kernels_benchmark/flex_attention_benchmark_causal_mask.py
@@ -74,6 +74,12 @@ def causal_mask(_, __, q_idx, kv_idx):
 batch_sizes = [16, 32, 64] if throughput_test else [batch_size]
 fa_kernel_mode = os.getenv('FA_KERNEL_MODE', 'fwd')
 
+if '580' in torch.xpu.get_device_name():
+    old_count = len(batch_sizes)
+    batch_sizes = [size for size in batch_sizes if size < 16]
+    if len(batch_sizes) != old_count:
+        print('Skipping batch_sizes >= 16 on B580')
+
 # Kernel profiling for Backward mode is not working as expected:
 # For details: https://github.com/pytorch/pytorch/issues/144778
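
For reference, a minimal standalone sketch of the gating pattern this patch introduces. The substring match on the device name and the cutoff of 16 mirror the diff; the helper name filter_batch_sizes_for_device and the torch.xpu.is_available() guard are additions here so the sketch runs on non-XPU hosts, not part of the patch itself.

from typing import List

import torch


def filter_batch_sizes_for_device(batch_sizes: List[int], cutoff: int = 16) -> List[int]:
    # Mirror the patch: on a B580 device, drop batch sizes at or above the cutoff
    # and report that they were skipped. On other devices, keep the list as-is.
    # NOTE: is_available() is an added guard, not in the original diff.
    if torch.xpu.is_available() and '580' in torch.xpu.get_device_name():
        kept = [size for size in batch_sizes if size < cutoff]
        if len(kept) != len(batch_sizes):
            print(f'Skipping batch_sizes >= {cutoff} on B580')
        return kept
    return batch_sizes


if __name__ == '__main__':
    # With the defaults from the benchmark, this prints [] on a B580
    # (all sizes are >= 16) and [16, 32, 64] unchanged everywhere else.
    print(filter_batch_sizes_for_device([16, 32, 64]))

A substring check is used rather than exact equality because torch.xpu.get_device_name() returns the full device name string (e.g. a marketing name containing "B580"), which would never compare equal to a bare '580'.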