
Commit 324b18a

linter
1 parent b382964 commit 324b18a

File tree

1 file changed (+61, -61 lines)

vllm/attention/ops/triton_unified_attention.py

Lines changed: 61 additions & 61 deletions
@@ -51,46 +51,46 @@ def find_seq_idx(query_start_len_ptr, target_idx, num_seqs,

@triton.jit
def kernel_unified_attention_2d(
    output_ptr,  # [num_tokens, num_query_heads, head_size]
    query_ptr,  # [num_tokens, num_query_heads, head_size]
    key_cache_ptr,  # [num_blks, blk_size, num_kv_heads, head_size]
    value_cache_ptr,  # [num_blks, blk_size, num_kv_heads, head_size]
    block_tables_ptr,  # [num_seqs, max_num_blocks_per_seq]
    seq_lens_ptr,  # [num_seqs]
    alibi_slopes_ptr,  # [num_query_heads]
    scale,  # float32
    k_scale,  # float32
    v_scale,  # float32
    out_scale,  # float32
    softcap,  # float32
    num_query_heads: tl.constexpr,  # int
    num_queries_per_kv: tl.constexpr,  # int
    block_table_stride: tl.int64,  # int
    query_stride_0: tl.int64,  # int
    query_stride_1: tl.int64,  # int, should be equal to head_size
    output_stride_0: tl.int64,  # int
    output_stride_1: tl.int64,  # int, should be equal to head_size
    BLOCK_SIZE: tl.constexpr,  # int
    HEAD_SIZE: tl.constexpr,  # int
    HEAD_SIZE_PADDED: tl.constexpr,  # int, must be power of 2
    USE_ALIBI_SLOPES: tl.constexpr,  # bool
    USE_SOFTCAP: tl.constexpr,  # bool
    SLIDING_WINDOW: tl.constexpr,  # int
    stride_k_cache_0: tl.int64,  # int
    stride_k_cache_1: tl.int64,  # int
    stride_k_cache_2: tl.int64,  # int
    stride_k_cache_3: tl.constexpr,  # int
    stride_v_cache_0: tl.int64,  # int
    stride_v_cache_1: tl.int64,  # int
    stride_v_cache_2: tl.int64,  # int
    stride_v_cache_3: tl.constexpr,  # int
    query_start_len_ptr,  # [num_seqs+1]
    BLOCK_Q: tl.constexpr,  # int
    num_seqs: tl.int32,
    BLOCK_M: tl.constexpr,  # int
    USE_FP8: tl.constexpr,  # bool
    FP8_MIN: tl.constexpr = float8_info.min,
    FP8_MAX: tl.constexpr = float8_info.max,
):
    q_block_global_idx = tl.program_id(0)
    kv_head_idx = tl.program_id(1)
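
A minimal host-side sketch of how a launch grid for kernel_unified_attention_2d could be derived from the arguments above, given that program_id(0) indexes query blocks across all sequences (q_block_global_idx) and program_id(1) indexes KV heads (kv_head_idx). The helper name unified_attention_2d_grid, the host-side name query_start_lens, and the per-sequence rounding are illustrative assumptions, not the actual vLLM launch code.

def unified_attention_2d_grid(query_start_lens, block_q, num_kv_heads):
    # query_start_lens is a host-side counterpart of query_start_len_ptr:
    # [num_seqs + 1] cumulative query-token counts.
    total_q_blocks = 0
    for i in range(len(query_start_lens) - 1):
        q_len = query_start_lens[i + 1] - query_start_lens[i]
        # Each sequence contributes ceil(q_len / BLOCK_Q) query blocks.
        total_q_blocks += (q_len + block_q - 1) // block_q
    return (total_q_blocks, num_kv_heads)

# Example: two sequences with 5 and 12 query tokens, BLOCK_Q = 8, 8 KV heads
# -> grid of (1 + 2, 8) = (3, 8) programs.
print(unified_attention_2d_grid([0, 5, 17], block_q=8, num_kv_heads=8))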
@@ -487,27 +487,27 @@ def kernel_unified_attention_3d(

@triton.jit
def reduce_segments(
    output_ptr,  # [num_tokens, num_query_heads, head_size]
    segm_output_ptr,
    # [num_tokens, num_query_heads, max_num_segments, head_size]
    segm_max_ptr,  # [num_tokens, num_query_heads, max_num_segments]
    segm_expsum_ptr,  # [num_tokens, num_query_heads, max_num_segments]
    seq_lens_ptr,  # [num_seqs]
    num_seqs,  # int
    num_query_heads: tl.constexpr,  # int
    out_scale,  # float32
    output_stride_0: tl.int64,  # int
    output_stride_1: tl.int64,  # int, should be equal to head_size
    block_table_stride: tl.int64,  # int
    BLOCK_SIZE: tl.constexpr,  # int
    HEAD_SIZE: tl.constexpr,  # int, must be power of 2
    HEAD_SIZE_PADDED: tl.constexpr,  # int, must be power of 2
    query_start_len_ptr,  # [num_seqs+1]
    BLOCK_Q: tl.constexpr,  # int
    NUM_SEGMENTS_PER_SEQ: tl.constexpr,  # int
    USE_FP8: tl.constexpr,  # bool
    FP8_MIN: tl.constexpr = float8_info.min,
    FP8_MAX: tl.constexpr = float8_info.max,
):
    query_token_idx = tl.program_id(0)
    query_head_idx = tl.program_id(1)
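
A rough NumPy sketch of the per-(query token, query head) reduction that reduce_segments is responsible for, assuming the usual split-KV convention: each segment holds an unnormalized partial output, its running max score, and its softmax exp-sum, matching the segm_output_ptr / segm_max_ptr / segm_expsum_ptr layout above. The name reduce_segments_reference is hypothetical and this is a reference illustration, not the kernel's actual code.

import numpy as np

def reduce_segments_reference(segm_output, segm_max, segm_expsum):
    # segm_output: [num_segments, head_size] unnormalized partial outputs
    # segm_max:    [num_segments] running max score per segment
    # segm_expsum: [num_segments] softmax exp-sum per segment
    m = np.max(segm_max)                 # global max across segments
    alpha = np.exp(segm_max - m)         # rescale factor for each segment
    denom = np.sum(alpha * segm_expsum)  # combined softmax denominator
    return (alpha[:, None] * segm_output).sum(axis=0) / denom

# If USE_FP8 is set, the result would presumably also be divided by out_scale
# and clamped to [FP8_MIN, FP8_MAX] before the store (an assumption based on
# the FP8_MIN/FP8_MAX defaults in the signature, not on code shown here).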

0 commit comments
