@@ -184,14 +184,15 @@ async def inference_vllm_async(request: Union[ChatCompletionRequest, CompletionR
     request_id = request_info['request_id']
 
     kwargs = {'max_new_tokens': request.max_tokens}
-    for key in ['n', 'stop', 'best_of', 'frequency_penalty', 'length_penalty', 'presence_penalty', 'num_beams']:
+    for key in ['n', 'best_of', 'frequency_penalty', 'length_penalty', 'presence_penalty', 'num_beams']:
         kwargs[key] = getattr(request, key)
     for key in ['temperature', 'top_k', 'top_p', 'repetition_penalty']:
         new_value = getattr(request, key)
         if new_value is None:
             kwargs[key] = getattr(llm_engine.generation_config, key)
         else:
             kwargs[key] = new_value
+    kwargs['stop'] = (llm_engine.generation_config.stop or []) + (getattr(request, 'stop') or [])
 
     generation_config = VllmGenerationConfig(**kwargs)
     if generation_config.use_beam_search and request.stream:
@@ -343,7 +344,7 @@ def __repr__(self) -> str:
 
 @torch.inference_mode()
 async def inference_pt_async(request: Union[ChatCompletionRequest, CompletionRequest], raw_request: Request):
-    global model, template
+    global model, template, _args
     result = await _prepare_request(request)
     if isinstance(result, JSONResponse):
         return result
@@ -359,8 +360,13 @@ async def inference_pt_async(request: Union[ChatCompletionRequest, CompletionReq
         new_value = getattr(request, key)
         if new_value is None:
             kwargs[key] = getattr(model.generation_config, key)
+            if key == 'temperature':
+                do_sample = getattr(model.generation_config, 'do_sample')
+                if not do_sample:
+                    kwargs[key] = 0
         else:
             kwargs[key] = new_value
+
     if kwargs['temperature'] == 0:
         kwargs['do_sample'] = False
         kwargs['temperature'] = 1
@@ -374,7 +380,8 @@ async def inference_pt_async(request: Union[ChatCompletionRequest, CompletionReq
     set_generation_config(model, generation_config)  # inplace
     model.generation_config = _old_generation_config
     request_info['generation_config'] = generation_config
-    request_info.update({'seed': request.seed, 'stop': request.stop, 'stream': request.stream})
+    stop = (_args.stop_words or []) + (getattr(request, 'stop') or [])
+    request_info.update({'seed': request.seed, 'stop': stop, 'stream': request.stream})
     logger.info(request_info)
 
     created_time = int(time.time())
@@ -397,7 +404,7 @@ async def _generate_full():
             model,
             template,
             **example,
-            stop_words=request.stop,
+            stop_words=stop,
             generation_config=generation_config,
             generation_info=generation_info,
             **adapter_kwargs)
@@ -441,7 +448,7 @@ def _generate_stream():
             model,
             template,
             **example,
-            stop_words=request.stop,
+            stop_words=stop,
             generation_config=generation_config,
             generation_info=generation_info,
             **adapter_kwargs)
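
Net effect of the change: stop words configured on the server side (vLLM's generation config, or `_args.stop_words` on the PyTorch path) are merged with any per-request `stop` list instead of being replaced by it, with `None` on either side treated as an empty list. A minimal, self-contained sketch of that merge behavior (the `merge_stop_words` helper name is illustrative only, not part of the patch):

from typing import List, Optional


def merge_stop_words(server_stop: Optional[List[str]], request_stop: Optional[List[str]]) -> List[str]:
    # Mirrors `(llm_engine.generation_config.stop or []) + (getattr(request, 'stop') or [])`
    # and `(_args.stop_words or []) + (getattr(request, 'stop') or [])` from the hunks above.
    return (server_stop or []) + (request_stop or [])


assert merge_stop_words(None, None) == []
assert merge_stop_words(['Observation:'], None) == ['Observation:']
assert merge_stop_words(['Observation:'], ['###']) == ['Observation:', '###']

The PyTorch hunk also keeps greedy decoding consistent: when the request omits `temperature` and the model's `generation_config` has `do_sample=False`, the temperature falls back to 0, which the existing `if kwargs['temperature'] == 0` branch then maps to `do_sample=False` with `temperature=1`.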