
Commit 0f308e9

[None][chore] Remove logprobs constraint on trtllm-serve pytorch backend (#9911)
Signed-off-by: Pengyun Lin <[email protected]>
1 parent a6a8898 commit 0f308e9
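
With this change, the PyTorch backend of trtllm-serve accepts top_logprobs values greater than 1, so a request like the minimal sketch below is no longer rejected. The base_url, port, API key, and model name are placeholders, not taken from this commit; the request mirrors test_chat_completion_top5_logprobs in the diff.

import openai

# Assumes a trtllm-serve instance started with `--backend pytorch` is
# listening on localhost:8000; adjust base_url and model for your deployment.
client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="dummy")

completion = client.chat.completions.create(
    model="<served-model-name>",  # placeholder
    messages=[{
        "role": "user",
        "content": "What is the capital of France?"
    }],
    max_completion_tokens=10,
    temperature=0.0,
    logprobs=True,
    top_logprobs=5,  # previously only k=1 worked on the pytorch backend
    extra_body={"ignore_eos": True},
)

# Each returned token now carries its top-5 alternatives.
for token_logprob in completion.choices[0].logprobs.content:
    print(token_logprob.token, len(token_logprob.top_logprobs))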

File tree: 1 file changed (+16 −68 lines)

@@ -1,6 +1,3 @@
-import os
-import tempfile
 import openai
 import pytest
 import yaml
@@ -22,34 +19,28 @@ def backend(request):
 
 
 @pytest.fixture(scope="module")
-def temp_extra_llm_api_options_file():
-    temp_dir = tempfile.gettempdir()
-    temp_file_path = os.path.join(temp_dir, "extra_llm_api_options.yaml")
-    try:
-        extra_llm_api_options_dict = {
-            "enable_chunked_prefill": False,
-            "gather_generation_logits": True,
-            "kv_cache_config": {
-                "enable_block_reuse": False,
-            }
+def temp_extra_llm_api_options_file(tmp_path_factory):
+    extra_llm_api_options_dict = {
+        "enable_chunked_prefill": False,
+        "gather_generation_logits": True,
+        "kv_cache_config": {
+            "enable_block_reuse": False,
         }
+    }
 
-        with open(temp_file_path, 'w') as f:
-            yaml.dump(extra_llm_api_options_dict, f)
-
-        yield temp_file_path
-    finally:
-        if os.path.exists(temp_file_path):
-            os.remove(temp_file_path)
+    temp_file_path = tmp_path_factory.mktemp(
+        "config") / "extra_llm_api_options.yaml"
+    with open(temp_file_path, 'w') as f:
+        yaml.dump(extra_llm_api_options_dict, f)
+    return temp_file_path
 
 
 @pytest.fixture(scope="module")
 def server(model_name: str, backend: str, temp_extra_llm_api_options_file: str):
     model_path = get_model_path(model_name)
-    args = [
-        "--backend", f"{backend}", "--extra_llm_api_options",
-        temp_extra_llm_api_options_file
-    ]
+    args = ["--backend", f"{backend}"]
+    if backend == "trt":
+        args += ["--extra_llm_api_options", temp_extra_llm_api_options_file]
     with RemoteOpenAIServer(model_path, args) as remote_server:
         yield remote_server
 
@@ -61,11 +52,7 @@ def async_client(server: RemoteOpenAIServer):
 
 @pytest.mark.asyncio(loop_scope="module")
 async def test_chat_completion_top5_logprobs(async_client: openai.AsyncOpenAI,
-                                             model_name: str, backend: str):
-    # Skip if backend is PyTorch as it does not support topk logprobs when k > 1
-    if backend == "pytorch":
-        pytest.skip("Topk logprobs is not supported")
-
+                                             model_name: str):
     messages = [{
         "role": "system",
         "content": "You are a helpful assistant."
@@ -94,42 +81,3 @@ async def test_chat_completion_top5_logprobs(async_client: openai.AsyncOpenAI,
         assert logprob_content.bytes is not None
         assert logprob_content.top_logprobs is not None
         assert len(logprob_content.top_logprobs) == 5
-
-
-@pytest.mark.asyncio(loop_scope="module")
-async def test_chat_completion_top1_logprobs(async_client: openai.AsyncOpenAI,
-                                             model_name: str, backend: str):
-    # Skip if backend is TRT because it is tested in test_chat_completion_top5_logprobs
-    if backend == "trt":
-        pytest.skip(
-            "TRT top logprobs is already tested in test_chat_completion_top5_logprobs"
-        )
-
-    messages = [{
-        "role": "system",
-        "content": "You are a helpful assistant."
-    }, {
-        "role": "user",
-        "content": "What is the capital of France?"
-    }]
-    # Test top_logprobs=1
-    chat_completion = await async_client.chat.completions.create(
-        model=model_name,
-        messages=messages,
-        max_completion_tokens=10,
-        temperature=0.0,
-        logprobs=True,
-        top_logprobs=1,
-        extra_body={
-            "ignore_eos": True,
-        })
-    logprobs = chat_completion.choices[0].logprobs
-    assert logprobs is not None and logprobs.content is not None
-    assert len(logprobs.content) == 10
-    for logprob_content in logprobs.content:
-        assert logprob_content.token is not None
-        assert logprob_content.logprob is not None
-        assert logprob_content.bytes is not None
-        assert logprob_content.top_logprobs is not None
-        # Check that the top_logprobs contains only one entry
-        assert len(logprob_content.top_logprobs) == 1

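For reference, the rewritten temp_extra_llm_api_options_file fixture serializes its options dict with yaml.dump before the trt backend passes the file via --extra_llm_api_options. The standalone sketch below only reproduces that dump to show the on-disk YAML; the file path handling is omitted.

import yaml

# Same options dict as the fixture in the diff above.
extra_llm_api_options_dict = {
    "enable_chunked_prefill": False,
    "gather_generation_logits": True,
    "kv_cache_config": {
        "enable_block_reuse": False,
    },
}

print(yaml.dump(extra_llm_api_options_dict))
# Expected output (PyYAML block style), i.e. the YAML consumed
# via --extra_llm_api_options:
#   enable_chunked_prefill: false
#   gather_generation_logits: true
#   kv_cache_config:
#     enable_block_reuse: false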