@@ -51,46 +51,46 @@ def find_seq_idx(query_start_len_ptr, target_idx, num_seqs,
5151
5252@triton .jit
5353def kernel_unified_attention_2d (
54- output_ptr , # [num_tokens, num_query_heads, head_size]
55- query_ptr , # [num_tokens, num_query_heads, head_size]
56- key_cache_ptr , # [num_blks, blk_size, num_kv_heads, head_size]
57- value_cache_ptr , # [num_blks, blk_size, num_kv_heads, head_size]
58- block_tables_ptr , # [num_seqs, max_num_blocks_per_seq]
59- seq_lens_ptr , # [num_seqs]
60- alibi_slopes_ptr , # [num_query_heads]
61- scale , # float32
62- k_scale , # float32
63- v_scale , # float32
64- out_scale , # float32
65- softcap , # float32
66- num_query_heads : tl .constexpr , # int
67- num_queries_per_kv : tl .constexpr , # int
68- block_table_stride : tl .int64 , # int
69- query_stride_0 : tl .int64 , # int
70- query_stride_1 : tl .int64 , # int, should be equal to head_size
71- output_stride_0 : tl .int64 , # int
72- output_stride_1 : tl .int64 , # int, should be equal to head_size
73- BLOCK_SIZE : tl .constexpr , # int
74- HEAD_SIZE : tl .constexpr , # int
75- HEAD_SIZE_PADDED : tl .constexpr , # int, must be power of 2
76- USE_ALIBI_SLOPES : tl .constexpr , # bool
77- USE_SOFTCAP : tl .constexpr , # bool
78- SLIDING_WINDOW : tl .constexpr , # int
79- stride_k_cache_0 : tl .int64 , # int
80- stride_k_cache_1 : tl .int64 , # int
81- stride_k_cache_2 : tl .int64 , # int
82- stride_k_cache_3 : tl .constexpr , # int
83- stride_v_cache_0 : tl .int64 , # int
84- stride_v_cache_1 : tl .int64 , # int
85- stride_v_cache_2 : tl .int64 , # int
86- stride_v_cache_3 : tl .constexpr , # int
87- query_start_len_ptr , # [num_seqs+1]
88- BLOCK_Q : tl .constexpr , # int
89- num_seqs : tl .int32 ,
90- BLOCK_M : tl .constexpr , # int
91- USE_FP8 : tl .constexpr , # bool
92- FP8_MIN : tl .constexpr = float8_info .min ,
93- FP8_MAX : tl .constexpr = float8_info .max ,
54+ output_ptr , # [num_tokens, num_query_heads, head_size]
55+ query_ptr , # [num_tokens, num_query_heads, head_size]
56+ key_cache_ptr , # [num_blks, blk_size, num_kv_heads, head_size]
57+ value_cache_ptr , # [num_blks, blk_size, num_kv_heads, head_size]
58+ block_tables_ptr , # [num_seqs, max_num_blocks_per_seq]
59+ seq_lens_ptr , # [num_seqs]
60+ alibi_slopes_ptr , # [num_query_heads]
61+ scale , # float32
62+ k_scale , # float32
63+ v_scale , # float32
64+ out_scale , # float32
65+ softcap , # float32
66+ num_query_heads : tl .constexpr , # int
67+ num_queries_per_kv : tl .constexpr , # int
68+ block_table_stride : tl .int64 , # int
69+ query_stride_0 : tl .int64 , # int
70+ query_stride_1 : tl .int64 , # int, should be equal to head_size
71+ output_stride_0 : tl .int64 , # int
72+ output_stride_1 : tl .int64 , # int, should be equal to head_size
73+ BLOCK_SIZE : tl .constexpr , # int
74+ HEAD_SIZE : tl .constexpr , # int
75+ HEAD_SIZE_PADDED : tl .constexpr , # int, must be power of 2
76+ USE_ALIBI_SLOPES : tl .constexpr , # bool
77+ USE_SOFTCAP : tl .constexpr , # bool
78+ SLIDING_WINDOW : tl .constexpr , # int
79+ stride_k_cache_0 : tl .int64 , # int
80+ stride_k_cache_1 : tl .int64 , # int
81+ stride_k_cache_2 : tl .int64 , # int
82+ stride_k_cache_3 : tl .constexpr , # int
83+ stride_v_cache_0 : tl .int64 , # int
84+ stride_v_cache_1 : tl .int64 , # int
85+ stride_v_cache_2 : tl .int64 , # int
86+ stride_v_cache_3 : tl .constexpr , # int
87+ query_start_len_ptr , # [num_seqs+1]
88+ BLOCK_Q : tl .constexpr , # int
89+ num_seqs : tl .int32 ,
90+ BLOCK_M : tl .constexpr , # int
91+ USE_FP8 : tl .constexpr , # bool
92+ FP8_MIN : tl .constexpr = float8_info .min ,
93+ FP8_MAX : tl .constexpr = float8_info .max ,
9494):
9595 q_block_global_idx = tl .program_id (0 )
9696 kv_head_idx = tl .program_id (1 )
@@ -487,27 +487,27 @@ def kernel_unified_attention_3d(
487487
488488@triton .jit
489489def reduce_segments (
490- output_ptr , # [num_tokens, num_query_heads, head_size]
491- segm_output_ptr ,
492- #[num_tokens, num_query_heads, max_num_segments, head_size]
493- segm_max_ptr , # [num_tokens, num_query_heads, max_num_segments]
494- segm_expsum_ptr , # [num_tokens, num_query_heads, max_num_segments]
495- seq_lens_ptr , # [num_seqs]
496- num_seqs , # int
497- num_query_heads : tl .constexpr , # int
498- out_scale , # float32
499- output_stride_0 : tl .int64 , # int
500- output_stride_1 : tl .int64 , # int, should be equal to head_size
501- block_table_stride : tl .int64 , # int
502- BLOCK_SIZE : tl .constexpr , # int
503- HEAD_SIZE : tl .constexpr , # int, must be power of 2
504- HEAD_SIZE_PADDED : tl .constexpr , # int, must be power of 2
505- query_start_len_ptr , # [num_seqs+1]
506- BLOCK_Q : tl .constexpr , # int
507- NUM_SEGMENTS_PER_SEQ : tl .constexpr , # int
508- USE_FP8 : tl .constexpr , # bool
509- FP8_MIN : tl .constexpr = float8_info .min ,
510- FP8_MAX : tl .constexpr = float8_info .max ,
490+ output_ptr , # [num_tokens, num_query_heads, head_size]
491+ segm_output_ptr ,
492+ #[num_tokens, num_query_heads, max_num_segments, head_size]
493+ segm_max_ptr , # [num_tokens, num_query_heads, max_num_segments]
494+ segm_expsum_ptr , # [num_tokens, num_query_heads, max_num_segments]
495+ seq_lens_ptr , # [num_seqs]
496+ num_seqs , # int
497+ num_query_heads : tl .constexpr , # int
498+ out_scale , # float32
499+ output_stride_0 : tl .int64 , # int
500+ output_stride_1 : tl .int64 , # int, should be equal to head_size
501+ block_table_stride : tl .int64 , # int
502+ BLOCK_SIZE : tl .constexpr , # int
503+ HEAD_SIZE : tl .constexpr , # int, must be power of 2
504+ HEAD_SIZE_PADDED : tl .constexpr , # int, must be power of 2
505+ query_start_len_ptr , # [num_seqs+1]
506+ BLOCK_Q : tl .constexpr , # int
507+ NUM_SEGMENTS_PER_SEQ : tl .constexpr , # int
508+ USE_FP8 : tl .constexpr , # bool
509+ FP8_MIN : tl .constexpr = float8_info .min ,
510+ FP8_MAX : tl .constexpr = float8_info .max ,
511511):
512512 query_token_idx = tl .program_id (0 )
513513 query_head_idx = tl .program_id (1 )
0 commit comments