@@ -25,47 +25,47 @@ def cdiv_fn(x, y):
25
25
26
26
@triton .jit
27
27
def kernel_paged_attention_2d (
28
- output_ptr , # [num_tokens, num_query_heads, head_size]
29
- query_ptr , # [num_tokens, num_query_heads, head_size]
30
- key_cache_ptr , # [num_blks, num_kv_heads, head_size // x, blk_size, x]
31
- value_cache_ptr , # [num_blks, num_kv_heads, head_size, blk_size]
32
- sink_ptr , # [num_query_heads]
33
- block_tables_ptr , # [num_seqs, max_num_blocks_per_seq]
34
- seq_lens_ptr , # [num_seqs]
35
- alibi_slopes_ptr , # [num_query_heads]
36
- scale , # float32
37
- k_scale , # float32
38
- v_scale , # float32
39
- out_scale ,
40
- num_query_heads : tl .constexpr , # int
41
- num_queries_per_kv : tl .constexpr , # int
42
- num_queries_per_kv_padded : tl .constexpr , # int
43
- block_table_stride : tl .int64 , # int
44
- query_stride_0 : tl .int64 , # int
45
- query_stride_1 : tl .int64 , # int, should be equal to head_size
46
- output_stride_0 : tl .int64 , # int
47
- output_stride_1 : tl .int64 , # int, should be equal to head_size
48
- BLOCK_SIZE : tl .constexpr , # int
49
- HEAD_SIZE : tl .constexpr , # int
50
- HEAD_SIZE_PADDED : tl .constexpr , # int, must be power of 2
51
- USE_ALIBI_SLOPES : tl .constexpr , # bool
52
- SLIDING_WINDOW : tl .constexpr , # int
53
- x : tl .constexpr , # int
54
- stride_k_cache_0 : tl .int64 , # int
55
- stride_k_cache_1 : tl .int64 , # int
56
- stride_k_cache_2 : tl .int64 , # int
57
- stride_k_cache_3 : tl .int64 , # int
58
- stride_k_cache_4 : tl .int64 , # int
59
- stride_v_cache_0 : tl .int64 , # int
60
- stride_v_cache_1 : tl .int64 , # int
61
- stride_v_cache_2 : tl .int64 , # int
62
- stride_v_cache_3 : tl .int64 , # int
63
- filter_by_query_len : tl .constexpr , # bool
64
- query_start_len_ptr , # [num_seqs+1]
65
- USE_FP8 : tl .constexpr ,
66
- USE_SINKS : tl .constexpr , # bool
67
- FP8_MIN : tl .constexpr = float8_info .min ,
68
- FP8_MAX : tl .constexpr = float8_info .max ,
28
+ output_ptr , # [num_tokens, num_query_heads, head_size]
29
+ query_ptr , # [num_tokens, num_query_heads, head_size]
30
+ key_cache_ptr , # [num_blks, num_kv_heads, head_size // x, blk_size, x]
31
+ value_cache_ptr , # [num_blks, num_kv_heads, head_size, blk_size]
32
+ sink_ptr , # [num_query_heads]
33
+ block_tables_ptr , # [num_seqs, max_num_blocks_per_seq]
34
+ seq_lens_ptr , # [num_seqs]
35
+ alibi_slopes_ptr , # [num_query_heads]
36
+ scale , # float32
37
+ k_scale , # float32
38
+ v_scale , # float32
39
+ out_scale ,
40
+ num_query_heads : tl .constexpr , # int
41
+ num_queries_per_kv : tl .constexpr , # int
42
+ num_queries_per_kv_padded : tl .constexpr , # int
43
+ block_table_stride : tl .int64 , # int
44
+ query_stride_0 : tl .int64 , # int
45
+ query_stride_1 : tl .int64 , # int, should be equal to head_size
46
+ output_stride_0 : tl .int64 , # int
47
+ output_stride_1 : tl .int64 , # int, should be equal to head_size
48
+ BLOCK_SIZE : tl .constexpr , # int
49
+ HEAD_SIZE : tl .constexpr , # int
50
+ HEAD_SIZE_PADDED : tl .constexpr , # int, must be power of 2
51
+ USE_ALIBI_SLOPES : tl .constexpr , # bool
52
+ SLIDING_WINDOW : tl .constexpr , # int
53
+ x : tl .constexpr , # int
54
+ stride_k_cache_0 : tl .int64 , # int
55
+ stride_k_cache_1 : tl .int64 , # int
56
+ stride_k_cache_2 : tl .int64 , # int
57
+ stride_k_cache_3 : tl .int64 , # int
58
+ stride_k_cache_4 : tl .int64 , # int
59
+ stride_v_cache_0 : tl .int64 , # int
60
+ stride_v_cache_1 : tl .int64 , # int
61
+ stride_v_cache_2 : tl .int64 , # int
62
+ stride_v_cache_3 : tl .int64 , # int
63
+ filter_by_query_len : tl .constexpr , # bool
64
+ query_start_len_ptr , # [num_seqs+1]
65
+ USE_FP8 : tl .constexpr ,
66
+ USE_SINKS : tl .constexpr , # bool
67
+ FP8_MIN : tl .constexpr = float8_info .min ,
68
+ FP8_MAX : tl .constexpr = float8_info .max ,
69
69
):
70
70
seq_idx = tl .program_id (0 )
71
71
kv_head_idx = tl .program_id (1 )
0 commit comments