Commit 5333717

mattfiamemilio authored and committed
chore(unit tests): remove network use, update async test (llamastack#3418)
# What does this PR do?

Update the async-detection test for vLLM:

- remove network access from the unit tests
- remove direct use of logging

The idea behind the test is to mock inference with a sleep, initiate concurrent inference calls, and verify that the total execution time is close to the sleep time. In a non-async environment the total time would be closer to sleep * (number of concurrent calls).

## Test Plan

CI
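For illustration, here is a minimal standalone sketch of that timing check, separate from the test code in this commit (the names fake_inference and main are illustrative): the inference call is replaced by an asyncio sleep, several calls are launched concurrently with asyncio.gather, and the assertion only passes if the calls actually overlap.

import asyncio
import time


async def fake_inference(sleep_time: float) -> str:
    # Stand-in for a mocked inference call: it only sleeps, yielding the event loop.
    await asyncio.sleep(sleep_time)
    return "done"


async def main() -> None:
    sleep_time = 0.5
    num_calls = 4

    start = time.time()
    await asyncio.gather(*(fake_inference(sleep_time) for _ in range(num_calls)))
    elapsed = time.time() - start

    # Truly async calls overlap, so elapsed stays close to a single sleep.
    # A blocking implementation would take roughly sleep_time * num_calls.
    assert elapsed < sleep_time * 2, f"calls appear to be serialized: {elapsed:.2f}s"
    print(f"{num_calls} concurrent calls finished in {elapsed:.2f}s")


if __name__ == "__main__":
    asyncio.run(main())

With a blocking time.sleep in place of asyncio.sleep, the same four calls would take roughly 2 s and the assertion would fail; that is exactly the difference the updated vLLM test is designed to catch.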
1 parent 22db3fc commit 5333717

File tree

1 file changed: +60 -100 lines changed


tests/unit/providers/inference/test_remote_vllm.py

Lines changed: 60 additions & 100 deletions
@@ -6,19 +6,15 @@
 
 import asyncio
 import json
-import logging  # allow-direct-logging
-import threading
 import time
-from http.server import BaseHTTPRequestHandler, HTTPServer
-from typing import Any
 from unittest.mock import AsyncMock, MagicMock, PropertyMock, patch
 
 import pytest
 from openai.types.chat.chat_completion_chunk import (
     ChatCompletionChunk as OpenAIChatCompletionChunk,
 )
 from openai.types.chat.chat_completion_chunk import (
-    Choice as OpenAIChoice,
+    Choice as OpenAIChoiceChunk,
 )
 from openai.types.chat.chat_completion_chunk import (
     ChoiceDelta as OpenAIChoiceDelta,
@@ -35,6 +31,9 @@
     ChatCompletionRequest,
     ChatCompletionResponseEventType,
     CompletionMessage,
+    OpenAIAssistantMessageParam,
+    OpenAIChatCompletion,
+    OpenAIChoice,
     SystemMessage,
     ToolChoice,
     ToolConfig,
@@ -61,41 +60,6 @@
 # -v -s --tb=short --disable-warnings
 
 
-class MockInferenceAdapterWithSleep:
-    def __init__(self, sleep_time: int, response: dict[str, Any]):
-        self.httpd = None
-
-        class DelayedRequestHandler(BaseHTTPRequestHandler):
-            # ruff: noqa: N802
-            def do_POST(self):
-                time.sleep(sleep_time)
-                response_body = json.dumps(response).encode("utf-8")
-                self.send_response(code=200)
-                self.send_header("Content-Type", "application/json")
-                self.send_header("Content-Length", len(response_body))
-                self.end_headers()
-                self.wfile.write(response_body)
-
-        self.request_handler = DelayedRequestHandler
-
-    def __enter__(self):
-        httpd = HTTPServer(("", 0), self.request_handler)
-        self.httpd = httpd
-        host, port = httpd.server_address
-        httpd_thread = threading.Thread(target=httpd.serve_forever)
-        httpd_thread.daemon = True  # stop server if this thread terminates
-        httpd_thread.start()
-
-        config = VLLMInferenceAdapterConfig(url=f"http://{host}:{port}")
-        inference_adapter = VLLMInferenceAdapter(config)
-        return inference_adapter
-
-    def __exit__(self, _exc_type, _exc_value, _traceback):
-        if self.httpd:
-            self.httpd.shutdown()
-            self.httpd.server_close()
-
-
 @pytest.fixture(scope="module")
 def mock_openai_models_list():
     with patch("openai.resources.models.AsyncModels.list", new_callable=AsyncMock) as mock_list:
@@ -201,7 +165,7 @@ async def test_tool_call_delta_empty_tool_call_buf():
 
     async def mock_stream():
         delta = OpenAIChoiceDelta(content="", tool_calls=None)
-        choices = [OpenAIChoice(delta=delta, finish_reason="stop", index=0)]
+        choices = [OpenAIChoiceChunk(delta=delta, finish_reason="stop", index=0)]
         mock_chunk = OpenAIChatCompletionChunk(
             id="chunk-1",
             created=1,
@@ -227,7 +191,7 @@ async def mock_stream():
             model="foo",
             object="chat.completion.chunk",
             choices=[
-                OpenAIChoice(
+                OpenAIChoiceChunk(
                     delta=OpenAIChoiceDelta(
                         content="",
                         tool_calls=[
@@ -252,7 +216,7 @@ async def mock_stream():
             model="foo",
             object="chat.completion.chunk",
             choices=[
-                OpenAIChoice(
+                OpenAIChoiceChunk(
                     delta=OpenAIChoiceDelta(
                         content="",
                         tool_calls=[
@@ -277,7 +241,9 @@ async def mock_stream():
             model="foo",
             object="chat.completion.chunk",
             choices=[
-                OpenAIChoice(delta=OpenAIChoiceDelta(content="", tool_calls=None), finish_reason="tool_calls", index=0)
+                OpenAIChoiceChunk(
+                    delta=OpenAIChoiceDelta(content="", tool_calls=None), finish_reason="tool_calls", index=0
+                )
             ],
         )
         for chunk in [mock_chunk_1, mock_chunk_2, mock_chunk_3]:
@@ -301,7 +267,7 @@ async def mock_stream():
             model="foo",
             object="chat.completion.chunk",
             choices=[
-                OpenAIChoice(
+                OpenAIChoiceChunk(
                     delta=OpenAIChoiceDelta(
                         content="",
                         tool_calls=[
@@ -326,7 +292,7 @@ async def mock_stream():
             model="foo",
             object="chat.completion.chunk",
             choices=[
-                OpenAIChoice(
+                OpenAIChoiceChunk(
                     delta=OpenAIChoiceDelta(
                         content="",
                         tool_calls=[
@@ -351,7 +317,9 @@ async def mock_stream():
             model="foo",
             object="chat.completion.chunk",
             choices=[
-                OpenAIChoice(delta=OpenAIChoiceDelta(content="", tool_calls=None), finish_reason="tool_calls", index=0)
+                OpenAIChoiceChunk(
+                    delta=OpenAIChoiceDelta(content="", tool_calls=None), finish_reason="tool_calls", index=0
+                )
             ],
         )
         for chunk in [mock_chunk_1, mock_chunk_2, mock_chunk_3]:
@@ -395,59 +363,6 @@ async def mock_stream():
     assert chunks[0].event.event_type.value == "start"
 
 
-@pytest.mark.allow_network
-def test_chat_completion_doesnt_block_event_loop(caplog):
-    loop = asyncio.new_event_loop()
-    loop.set_debug(True)
-    caplog.set_level(logging.WARNING)
-
-    # Log when event loop is blocked for more than 200ms
-    loop.slow_callback_duration = 0.5
-    # Sleep for 500ms in our delayed http response
-    sleep_time = 0.5
-
-    mock_model = Model(identifier="mock-model", provider_resource_id="mock-model", provider_id="vllm-inference")
-    mock_response = {
-        "id": "chatcmpl-abc123",
-        "object": "chat.completion",
-        "created": 1,
-        "modle": "mock-model",
-        "choices": [
-            {
-                "message": {"content": ""},
-                "logprobs": None,
-                "finish_reason": "stop",
-                "index": 0,
-            }
-        ],
-    }
-
-    async def do_chat_completion():
-        await inference_adapter.chat_completion(
-            "mock-model",
-            [],
-            stream=False,
-            tools=None,
-            tool_config=ToolConfig(tool_choice=ToolChoice.auto),
-        )
-
-    with MockInferenceAdapterWithSleep(sleep_time, mock_response) as inference_adapter:
-        inference_adapter.model_store = AsyncMock()
-        inference_adapter.model_store.get_model.return_value = mock_model
-        loop.run_until_complete(inference_adapter.initialize())
-
-        # Clear the logs so far and run the actual chat completion we care about
-        caplog.clear()
-        loop.run_until_complete(do_chat_completion())
-
-        # Ensure we don't have any asyncio warnings in the captured log
-        # records from our chat completion call. A message gets logged
-        # here any time we exceed the slow_callback_duration configured
-        # above.
-        asyncio_warnings = [record.message for record in caplog.records if record.name == "asyncio"]
-        assert not asyncio_warnings
-
-
 async def test_get_params_empty_tools(vllm_inference_adapter):
     request = ChatCompletionRequest(
         tools=[],
@@ -696,3 +611,48 @@ async def mock_list():
     assert "Health check failed: Connection failed" in health_response["message"]
 
     mock_models.list.assert_called_once()
+
+
+async def test_openai_chat_completion_is_async(vllm_inference_adapter):
+    """
+    Verify that openai_chat_completion is async and doesn't block the event loop.
+
+    To do this we mock the underlying inference with a sleep, start multiple
+    inference calls in parallel, and ensure the total time taken is less
+    than the sum of the individual sleep times.
+    """
+    sleep_time = 0.5
+
+    async def mock_create(*args, **kwargs):
+        await asyncio.sleep(sleep_time)
+        return OpenAIChatCompletion(
+            id="chatcmpl-abc123",
+            created=1,
+            model="mock-model",
+            choices=[
+                OpenAIChoice(
+                    message=OpenAIAssistantMessageParam(
+                        content="nothing interesting",
+                    ),
+                    finish_reason="stop",
+                    index=0,
+                )
+            ],
+        )
+
+    async def do_inference():
+        await vllm_inference_adapter.openai_chat_completion(
+            "mock-model", messages=["one fish", "two fish"], stream=False
+        )
+
+    with patch.object(VLLMInferenceAdapter, "client", new_callable=PropertyMock) as mock_create_client:
+        mock_client = MagicMock()
+        mock_client.chat.completions.create = AsyncMock(side_effect=mock_create)
+        mock_create_client.return_value = mock_client
+
+        start_time = time.time()
+        await asyncio.gather(do_inference(), do_inference(), do_inference(), do_inference())
+        total_time = time.time() - start_time
+
+        assert mock_create_client.call_count == 4  # no cheating
+        assert total_time < (sleep_time * 2), f"Total time taken: {total_time}s exceeded expected max"
