@@ -44,10 +44,10 @@ def setUpClass(cls):
44
44
45
45
@parameterized .expand (
46
46
[
47
- ("eager_paged " , 64 , 128 , 64 ),
48
- ("sdpa_paged " , 32 , 256 , 128 ),
49
- ("paged_attention " , 16 , 512 , 256 ),
50
- ("flex_paged " , 64 , 128 , 64 ),
47
+ ("paged|eager " , 64 , 128 , 64 ),
48
+ ("paged|sdpa " , 32 , 256 , 128 ),
49
+ ("paged|flash_attention_2 " , 16 , 512 , 256 ),
50
+ ("paged|flex_attention " , 64 , 128 , 64 ),
51
51
]
52
52
)
53
53
def test_generate_batch_consistency (self , attn_impl , num_blocks , block_size , max_batch_tokens ):
@@ -89,10 +89,10 @@ def test_generate_batch_consistency(self, attn_impl, num_blocks, block_size, max
89
89
90
90
@parameterized .expand (
91
91
[
92
- ("eager_paged " , 64 , 128 , 64 ),
93
- ("sdpa_paged " , 32 , 256 , 128 ),
94
- ("paged_attention " , 16 , 512 , 256 ),
95
- ("flex_paged " , 64 , 128 , 64 ),
92
+ ("paged|eager " , 64 , 128 , 64 ),
93
+ ("paged|sdpa " , 32 , 256 , 128 ),
94
+ ("paged|flash_attention_2 " , 16 , 512 , 256 ),
95
+ ("paged|flex_attention " , 64 , 128 , 64 ),
96
96
]
97
97
)
98
98
def test_generate_batch_with_sampling (self , attn_impl , num_blocks , block_size , max_batch_tokens ):
0 commit comments