
Commit 080bb43

fix all
1 parent b989607 commit 080bb43

4 files changed: 82 additions, 269 deletions
Lines changed: 68 additions & 0 deletions
@@ -0,0 +1,68 @@
import torch
from .impl import ChunkedPrefillBackend
from typing import List
from lightllm.server.router.model_infer.infer_batch import InferReq
from lightllm.server.router.model_infer.mode_backend.pre import prepare_prefill_inputs
from lightllm.server.router.model_infer.mode_backend.generic_post_process import sample
from lightllm.server.router.model_infer.mode_backend.overlap_events import OverlapEventPack


class ReturnPromptLogProbBackend(ChunkedPrefillBackend):
    def __init__(self) -> None:
        super().__init__()
        self.prefill = self.return_all_prompt_logprobs_prefill
        return

    def return_all_prompt_logprobs_prefill(
        self,
        event_pack: OverlapEventPack,
        prefill_reqs: List[InferReq]):

        # In return_all_prompt_logprobs mode, dynamic prompt cache must not be enabled.
        assert self.radix_cache is None
        assert self.disable_chunked_prefill is True

        model_input, run_reqs = prepare_prefill_inputs(
            prefill_reqs, is_chuncked_mode=not self.disable_chunked_prefill, is_multimodal=self.is_multimodal
        )

        model_output = self.model.forward(model_input)
        prompt_all_logits = model_output.logits

        input_ids = model_input.input_ids
        b_ready_cache_len = model_input.b_ready_cache_len
        b_seq_len = model_input.b_seq_len
        # Index of each request's last position in the packed batch; these logits
        # are later used to sample the first generated token.
        last_index = torch.cumsum(b_seq_len, dim=0, dtype=torch.long) - 1
        logits = prompt_all_logits[last_index, :]

        # Per-request query lengths and start offsets (exclusive prefix sum).
        b_q_seq_len = b_seq_len - b_ready_cache_len
        b_start_loc = torch.cumsum(b_q_seq_len, dim=0, dtype=torch.long) - b_q_seq_len
        b_start_loc = b_start_loc.cpu().numpy()
        b_q_seq_len = b_q_seq_len.cpu().numpy()

        for req_obj, start_loc, q_seq_len in zip(run_reqs, b_start_loc, b_q_seq_len):
            req_obj: InferReq = req_obj
            cur_ids: torch.Tensor = input_ids[start_loc : start_loc + q_seq_len]
            cur_logits = prompt_all_logits[start_loc : start_loc + q_seq_len]
            # Position i predicts token i + 1: drop the last row, then gather each
            # prompt token's log-probability from the preceding position.
            cur_logprobs = torch.log_softmax(cur_logits, dim=-1, dtype=torch.float)[0:-1, :]
            cur_logprobs = torch.gather(cur_logprobs, dim=1, index=cur_ids[1:].view(-1, 1)).detach().cpu().numpy()

            if req_obj.shm_req.input_len > 1:
                req_obj.shm_req.shm_logprobs.arr[1 : req_obj.shm_req.input_len] = cur_logprobs.flatten()

        if self.prefill_mask_func is not None:
            self.prefill_mask_func(run_reqs, logits)

        next_token_ids, next_token_probs = sample(logits, run_reqs, self.eos_id)
        next_token_ids = next_token_ids.detach().cpu().numpy()
        next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy()

        update_packs = self._pre_post_handle(run_reqs, is_chuncked_mode=not self.disable_chunked_prefill)
        self._post_handle(
            run_reqs=run_reqs,
            next_token_ids=next_token_ids,
            next_token_logprobs=next_token_logprobs,
            run_reqs_update_packs=update_packs,
            extra_post_req_handle_func=self.extra_post_req_handle_func,
        )
        return
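
For reference, a minimal, self-contained sketch of the two indexing steps used above: the exclusive prefix sum that gives each request's start offset in the packed batch, and the log_softmax + gather that reads each prompt token's log-probability from the position before it. The tensor names and toy sizes below are illustrative only, not lightllm objects:

import torch

# Toy packed batch: two requests with query lengths 3 and 2 over a vocabulary of 5.
vocab_size = 5
b_q_seq_len = torch.tensor([3, 2], dtype=torch.long)
total_tokens = int(b_q_seq_len.sum())
input_ids = torch.randint(0, vocab_size, (total_tokens,))
prompt_all_logits = torch.randn(total_tokens, vocab_size)

# Exclusive prefix sum: start offset of each request inside the packed batch.
b_start_loc = torch.cumsum(b_q_seq_len, dim=0) - b_q_seq_len
# Inclusive prefix sum minus one: index of each request's last position, whose
# logits would be used to sample the first generated token.
last_index = torch.cumsum(b_q_seq_len, dim=0) - 1
last_logits = prompt_all_logits[last_index, :]  # shape (2, vocab_size)

for start, q_len in zip(b_start_loc.tolist(), b_q_seq_len.tolist()):
    cur_ids = input_ids[start : start + q_len]
    cur_logits = prompt_all_logits[start : start + q_len]
    # Position i predicts token i + 1, so drop the last row and gather the
    # log-probabilities of tokens 1..q_len-1 from the preceding positions.
    cur_logprobs = torch.log_softmax(cur_logits, dim=-1)[:-1, :]
    token_logprobs = torch.gather(cur_logprobs, dim=1, index=cur_ids[1:].view(-1, 1)).flatten()
    print(token_logprobs.shape)  # torch.Size([q_len - 1])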

lightllm/server/router/model_infer/mode_backend/continues_batch/impl_for_reward_model.py renamed to lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_reward_model.py

Lines changed: 14 additions & 33 deletions
@@ -1,35 +1,24 @@
 import torch
 from typing import List, Tuple
-from .impl import ContinuesBatchBackend
-from lightllm.server.router.model_infer.infer_batch import InferReq, InferSamplingParams, g_infer_context
+from .impl import ChunkedPrefillBackend
+from lightllm.server.router.model_infer.infer_batch import InferReq
 from lightllm.server.router.model_infer.mode_backend.pre import prepare_prefill_inputs
-from lightllm.server.core.objs import FinishStatus
+from lightllm.server.router.model_infer.mode_backend.overlap_events import OverlapEventPack
 
-
-class RewardModelBackend(ContinuesBatchBackend):
+class RewardModelBackend(ChunkedPrefillBackend):
     def __init__(self) -> None:
         super().__init__()
 
-    def decode(self):
-        uninit_reqs, aborted_reqs, ok_finished_reqs, prefill_reqs, decode_reqs = self._get_classed_reqs(
-            g_infer_context.infer_req_ids
-        )
-
-        if aborted_reqs:
-            g_infer_context.filter_reqs(aborted_reqs)
-
-        if prefill_reqs:
-            self._prefill_reqs(req_objs=prefill_reqs)
-
-        if decode_reqs:
-            self.normal_decode(decode_reqs=decode_reqs, uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs)
-
-        self._overlap_req_init_and_filter(uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True)
+        self.prefill = self.reward_prefill
         return
 
-    def _prefill_reqs(self, req_objs: List[InferReq]):
+    def reward_prefill(self,
+                       event_pack: OverlapEventPack,
+                       prefill_reqs: List[InferReq]):
+
+        assert self.disable_chunked_prefill is True
         model_input, run_reqs = prepare_prefill_inputs(
-            req_objs, is_chuncked_mode=False, is_multimodal=self.is_multimodal
+            prefill_reqs, is_chuncked_mode=not self.disable_chunked_prefill, is_multimodal=self.is_multimodal
         )
 
         model_output = self.model.forward(model_input)
@@ -39,20 +28,14 @@ def _prefill_reqs(self, req_objs: List[InferReq]):
         next_token_id = 1
         next_token_logprob = 1.0
 
-        finished_req_ids = []
-
         for req_obj, score in zip(run_reqs, scores):
             # prefill and decode are the same
             req_obj: InferReq = req_obj
             req_obj.cur_kv_len = req_obj.get_cur_total_len()
-
-            req_obj.set_next_gen_token_id(next_token_id, next_token_logprob)
+
             req_obj.cur_output_len += 1
-
-            req_obj.update_finish_status(self.eos_id)
-
-            if req_obj.finish_status.is_finished() or req_obj.shm_req.router_aborted:
-                finished_req_ids.append(req_obj.shm_req.request_id)
+            req_obj.set_next_gen_token_id(next_token_id, next_token_logprob, output_len=req_obj.cur_output_len)
+            req_obj.update_finish_status(self.eos_id, output_len=req_obj.cur_output_len)
 
             if self.is_master_in_dp:
                 # write the reward_score
@@ -69,6 +52,4 @@ def _prefill_reqs(self, req_objs: List[InferReq]):
                 req_obj.shm_req.finish_status = req_obj.finish_status
 
                 req_obj.shm_req.candetoken_out_len = req_obj.cur_output_len
-
-        g_infer_context.filter(finished_req_ids)
         return
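
Both files in this commit follow the same design: rather than overriding the whole decode()/scheduling loop (as the deleted continues_batch implementation did), each backend keeps the shared ChunkedPrefillBackend loop and only registers its own prefill handler in __init__. Below is a minimal standalone sketch of that registration pattern; BaseBackend, Request, and run_step are made-up names for illustration, not lightllm APIs:

from typing import Callable, List


class Request:
    """Stand-in for an inference request object."""
    def __init__(self, prompt: str) -> None:
        self.prompt = prompt


class BaseBackend:
    """Owns the scheduling loop; subclasses only replace the prefill hook."""
    def __init__(self) -> None:
        self.prefill: Callable[[List[Request]], None] = self.default_prefill

    def default_prefill(self, reqs: List[Request]) -> None:
        print(f"default prefill for {len(reqs)} request(s)")

    def run_step(self, reqs: List[Request]) -> None:
        # The loop never changes; it calls whatever handler is registered.
        self.prefill(reqs)


class RewardBackend(BaseBackend):
    def __init__(self) -> None:
        super().__init__()
        # Same idea as `self.prefill = self.reward_prefill` in the commit.
        self.prefill = self.reward_prefill

    def reward_prefill(self, reqs: List[Request]) -> None:
        print(f"reward prefill: scoring {len(reqs)} request(s)")


if __name__ == "__main__":
    RewardBackend().run_step([Request("hello")])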

lightllm/server/router/model_infer/mode_backend/continues_batch/impl.py

Lines changed: 0 additions & 100 deletions
This file was deleted.

lightllm/server/router/model_infer/mode_backend/continues_batch/impl_mtp.py

Lines changed: 0 additions & 136 deletions
This file was deleted.
