-import os
-import tempfile
-
 import openai
 import pytest
 import yaml
@@ -22,34 +19,28 @@ def backend(request):
 
 
 @pytest.fixture(scope="module")
-def temp_extra_llm_api_options_file():
-    temp_dir = tempfile.gettempdir()
-    temp_file_path = os.path.join(temp_dir, "extra_llm_api_options.yaml")
-    try:
-        extra_llm_api_options_dict = {
-            "enable_chunked_prefill": False,
-            "gather_generation_logits": True,
-            "kv_cache_config": {
-                "enable_block_reuse": False,
-            }
+def temp_extra_llm_api_options_file(tmp_path_factory):
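+    # tmp_path_factory gives a pytest-managed temp dir, so no manual cleanup is needed.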
+    extra_llm_api_options_dict = {
+        "enable_chunked_prefill": False,
+        "gather_generation_logits": True,
+        "kv_cache_config": {
+            "enable_block_reuse": False,
         }
+    }
 
-        with open(temp_file_path, 'w') as f:
-            yaml.dump(extra_llm_api_options_dict, f)
-
-        yield temp_file_path
-    finally:
-        if os.path.exists(temp_file_path):
-            os.remove(temp_file_path)
+    temp_file_path = tmp_path_factory.mktemp(
+        "config") / "extra_llm_api_options.yaml"
+    with open(temp_file_path, 'w') as f:
+        yaml.dump(extra_llm_api_options_dict, f)
+    return temp_file_path
 
 
 @pytest.fixture(scope="module")
 def server(model_name: str, backend: str, temp_extra_llm_api_options_file: str):
     model_path = get_model_path(model_name)
-    args = [
-        "--backend", f"{backend}", "--extra_llm_api_options",
-        temp_extra_llm_api_options_file
-    ]
+    args = ["--backend", f"{backend}"]
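+    # Only the TRT backend consumes the extra LLM API options file.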
+    if backend == "trt":
+        args += ["--extra_llm_api_options", temp_extra_llm_api_options_file]
     with RemoteOpenAIServer(model_path, args) as remote_server:
         yield remote_server
 
@@ -61,11 +52,7 @@ def async_client(server: RemoteOpenAIServer):
 
 @pytest.mark.asyncio(loop_scope="module")
 async def test_chat_completion_top5_logprobs(async_client: openai.AsyncOpenAI,
-                                             model_name: str, backend: str):
-    # Skip if backend is PyTorch as it does not support topk logprobs when k > 1
-    if backend == "pytorch":
-        pytest.skip("Topk logprobs is not supported")
-
+                                             model_name: str):
6956 messages = [{
7057 "role" : "system" ,
7158 "content" : "You are a helpful assistant."
@@ -94,42 +81,3 @@ async def test_chat_completion_top5_logprobs(async_client: openai.AsyncOpenAI,
         assert logprob_content.bytes is not None
         assert logprob_content.top_logprobs is not None
         assert len(logprob_content.top_logprobs) == 5
-
-
-@pytest.mark.asyncio(loop_scope="module")
-async def test_chat_completion_top1_logprobs(async_client: openai.AsyncOpenAI,
-                                             model_name: str, backend: str):
-    # Skip if backend is TRT because it is tested in test_chat_completion_top5_logprobs
-    if backend == "trt":
-        pytest.skip(
-            "TRT top logprobs is already tested in test_chat_completion_top5_logprobs"
-        )
-
-    messages = [{
-        "role": "system",
-        "content": "You are a helpful assistant."
-    }, {
-        "role": "user",
-        "content": "What is the capital of France?"
-    }]
-    # Test top_logprobs=1
-    chat_completion = await async_client.chat.completions.create(
-        model=model_name,
-        messages=messages,
-        max_completion_tokens=10,
-        temperature=0.0,
-        logprobs=True,
-        top_logprobs=1,
-        extra_body={
-            "ignore_eos": True,
-        })
-    logprobs = chat_completion.choices[0].logprobs
-    assert logprobs is not None and logprobs.content is not None
-    assert len(logprobs.content) == 10
-    for logprob_content in logprobs.content:
-        assert logprob_content.token is not None
-        assert logprob_content.logprob is not None
-        assert logprob_content.bytes is not None
-        assert logprob_content.top_logprobs is not None
-        # Check that the top_logprobs contains only one entry
-        assert len(logprob_content.top_logprobs) == 1