Skip to content

Commit a740b7f

Browse files
authored
update tgi api for beam output (#726)
Co-authored-by: shihaobai <baishihao@sensetime.com>
1 parent 5c8ce8c commit a740b7f

File tree

3 files changed

+33
-10
lines changed

3 files changed

+33
-10
lines changed

lightllm/server/api_tgi.py

Lines changed: 28 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import json
1010

1111

12-
def format_tgi_params(params):
12+
def format_tgi_params(params, num_beam: int):
1313
"""
1414
tgi params format -> lightllm server params format
1515
pub(crate) struct GenerateParameters {
@@ -40,7 +40,7 @@ def format_tgi_params(params):
4040
if "stop_sequences" not in params:
4141
params["stop_sequences"] = params.pop("stop", None)
4242
# remove keys lightllm not used
43-
# params.pop("best_of", 1)
43+
params["best_of"] = num_beam
4444
params.pop("typical_p", 0.0)
4545
params.pop("return_full_text", False)
4646
params.pop("stop", None)
@@ -49,14 +49,17 @@ def format_tgi_params(params):
4949
params.pop("details", False)
5050
params.pop("decoder_input_details", False)
5151
params.pop("seed", 0)
52+
params.pop("token_healing_top_k", 0)
53+
params.pop("token_healing_unmerge_last_token", 0)
5254
return params
5355

5456

5557
async def tgi_generate_impl(request: Request, httpserver_manager: HttpServerManager) -> Response:
5658

5759
request_dict = await request.json()
5860
prompt = request_dict.pop("inputs")
59-
sample_params_dict = format_tgi_params(request_dict["parameters"])
61+
num_beam = request_dict.get("num_beam", 1)
62+
sample_params_dict = format_tgi_params(request_dict["parameters"], num_beam)
6063
return_details = sample_params_dict.pop("return_details", False)
6164
sampling_params = SamplingParams()
6265
sampling_params.init(tokenizer=httpserver_manager.tokenizer, **sample_params_dict)
@@ -74,6 +77,8 @@ async def tgi_generate_impl(request: Request, httpserver_manager: HttpServerMana
7477
prompt_logprobs = None
7578
prompt_token_ids = None
7679
is_first_metadata = True
80+
best_score = -float("inf")
81+
best_sub_id = 0
7782
async for sub_req_id, request_output, metadata, finish_status in results_generator:
7883
# when set "--return_all_prompt_logprobs", the first token metadata will contains
7984
# prompt_logprobs and prompt_token_ids
@@ -93,27 +98,41 @@ async def tgi_generate_impl(request: Request, httpserver_manager: HttpServerMana
9398
tokens_dict[sub_req_id].append(metadata)
9499
if finish_status.is_finished():
95100
finish_status_dict[sub_req_id] = finish_status
101+
if metadata["cumlogprob"] > best_score:
102+
best_score = metadata["cumlogprob"]
103+
best_sub_id = sub_req_id
96104

97-
rets = []
105+
ret = None
106+
beam_sequences = []
98107
for sub_id in list(final_output_dict.keys()):
108+
if return_details:
109+
beam_ret = {
110+
"generated_text": "".join(final_output_dict[sub_id]),
111+
"finish_reason": finish_status_dict[sub_id].get_finish_reason(),
112+
"generated_tokens": count_output_tokens_dict[sub_id],
113+
"logprob": tokens_dict[sub_id][-1]["cumlogprob"],
114+
}
115+
beam_sequences.append(beam_ret)
116+
if sub_id != best_sub_id:
117+
continue
99118
ret = {
100119
"generated_text": "".join(final_output_dict[sub_id]),
101-
"count_output_tokens": count_output_tokens_dict[sub_id],
102-
"finish_reason": finish_status_dict[sub_id].get_finish_reason(),
103120
}
104121
if return_details:
105122
ret["details"] = {
106-
"tokens": tokens_dict[sub_id],
107123
"generated_tokens": count_output_tokens_dict[sub_id],
108124
"finish_reason": finish_status_dict[sub_id].get_finish_reason(),
125+
"tokens": tokens_dict[sub_id],
109126
}
110127
if prompt_token_ids is not None:
111128
ret["prompt_token_ids"] = prompt_token_ids
112129
if prompt_logprobs is not None:
113130
ret["prompt_logprobs"] = prompt_logprobs
114-
rets.append(ret)
131+
assert ret is not None
132+
if return_details:
133+
ret["beam_sequences"] = beam_sequences
115134
# wrap generation inside a Vec to match api-inference
116-
json_compatible_item_data = jsonable_encoder(rets)
135+
json_compatible_item_data = jsonable_encoder([ret])
117136
return JSONResponse(content=json_compatible_item_data)
118137

119138

lightllm/server/core/objs/req.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,8 @@ class Req(ctypes.Structure):
9191
("can_released_mark", ctypes.c_bool),
9292
# reward_model 使用的变量
9393
("reward_score", ctypes.c_float),
94+
# 请求回复累计概率和
95+
("cumlogprob", ctypes.c_float),
9496
]
9597

9698
def init(
@@ -119,6 +121,7 @@ def init(
119121
self.finish_token_index = -1
120122
self.can_released_mark = False
121123
self.reward_score = math.nan
124+
self.cumlogprob = 0.0
122125
if isinstance(sample_param, SamplingParams):
123126
self.sample_params = sample_param
124127
else:

lightllm/server/httpserver/manager.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -356,7 +356,6 @@ async def _wait_to_token_package(
356356
metadata["prompt_ids"] = prompt_ids
357357

358358
prompt_cache_len = metadata.pop("prompt_cache_len", 0)
359-
360359
if is_first_token:
361360
first_token_cost_ms = (time.time() - start_time) * 1000
362361
is_first_token = False
@@ -474,9 +473,11 @@ async def handle_loop(self):
474473
if not req.out_tokens_queue.is_empty():
475474

476475
text, src_index, special, count_output_tokens = req.out_tokens_queue.peek()
476+
req.cumlogprob += float(req.shm_logprobs.arr[src_index])
477477
metadata = {
478478
"id": int(req.shm_prompt_ids.arr[src_index]),
479479
"logprob": float(req.shm_logprobs.arr[src_index]),
480+
"cumlogprob": float(req.cumlogprob) / count_output_tokens,
480481
"special": special,
481482
"count_output_tokens": count_output_tokens,
482483
"prompt_cache_len": req.prompt_cache_len,

0 commit comments

Comments (0)