
Commit 70f40f5

test: add models.list() recording/replay support
- Add patching for the OpenAI AsyncModels.list method to the inference recorder
- Create AsyncIterableModelsWrapper that supports both usage patterns (sketched below):
  - Direct async iteration: async for m in client.models.list()
  - Await then iterate: res = await client.models.list(); async for m in res
- Update streaming detection to handle the AsyncPage objects that models.list returns
- Preserve all existing recording/replay functionality for other endpoints

Signed-off-by: Derek Higgins <[email protected]>
1 parent 06566d0 commit 70f40f5
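
For context, these are the two call styles the commit has to keep working, sketched against a plain AsyncOpenAI client (assumes OPENAI_API_KEY is set in the environment; the unpatched client supports both styles, and the patched one must too):

    import asyncio

    from openai import AsyncOpenAI

    async def main() -> None:
        client = AsyncOpenAI()

        # Style 1: iterate the un-awaited call directly
        async for model in client.models.list():
            print(model.id)

        # Style 2: await first, then iterate the result
        res = await client.models.list()
        async for model in res:
            print(model.id)

    asyncio.run(main())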

llama_stack/testing/inference_recorder.py

Lines changed: 40 additions & 2 deletions
@@ -17,6 +17,7 @@
 from pathlib import Path
 from typing import Any, Literal, cast
 
+from openai.pagination import AsyncPage
 from openai.types.chat import ChatCompletion, ChatCompletionChunk
 
 from llama_stack.log import get_logger
@@ -296,7 +297,8 @@ async def replay_stream():
     }
 
     # Determine if this is a streaming request based on request parameters
-    is_streaming = body.get("stream", False)
+    # or if the response is an AsyncPage (like models.list returns)
+    is_streaming = body.get("stream", False) or isinstance(response, AsyncPage)
 
     if is_streaming:
         # For streaming responses, we need to collect all chunks immediately before yielding
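
The recorder's streaming path drains a response with async for and stores every item, which is why an AsyncPage can ride the same code path as a chat-completion chunk stream even though its request body never sets stream=True. A minimal sketch of that collect-then-replay idea (the helper name is illustrative, not from the recorder):

    from typing import Any, AsyncIterator

    async def collect_chunks(response: AsyncIterator[Any]) -> list[Any]:
        # ChatCompletionChunk streams and AsyncPage results are both
        # consumed with async for, so each item can be recorded once
        # and replayed later in the same order.
        return [item async for item in response]
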
@@ -332,9 +334,11 @@ def patch_inference_clients():
     from openai.resources.chat.completions import AsyncCompletions as AsyncChatCompletions
     from openai.resources.completions import AsyncCompletions
     from openai.resources.embeddings import AsyncEmbeddings
+    from openai.resources.models import AsyncModels
 
     # Store original methods for both OpenAI and Ollama clients
     _original_methods = {
+        "models_list": AsyncModels.list,
         "chat_completions_create": AsyncChatCompletions.create,
         "completions_create": AsyncCompletions.create,
         "embeddings_create": AsyncEmbeddings.create,
@@ -346,7 +350,38 @@ def patch_inference_clients():
         "ollama_list": OllamaAsyncClient.list,
     }
 
-    # Create patched methods for OpenAI client
+    # Special handling for models.list which needs to return something directly async-iterable
+    # Direct iteration: async for m in client.models.list()
+    # Await then iterate: res = await client.models.list(); async for m in res
+    def patched_models_list(self, *args, **kwargs):
+        class AsyncIterableModelsWrapper:
+            def __init__(self, original_method, client_self, args, kwargs):
+                self.original_method = original_method
+                self.client_self = client_self
+                self.args = args
+                self.kwargs = kwargs
+                self._result = None
+
+            def __aiter__(self):
+                return self._async_iter()
+
+            async def _async_iter(self):
+                # Get the result from the patched method
+                result = await _patched_inference_method(
+                    self.original_method, self.client_self, "openai", "/v1/models", *self.args, **self.kwargs
+                )
+                async for item in result:
+                    yield item
+
+            def __await__(self):
+                # When awaited, return self (since we're already async-iterable)
+                async def _return_self():
+                    return self
+
+                return _return_self().__await__()
+
+        return AsyncIterableModelsWrapper(_original_methods["models_list"], self, args, kwargs)
+
     async def patched_chat_completions_create(self, *args, **kwargs):
         return await _patched_inference_method(
             _original_methods["chat_completions_create"], self, "openai", "/v1/chat/completions", *args, **kwargs
@@ -363,6 +398,7 @@ async def patched_embeddings_create(self, *args, **kwargs):
         )
 
     # Apply OpenAI patches
+    AsyncModels.list = patched_models_list
     AsyncChatCompletions.create = patched_chat_completions_create
     AsyncCompletions.create = patched_completions_create
     AsyncEmbeddings.create = patched_embeddings_create
@@ -419,8 +455,10 @@ def unpatch_inference_clients():
     from openai.resources.chat.completions import AsyncCompletions as AsyncChatCompletions
     from openai.resources.completions import AsyncCompletions
     from openai.resources.embeddings import AsyncEmbeddings
+    from openai.resources.models import AsyncModels
 
     # Restore OpenAI client methods
+    AsyncModels.list = _original_methods["models_list"]
     AsyncChatCompletions.create = _original_methods["chat_completions_create"]
     AsyncCompletions.create = _original_methods["completions_create"]
     AsyncEmbeddings.create = _original_methods["embeddings_create"]
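
Patching and unpatching amount to a plain save-and-restore of class attributes; a compact illustration of the same round-trip as a context manager (the helper is hypothetical, not part of the recorder):

    from contextlib import contextmanager

    from openai.resources.models import AsyncModels

    @contextmanager
    def models_list_patched(replacement):
        original = AsyncModels.list
        AsyncModels.list = replacement  # patch the class attribute...
        try:
            yield
        finally:
            AsyncModels.list = original  # ...and always restore it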
