
Commit 5cd75ce

[Inference/Kernel] refactor kvcache manager and rotary_embedding and kvcache_memcpy operator (#5663)

* refactor kvcache manager and rotary_embedding and kvcache_memcpy operator
* refactor decode_kv_cache_memcpy
* enable alibi in pagedattention
1 parent 5f00002 commit 5cd75ce

14 files changed (+368 / -235 lines)

colossalai/inference/kv_cache/kvcache_manager.py

Lines changed: 17 additions & 6 deletions
@@ -90,9 +90,18 @@ def __init__(self, config: InferenceConfig, model_config: PretrainedConfig, verb
         self.num_blocks = self.max_blocks_per_sequence * self.max_batch_size * self.beam_width

         # Physical cache allocation
-        alloc_shape = (self.num_blocks, self.kv_head_num, self.block_size, self.head_size)
-        self.logger.info(f"Allocating KV cache with shape: {alloc_shape} consisting of {self.num_blocks} blocks.")
-        self._kv_caches = self._init_device_caches(alloc_shape)
+        if config.use_cuda_kernel:
+            x = 16 // torch.tensor([], dtype=config.dtype).element_size()
+            kalloc_shape = (self.num_blocks, self.kv_head_num, self.head_size // x, self.block_size, x)
+            valloc_shape = (self.num_blocks, self.kv_head_num, self.block_size, self.head_size)
+            self.logger.info(
+                f"Allocating K cache with shape: {kalloc_shape}, V cache with shape: {valloc_shape} consisting of {self.num_blocks} blocks."
+            )
+            self._kv_caches = self._init_device_caches(kalloc_shape, valloc_shape)
+        else:
+            alloc_shape = (self.num_blocks, self.kv_head_num, self.block_size, self.head_size)
+            self.logger.info(f"Allocating KV cache with shape: {alloc_shape} consisting of {self.num_blocks} blocks.")
+            self._kv_caches = self._init_device_caches(alloc_shape, alloc_shape)
         self.total_physical_cache_size_in_bytes = (
             self.elem_size_in_bytes
             * self.num_layers
@@ -479,7 +488,9 @@ def _init_logical_caches(self):
             blocks.append(cache_block)
         return blocks

-    def _init_device_caches(self, alloc_shape: Tuple[int, ...]) -> Tuple[torch.Tensor, torch.Tensor]:
+    def _init_device_caches(
+        self, kalloc_shape: Tuple[int, ...], valloc_shape: Tuple[int, ...]
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Initialize the physical cache on the device.

         For each layer of the model, we allocate two tensors for key and value respectively,
@@ -488,6 +499,6 @@ def _init_device_caches(self, alloc_shape: Tuple[int, ...]) -> Tuple[torch.Tenso
         k_cache: List[torch.Tensor] = []
         v_cache: List[torch.Tensor] = []
         for _ in range(self.num_layers):
-            k_cache.append(torch.zeros(alloc_shape, dtype=self.dtype, device=self.device))
-            v_cache.append(torch.zeros(alloc_shape, dtype=self.dtype, device=self.device))
+            k_cache.append(torch.zeros(kalloc_shape, dtype=self.dtype, device=self.device))
+            v_cache.append(torch.zeros(valloc_shape, dtype=self.dtype, device=self.device))
         return k_cache, v_cache
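A minimal, hypothetical sketch of what the branch above allocates, assuming fp16 (so x = 16 // 2 = 8) and made-up sizes. The 5-D K layout groups the head dimension into chunks of x contiguous elements for the CUDA kernels, while the V cache keeps the original 4-D blocked layout. The element mapping in the last comment is an assumption about how the kernels index the new layout, not something shown in this commit.

import torch

dtype = torch.float16
num_blocks, kv_head_num, block_size, head_size = 256, 8, 16, 128  # illustrative values only

x = 16 // torch.tensor([], dtype=dtype).element_size()  # 16-byte chunks -> 8 fp16 elements
kalloc_shape = (num_blocks, kv_head_num, head_size // x, block_size, x)
valloc_shape = (num_blocks, kv_head_num, block_size, head_size)

print(kalloc_shape)  # (256, 8, 16, 16, 8)
print(valloc_shape)  # (256, 8, 16, 128)

# Assumed correspondence between the two K layouts (not taken from this diff):
# old_k_cache[blk, head, pos, d]  <->  new_k_cache[blk, head, d // x, pos, d % x]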

colossalai/inference/modeling/models/nopadding_baichuan.py

Lines changed: 31 additions & 15 deletions
@@ -310,6 +310,7 @@ def forward(
                 alibi_slopes=self.alibi_slopes,
                 max_seq_len=kv_seq_len,
                 sm_scale=sm_scale,
+                use_new_kcache_layout=use_cuda_kernel,
             )
         else:
             q_len = tokens_to_verify + 1 if is_verifier else 1
@@ -332,6 +333,21 @@ def forward(
                 inference_ops.decode_kv_cache_memcpy(
                     key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables
                 )
+                inference_ops.flash_decoding_attention(
+                    output_tensor,
+                    query_states,
+                    k_cache,
+                    v_cache,
+                    sequence_lengths,
+                    block_tables,
+                    block_size,
+                    kv_seq_len,
+                    fd_inter_tensor.mid_output,
+                    fd_inter_tensor.mid_output_lse,
+                    self.alibi_slopes,
+                    sm_scale,
+                )
+                attn_output = output_tensor
             else:
                 if not is_verifier and not self.use_alibi_attn:
                     decoding_fused_rotary_embedding(
@@ -355,21 +371,21 @@ def forward(
                         value_states, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables, n=q_len
                     )

-            attn_output = flash_decoding_attention(
-                q=query_states,
-                k_cache=k_cache,
-                v_cache=v_cache,
-                kv_seq_len=sequence_lengths,
-                block_tables=block_tables,
-                block_size=block_size,
-                max_seq_len_in_batch=kv_seq_len,
-                output=output_tensor,
-                mid_output=fd_inter_tensor.mid_output,
-                mid_output_lse=fd_inter_tensor.mid_output_lse,
-                alibi_slopes=self.alibi_slopes,
-                sm_scale=sm_scale,
-                q_len=q_len,
-            )
+                attn_output = flash_decoding_attention(
+                    q=query_states,
+                    k_cache=k_cache,
+                    v_cache=v_cache,
+                    kv_seq_len=sequence_lengths,
+                    block_tables=block_tables,
+                    block_size=block_size,
+                    max_seq_len_in_batch=kv_seq_len,
+                    output=output_tensor,
+                    mid_output=fd_inter_tensor.mid_output,
+                    mid_output_lse=fd_inter_tensor.mid_output_lse,
+                    alibi_slopes=self.alibi_slopes,
+                    sm_scale=sm_scale,
+                    q_len=q_len,
+                )

         attn_output = attn_output.view(-1, self.hidden_size)
         attn_output = self.o_proj(attn_output)
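With this change the Baichuan decode path can route ALiBi attention through the CUDA kernel: after inference_ops.decode_kv_cache_memcpy, inference_ops.flash_decoding_attention is called positionally with self.alibi_slopes just before sm_scale. The model supplies its own precomputed slopes; the sketch below only illustrates the standard ALiBi slope schedule (Press et al.) for a power-of-two head count, as an assumption about what such a tensor typically contains. It is not the repository's helper.

import torch

def alibi_slopes_power_of_two(num_heads: int) -> torch.Tensor:
    # Standard ALiBi schedule: slope_i = 2 ** (-8 * (i + 1) / num_heads)
    ratio = 2.0 ** (-8.0 / num_heads)
    return torch.tensor([ratio ** (i + 1) for i in range(num_heads)], dtype=torch.float32)

print(alibi_slopes_power_of_two(8))
# tensor([0.5000, 0.2500, 0.1250, 0.0625, 0.0312, 0.0156, 0.0078, 0.0039])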

colossalai/inference/modeling/models/nopadding_llama.py

Lines changed: 31 additions & 36 deletions
@@ -98,15 +98,8 @@ def llama_model_forward(
     """
     block_tables = inputmetadata.block_tables
     sequence_lengths = inputmetadata.sequence_lengths
-    batch_size = inputmetadata.batch_size
     kv_seq_len = inputmetadata.kv_seq_len

-    # NOTE: After testing, the performance of this configuration is relatively good. With updates
-    # and optimizations to the CUDA kernel implementation, a more detailed analysis of this configuration's
-    # selection should be conducted.
-    if batch_size >= 32 and kv_seq_len > 512:
-        use_cuda_kernel = False
-
     # NOTE (yuanheng-zhao): fow now, only triton kernels support verification process
     # during speculative-decoding (`q_len > 1`)
     # We will expicitly disable `use_cuda_kernel` here when speculative-decoding is enabled
@@ -575,6 +568,7 @@ def forward(
                 output=output_tensor,
                 max_seq_len=kv_seq_len,
                 sm_scale=sm_scale,
+                use_new_kcache_layout=use_cuda_kernel,
             )
         else:
             q_len = tokens_to_verify + 1 if is_verifier else 1
@@ -592,20 +586,21 @@ def forward(
                     block_tables,
                     high_precision,
                 )
-                # inference_ops.flash_decoding_attention(
-                #     output_tensor,
-                #     query_states,
-                #     k_cache,
-                #     v_cache,
-                #     sequence_lengths,
-                #     block_tables,
-                #     block_size,
-                #     kv_seq_len,
-                #     fd_inter_tensor.mid_output,
-                #     fd_inter_tensor.mid_output_lse,
-                #     sm_scale,
-                # )
-                # attn_output = output_tensor
+                inference_ops.flash_decoding_attention(
+                    output_tensor,
+                    query_states,
+                    k_cache,
+                    v_cache,
+                    sequence_lengths,
+                    block_tables,
+                    block_size,
+                    kv_seq_len,
+                    fd_inter_tensor.mid_output,
+                    fd_inter_tensor.mid_output_lse,
+                    None,
+                    sm_scale,
+                )
+                attn_output = output_tensor
             else:
                 if is_verifier:
                     rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
@@ -627,21 +622,21 @@ def forward(
                     block_tables,
                     sequence_lengths,
                 )
-            attn_output = flash_decoding_attention(
-                q=query_states,
-                k_cache=k_cache,
-                v_cache=v_cache,
-                kv_seq_len=sequence_lengths,
-                block_tables=block_tables,
-                block_size=block_size,
-                max_seq_len_in_batch=kv_seq_len,
-                output=output_tensor,
-                mid_output=fd_inter_tensor.mid_output,
-                mid_output_lse=fd_inter_tensor.mid_output_lse,
-                sm_scale=sm_scale,
-                kv_group_num=self.num_key_value_groups,
-                q_len=q_len,
-            )
+                attn_output = flash_decoding_attention(
+                    q=query_states,
+                    k_cache=k_cache,
+                    v_cache=v_cache,
+                    kv_seq_len=sequence_lengths,
+                    block_tables=block_tables,
+                    block_size=block_size,
+                    max_seq_len_in_batch=kv_seq_len,
+                    output=output_tensor,
+                    mid_output=fd_inter_tensor.mid_output,
+                    mid_output_lse=fd_inter_tensor.mid_output_lse,
+                    sm_scale=sm_scale,
+                    kv_group_num=self.num_key_value_groups,
+                    q_len=q_len,
+                )

         attn_output = attn_output.view(-1, self.hidden_size)
         attn_output = self.o_proj(attn_output)
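Both model files now follow the same decode-time dispatch: with use_cuda_kernel, key/value states are copied via inference_ops.decode_kv_cache_memcpy and attention runs through the positional inference_ops.flash_decoding_attention call (alibi_slopes, or None for Llama, comes right before sm_scale); otherwise the Triton flash_decoding_attention keyword path is used with the same intermediate buffers. A condensed, hypothetical helper showing that pattern; argument names follow the diff, but the function itself is not part of the commit.

def decode_attention_dispatch(
    use_cuda_kernel, inference_ops, flash_decoding_attention,
    query_states, k_cache, v_cache, sequence_lengths, block_tables,
    block_size, kv_seq_len, fd_inter_tensor, output_tensor, sm_scale,
    alibi_slopes=None, kv_group_num=1, q_len=1,
):
    if use_cuda_kernel:
        # CUDA kernel: positional arguments, alibi_slopes (or None) precedes sm_scale
        inference_ops.flash_decoding_attention(
            output_tensor, query_states, k_cache, v_cache, sequence_lengths,
            block_tables, block_size, kv_seq_len,
            fd_inter_tensor.mid_output, fd_inter_tensor.mid_output_lse,
            alibi_slopes, sm_scale,
        )
        return output_tensor
    # Triton kernel: keyword arguments, same split-KV intermediate buffers
    return flash_decoding_attention(
        q=query_states, k_cache=k_cache, v_cache=v_cache,
        kv_seq_len=sequence_lengths, block_tables=block_tables, block_size=block_size,
        max_seq_len_in_batch=kv_seq_len, output=output_tensor,
        mid_output=fd_inter_tensor.mid_output, mid_output_lse=fd_inter_tensor.mid_output_lse,
        alibi_slopes=alibi_slopes, sm_scale=sm_scale,
        kv_group_num=kv_group_num, q_len=q_len,
    )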

examples/inference/benchmark_ops/benchmark_flash_decoding_attention.py

Lines changed: 5 additions & 1 deletion
@@ -20,7 +20,7 @@
 configs = [
     triton.testing.Benchmark(
         x_names=["MAX_NUM_BLOCKS_PER_SEQ"],
-        x_vals=[2**i for i in range(3, 8)],
+        x_vals=[2**i for i in range(2, 8)],
         line_arg="provider",
         line_vals=[
             "vllm_paged_decoding_attention",
@@ -113,6 +113,8 @@ def benchmark_flash_decoding_attention(
     kv_max_split_num = (max_seq_len_across_batch + BLOCK_SIZE - 1) // BLOCK_SIZE
     output = torch.empty((BATCH_SIZE, NUM_ATTN_HEADS, HEAD_SIZE), dtype=dtype, device=device)
     sm_scale = 1.0 / (HEAD_SIZE**0.5)
+    alibi_slopes = None
+    kv_scale = 1.0

     mid_output = torch.empty(
         size=(BATCH_SIZE, NUM_ATTN_HEADS, kv_max_split_num, HEAD_SIZE), dtype=torch.float32, device=device
@@ -136,6 +138,7 @@ def benchmark_flash_decoding_attention(
             max_seq_len_across_batch,
             alibi_slopes,
             "auto",
+            kv_scale,
         )
     elif provider == "triton_flash_decoding_attention":
         fn = lambda: flash_decoding_attention(
@@ -164,6 +167,7 @@ def benchmark_flash_decoding_attention(
             max_seq_len_across_batch,
             mid_output,
             mid_output_lse,
+            alibi_slopes,
             sm_scale,
         )
     else:
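As context for the alibi_slopes / kv_scale arguments now threaded through the benchmark, the mid_output and mid_output_lse buffers it allocates are the split-KV intermediates of flash-decoding: the KV sequence is processed in chunks of BLOCK_SIZE, and each chunk contributes a partial output plus a log-sum-exp that the kernel reduces at the end. A small worked sizing example with illustrative values; the mid_output shape follows the benchmark above, while the mid_output_lse shape (one scalar per partial result) is an assumption, since its allocation is not shown in this diff.

import torch

BATCH_SIZE, NUM_ATTN_HEADS, HEAD_SIZE, BLOCK_SIZE = 16, 32, 128, 16  # illustrative
max_seq_len_across_batch = 1024

# ceil-divide the longest sequence in the batch into BLOCK_SIZE-sized KV chunks
kv_max_split_num = (max_seq_len_across_batch + BLOCK_SIZE - 1) // BLOCK_SIZE  # 64

# one fp32 partial output per (sequence, head, chunk), reduced by the kernel afterwards
mid_output = torch.empty(BATCH_SIZE, NUM_ATTN_HEADS, kv_max_split_num, HEAD_SIZE, dtype=torch.float32)
mid_output_lse = torch.empty(BATCH_SIZE, NUM_ATTN_HEADS, kv_max_split_num, dtype=torch.float32)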

examples/inference/benchmark_ops/benchmark_fused_rotary_embdding_unpad.py

Lines changed: 14 additions & 4 deletions
@@ -2,7 +2,11 @@

 from colossalai.kernel.kernel_loader import InferenceOpsLoader
 from colossalai.kernel.triton import copy_kv_to_blocked_cache, decoding_fused_rotary_embedding, rotary_embedding
-from tests.test_infer.test_ops.triton.kernel_utils import mock_alloc_block_table_and_kvcache_v2, mock_alloc_single_token
+from tests.test_infer.test_ops.triton.kernel_utils import (
+    mock_alloc_block_table_and_kvcache_v2,
+    mock_alloc_block_table_and_kvcache_v3,
+    mock_alloc_single_token,
+)

 inference_ops = InferenceOpsLoader().load()

@@ -68,11 +72,17 @@ def benchmark_rotary_emb(
     cache_shape = (BATCH_SIZE * max_num_blocks_per_seq, num_kv_heads, block_size, head_dim)
     k_cache = torch.zeros(size=cache_shape, dtype=dtype, device="cuda")
     v_cache = torch.zeros(size=cache_shape, dtype=dtype, device="cuda")
+    x = 16 // torch.tensor([], dtype=dtype).element_size()
+    new_cache_shape = (BATCH_SIZE * max_num_blocks_per_seq, num_kv_heads, head_dim // x, block_size, x)
+    new_k_cache = torch.zeros(size=new_cache_shape, dtype=dtype, device="cuda")

     past_kv_seq_lengths = torch.tensor([SEQ_LEN - 1 for _ in range(BATCH_SIZE)], dtype=torch.int32, device="cuda")
     block_tables = mock_alloc_block_table_and_kvcache_v2(
         k, v, k_cache, v_cache, past_kv_seq_lengths, BATCH_SIZE, max_num_blocks_per_seq, block_size
     )
+    _ = mock_alloc_block_table_and_kvcache_v3(
+        k, v, new_k_cache, v_cache, past_kv_seq_lengths, BATCH_SIZE, max_num_blocks_per_seq, block_size
+    )
     new_k = torch.randn((BATCH_SIZE, num_kv_heads, head_dim), dtype=dtype, device="cuda")
     new_q = torch.randn_like(new_k)
     new_v = torch.randn_like(new_k)
@@ -94,12 +104,12 @@ def benchmark_rotary_emb(
         )
     elif provider == "no_fused_cuda_rotary_emb_func":
         fn = lambda: [
-            inference_ops.rotary_embedding(new_q, new_k, cos, sin),
-            inference_ops.decode_kv_cache_memcpy(new_k, new_v, k_cache, v_cache, kv_seq_lengths, block_tables),
+            inference_ops.rotary_embedding(new_q, new_k, cos, sin, True),
+            inference_ops.decode_kv_cache_memcpy(new_k, new_v, new_k_cache, v_cache, kv_seq_lengths, block_tables),
         ]
     elif provider == "fused_cuda_rotary_emb_func":
         fn = lambda: inference_ops.rotary_embedding_and_cache_copy(
-            new_q, new_k, new_v, cos, sin, k_cache, v_cache, kv_seq_lengths, block_tables
+            new_q, new_k, new_v, cos, sin, new_k_cache, v_cache, kv_seq_lengths, block_tables, True
        )
     else:
         raise ValueError("Undefined provider")
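The two CUDA providers above exercise the same state update in two ways: rotary_embedding followed by decode_kv_cache_memcpy versus the fused rotary_embedding_and_cache_copy, both now writing keys into new_k_cache and taking a trailing boolean (passed as True here, which appears to correspond to the high_precision flag seen in the Llama forward). A hypothetical cross-check, not part of the commit, that one could run with the tensors prepared in benchmark_rotary_emb, assuming the CUDA extension is built and a GPU is available.

import torch

def check_fused_matches_unfused(
    inference_ops, new_q, new_k, new_v, cos, sin,
    new_k_cache, v_cache, kv_seq_lengths, block_tables,
):
    # two-step path: rotate q/k in place, then copy k/v into the block caches
    q1, k1, v1 = new_q.clone(), new_k.clone(), new_v.clone()
    kc1, vc1 = new_k_cache.clone(), v_cache.clone()
    inference_ops.rotary_embedding(q1, k1, cos, sin, True)
    inference_ops.decode_kv_cache_memcpy(k1, v1, kc1, vc1, kv_seq_lengths, block_tables)

    # fused path: rotation and cache copy in a single kernel launch
    q2, k2, v2 = new_q.clone(), new_k.clone(), new_v.clone()
    kc2, vc2 = new_k_cache.clone(), v_cache.clone()
    inference_ops.rotary_embedding_and_cache_copy(
        q2, k2, v2, cos, sin, kc2, vc2, kv_seq_lengths, block_tables, True
    )

    # both paths should leave identical K/V block caches behind
    torch.testing.assert_close(kc1, kc2)
    torch.testing.assert_close(vc1, vc2)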

examples/inference/benchmark_ops/benchmark_kv_cache_memcopy.py

Lines changed: 4 additions & 0 deletions
@@ -4,6 +4,7 @@
 from colossalai.kernel.kernel_loader import InferenceOpsLoader
 from colossalai.kernel.triton import copy_kv_to_blocked_cache
 from colossalai.utils import get_current_device
+from tests.test_infer.test_ops.cuda.test_kv_cache_memcpy import prepare_data as prepare_data_new_kcache_layout
 from tests.test_infer.test_ops.triton.test_kvcache_copy import prepare_data

 try:
@@ -68,6 +69,9 @@ def benchmark_kvcache_copy(
     elif provider == "triton_copy_func":
         fn = lambda: copy_kv_to_blocked_cache(new_k, new_v, k_cache, v_cache, context_lengths, block_tables)
     elif provider == "cuda_copy_func":
+        _, _, k_cache, _, _, _, _, _, _ = prepare_data_new_kcache_layout(
+            bsz, num_kv_heads, block_size, max_seq_len // block_size, context_lengths - 1, device, dtype
+        )
         new_k = new_k.squeeze(1) if new_k.dim() == 4 else new_k
         new_v = new_v.squeeze(1) if new_v.dim() == 4 else new_v
         fn = lambda: inference_ops.decode_kv_cache_memcpy(new_k, new_v, k_cache, v_cache, context_lengths, block_tables)
