@@ -6,19 +6,15 @@
 
 import asyncio
 import json
-import logging  # allow-direct-logging
-import threading
 import time
-from http.server import BaseHTTPRequestHandler, HTTPServer
-from typing import Any
 from unittest.mock import AsyncMock, MagicMock, PropertyMock, patch
 
 import pytest
 from openai.types.chat.chat_completion_chunk import (
     ChatCompletionChunk as OpenAIChatCompletionChunk,
 )
 from openai.types.chat.chat_completion_chunk import (
-    Choice as OpenAIChoice,
+    Choice as OpenAIChoiceChunk,
 )
 from openai.types.chat.chat_completion_chunk import (
     ChoiceDelta as OpenAIChoiceDelta,
@@ -35,6 +31,9 @@
     ChatCompletionRequest,
     ChatCompletionResponseEventType,
     CompletionMessage,
+    OpenAIAssistantMessageParam,
+    OpenAIChatCompletion,
+    OpenAIChoice,
     SystemMessage,
     ToolChoice,
     ToolConfig,
@@ -61,41 +60,6 @@
 # -v -s --tb=short --disable-warnings
 
 
-class MockInferenceAdapterWithSleep:
-    def __init__(self, sleep_time: int, response: dict[str, Any]):
-        self.httpd = None
-
-        class DelayedRequestHandler(BaseHTTPRequestHandler):
-            # ruff: noqa: N802
-            def do_POST(self):
-                time.sleep(sleep_time)
-                response_body = json.dumps(response).encode("utf-8")
-                self.send_response(code=200)
-                self.send_header("Content-Type", "application/json")
-                self.send_header("Content-Length", len(response_body))
-                self.end_headers()
-                self.wfile.write(response_body)
-
-        self.request_handler = DelayedRequestHandler
-
-    def __enter__(self):
-        httpd = HTTPServer(("", 0), self.request_handler)
-        self.httpd = httpd
-        host, port = httpd.server_address
-        httpd_thread = threading.Thread(target=httpd.serve_forever)
-        httpd_thread.daemon = True  # stop server if this thread terminates
-        httpd_thread.start()
-
-        config = VLLMInferenceAdapterConfig(url=f"http://{host}:{port}")
-        inference_adapter = VLLMInferenceAdapter(config)
-        return inference_adapter
-
-    def __exit__(self, _exc_type, _exc_value, _traceback):
-        if self.httpd:
-            self.httpd.shutdown()
-            self.httpd.server_close()
-
-
 @pytest.fixture(scope="module")
 def mock_openai_models_list():
     with patch("openai.resources.models.AsyncModels.list", new_callable=AsyncMock) as mock_list:
@@ -201,7 +165,7 @@ async def test_tool_call_delta_empty_tool_call_buf():
 
     async def mock_stream():
         delta = OpenAIChoiceDelta(content="", tool_calls=None)
-        choices = [OpenAIChoice(delta=delta, finish_reason="stop", index=0)]
+        choices = [OpenAIChoiceChunk(delta=delta, finish_reason="stop", index=0)]
         mock_chunk = OpenAIChatCompletionChunk(
             id="chunk-1",
             created=1,
@@ -227,7 +191,7 @@ async def mock_stream():
             model="foo",
             object="chat.completion.chunk",
             choices=[
-                OpenAIChoice(
+                OpenAIChoiceChunk(
                     delta=OpenAIChoiceDelta(
                         content="",
                         tool_calls=[
@@ -252,7 +216,7 @@ async def mock_stream():
             model="foo",
             object="chat.completion.chunk",
             choices=[
-                OpenAIChoice(
+                OpenAIChoiceChunk(
                     delta=OpenAIChoiceDelta(
                         content="",
                         tool_calls=[
@@ -277,7 +241,9 @@ async def mock_stream():
             model="foo",
             object="chat.completion.chunk",
             choices=[
-                OpenAIChoice(delta=OpenAIChoiceDelta(content="", tool_calls=None), finish_reason="tool_calls", index=0)
+                OpenAIChoiceChunk(
+                    delta=OpenAIChoiceDelta(content="", tool_calls=None), finish_reason="tool_calls", index=0
+                )
             ],
         )
         for chunk in [mock_chunk_1, mock_chunk_2, mock_chunk_3]:
@@ -301,7 +267,7 @@ async def mock_stream():
             model="foo",
             object="chat.completion.chunk",
             choices=[
-                OpenAIChoice(
+                OpenAIChoiceChunk(
                     delta=OpenAIChoiceDelta(
                         content="",
                         tool_calls=[
@@ -326,7 +292,7 @@ async def mock_stream():
             model="foo",
             object="chat.completion.chunk",
             choices=[
-                OpenAIChoice(
+                OpenAIChoiceChunk(
                     delta=OpenAIChoiceDelta(
                         content="",
                         tool_calls=[
@@ -351,7 +317,9 @@ async def mock_stream():
             model="foo",
             object="chat.completion.chunk",
             choices=[
-                OpenAIChoice(delta=OpenAIChoiceDelta(content="", tool_calls=None), finish_reason="tool_calls", index=0)
+                OpenAIChoiceChunk(
+                    delta=OpenAIChoiceDelta(content="", tool_calls=None), finish_reason="tool_calls", index=0
+                )
             ],
         )
         for chunk in [mock_chunk_1, mock_chunk_2, mock_chunk_3]:
@@ -395,59 +363,6 @@ async def mock_stream():
     assert chunks[0].event.event_type.value == "start"
 
 
-@pytest.mark.allow_network
-def test_chat_completion_doesnt_block_event_loop(caplog):
-    loop = asyncio.new_event_loop()
-    loop.set_debug(True)
-    caplog.set_level(logging.WARNING)
-
-    # Log when event loop is blocked for more than 200ms
-    loop.slow_callback_duration = 0.5
-    # Sleep for 500ms in our delayed http response
-    sleep_time = 0.5
-
-    mock_model = Model(identifier="mock-model", provider_resource_id="mock-model", provider_id="vllm-inference")
-    mock_response = {
-        "id": "chatcmpl-abc123",
-        "object": "chat.completion",
-        "created": 1,
-        "modle": "mock-model",
-        "choices": [
-            {
-                "message": {"content": ""},
-                "logprobs": None,
-                "finish_reason": "stop",
-                "index": 0,
-            }
-        ],
-    }
-
-    async def do_chat_completion():
-        await inference_adapter.chat_completion(
-            "mock-model",
-            [],
-            stream=False,
-            tools=None,
-            tool_config=ToolConfig(tool_choice=ToolChoice.auto),
-        )
-
-    with MockInferenceAdapterWithSleep(sleep_time, mock_response) as inference_adapter:
-        inference_adapter.model_store = AsyncMock()
-        inference_adapter.model_store.get_model.return_value = mock_model
-        loop.run_until_complete(inference_adapter.initialize())
-
-        # Clear the logs so far and run the actual chat completion we care about
-        caplog.clear()
-        loop.run_until_complete(do_chat_completion())
-
-        # Ensure we don't have any asyncio warnings in the captured log
-        # records from our chat completion call. A message gets logged
-        # here any time we exceed the slow_callback_duration configured
-        # above.
-        asyncio_warnings = [record.message for record in caplog.records if record.name == "asyncio"]
-        assert not asyncio_warnings
-
-
 async def test_get_params_empty_tools(vllm_inference_adapter):
     request = ChatCompletionRequest(
         tools=[],
@@ -696,3 +611,54 @@ async def mock_list():
     assert "Health check failed: Connection failed" in health_response["message"]
 
     mock_models.list.assert_called_once()
+
+
+async def test_openai_chat_completion_is_async(vllm_inference_adapter):
+    """
+    Verify that openai_chat_completion is async and doesn't block the event loop.
+
+    To do this we mock the underlying inference with a sleep, start multiple
+    inference calls in parallel, and ensure the total time taken is less
+    than the sum of the individual sleep times.
+    """
+    sleep_time = 0.5
+
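+    # A lightweight async stand-in for the OpenAI client call: it awaits
+    # (yielding the event loop) for sleep_time seconds, then returns a canned completion.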
+    async def mock_create(*args, **kwargs):
+        await asyncio.sleep(sleep_time)
+        return OpenAIChatCompletion(
+            id="chatcmpl-abc123",
+            created=1,
+            model="mock-model",
+            choices=[
+                OpenAIChoice(
+                    message=OpenAIAssistantMessageParam(
+                        content="nothing interesting",
+                    ),
+                    finish_reason="stop",
+                    index=0,
+                )
+            ],
+        )
+
+    async def do_inference():
+        await vllm_inference_adapter.openai_chat_completion(
+            "mock-model", messages=["one fish", "two fish"], stream=False
+        )
+
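+    # Patch the adapter's `client` property so every inference call goes through
+    # mock_create instead of reaching a real vLLM server.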
+    with patch.object(VLLMInferenceAdapter, "client", new_callable=PropertyMock) as mock_create_client:
+        mock_client = MagicMock()
+        mock_client.chat.completions.create = AsyncMock(side_effect=mock_create)
+        mock_create_client.return_value = mock_client
+
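+        # Four concurrent calls, each awaiting sleep_time seconds: if the adapter
+        # is truly async they overlap and finish in roughly sleep_time overall.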
+        start_time = time.time()
+        await asyncio.gather(do_inference(), do_inference(), do_inference(), do_inference())
+        total_time = time.time() - start_time
+
+        assert mock_create_client.call_count == 4  # no cheating
+        assert total_time < (sleep_time * 2), f"Total time taken: {total_time}s exceeded expected max"