Skip to content

Commit 338a78c

Browse files
ixlmar and SimengLiu-nv
authored and committed
[TRTLLM-8598][feat] enable n > 1 in OpenAI API with PyTorch backend (#8951)
Signed-off-by: ixlmar <206748156+ixlmar@users.noreply.github.com>
1 parent dbb58ba commit 338a78c

File tree

5 files changed

+102
-20
lines changed

5 files changed

+102
-20
lines changed

.pre-commit-config.yaml

File mode changed from 100644 to 100755
Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -66,7 +66,7 @@ repos:
6666
additional_dependencies:
6767
- tomli
6868
# add ignore words list
69-
args: ["-L", "Mor,ans,thirdparty"]
69+
args: ["-L", "Mor,ans,thirdparty", "--skip", "security_scanning/*"]
7070
- repo: https://github.com/astral-sh/ruff-pre-commit
7171
rev: v0.9.4
7272
hooks:

tensorrt_llm/serve/chat_utils.py

Lines changed: 0 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -197,9 +197,3 @@ def parse_chat_messages_coroutines(
197197

198198
return conversation, mm_data_tracker.retrieve_all_async(
199199
), mm_placeholder_counts
200-
201-
202-
def check_multiple_response(n: int, backend: Optional[str]):
203-
if n > 1 and backend == "pytorch":
204-
raise ValueError(
205-
"Multiple response is not supported in PyTorch workflow")

tensorrt_llm/serve/openai_server.py

Lines changed: 1 addition & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -31,8 +31,7 @@
3131
from tensorrt_llm.llmapi.llm import RequestOutput
3232
from tensorrt_llm.logger import logger
3333
from tensorrt_llm.metrics.collector import MetricsCollector
34-
from tensorrt_llm.serve.chat_utils import (check_multiple_response,
35-
parse_chat_messages_coroutines)
34+
from tensorrt_llm.serve.chat_utils import parse_chat_messages_coroutines
3635
from tensorrt_llm.serve.metadata_server import create_metadata_server
3736
from tensorrt_llm.serve.openai_protocol import (ChatCompletionRequest,
3837
ChatCompletionResponse,
@@ -417,7 +416,6 @@ async def create_chat_response(
417416
return chat_response
418417

419418
try:
420-
check_multiple_response(request.n, self.llm.args.backend)
421419
conversation: List[ConversationMessage] = []
422420
tool_dicts = None if request.tools is None else [
423421
tool.model_dump() for tool in request.tools
@@ -524,7 +522,6 @@ async def create_mm_embedding_response(promise: RequestOutput):
524522
)
525523

526524
try:
527-
check_multiple_response(request.n, self.llm.args.backend)
528525
conversation: List[ConversationMessage] = []
529526
tool_dicts = None if request.tools is None else [
530527
tool.model_dump() for tool in request.tools
@@ -651,7 +648,6 @@ async def generator_wrapper(generator: AsyncIterator[Any]):
651648
yield "data: [DONE]\n\n"
652649

653650
try:
654-
check_multiple_response(request.n, self.llm.args.backend)
655651
if isinstance(request.prompt, str) or \
656652
(isinstance(request.prompt, list) and isinstance(request.prompt[0], int)):
657653
prompts = [request.prompt]

tests/unittest/llmapi/apps/_test_openai_chat.py

Lines changed: 77 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -68,6 +68,8 @@ def server(model_name: str, backend: str, extra_llm_api_options: bool,
6868
temp_extra_llm_api_options_file: str, num_postprocess_workers: int):
6969
model_path = get_model_path(model_name)
7070
args = ["--backend", f"{backend}"]
71+
args.extend(["--kv_cache_free_gpu_memory_fraction",
72+
"0.2"]) # for co-existence with other servers
7173
if backend == "trt":
7274
args.extend(["--max_beam_width", "4"])
7375
if extra_llm_api_options:
@@ -78,11 +80,34 @@ def server(model_name: str, backend: str, extra_llm_api_options: bool,
7880
yield remote_server
7981

8082

83+
@pytest.fixture(scope="module")
84+
def server_with_beam_search(model_name: str, backend: str,
85+
extra_llm_api_options: bool,
86+
temp_extra_llm_api_options_file: str,
87+
num_postprocess_workers: int):
88+
model_path = get_model_path(model_name)
89+
args = ["--backend", f"{backend}"]
90+
args.extend(["--kv_cache_free_gpu_memory_fraction",
91+
"0.2"]) # for co-existence with other servers
92+
args.extend(["--max_beam_width", "2"])
93+
if extra_llm_api_options:
94+
args.extend(
95+
["--extra_llm_api_options", temp_extra_llm_api_options_file])
96+
args.extend(["--num_postprocess_workers", f"{num_postprocess_workers}"])
97+
with RemoteOpenAIServer(model_path, args) as remote_server:
98+
yield remote_server
99+
100+
81101
@pytest.fixture(scope="module")
82102
def client(server: RemoteOpenAIServer):
83103
return server.get_client()
84104

85105

106+
@pytest.fixture(scope="module")
107+
def client_with_beam_search(server_with_beam_search: RemoteOpenAIServer):
108+
return server_with_beam_search.get_client()
109+
110+
86111
@pytest.fixture(scope="module")
87112
def async_client(server: RemoteOpenAIServer):
88113
return server.get_async_client()
@@ -180,7 +205,33 @@ def test_multiple_responses(client: openai.OpenAI, model_name: str,
180205
backend: str):
181206
if backend == "pytorch":
182207
pytest.skip(
183-
"Multiple responses are not supported in PyTorch backend yet")
208+
"'n' not allowed with temperature=0 unless TLLM_ALLOW_N_GREEDY_DECODING=1"
209+
)
210+
messages = [{
211+
"role": "system",
212+
"content": "you are a helpful assistant"
213+
}, {
214+
"role": "user",
215+
"content": "what is 1+1?"
216+
}]
217+
# test n and best_of
218+
chat_completion = client.chat.completions.create(
219+
model=model_name,
220+
messages=messages,
221+
max_completion_tokens=10,
222+
n=2,
223+
temperature=0.0,
224+
extra_body=dict(best_of=4),
225+
)
226+
assert len(chat_completion.choices) == 2
227+
228+
229+
def test_multiple_responses_and_beam_search(client: openai.OpenAI,
230+
model_name: str, backend: str):
231+
if backend == "pytorch":
232+
pytest.skip(
233+
"Mixing beam search and regular requests is not supported in PyTorch backend"
234+
)
184235

185236
messages = [{
186237
"role": "system",
@@ -202,6 +253,7 @@ def test_multiple_responses(client: openai.OpenAI, model_name: str,
202253
assert chat_completion.choices[
203254
0].message.content != chat_completion.choices[
204255
1].message.content, "beam search should be different"
256+
205257
# test n and best_of
206258
chat_completion = client.chat.completions.create(
207259
model=model_name,
@@ -214,6 +266,30 @@ def test_multiple_responses(client: openai.OpenAI, model_name: str,
214266
assert len(chat_completion.choices) == 2
215267

216268

269+
def test_multiple_responses_with_beam_search(
270+
client_with_beam_search: openai.OpenAI, model_name: str):
271+
messages = [{
272+
"role": "system",
273+
"content": "you are a helpful assistant"
274+
}, {
275+
"role": "user",
276+
"content": "what is 1+1?"
277+
}]
278+
# test beam search
279+
chat_completion = client_with_beam_search.chat.completions.create(
280+
model=model_name,
281+
messages=messages,
282+
max_completion_tokens=10,
283+
n=2,
284+
temperature=0.0,
285+
extra_body=dict(use_beam_search=True),
286+
)
287+
assert len(chat_completion.choices) == 2
288+
assert chat_completion.choices[
289+
0].message.content != chat_completion.choices[
290+
1].message.content, "beam search should be different"
291+
292+
217293
@pytest.mark.asyncio(loop_scope="module")
218294
async def test_chat_streaming(async_client: openai.AsyncOpenAI,
219295
model_name: str):

tests/unittest/llmapi/apps/_test_openai_completions.py

Lines changed: 23 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -33,8 +33,21 @@ def num_postprocess_workers(request):
3333
def server(model_name: str, backend: str, num_postprocess_workers: int):
3434
model_path = get_model_path(model_name)
3535
args = ["--backend", f"{backend}"]
36-
if backend == "trt":
37-
args.extend(["--max_beam_width", "4"])
36+
args.extend(["--kv_cache_free_gpu_memory_fraction",
37+
"0.2"]) # for co-existence with other servers
38+
args.extend(["--num_postprocess_workers", f"{num_postprocess_workers}"])
39+
with RemoteOpenAIServer(model_path, args) as remote_server:
40+
yield remote_server
41+
42+
43+
@pytest.fixture(scope="module")
44+
def server_with_beam_search(model_name: str, backend: str,
45+
num_postprocess_workers: int):
46+
model_path = get_model_path(model_name)
47+
args = ["--backend", f"{backend}"]
48+
args.extend(["--kv_cache_free_gpu_memory_fraction",
49+
"0.2"]) # for co-existence with other servers
50+
args.extend(["--max_beam_width", "2"])
3851
args.extend(["--num_postprocess_workers", f"{num_postprocess_workers}"])
3952
with RemoteOpenAIServer(model_path, args) as remote_server:
4053
yield remote_server
@@ -50,6 +63,11 @@ def async_client(server: RemoteOpenAIServer):
5063
return server.get_async_client()
5164

5265

66+
@pytest.fixture(scope="module")
67+
def async_client_with_beam_search(server_with_beam_search: RemoteOpenAIServer):
68+
return server_with_beam_search.get_async_client()
69+
70+
5371
def test_single_completion(client: openai.OpenAI, model_name):
5472
completion = client.completions.create(
5573
model=model_name,
@@ -146,12 +164,10 @@ async def test_batch_completions(async_client: openai.AsyncOpenAI, model_name,
146164
@pytest.mark.asyncio(loop_scope="module")
147165
@pytest.mark.parametrize("prompts",
148166
[["Hello, my name is"] * 2, [[0, 0, 0, 0, 0]] * 2])
149-
async def test_batch_completions_beam_search(async_client: openai.AsyncOpenAI,
150-
model_name, prompts, backend):
167+
async def test_batch_completions_beam_search(
168+
async_client_with_beam_search: openai.AsyncOpenAI, model_name, prompts):
151169
# test beam search
152-
if backend == 'pytorch':
153-
pytest.skip("Beam search is not supported in PyTorch backend yet")
154-
batch = await async_client.completions.create(
170+
batch = await async_client_with_beam_search.completions.create(
155171
model=model_name,
156172
prompt=prompts,
157173
n=2,

0 commit comments

Comments (0)