
Commit 204f6fa

fix
1 parent 2d46245 commit 204f6fa

File tree

3 files changed: +84, -71 lines

lightllm/server/router/model_infer/infer_batch.py

Lines changed: 10 additions & 0 deletions
@@ -78,6 +78,16 @@ def add_reqs(self, requests: List[Tuple[int, int, Any, int]], init_prefix_cache:
             req_objs.append(r_obj)
 
         self.infer_req_ids.extend(request_ids)
+
+        # In multi-output (diverse) mode, each request must also be registered in its group object InferReqGroup
+        if get_env_start_args().diverse_mode:
+            for r_id in request_ids:
+                req: InferReq = g_infer_context.requests_mapping[r_id]
+                group_req_id = req.shm_req.group_req_id
+                if group_req_id not in g_infer_context.group_mapping:
+                    g_infer_context.group_mapping[group_req_id] = InferReqGroup(group_req_id=group_req_id)
+                g_infer_context.group_mapping[group_req_id].add_req(r_id)
+
         return req_objs
 
     def free_a_req_mem(self, free_token_index: List, req: "InferReq", is_group_finished: bool):
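
This grouping step previously lived in DiversehBackend.build_group and ran inside decode(); the commit moves it into add_reqs so a request joins its InferReqGroup as soon as it is registered. A toy sketch of the keying scheme, using illustrative classes and ids rather than the real shared-memory request objects:

class ToyGroup:
    def __init__(self, group_req_id):
        self.group_req_id = group_req_id
        self.req_ids = []

    def add_req(self, r_id):
        self.req_ids.append(r_id)

# several sub-requests (e.g. multiple sampled outputs for one prompt) share one group_req_id
sub_req_to_group = {101: 100, 102: 100, 103: 100, 201: 200}

group_mapping = {}
for r_id, group_req_id in sub_req_to_group.items():
    if group_req_id not in group_mapping:
        group_mapping[group_req_id] = ToyGroup(group_req_id)
    group_mapping[group_req_id].add_req(r_id)

assert group_mapping[100].req_ids == [101, 102, 103]
assert group_mapping[200].req_ids == [201]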

lightllm/server/router/model_infer/mode_backend/diverse_backend/impl.py

Lines changed: 30 additions & 71 deletions
@@ -4,33 +4,20 @@
     g_infer_context,
     InferReq,
     InferReqGroup,
-    InferSamplingParams,
 )
 from typing import List, Tuple
-from lightllm.utils.log_utils import init_logger
-from lightllm.server.tokenizer import get_tokenizer
 from lightllm.server.req_id_generator import convert_sub_id_to_group_id
 from lightllm.server.router.model_infer.mode_backend.pre import (
     prepare_prefill_inputs,
-    prepare_decode_inputs,
 )
 from lightllm.server.router.model_infer.mode_backend.generic_post_process import sample
+from lightllm.server.router.model_infer.mode_backend.overlap_events import OverlapEventPack
 
 
 class DiversehBackend(ModeBackend):
     def __init__(self) -> None:
         super().__init__()
-
-    def init_custom(self):
-        pass
-
-    def build_group(self, req_ids: List[int]):
-        for r_id in req_ids:
-            req: InferReq = g_infer_context.requests_mapping[r_id]
-            group_req_id = req.shm_req.group_req_id
-            if group_req_id not in g_infer_context.group_mapping:
-                g_infer_context.group_mapping[group_req_id] = InferReqGroup(group_req_id=group_req_id)
-            g_infer_context.group_mapping[group_req_id].add_req(r_id)
+        self.prefill = self.beam_prefill
 
     def diverse_copy(self, groups: List[InferReqGroup]):
         batch_idx = []
@@ -46,64 +33,36 @@ def diverse_copy(self, groups: List[InferReqGroup]):
             run_reqs.extend(req_group.get_all_reqs())
         return batch_idx, run_reqs
 
-    def decode(self):
-        uninit_reqs, aborted_reqs, ok_finished_reqs, prefill_reqs, decode_reqs = self._get_classed_reqs(
-            g_infer_context.infer_req_ids,
-            strict_prefill=True,
+    def beam_prefill(self, event_pack: OverlapEventPack, prefill_reqs: List[InferReq]):
+        group_reqs = [
+            g_infer_context.requests_mapping[req.req_id]
+            for req in prefill_reqs
+            if convert_sub_id_to_group_id(req.req_id) == req.req_id
+        ]
+        groups = [
+            g_infer_context.group_mapping[req.req_id]
+            for req in prefill_reqs
+            if convert_sub_id_to_group_id(req.req_id) == req.req_id
+        ]
+        model_input, group_run_reqs = prepare_prefill_inputs(
+            group_reqs, is_chuncked_mode=not self.disable_chunked_prefill, is_multimodal=self.is_multimodal
         )
+        model_output = self.model.forward(model_input)
+        logits = model_output.logits
 
-        if aborted_reqs:
-            g_infer_context.filter_reqs(aborted_reqs)
-        if prefill_reqs:
-            group_reqs = [
-                g_infer_context.requests_mapping[req.req_id]
-                for req in prefill_reqs
-                if convert_sub_id_to_group_id(req.req_id) == req.req_id
-            ]
-            groups = [
-                g_infer_context.group_mapping[req.req_id]
-                for req in prefill_reqs
-                if convert_sub_id_to_group_id(req.req_id) == req.req_id
-            ]
-            model_input, group_run_reqs = prepare_prefill_inputs(
-                group_reqs, is_chuncked_mode=True, is_multimodal=self.is_multimodal
-            )
-            model_output = self.model.forward(model_input)
-            logits = model_output.logits
-
-            uninit_req_ids = [req.req_id for req in uninit_reqs]
-            self._overlap_req_init_and_filter(
-                uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True
-            )
-            self.build_group(uninit_req_ids)
-            batch_idx, run_reqs = self.diverse_copy(groups)
-            logits = logits[batch_idx]
-            next_token_ids, next_token_probs = sample(logits, run_reqs, self.eos_id)
-            next_token_ids = next_token_ids.detach().cpu().numpy()
-            next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy()
+        batch_idx, run_reqs = self.diverse_copy(groups)
+        logits = logits[batch_idx]
 
-            self._post_handle(
-                run_reqs, next_token_ids, next_token_logprobs, is_chuncked_mode=True, do_filter_finished_reqs=False
-            )
+        next_token_ids_gpu, next_token_probs_gpu = sample(logits, run_reqs, self.eos_id)
+        next_token_ids_cpu = next_token_ids_gpu.detach().cpu().numpy()
+        next_token_logprobs_cpu = torch.log(next_token_probs_gpu).detach().cpu().numpy()
 
-        if decode_reqs:
-            model_input, run_reqs = prepare_decode_inputs(decode_reqs)
-            model_output = self.model.forward(model_input)
-            logits = model_output.logits
-            uninit_req_ids = [req.req_id for req in uninit_reqs]
-            self._overlap_req_init_and_filter(
-                uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True
-            )
-            self.build_group(uninit_req_ids)
-
-            next_token_ids, next_token_probs = sample(logits, run_reqs, self.eos_id)
-            next_token_ids = next_token_ids.detach().cpu().numpy()
-            next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy()
-
-            self._post_handle(
-                run_reqs, next_token_ids, next_token_logprobs, is_chuncked_mode=False, do_filter_finished_reqs=False
-            )
-
-            uninit_req_ids = [req.req_id for req in uninit_reqs]
-            self._overlap_req_init_and_filter(uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True)
-            self.build_group(uninit_req_ids)
+        update_packs = self._pre_post_handle(run_reqs, is_chuncked_mode=not self.disable_chunked_prefill)
+        self._post_handle(
+            run_reqs=run_reqs,
+            next_token_ids=next_token_ids_cpu,
+            next_token_logprobs=next_token_logprobs_cpu,
+            run_reqs_update_packs=update_packs,
+            extra_post_req_handle_func=self.extra_post_req_handle_func,
+        )
         return
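
With the old decode() removed, beam_prefill handles only the prefill side: the model runs once per group (only the leader request, whose sub id equals its group_req_id, goes through prepare_prefill_inputs), and the resulting logits row is then fanned out to every sub-request via diverse_copy before sampling. A minimal sketch of that fan-out, assuming a toy batch_idx layout rather than the real InferReqGroup bookkeeping:

import torch

# one logits row per prefilled group leader (hypothetical sizes: 2 groups, vocab size 8)
group_logits = torch.randn(2, 8)

# group 0 has three sub-requests, group 1 has two
batch_idx = [0, 0, 0, 1, 1]

# replicate each group's row so every sub-request samples from the same prompt logits
logits = group_logits[batch_idx]
assert logits.shape == (5, 8)
assert torch.equal(logits[1], logits[2])  # same group -> identical rows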
Lines changed: 44 additions & 0 deletions
@@ -0,0 +1,44 @@
+import torch
+import threading
+import collections
+from typing import List, Dict
+
+
+class PinMemTensorManager:
+    def __init__(self):
+        self.lock = threading.Lock()
+        self.key_to_tensor_list: Dict[str, List[torch.Tensor]] = collections.defaultdict(list)
+        self.key_to_alloc_index: Dict[str, int] = {}
+
+    def alloc_pin_tensor(self, key: str, size: int, dtype: torch.dtype):
+        """
+        Use a cache of 4 pinned-memory buffers per key to speed up pinned-memory allocation and release.
+        """
+        with self.lock:
+            if key not in self.key_to_tensor_list:
+                self.key_to_tensor_list[key].append(
+                    torch.empty(size=(size,), dtype=dtype, device="cpu", pin_memory=True)
+                )
+                self.key_to_tensor_list[key].append(
+                    torch.empty(size=(size,), dtype=dtype, device="cpu", pin_memory=True)
+                )
+                self.key_to_tensor_list[key].append(
+                    torch.empty(size=(size,), dtype=dtype, device="cpu", pin_memory=True)
+                )
+                self.key_to_tensor_list[key].append(
+                    torch.empty(size=(size,), dtype=dtype, device="cpu", pin_memory=True)
+                )
+                self.key_to_alloc_index[key] = 0
+
+            alloc_index = self.key_to_alloc_index[key]
+            buff_tensor = self.key_to_tensor_list[key][alloc_index]
+            if buff_tensor.numel() < size:
+                self.key_to_tensor_list[key][alloc_index] = torch.empty(
+                    size=(size,), dtype=dtype, device="cpu", pin_memory=True
+                )
+                buff_tensor = self.key_to_tensor_list[key][alloc_index]
+            self.key_to_alloc_index[key] = (alloc_index + 1) % 4
+            return buff_tensor[0:size]
+
+
+g_pin_mem_manager = PinMemTensorManager()
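
The new module (its file path is not shown in this diff view) adds a small round-robin cache of pinned CPU tensors: per key it keeps four buffers, hands them out in rotation, and grows a buffer whenever a larger size is requested. A usage sketch, assuming a CUDA device and a caller that consumes the returned view promptly; the function name and key below are illustrative, only alloc_pin_tensor comes from the code above:

import torch

def fetch_token_ids_to_cpu(next_token_ids_gpu: torch.Tensor):
    # reuse a pinned buffer keyed by purpose, sized to the current batch
    pin_buf = g_pin_mem_manager.alloc_pin_tensor(
        key="next_token_ids", size=next_token_ids_gpu.numel(), dtype=next_token_ids_gpu.dtype
    )
    # pinned memory lets the device-to-host copy run asynchronously on the current stream
    pin_buf.copy_(next_token_ids_gpu.view(-1), non_blocking=True)
    torch.cuda.current_stream().synchronize()
    return pin_buf.numpy()

Because alloc_pin_tensor rotates through four buffers per key, the returned view is only safe to use until the same key has been allocated four more times, which is why the sketch synchronizes and hands the data off immediately.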
