@@ -9,7 +9,7 @@
 import json
 
 
-def format_tgi_params(params):
+def format_tgi_params(params, num_beam: int):
     """
     tgi params format -> lightllm server params format
     pub(crate) struct GenerateParameters {
@@ -40,7 +40,7 @@ def format_tgi_params(params):
     if "stop_sequences" not in params:
         params["stop_sequences"] = params.pop("stop", None)
     # remove keys lightllm does not use
-    # params.pop("best_of", 1)
+    params["best_of"] = num_beam
     params.pop("typical_p", 0.0)
     params.pop("return_full_text", False)
     params.pop("stop", None)
@@ -49,14 +49,17 @@ def format_tgi_params(params):
     params.pop("details", False)
     params.pop("decoder_input_details", False)
     params.pop("seed", 0)
+    params.pop("token_healing_top_k", 0)
+    params.pop("token_healing_unmerge_last_token", 0)
     return params
 
 
 async def tgi_generate_impl(request: Request, httpserver_manager: HttpServerManager) -> Response:
 
     request_dict = await request.json()
     prompt = request_dict.pop("inputs")
-    sample_params_dict = format_tgi_params(request_dict["parameters"])
+    num_beam = request_dict.get("num_beam", 1)
+    sample_params_dict = format_tgi_params(request_dict["parameters"], num_beam)
     return_details = sample_params_dict.pop("return_details", False)
     sampling_params = SamplingParams()
     sampling_params.init(tokenizer=httpserver_manager.tokenizer, **sample_params_dict)
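
Note: as a quick illustration of the parameter rewrite above, here is a minimal sketch of what format_tgi_params now does to a request's "parameters" dict. The input keys are hypothetical, and the function's elided lines may map or drop additional keys:

    # Hypothetical TGI-style parameters, limited to keys handled in the hunks shown.
    tgi_params = {"stop": ["\n\n"], "typical_p": 0.9, "seed": 42}

    out = format_tgi_params(tgi_params, num_beam=4)
    # "stop" is renamed to "stop_sequences", unsupported keys are popped,
    # and the beam width is forwarded as "best_of"; the expected result is
    # roughly: {"stop_sequences": ["\n\n"], "best_of": 4}
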
@@ -74,6 +77,8 @@ async def tgi_generate_impl(request: Request, httpserver_manager: HttpServerMana
     prompt_logprobs = None
     prompt_token_ids = None
     is_first_metadata = True
+    best_score = -float("inf")
+    best_sub_id = 0
     async for sub_req_id, request_output, metadata, finish_status in results_generator:
         # when "--return_all_prompt_logprobs" is set, the first token metadata will contain
         # prompt_logprobs and prompt_token_ids
@@ -93,27 +98,41 @@ async def tgi_generate_impl(request: Request, httpserver_manager: HttpServerMana
         tokens_dict[sub_req_id].append(metadata)
         if finish_status.is_finished():
             finish_status_dict[sub_req_id] = finish_status
+            if metadata["cumlogprob"] > best_score:
+                best_score = metadata["cumlogprob"]
+                best_sub_id = sub_req_id
 
-    rets = []
+    ret = None
+    beam_sequences = []
     for sub_id in list(final_output_dict.keys()):
+        if return_details:
+            beam_ret = {
+                "generated_text": "".join(final_output_dict[sub_id]),
+                "finish_reason": finish_status_dict[sub_id].get_finish_reason(),
+                "generated_tokens": count_output_tokens_dict[sub_id],
+                "logprob": tokens_dict[sub_id][-1]["cumlogprob"],
+            }
+            beam_sequences.append(beam_ret)
+        if sub_id != best_sub_id:
+            continue
         ret = {
             "generated_text": "".join(final_output_dict[sub_id]),
-            "count_output_tokens": count_output_tokens_dict[sub_id],
-            "finish_reason": finish_status_dict[sub_id].get_finish_reason(),
         }
         if return_details:
             ret["details"] = {
-                "tokens": tokens_dict[sub_id],
                 "generated_tokens": count_output_tokens_dict[sub_id],
                 "finish_reason": finish_status_dict[sub_id].get_finish_reason(),
+                "tokens": tokens_dict[sub_id],
             }
         if prompt_token_ids is not None:
             ret["prompt_token_ids"] = prompt_token_ids
         if prompt_logprobs is not None:
             ret["prompt_logprobs"] = prompt_logprobs
-        rets.append(ret)
+    assert ret is not None
+    if return_details:
+        ret["beam_sequences"] = beam_sequences
     # wrap generation inside a Vec to match api-inference
-    json_compatible_item_data = jsonable_encoder(rets)
+    json_compatible_item_data = jsonable_encoder([ret])
     return JSONResponse(content=json_compatible_item_data)
 
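
Note: for completeness, a minimal client-side sketch of the new beam-search path. The endpoint path, host, and port are assumptions about the deployment; what the diff does establish is that "num_beam" is read from the top level of the request body while "return_details" travels inside "parameters":

    import requests

    payload = {
        "inputs": "What is AI?",
        "num_beam": 4,  # beam width, forwarded into params["best_of"]
        "parameters": {
            "max_new_tokens": 32,
            "return_details": True,  # also populates "beam_sequences"
        },
    }

    # assumed endpoint; adjust to the actual server address
    resp = requests.post("http://localhost:8000/generate", json=payload)
    result = resp.json()[0]  # the server wraps the single best beam in a list

    print(result["generated_text"])  # text of the highest-cumlogprob beam
    for beam in result.get("beam_sequences", []):  # every finished beam
        print(beam["logprob"], beam["generated_text"])

Returning only the best beam keeps the response shape compatible with TGI's single-generation list; the other finished candidates are exposed through "beam_sequences" only when details are requested.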