Commit 79fc2f4

[None][chore] Enhance trtllm-serve example test (NVIDIA#6604)
Signed-off-by: Pengyun Lin <[email protected]>
1 parent b7347ce commit 79fc2f4

File tree: 2 files changed, +49 −9 lines

examples/serve/openai_completion_client_json_schema.py (12 additions, 2 deletions)
@@ -1,5 +1,9 @@
 ### :title OpenAI Completion Client with JSON Schema

+# This example requires to specify `guided_decoding_backend` as
+# `xgrammar` or `llguidance` in the extra_llm_api_options.yaml file.
+import json
+
 from openai import OpenAI

 client = OpenAI(
@@ -18,7 +22,6 @@
         "content":
         f"Give me the information of the biggest city of China in the JSON format.",
     }],
-    max_tokens=100,
     temperature=0,
     response_format={
         "type": "json",
@@ -39,4 +42,11 @@
         }
     },
 )
-print(response.choices[0].message.content)
+
+content = response.choices[0].message.content
+try:
+    response_json = json.loads(content)
+    assert "name" in response_json and "population" in response_json
+    print(content)
+except json.JSONDecodeError:
+    print("Failed to decode JSON response")

tests/unittest/llmapi/apps/_test_trtllm_serve_example.py (37 additions, 7 deletions)
@@ -1,8 +1,11 @@
+import json
 import os
 import subprocess
 import sys
+import tempfile

 import pytest
+import yaml

 from .openai_server import RemoteOpenAIServer

@@ -16,10 +19,26 @@ def model_name():


 @pytest.fixture(scope="module")
-def server(model_name: str):
+def temp_extra_llm_api_options_file():
+    temp_dir = tempfile.gettempdir()
+    temp_file_path = os.path.join(temp_dir, "extra_llm_api_options.yaml")
+    try:
+        extra_llm_api_options_dict = {"guided_decoding_backend": "xgrammar"}
+        with open(temp_file_path, 'w') as f:
+            yaml.dump(extra_llm_api_options_dict, f)
+
+        yield temp_file_path
+    finally:
+        if os.path.exists(temp_file_path):
+            os.remove(temp_file_path)
+
+
+@pytest.fixture(scope="module")
+def server(model_name: str, temp_extra_llm_api_options_file: str):
     model_path = get_model_path(model_name)
     # fix port to facilitate concise trtllm-serve examples
-    with RemoteOpenAIServer(model_path, port=8000) as remote_server:
+    args = ["--extra_llm_api_options", temp_extra_llm_api_options_file]
+    with RemoteOpenAIServer(model_path, args, port=8000) as remote_server:
         yield remote_server


@@ -40,8 +59,19 @@ def test_trtllm_serve_examples(exe: str, script: str,
                               server: RemoteOpenAIServer, example_root: str):
     client_script = os.path.join(example_root, script)
     # CalledProcessError will be raised if any errors occur
-    subprocess.run([exe, client_script],
-                   stdout=subprocess.PIPE,
-                   stderr=subprocess.PIPE,
-                   text=True,
-                   check=True)
+    result = subprocess.run([exe, client_script],
+                            stdout=subprocess.PIPE,
+                            stderr=subprocess.PIPE,
+                            text=True,
+                            check=True)
+    if script.startswith("curl"):
+        # For curl scripts, we expect a JSON response
+        result_stdout = result.stdout.strip()
+        try:
+            data = json.loads(result_stdout)
+            assert "code" not in data or data[
+                "code"] == 200, f"Unexpected response: {data}"
+        except json.JSONDecodeError as e:
+            pytest.fail(
+                f"Failed to parse JSON response from {script}: {e}\nStdout: {result_stdout}\nStderr: {result.stderr}"
+            )
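
For reference, a small sketch of how the new curl-response check behaves. Both payloads are made up, since the server's exact response shape is not part of this diff; the check only relies on error responses carrying a non-200 "code" field:

import json

# A successful completion payload has no "code" field, so the check passes.
ok = json.loads('{"id": "cmpl-1", "choices": [{"text": "Shanghai"}]}')
assert "code" not in ok or ok["code"] == 200

# An illustrative error payload carries "code": 500 and is caught.
err = json.loads('{"code": 500, "message": "internal error"}')
assert not ("code" not in err or err["code"] == 200)

Running pytest on tests/unittest/llmapi/apps/_test_trtllm_serve_example.py should then exercise the example scripts against the fixture-configured server.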
