Commit fd07d79

Author: niushengxiao
feat: add page_size_variable mode for fa3 backend
1 parent 6095789 commit fd07d79

11 files changed: +790 −26 lines
lightllm/common/mem_utils.py

Lines changed: 4 additions & 0 deletions

@@ -4,6 +4,7 @@
 from lightllm.common.export_calibration_mem_manager import ExportCalibrationMemoryManager
 from lightllm.common.ppl_int8kv_mem_manager import PPLINT8KVMemoryManager
 from lightllm.common.ppl_int4kv_mem_manager import PPLINT4KVMemoryManager
+from lightllm.common.page_size_variable_mem_manager import PageSizeVariableMemoryManager
 from lightllm.utils.log_utils import init_logger
 
 logger = init_logger(__name__)
@@ -28,6 +29,9 @@ def select_mem_manager_class(mode):
     elif "export_fp8kv_calibration" in mode:
         memory_manager_class = ExportCalibrationMemoryManager
         logger.info("Using mode export fp8kv calibration")
+    elif "page_size_variable" in mode:
+        memory_manager_class = PageSizeVariableMemoryManager
+        logger.info("Page size will be variable")
     else:
         memory_manager_class = MemoryManager
         logger.info("Model kv cache using mode normal")
lightllm/common/page_size_variable_mem_manager.py

Lines changed: 193 additions & 0 deletions

@@ -0,0 +1,193 @@
import torch
import numpy as np
from .mem_manager import MemoryManager
from typing import List, Union
from lightllm.utils.log_utils import init_logger
from lightllm.utils.envs_utils import get_page_size
from lightllm.common.infer_utils import init_req_to_token_indexes
from lightllm.common.basemodel.triton_kernel.copy_kv_index_to_req import copy_kv_index_to_req


def cdiv(a, b):
    return (a + b - 1) // b


logger = init_logger(__name__)


class PageSizeVariableMemoryManager(MemoryManager):
    def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9):
        super().__init__(size, dtype, head_num, head_dim, layer_num, always_copy, mem_fraction)
        self.req_to_page_indexs = None
        page_size = get_page_size()
        self.page_idx_pool = torch.arange(
            0, cdiv(self.size, page_size), dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True
        )
        self.mark_page_start = 0
        self.can_use_page_size = cdiv(self.size, page_size)

    def _init_buffers(self, size, dtype, head_num, head_dim, layer_num):
        self.kv_buffer = torch.empty(
            (layer_num, cdiv(size, get_page_size()) * get_page_size(), 2 * head_num, head_dim),
            dtype=dtype,
            device="cuda",
        )

    # The length must be an exact multiple of page_size, and token indexes within a page must be contiguous.
    def check_cache_page_valid(self, values: torch.Tensor):
        end = len(values)
        assert end % self.page_size == 0, "Values length must be a multiple of page size"
        total_pages = end // self.page_size
        for page_idx in range(total_pages):
            values_start = page_idx * self.page_size
            values_end = min((page_idx + 1) * self.page_size, end)
            page_token_idxs = values[values_start:values_end]
            if len(page_token_idxs) > 1:
                expected_idxs = torch.arange(
                    page_token_idxs[0],
                    page_token_idxs[0] + len(page_token_idxs),
                    dtype=page_token_idxs.dtype,
                    device=page_token_idxs.device,
                )
                if not torch.equal(page_token_idxs, expected_idxs):
                    return False
        return True

    def set_prefix_cache_to_req(self, req_idx: int, start: int, end: int, values: torch.Tensor):
        # assert self.check_cache_page_valid(values), "Values must be valid for page size"
        page_size = get_page_size()
        self.req_to_page_indexs[req_idx, start // page_size : end // page_size] = values[::page_size] // page_size
        self.req_to_token_indexs[req_idx, start:end] = values

    def expand_by_page_size(self, b_token_len, page_size):
        # Expand per-request token counts into per-page token counts,
        # e.g. b_token_len = [9, 9, 9] with page_size = 4 -> p_token_len = [4, 4, 1, 4, 4, 1, 4, 4, 1]
        b_page_len = cdiv(b_token_len, page_size)
        need_pages_num = b_page_len.sum()
        p_token_len = torch.full((need_pages_num,), page_size, dtype=b_token_len.dtype, device=b_token_len.device)
        cumsum_pages = torch.cumsum(b_page_len, dim=0)
        last_page_positions = cumsum_pages - 1
        remainders = b_token_len - (b_page_len - 1) * page_size
        p_token_len[last_page_positions] = remainders
        return need_pages_num, b_page_len, p_token_len

    def get_paged_token_indexs(self, b_req_idx, page_size, b_seq_len, b_ready_cache_len, is_prefill):
        if is_prefill:
            b_req_idx = b_req_idx.cuda()
            b_seq_len = b_seq_len.cuda()
            b_ready_cache_len = b_ready_cache_len.cuda()

            b_token_len = b_seq_len - b_ready_cache_len
            total_pages_needed, b_page_len, p_token_len = self.expand_by_page_size(b_token_len, page_size)
            if self.can_use_page_size < total_pages_needed:
                raise RuntimeError(
                    f"No available pages for alloc. remaining: {self.can_use_page_size}, needed: {total_pages_needed}"
                )

            allocated_pages = self.page_idx_pool[
                self.mark_page_start : self.mark_page_start + total_pages_needed
            ].cuda()

            def get_offsets_by_length(b_len, max_len):
                # e.g. b_len = [3, 4, 5] -> [0, 1, 2, 0, 1, 2, 3, 0, 1, 2, 3, 4]
                offsets = torch.arange(max_len, dtype=b_len.dtype, device=b_len.device)
                offset_mask = offsets.unsqueeze(0) < b_len.unsqueeze(1)
                return torch.masked_select(offsets, offset_mask)

            page_offsets = get_offsets_by_length(b_page_len, b_page_len.max())
            token_offsets = get_offsets_by_length(p_token_len, page_size)

            # Update req_to_page_indexs; b_ready_cache_len is always an exact multiple of page_size.
            page_starts = b_ready_cache_len // page_size
            req_id = torch.repeat_interleave(
                torch.arange(len(b_req_idx), dtype=b_token_len.dtype, device=b_token_len.device), b_page_len
            )
            self.req_to_page_indexs[b_req_idx[req_id], page_starts[req_id] + page_offsets] = allocated_pages

            self.mark_page_start += total_pages_needed
            self.can_use_page_size -= total_pages_needed
            page_bases = allocated_pages * page_size
            return torch.repeat_interleave(page_bases, p_token_len) + token_offsets
        else:
            b_seq_len = b_seq_len.cuda()
            b_req_idx = b_req_idx.cuda()
            need_new_page_mask = (b_seq_len - 1) % page_size == 0
            new_pages_num = need_new_page_mask.sum()
            if self.can_use_page_size < new_pages_num:
                raise RuntimeError(
                    f"No available pages for alloc. remaining: {self.can_use_page_size}, needed: {new_pages_num}"
                )

            token_idxs = torch.zeros_like(b_seq_len, device=b_seq_len.device)
            if new_pages_num > 0:
                new_pages = self.page_idx_pool[self.mark_page_start : self.mark_page_start + new_pages_num].cuda()
                self.mark_page_start += new_pages_num
                self.can_use_page_size -= new_pages_num
                token_idxs[need_new_page_mask] = new_pages * page_size

                # req_to_page_indexs also needs to be updated for the newly opened pages.
                new_page_req_indices = b_req_idx[need_new_page_mask]
                page_positions = (b_seq_len[need_new_page_mask] - 1) // page_size
                self.req_to_page_indexs[new_page_req_indices, page_positions] = new_pages

            mask = ~need_new_page_mask
            if mask.any():
                seq_lens = b_seq_len[mask]
                token_idxs[mask] = (
                    self.req_to_token_indexs[b_req_idx[mask], seq_lens - 2] // page_size * page_size
                    + (seq_lens - 1) % page_size
                )
            return token_idxs

    def alloc(self, need_size, b_req_idx, b_seq_len, b_ready_cache_len=None, is_prefill=False) -> torch.Tensor:
        page_size = get_page_size()
        token_idxs = self.get_paged_token_indexs(b_req_idx, page_size, b_seq_len, b_ready_cache_len, is_prefill)
        self.can_use_mem_size -= need_size
        self.shared_can_use_token_num.set_value(self.can_use_mem_size)

        if self.req_to_token_indexs is not None:
            assert b_req_idx is not None and b_seq_len is not None, "b_req_idx and b_seq_len must be provided"
            if is_prefill:
                init_req_to_token_indexes(
                    self.req_to_token_indexs,
                    b_req_idx,
                    b_seq_len,
                    b_ready_cache_len,
                    token_idxs,
                )
            else:
                copy_kv_index_to_req(
                    self.req_to_token_indexs,
                    b_req_idx.cuda(),
                    b_seq_len.cuda(),
                    token_idxs.cuda(),
                )
        return token_idxs

    def free(self, free_index: Union[torch.Tensor, List[int]]):
        self.can_use_mem_size += len(free_index)
        self.shared_can_use_token_num.set_value(self.can_use_mem_size)

        page_size = get_page_size()
        if isinstance(free_index, list):
            free_index = torch.tensor(free_index, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True)

        if len(free_index) == 0:
            return

        page_indices = free_index // page_size
        unique_pages = torch.unique(page_indices)
        for page_idx in sorted(unique_pages, reverse=True):  # push back in reverse order to keep the pool's relative order
            self.mark_page_start -= 1
            self.page_idx_pool[self.mark_page_start] = page_idx
            self.can_use_page_size += 1

        return

    def free_all(self):
        super().free_all()
        page_size = get_page_size()
        self.mark_page_start = 0
        self.can_use_page_size = cdiv(self.size, page_size)
        self.page_idx_pool = torch.arange(
            0, cdiv(self.size, page_size), dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True
        )
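To make the paging arithmetic concrete, here is a small standalone sketch (plain PyTorch, toy values, no lightllm imports) that reproduces what expand_by_page_size and get_offsets_by_length compute in the prefill path above; it is illustrative only, not part of the commit.

import torch

def cdiv(a, b):
    return (a + b - 1) // b

page_size = 4
b_token_len = torch.tensor([9, 9, 9])  # new tokens per request

# per-request page counts and per-page token counts, as in expand_by_page_size
b_page_len = cdiv(b_token_len, page_size)  # [3, 3, 3]
p_token_len = torch.full((int(b_page_len.sum()),), page_size, dtype=b_token_len.dtype)
p_token_len[torch.cumsum(b_page_len, dim=0) - 1] = b_token_len - (b_page_len - 1) * page_size
print(p_token_len.tolist())  # [4, 4, 1, 4, 4, 1, 4, 4, 1]

# a token slot is page_base + offset inside the page, as in the prefill branch
allocated_pages = torch.arange(int(b_page_len.sum()))  # pretend the pool handed out pages 0..8
offsets = torch.arange(page_size)
token_offsets = torch.masked_select(offsets, offsets.unsqueeze(0) < p_token_len.unsqueeze(1))
token_idxs = torch.repeat_interleave(allocated_pages * page_size, p_token_len) + token_offsets
print(token_idxs[:9].tolist())  # first request fills pages 0..2: [0, 1, 2, 3, 4, 5, 6, 7, 8]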

lightllm/common/req_manager.py

Lines changed: 9 additions & 1 deletion

@@ -5,7 +5,7 @@
 from typing import List, Optional
 from lightllm.common.basemodel.triton_kernel.gen_sampling_params import token_id_counter
 from lightllm.common.basemodel.triton_kernel.gen_sampling_params import update_req_to_token_id_counter
-from lightllm.utils.envs_utils import enable_env_vars, get_env_start_args
+from lightllm.utils.envs_utils import enable_env_vars, get_env_start_args, get_page_size
 from lightllm.utils.config_utils import get_vocab_size
 
 logger = init_logger(__name__)
@@ -63,6 +63,14 @@ def __init__(self, max_request_num, max_sequence_length, mem_manager: MemoryManager):
             (max_request_num + 1, max_sequence_length), dtype=torch.int32, device="cuda"
         )
         mem_manager.req_to_token_indexs = self.req_to_token_indexs
+        if hasattr(mem_manager, "req_to_page_indexs"):
+            page_size = get_page_size()
+            self.req_to_page_indexs = torch.zeros(
+                (max_request_num + 1, (max_sequence_length + page_size - 1) // page_size),
+                dtype=torch.int32,
+                device="cuda",
+            )
+            mem_manager.req_to_page_indexs = self.req_to_page_indexs
         self.mem_manager = mem_manager
         self.req_sampling_params_manager = ReqSamplingParamsManager(max_request_num)
         self.max_request_num = max_request_num
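For a sense of scale, the new per-request page table is page_size times narrower than req_to_token_indexs; a quick sketch with hypothetical sizes (the numbers below are illustrative, not values used by the commit):

import torch

max_request_num, max_sequence_length, page_size = 1000, 8192, 64  # hypothetical sizes
pages_per_req = (max_sequence_length + page_size - 1) // page_size  # 128 page slots per request
req_to_page_indexs = torch.zeros((max_request_num + 1, pages_per_req), dtype=torch.int32)
print(req_to_page_indexs.shape)  # torch.Size([1001, 128]) vs. (1001, 8192) for req_to_token_indexs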

lightllm/models/llama/flashattention_infer_struct.py

Lines changed: 24 additions & 18 deletions

@@ -3,12 +3,16 @@
 import numpy as np
 import torch.distributed as dist
 from lightllm.models.llama.infer_struct import LlamaInferStateInfo
-from lightllm.utils.envs_utils import get_env_start_args
+from lightllm.utils.envs_utils import get_env_start_args, get_page_size
 from lightllm.utils.dist_utils import get_current_device_id
 from lightllm.models.deepseek2.triton_kernel.repack_kv_index import repack_kv_index
 from lightllm.common.basemodel.batch_objs import ModelInput
 
 
+def cdiv(a, b):
+    return (a + b - 1) // b
+
+
 class FlashAttentionStateInfo(LlamaInferStateInfo):
     _shared_page_table_buffer = None
@@ -29,32 +33,34 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor):
         if self.is_prefill:
             self.cu_seqlens_q = self.b1_cu_q_seq_len.int()
             self.cu_seqlens_k = self.b1_cu_kv_seq_len.int()
-            self.page_table = torch.empty(
-                (self.batch_size, self.max_seq_len), dtype=torch.int32, device=input_ids.device
-            )
-            self.page_table.copy_(model.req_manager.req_to_token_indexs[self.b_req_idx, : self.max_seq_len])
+            length = cdiv(self.max_seq_len, get_page_size())
+            self.page_table = torch.empty((self.batch_size, length), dtype=torch.int32, device=input_ids.device)
+            if "page_size_variable" in model.mode:
+                self.page_table.copy_(model.req_manager.req_to_page_indexs[self.b_req_idx, :length])
+            else:
+                self.page_table.copy_(model.req_manager.req_to_token_indexs[self.b_req_idx, :length])
         else:
             # Meta information of flashattention for decoding
             self.cu_seqlens_q = self.b1_cu_q_seq_len.int()
             self.cu_seqlens_k = self.b1_cu_kv_seq_len.int()
             max_seq_len_k = self.max_kv_seq_len
             if self.batch_size <= model.graph_max_batch_size and self.max_len_in_batch <= model.graph_max_len_in_batch:
-                page_buffer = FlashAttentionStateInfo.get_page_table_buffer(
-                    model.graph_max_batch_size, model.graph_max_len_in_batch
+                page_size = get_page_size()
+                length = cdiv(model.graph_max_len_in_batch, page_size)
+                page_buffer = FlashAttentionStateInfo.get_page_table_buffer(model.graph_max_batch_size, length)
+                self.page_table = page_buffer[self.microbatch_index][: self.batch_size * length].reshape(
+                    self.batch_size, length
                 )
-                self.page_table = page_buffer[self.microbatch_index][
-                    : self.batch_size * model.graph_max_len_in_batch
-                ].reshape(self.batch_size, model.graph_max_len_in_batch)
             else:
-                self.page_table = torch.empty(
-                    (self.batch_size, self.max_len_in_batch), dtype=torch.int32, device=input_ids.device
-                )
+                length = cdiv(self.max_len_in_batch, get_page_size())
+                self.page_table = torch.empty((self.batch_size, length), dtype=torch.int32, device=input_ids.device)
 
-            self.page_table[:, :max_seq_len_k].copy_(
-                model.req_manager.req_to_token_indexs[self.b_req_idx, :max_seq_len_k],
-                non_blocking=True,
-            )
-            self.page_table[:, max_seq_len_k:].fill_(0)
+            length = cdiv(max_seq_len_k, get_page_size())
+            if "page_size_variable" in model.mode:
+                self.page_table[:, :length].copy_(model.req_manager.req_to_page_indexs[self.b_req_idx, :length])
+            else:
+                self.page_table[:, :length].copy_(model.req_manager.req_to_token_indexs[self.b_req_idx, :length])
+            self.page_table[:, length:].fill_(0)
 
         if "offline_calibration_fp8kv" in model.mode:
             if self.is_prefill:
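In page_size_variable mode the page_table rows copied above hold page indexes rather than per-token KV slots, so each row only needs cdiv(seq_len, page_size) entries. A hedged sketch of how a token's KV slot relates to its page-table entry, following the indexing used by PageSizeVariableMemoryManager (the tensors and values here are made up for illustration):

import torch

page_size = 4
page_table_row = torch.tensor([7, 2], dtype=torch.int32)  # a request whose 6 tokens live in pages 7 and 2
seq_len = 6

# slot of token t = (index of its page) * page_size + offset within the page
t = torch.arange(seq_len)
kv_slots = page_table_row[t // page_size].to(torch.int64) * page_size + t % page_size
print(kv_slots.tolist())  # [28, 29, 30, 31, 8, 9]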
