@@ -68,6 +68,8 @@ def server(model_name: str, backend: str, extra_llm_api_options: bool,
            temp_extra_llm_api_options_file: str, num_postprocess_workers: int):
     model_path = get_model_path(model_name)
     args = ["--backend", f"{backend}"]
+    args.extend(["--kv_cache_free_gpu_memory_fraction",
+                 "0.2"])  # for co-existence with other servers
     if backend == "trt":
         args.extend(["--max_beam_width", "4"])
     if extra_llm_api_options:
@@ -78,11 +80,34 @@ def server(model_name: str, backend: str, extra_llm_api_options: bool,
         yield remote_server


+@pytest.fixture(scope="module")
+def server_with_beam_search(model_name: str, backend: str,
+                            extra_llm_api_options: bool,
+                            temp_extra_llm_api_options_file: str,
+                            num_postprocess_workers: int):
+    model_path = get_model_path(model_name)
+    args = ["--backend", f"{backend}"]
+    args.extend(["--kv_cache_free_gpu_memory_fraction",
+                 "0.2"])  # for co-existence with other servers
+    args.extend(["--max_beam_width", "2"])
+    if extra_llm_api_options:
+        args.extend(
+            ["--extra_llm_api_options", temp_extra_llm_api_options_file])
+    args.extend(["--num_postprocess_workers", f"{num_postprocess_workers}"])
+    with RemoteOpenAIServer(model_path, args) as remote_server:
+        yield remote_server
+
+
 @pytest.fixture(scope="module")
 def client(server: RemoteOpenAIServer):
     return server.get_client()


+@pytest.fixture(scope="module")
+def client_with_beam_search(server_with_beam_search: RemoteOpenAIServer):
+    return server_with_beam_search.get_client()
+
+
 @pytest.fixture(scope="module")
 def async_client(server: RemoteOpenAIServer):
     return server.get_async_client()
@@ -180,7 +205,33 @@ def test_multiple_responses(client: openai.OpenAI, model_name: str,
                             backend: str):
     if backend == "pytorch":
         pytest.skip(
-            "Multiple responses are not supported in PyTorch backend yet")
+            "'n' not allowed with temperature=0 unless TLLM_ALLOW_N_GREEDY_DECODING=1"
+        )
+    messages = [{
+        "role": "system",
+        "content": "you are a helpful assistant"
+    }, {
+        "role": "user",
+        "content": "what is 1+1?"
+    }]
+    # test n and best_of
+    chat_completion = client.chat.completions.create(
+        model=model_name,
+        messages=messages,
+        max_completion_tokens=10,
+        n=2,
+        temperature=0.0,
+        extra_body=dict(best_of=4),
+    )
+    assert len(chat_completion.choices) == 2
+
+
+def test_multiple_responses_and_beam_search(client: openai.OpenAI,
+                                            model_name: str, backend: str):
+    if backend == "pytorch":
+        pytest.skip(
+            "Mixing beam search and regular requests is not supported in PyTorch backend"
+        )

     messages = [{
         "role": "system",
@@ -202,6 +253,7 @@ def test_multiple_responses(client: openai.OpenAI, model_name: str,
     assert chat_completion.choices[
         0].message.content != chat_completion.choices[
             1].message.content, "beam search should be different"
+
     # test n and best_of
     chat_completion = client.chat.completions.create(
         model=model_name,
@@ -214,6 +266,30 @@ def test_multiple_responses(client: openai.OpenAI, model_name: str,
     assert len(chat_completion.choices) == 2


+def test_multiple_responses_with_beam_search(
+        client_with_beam_search: openai.OpenAI, model_name: str):
+    messages = [{
+        "role": "system",
+        "content": "you are a helpful assistant"
+    }, {
+        "role": "user",
+        "content": "what is 1+1?"
+    }]
+    # test beam search
+    chat_completion = client_with_beam_search.chat.completions.create(
+        model=model_name,
+        messages=messages,
+        max_completion_tokens=10,
+        n=2,
+        temperature=0.0,
+        extra_body=dict(use_beam_search=True),
+    )
+    assert len(chat_completion.choices) == 2
+    assert chat_completion.choices[
+        0].message.content != chat_completion.choices[
+            1].message.content, "beam search should be different"
+
+
 @pytest.mark.asyncio(loop_scope="module")
 async def test_chat_streaming(async_client: openai.AsyncOpenAI,
                               model_name: str):