Commit c1378df

mattfiamemilio authored and committed
chore(recorder, tests): add test for openai /v1/models (llamastack#3426)
# What does this PR do?

- [x] adds a test for the recorder's handling of /v1/models
- [x] adds a fix for /v1/models handling

## Test Plan

ci
1 parent b96eed9 commit c1378df

File tree

2 files changed: +79 -32 lines changed

- llama_stack/testing/inference_recorder.py
- tests/unit/distribution/test_inference_recordings.py

llama_stack/testing/inference_recorder.py

Lines changed: 32 additions & 28 deletions
```diff
@@ -7,6 +7,7 @@
 from __future__ import annotations  # for forward references
 
 import hashlib
+import inspect
 import json
 import os
 from collections.abc import Generator
```
```diff
@@ -198,16 +199,11 @@ def _extract_model_identifiers():
 
         Supported endpoints:
         - '/api/tags' (Ollama): response body has 'models': [ { name/model/digest/id/... }, ... ]
-        - '/v1/models' (OpenAI): response body has 'data': [ { id: ... }, ... ]
+        - '/v1/models' (OpenAI): response body is: [ { id: ... }, ... ]
         Returns a list of unique identifiers or None if structure doesn't match.
         """
-        body = response["body"]
-        if endpoint == "/api/tags":
-            items = body.get("models")
-            idents = [m.model for m in items]
-        else:
-            items = body.get("data")
-            idents = [m.id for m in items]
+        items = response["body"]
+        idents = [m.model if endpoint == "/api/tags" else m.id for m in items]
         return sorted(set(idents))
 
     identifiers = _extract_model_identifiers()
```
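As a quick illustration of the rewritten extractor, here is a minimal standalone sketch; the `SimpleNamespace` objects are hypothetical stand-ins for recorded model entries, not the recorder's actual types:

```python
from types import SimpleNamespace

def extract(endpoint: str, items) -> list[str]:
    # The diff's one-liner: Ollama entries carry .model, OpenAI entries carry .id.
    idents = [m.model if endpoint == "/api/tags" else m.id for m in items]
    return sorted(set(idents))

openai_items = [SimpleNamespace(id="qwen2.5:7b"), SimpleNamespace(id="llama3.2:3b")]
ollama_items = [SimpleNamespace(model="llama3.2:3b"), SimpleNamespace(model="llama3.2:3b")]

print(extract("/v1/models", openai_items))  # ['llama3.2:3b', 'qwen2.5:7b']
print(extract("/api/tags", ollama_items))   # ['llama3.2:3b']
```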
```diff
@@ -219,28 +215,22 @@ def _combine_model_list_responses(endpoint: str, records: list[dict[str, Any]])
     seen: dict[str, dict[str, Any]] = {}
     for rec in records:
         body = rec["response"]["body"]
-        if endpoint == "/api/tags":
-            items = body.models
-        elif endpoint == "/v1/models":
-            items = body.data
-        else:
-            items = []
-
-        for m in items:
-            if endpoint == "/v1/models":
+        if endpoint == "/v1/models":
+            for m in body:
                 key = m.id
-            else:
+                seen[key] = m
+        elif endpoint == "/api/tags":
+            for m in body.models:
                 key = m.model
-            seen[key] = m
+                seen[key] = m
 
     ordered = [seen[k] for k in sorted(seen.keys())]
     canonical = records[0]
     canonical_req = canonical.get("request", {})
     if isinstance(canonical_req, dict):
         canonical_req["endpoint"] = endpoint
-    if endpoint == "/v1/models":
-        body = {"data": ordered, "object": "list"}
-    else:
+    body = ordered
+    if endpoint == "/api/tags":
         from ollama import ListResponse
 
         body = ListResponse(models=ordered)
```
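The restructured merge keeps the same dedupe-then-sort contract for overlapping recordings. A small sketch of that pattern for the `/v1/models` branch, using hypothetical entries:

```python
from types import SimpleNamespace

# Hypothetical bodies from two recordings that overlap on "bar".
bodies = [
    [SimpleNamespace(id="foo"), SimpleNamespace(id="bar")],
    [SimpleNamespace(id="bar"), SimpleNamespace(id="baz")],
]

seen = {}
for body in bodies:
    for m in body:  # /v1/models branch: key on m.id, later recordings win
        seen[m.id] = m

ordered = [seen[k] for k in sorted(seen)]
print([m.id for m in ordered])  # ['bar', 'baz', 'foo']
```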
```diff
@@ -252,7 +242,10 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
 
     if _current_mode == InferenceMode.LIVE or _current_storage is None:
         # Normal operation
-        return await original_method(self, *args, **kwargs)
+        if inspect.iscoroutinefunction(original_method):
+            return await original_method(self, *args, **kwargs)
+        else:
+            return original_method(self, *args, **kwargs)
 
     # Get base URL based on client type
     if client_type == "openai":
```
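The `inspect.iscoroutinefunction` guard is needed because not every patched method is a coroutine function; in the OpenAI client, `models.list` is a plain `def` returning an async paginator, which is also why `patched_models_list` below becomes a plain `def`. A self-contained sketch of the dispatch, with hypothetical stand-in methods:

```python
import asyncio
import inspect

async def coro_method(self, **kwargs):
    # Stand-in for an async endpoint such as chat.completions.create.
    return "awaited"

def plain_method(self, **kwargs):
    # Stand-in for a sync wrapper such as models.list.
    return "not awaited"

async def call(original_method, self, **kwargs):
    # Mirrors the diff: await only when the target is a coroutine function.
    if inspect.iscoroutinefunction(original_method):
        return await original_method(self, **kwargs)
    return original_method(self, **kwargs)

async def main():
    print(await call(coro_method, None))   # awaited
    print(await call(plain_method, None))  # not awaited

asyncio.run(main())
```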
```diff
@@ -300,7 +293,14 @@ async def replay_stream():
             )
 
     elif _current_mode == InferenceMode.RECORD:
-        response = await original_method(self, *args, **kwargs)
+        if inspect.iscoroutinefunction(original_method):
+            response = await original_method(self, *args, **kwargs)
+        else:
+            response = original_method(self, *args, **kwargs)
+
+        # we want to store the result of the iterator, not the iterator itself
+        if endpoint == "/v1/models":
+            response = [m async for m in response]
 
         request_data = {
             "method": method,
```
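Materializing the response with `[m async for m in response]` is what makes `/v1/models` recordable: an async iterator can be consumed only once and cannot be serialized, so the recorder stores the items it yields. A minimal sketch of that step, with `fake_paginator` as a hypothetical stand-in for the paginator `models.list()` returns:

```python
import asyncio

async def fake_paginator():
    # Hypothetical stand-in for the async iterable returned by models.list().
    for model_id in ("foo", "bar"):
        yield {"id": model_id}

async def main():
    response = fake_paginator()
    # As in the diff: store what the iterator yields, not the iterator itself.
    items = [m async for m in response]
    print(items)  # [{'id': 'foo'}, {'id': 'bar'}]

asyncio.run(main())
```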
```diff
@@ -380,10 +380,14 @@ async def patched_embeddings_create(self, *args, **kwargs):
             _original_methods["embeddings_create"], self, "openai", "/v1/embeddings", *args, **kwargs
         )
 
-    async def patched_models_list(self, *args, **kwargs):
-        return await _patched_inference_method(
-            _original_methods["models_list"], self, "openai", "/v1/models", *args, **kwargs
-        )
+    def patched_models_list(self, *args, **kwargs):
+        async def _iter():
+            for item in await _patched_inference_method(
+                _original_methods["models_list"], self, "openai", "/v1/models", *args, **kwargs
+            ):
+                yield item
+
+        return _iter()
 
     # Apply OpenAI patches
     AsyncChatCompletions.create = patched_chat_completions_create
```
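The new `patched_models_list` is a plain `def` that returns an async generator, so call sites can keep writing `async for m in client.models.list()` whether the items come from the live API or from a recording. A sketch of that wrapper shape, with `_fetch_models` as a hypothetical stand-in for `_patched_inference_method`:

```python
import asyncio

async def _fetch_models():
    # Hypothetical stand-in for _patched_inference_method, which in REPLAY
    # mode resolves to the stored list of models.
    return [{"id": "foo"}, {"id": "bar"}]

def models_list():
    # Plain def, like the patched method: it hands back an async generator,
    # so existing `async for` call sites keep working unchanged.
    async def _iter():
        for item in await _fetch_models():
            yield item

    return _iter()

async def main():
    print([m["id"] async for m in models_list()])  # ['foo', 'bar']

asyncio.run(main())
```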

tests/unit/distribution/test_inference_recordings.py

Lines changed: 47 additions & 4 deletions
```diff
@@ -6,10 +6,11 @@
 
 import tempfile
 from pathlib import Path
-from unittest.mock import patch
+from unittest.mock import AsyncMock, Mock, patch
 
 import pytest
 from openai import AsyncOpenAI
+from openai.types.model import Model as OpenAIModel
 
 # Import the real Pydantic response types instead of using Mocks
 from llama_stack.apis.inference import (
```
```diff
@@ -158,7 +159,9 @@ async def mock_create(*args, **kwargs):
             return real_openai_chat_response
 
         temp_storage_dir = temp_storage_dir / "test_recording_mode"
-        with patch("openai.resources.chat.completions.AsyncCompletions.create", side_effect=mock_create):
+        with patch(
+            "openai.resources.chat.completions.AsyncCompletions.create", new_callable=AsyncMock, side_effect=mock_create
+        ):
             with inference_recording(mode=InferenceMode.RECORD, storage_dir=str(temp_storage_dir)):
                 client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
 
```
```diff
@@ -184,7 +187,9 @@ async def mock_create(*args, **kwargs):
 
         temp_storage_dir = temp_storage_dir / "test_replay_mode"
         # First, record a response
-        with patch("openai.resources.chat.completions.AsyncCompletions.create", side_effect=mock_create):
+        with patch(
+            "openai.resources.chat.completions.AsyncCompletions.create", new_callable=AsyncMock, side_effect=mock_create
+        ):
             with inference_recording(mode=InferenceMode.RECORD, storage_dir=str(temp_storage_dir)):
                 client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
 
```
```diff
@@ -213,6 +218,42 @@ async def mock_create(*args, **kwargs):
             # Verify the original method was NOT called
             mock_create_patch.assert_not_called()
 
+    async def test_replay_mode_models(self, temp_storage_dir):
+        """Test that replay mode returns stored responses without making real model listing calls."""
+
+        async def _async_iterator(models):
+            for model in models:
+                yield model
+
+        models = [
+            OpenAIModel(id="foo", created=1, object="model", owned_by="test"),
+            OpenAIModel(id="bar", created=2, object="model", owned_by="test"),
+        ]
+
+        expected_ids = {m.id for m in models}
+
+        temp_storage_dir = temp_storage_dir / "test_replay_mode_models"
+
+        # baseline - mock works without recording
+        client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
+        client.models._get_api_list = Mock(return_value=_async_iterator(models))
+        assert {m.id async for m in client.models.list()} == expected_ids
+        client.models._get_api_list.assert_called_once()
+
+        # record the call
+        with inference_recording(mode=InferenceMode.RECORD, storage_dir=temp_storage_dir):
+            client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
+            client.models._get_api_list = Mock(return_value=_async_iterator(models))
+            assert {m.id async for m in client.models.list()} == expected_ids
+            client.models._get_api_list.assert_called_once()
+
+        # replay the call
+        with inference_recording(mode=InferenceMode.REPLAY, storage_dir=temp_storage_dir):
+            client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
+            client.models._get_api_list = Mock(return_value=_async_iterator(models))
+            assert {m.id async for m in client.models.list()} == expected_ids
+            client.models._get_api_list.assert_not_called()
+
     async def test_replay_missing_recording(self, temp_storage_dir):
         """Test that replay mode fails when no recording is found."""
         temp_storage_dir = temp_storage_dir / "test_replay_missing_recording"
```
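The test stubs `client.models._get_api_list` (a private openai-python helper underneath `models.list()`) rather than `models.list` itself, presumably so the recorder's patched `models.list` still runs while the network layer is faked. A minimal sketch of mocking with an async iterator, using hypothetical item values:

```python
import asyncio
from unittest.mock import Mock

async def _async_iterator(items):
    for item in items:
        yield item

# A Mock whose return value is an async iterator, mirroring the test's stub
# of client.models._get_api_list; "foo" and "bar" are hypothetical values.
fake_get_api_list = Mock(return_value=_async_iterator(["foo", "bar"]))

async def main():
    print([item async for item in fake_get_api_list()])  # ['foo', 'bar']
    fake_get_api_list.assert_called_once()

asyncio.run(main())
```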
```diff
@@ -233,7 +274,9 @@ async def mock_create(*args, **kwargs):
 
         temp_storage_dir = temp_storage_dir / "test_embeddings_recording"
         # Record
-        with patch("openai.resources.embeddings.AsyncEmbeddings.create", side_effect=mock_create):
+        with patch(
+            "openai.resources.embeddings.AsyncEmbeddings.create", new_callable=AsyncMock, side_effect=mock_create
+        ):
             with inference_recording(mode=InferenceMode.RECORD, storage_dir=str(temp_storage_dir)):
                 client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
 
```