Skip to content

Commit 5c8ce8c

Browse files
fix prefix noquant bug (#725)
1 parent 865eff9 commit 5c8ce8c

File tree

6 files changed

+41
-33
lines changed

6 files changed

+41
-33
lines changed

lightllm/common/int8kv_mem_manager.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,16 @@ def get_cell_size(self):
1414
) + 2 * self.head_num * self.layer_num * torch._utils._element_size(self.dtype)
1515

1616
def _init_buffers(self, size, dtype, head_num, head_dim, layer_num):
17-
self.kv_buffer = [
18-
torch.empty((size, 2 * head_num, head_dim), dtype=torch.int8, device="cuda") for _ in range(layer_num)
19-
]
20-
self.scale_buffer = [torch.empty((size, 2 * head_num, 1), dtype=dtype, device="cuda") for _ in range(layer_num)]
17+
self.kv_buffer = torch.empty((layer_num, size, 2 * head_num, head_dim), dtype=torch.int8, device="cuda")
18+
self.scale_buffer = torch.empty((layer_num, size, 2 * head_num, 1), dtype=dtype, device="cuda")
2119

2220
def _free_buffers(self):
2321
self.kv_buffer = None
2422
self.scale_buffer = None
23+
24+
def get_index_kv_buffer(self, index):
25+
return {"kv_buffer": self.kv_buffer[:, index], "scale_buffer": self.scale_buffer[:, index]}
26+
27+
def load_index_kv_buffer(self, index, load_tensor_dict):
28+
self.kv_buffer[:, index].copy_(load_tensor_dict["kv_buffer"])
29+
self.scale_buffer[:, index].copy_(load_tensor_dict["scale_buffer"])

lightllm/common/mem_manager.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -295,6 +295,12 @@ def resize_mem(self, new_size):
295295
self._init_buffers(size, dtype, head_num, head_dim, layer_num)
296296
return
297297

298+
def get_index_kv_buffer(self, index):
299+
return {"kv_buffer": self.kv_buffer[:, index]}
300+
301+
def load_index_kv_buffer(self, index, load_tensor_dict):
302+
self.kv_buffer[:, index].copy_(load_tensor_dict["kv_buffer"])
303+
298304

299305
class ReadOnlyStaticsMemoryManager:
300306
"""

lightllm/common/ppl_int4kv_mem_manager.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,18 @@ def get_cell_size(self):
1717
)
1818

1919
def _init_buffers(self, size, dtype, head_num, head_dim, layer_num):
20-
self.kv_buffer = [
21-
torch.empty((size, 2 * head_num, head_dim // 2), dtype=torch.int8, device="cuda") for _ in range(layer_num)
22-
]
23-
self.scale_buffer = [
24-
torch.empty((size, 2 * head_num, head_dim // self.group_quant_size), dtype=dtype, device="cuda")
25-
for _ in range(layer_num)
26-
]
20+
self.kv_buffer = torch.empty((layer_num, size, 2 * head_num, head_dim // 2), dtype=torch.int8, device="cuda")
21+
self.scale_buffer = torch.empty(
22+
(layer_num, size, 2 * head_num, head_dim // self.group_quant_size), dtype=dtype, device="cuda"
23+
)
2724

2825
def _free_buffers(self):
2926
self.kv_buffer = None
3027
self.scale_buffer = None
28+
29+
def get_index_kv_buffer(self, index):
30+
return {"kv_buffer": self.kv_buffer[:, index], "scale_buffer": self.scale_buffer[:, index]}
31+
32+
def load_index_kv_buffer(self, index, load_tensor_dict):
33+
self.kv_buffer[:, index].copy_(load_tensor_dict["kv_buffer"])
34+
self.scale_buffer[:, index].copy_(load_tensor_dict["scale_buffer"])

lightllm/common/ppl_int8kv_mem_manager.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,18 @@ def get_cell_size(self):
1717
)
1818

1919
def _init_buffers(self, size, dtype, head_num, head_dim, layer_num):
20-
self.kv_buffer = [
21-
torch.empty((size, 2 * head_num, head_dim), dtype=torch.int8, device="cuda") for _ in range(layer_num)
22-
]
23-
self.scale_buffer = [
24-
torch.empty((size, 2 * head_num, head_dim // self.group_quant_size), dtype=dtype, device="cuda")
25-
for _ in range(layer_num)
26-
]
20+
self.kv_buffer = torch.empty((layer_num, size, 2 * head_num, head_dim), dtype=torch.int8, device="cuda")
21+
self.scale_buffer = torch.empty(
22+
(layer_num, size, 2 * head_num, head_dim // self.group_quant_size), dtype=dtype, device="cuda"
23+
)
2724

2825
def _free_buffers(self):
2926
self.kv_buffer = None
3027
self.scale_buffer = None
28+
29+
def get_index_kv_buffer(self, index):
30+
return {"kv_buffer": self.kv_buffer[:, index], "scale_buffer": self.scale_buffer[:, index]}
31+
32+
def load_index_kv_buffer(self, index, load_tensor_dict):
33+
self.kv_buffer[:, index].copy_(load_tensor_dict["kv_buffer"])
34+
self.scale_buffer[:, index].copy_(load_tensor_dict["scale_buffer"])

lightllm/server/router/model_infer/infer_batch.py

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -116,16 +116,9 @@ def _save_promptcache_kvbuffer(self):
116116
"""
117117
prompt_cache_token_id = list(self.radix_cache.root_node.children.values())[0].token_id_key
118118
print(f"prompt_cache_token_id : {prompt_cache_token_id}")
119-
if isinstance(self.radix_cache.mem_manager.kv_buffer, list):
120-
kv_buffer_list = []
121-
for i in range(len(self.radix_cache.mem_manager.kv_buffer)):
122-
kv_buffer_list.append(self.radix_cache.mem_manager.kv_buffer[i][: len(prompt_cache_token_id)])
123-
torch.save(kv_buffer_list, f"prompt_cache_rank_{dist.get_rank()}.pt")
124-
else:
125-
torch.save(
126-
self.radix_cache.mem_manager.kv_buffer[:, : len(prompt_cache_token_id)],
127-
f"prompt_cache_rank_{dist.get_rank()}.pt",
128-
)
119+
index = range(len(prompt_cache_token_id))
120+
prompt_cache_kv_buffer = self.radix_cache.mem_manager.get_index_kv_buffer(index)
121+
torch.save(prompt_cache_kv_buffer, f"prompt_cache_rank_{dist.get_rank()}.pt")
129122

130123
@torch.no_grad()
131124
def filter(self, finished_request_ids: List[int]):

lightllm/server/router/model_infer/mode_backend/base_backend.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -270,11 +270,7 @@ def preload_prompt_cache_kv_buffer(self, model_cfg):
270270
prompt_cache_kv_buffer = torch.load(prompt_cache_kv_buffer_path, weights_only=True, map_location="cpu")
271271
intact_kv_len = len(model_cfg["prompt_cache_token_ids"])
272272
intact_kv_index = self.radix_cache.mem_manager.alloc(intact_kv_len)
273-
if isinstance(self.radix_cache.mem_manager.kv_buffer, list):
274-
for i in range(len(self.radix_cache.mem_manager.kv_buffer)):
275-
self.radix_cache.mem_manager.kv_buffer[i][intact_kv_index].copy_(prompt_cache_kv_buffer[i])
276-
else:
277-
self.radix_cache.mem_manager.kv_buffer[:, intact_kv_index].copy_(prompt_cache_kv_buffer)
273+
self.radix_cache.mem_manager.load_index_kv_buffer(intact_kv_index, prompt_cache_kv_buffer)
278274
self.radix_cache.insert(
279275
torch.tensor(model_cfg["prompt_cache_token_ids"], dtype=torch.int64, device="cpu"),
280276
intact_kv_index,

0 commit comments

Comments (0)