diff --git a/benchmarks/triton_kernels_benchmark/flex_attention_benchmark_causal_mask.py b/benchmarks/triton_kernels_benchmark/flex_attention_benchmark_causal_mask.py index 746dc57734..ed49066a2c 100644 --- a/benchmarks/triton_kernels_benchmark/flex_attention_benchmark_causal_mask.py +++ b/benchmarks/triton_kernels_benchmark/flex_attention_benchmark_causal_mask.py @@ -25,6 +25,7 @@ def get_flex_attn_fwd_configs(*args, **kwargs): # pylint: disable=unused-argument configs = [ FlexConfig(32, 16, 2, 4), + FlexConfig(64, 32, 2, 4), FlexConfig(128, 64, 2, 16), FlexConfig(128, 64, 2, 8), FlexConfig(128, 32, 2, 16),