
Commit 9614828

Add support for AsyncPage

Signed-off-by: Derek Higgins <[email protected]>
1 parent 707fd0b

35 files changed: +15701 -9 lines
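For context, the behavior this commit teaches the inference recorder to handle is the OpenAI client's paginated models.list() call, which the newly added vLLM CI jobs appear to exercise. A minimal sketch of that call, assuming a locally running OpenAI-compatible server (the base_url and printed model ids are illustrative, not part of the commit):

    import asyncio

    from openai import AsyncOpenAI


    async def main() -> None:
        # Illustrative endpoint: a local vLLM server exposes an OpenAI-compatible API.
        client = AsyncOpenAI(base_url="http://localhost:8000/v1", api_key="not-needed")

        # models.list() returns an AsyncPaginator; awaiting it yields an AsyncPage[Model] ...
        page = await client.models.list()
        print([m.id for m in page.data])

        # ... and the same paginator can be consumed directly with `async for`.
        async for model in client.models.list():
            print(model.id)


    asyncio.run(main())

Both consumption styles matter below: the recorder has to preserve them when it wraps the client.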

.github/actions/run-and-record-tests/action.yml

Lines changed: 1 addition & 0 deletions

@@ -71,6 +71,7 @@ runs:
         shell: bash
         run: |
           sudo docker logs ollama > ollama-${{ inputs.inference-mode }}.log || true
+          sudo docker logs vllm > vllm-${{ inputs.inference-mode }}.log || true

     - name: Upload logs
       if: ${{ always() }}

.github/workflows/integration-tests.yml

Lines changed: 1 addition & 2 deletions

@@ -20,7 +20,6 @@ on:
   schedule:
     # If changing the cron schedule, update the provider in the test-matrix job
     - cron: '0 0 * * *' # (test latest client) Daily at 12 AM UTC
-    - cron: '1 0 * * 0' # (test vllm) Weekly on Sunday at 1 AM UTC
   workflow_dispatch:
     inputs:
       test-all-client-versions:
@@ -47,7 +46,7 @@ jobs:
       matrix:
         client-type: [library, server]
         # Use vllm on weekly schedule, otherwise use test-provider input (defaults to ollama)
-        provider: ${{ (github.event.schedule == '1 0 * * 0') && fromJSON('["vllm"]') || fromJSON(format('["{0}"]', github.event.inputs.test-provider || 'ollama')) }}
+        provider: [ollama, vllm]
         # Use Python 3.13 only on nightly schedule (daily latest client test), otherwise use 3.12
         python-version: ${{ github.event.schedule == '0 0 * * *' && fromJSON('["3.12", "3.13"]') || fromJSON('["3.12"]') }}
         client-version: ${{ (github.event.schedule == '0 0 * * *' || github.event.inputs.test-all-client-versions == 'true') && fromJSON('["published", "latest"]') || fromJSON('["latest"]') }}

.github/workflows/record-integration-tests.yml

Lines changed: 13 additions & 6 deletions

@@ -31,6 +31,9 @@ jobs:
     if: contains(github.event.pull_request.labels.*.name, 're-record-tests') ||
       contains(github.event.pull_request.labels.*.name, 're-record-vision-tests')
     runs-on: ubuntu-latest
+    strategy:
+      matrix:
+        provider: [ollama, vllm]
     outputs:
       test-types: ${{ steps.generate-test-types.outputs.test-types }}
       matrix-modes: ${{ steps.generate-test-types.outputs.matrix-modes }}
@@ -42,17 +45,21 @@ jobs:
       - name: Generate test types
         id: generate-test-types
         run: |
-          # Get test directories dynamically, excluding non-test directories
-          TEST_TYPES=$(find tests/integration -maxdepth 1 -mindepth 1 -type d -printf "%f\n" |
-            grep -Ev "^(__pycache__|fixtures|test_cases|recordings|post_training)$" |
-            sort | jq -R -s -c 'split("\n")[:-1]')
-          echo "test-types=$TEST_TYPES" >> $GITHUB_OUTPUT
+          if [ "${{ matrix.provider }}" == "vllm" ]; then
+            echo "test-types=[\"inference\"]" >> $GITHUB_OUTPUT
+          else
+            # Get test directories dynamically, excluding non-test directories
+            TEST_TYPES=$(find tests/integration -maxdepth 1 -mindepth 1 -type d -printf "%f\n" |
+              grep -Ev "^(__pycache__|fixtures|test_cases|recordings|post_training)$" |
+              sort | jq -R -s -c 'split("\n")[:-1]')
+            echo "test-types=$TEST_TYPES" >> $GITHUB_OUTPUT
+          fi

           labels=$(gh pr view ${{ github.event.pull_request.number }} --json labels --jq '.labels[].name')
           echo "labels=$labels"

           modes_array=()
-          if [[ $labels == *"re-record-vision-tests"* ]]; then
+          if [[ $labels == *"re-record-vision-tests"* ]] && [[ "${{ matrix.provider }}" == "ollama" ]]; then
             modes_array+=("vision")
           fi
           if [[ $labels == *"re-record-tests"* ]]; then

llama_stack/testing/inference_recorder.py

Lines changed: 80 additions & 1 deletion

@@ -108,13 +108,29 @@ def _deserialize_response(data: dict[str, Any]) -> Any:
     try:
         # Import the original class and reconstruct the object
         module_path, class_name = data["__type__"].rsplit(".", 1)
+
+        # Handle generic types (e.g. AsyncPage[Model]) by removing the generic part
+        if "[" in class_name and "]" in class_name:
+            class_name = class_name.split("[")[0]
+
         module = __import__(module_path, fromlist=[class_name])
         cls = getattr(module, class_name)

         if not hasattr(cls, "model_validate"):
             raise ValueError(f"Pydantic class {cls} does not support model_validate?")

-        return cls.model_validate(data["__data__"])
+        # Special handling for AsyncPage - convert nested model dicts to proper model objects
+        validate_data = data["__data__"]
+        if class_name == "AsyncPage" and isinstance(validate_data, dict) and "data" in validate_data:
+            # Convert model dictionaries to objects with attributes so they work with .id access
+            from types import SimpleNamespace
+
+            validate_data = dict(validate_data)
+            validate_data["data"] = [
+                SimpleNamespace(**item) if isinstance(item, dict) else item for item in validate_data["data"]
+            ]
+
+        return cls.model_validate(validate_data)
     except (ImportError, AttributeError, TypeError, ValueError) as e:
         logger.warning(f"Failed to deserialize object of type {data['__type__']}: {e}")
         return data["__data__"]
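To make the hunk above concrete, here is a small, self-contained sketch of the transformation it performs. The recorded `__type__` string and payload shape are assumptions inferred from the code, not shown in the diff:

    from types import SimpleNamespace

    # Assumed shape of a recorded models.list() response (illustrative).
    data = {
        "__type__": "openai.pagination.AsyncPage[Model]",
        "__data__": {"object": "list", "data": [{"id": "llama-3.1-8b", "object": "model"}]},
    }

    # The generic suffix is stripped so the class can be imported by name.
    module_path, class_name = data["__type__"].rsplit(".", 1)
    class_name = class_name.split("[")[0]  # "AsyncPage[Model]" -> "AsyncPage"
    print(module_path, class_name)  # openai.pagination AsyncPage

    # Nested dicts become attribute-accessible objects, so callers can use `.id`.
    items = [SimpleNamespace(**item) for item in data["__data__"]["data"]]
    print(items[0].id)  # llama-3.1-8b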
@@ -332,9 +348,11 @@ def patch_inference_clients():
     from openai.resources.chat.completions import AsyncCompletions as AsyncChatCompletions
     from openai.resources.completions import AsyncCompletions
     from openai.resources.embeddings import AsyncEmbeddings
+    from openai.resources.models import AsyncModels

     # Store original methods for both OpenAI and Ollama clients
     _original_methods = {
+        "model_list": AsyncModels.list,
         "chat_completions_create": AsyncChatCompletions.create,
         "completions_create": AsyncCompletions.create,
         "embeddings_create": AsyncEmbeddings.create,
@@ -347,6 +365,64 @@ def patch_inference_clients():
     }

     # Create patched methods for OpenAI client
+    def patched_model_list(self, *args, **kwargs):
+        # The original models.list() returns an AsyncPaginator that can be used with async for
+        # We need to create a wrapper that preserves this behavior
+        class PatchedAsyncPaginator:
+            def __init__(self, original_method, instance, client_type, endpoint, args, kwargs):
+                self.original_method = original_method
+                self.instance = instance
+                self.client_type = client_type
+                self.endpoint = endpoint
+                self.args = args
+                self.kwargs = kwargs
+                self._result = None
+
+            def __await__(self):
+                # Make it awaitable like the original AsyncPaginator
+                async def _await():
+                    self._result = await _patched_inference_method(
+                        self.original_method, self.instance, self.client_type, self.endpoint, *self.args, **self.kwargs
+                    )
+                    return self._result
+
+                return _await().__await__()
+
+            def __aiter__(self):
+                # Make it async iterable like the original AsyncPaginator
+                return self
+
+            async def __anext__(self):
+                # Get the result if we haven't already
+                if self._result is None:
+                    self._result = await _patched_inference_method(
+                        self.original_method, self.instance, self.client_type, self.endpoint, *self.args, **self.kwargs
+                    )
+
+                # Initialize iteration on first call
+                if not hasattr(self, "_iter_index"):
+                    # Extract the data list from the result
+                    if hasattr(self._result, "data") and isinstance(self._result.data, list):
+                        self._data_list = self._result.data
+                    elif isinstance(self._result, list):
+                        self._data_list = self._result
+                    else:
+                        # Not a list-like response, return it once
+                        if hasattr(self, "_returned"):
+                            raise StopAsyncIteration
+                        self._returned = True
+                        return self._result
+                    self._iter_index = 0
+
+                # Return next item from the list
+                if self._iter_index >= len(self._data_list):
+                    raise StopAsyncIteration
+                item = self._data_list[self._iter_index]
+                self._iter_index += 1
+                return item
+
+        return PatchedAsyncPaginator(_original_methods["model_list"], self, "openai", "/v1/models", args, kwargs)

     async def patched_chat_completions_create(self, *args, **kwargs):
         return await _patched_inference_method(
             _original_methods["chat_completions_create"], self, "openai", "/v1/chat/completions", *args, **kwargs
@@ -363,6 +439,7 @@ async def patched_embeddings_create(self, *args, **kwargs):
         )

     # Apply OpenAI patches
+    AsyncModels.list = patched_model_list
     AsyncChatCompletions.create = patched_chat_completions_create
     AsyncCompletions.create = patched_completions_create
     AsyncEmbeddings.create = patched_embeddings_create
@@ -419,8 +496,10 @@ def unpatch_inference_clients():
     from openai.resources.chat.completions import AsyncCompletions as AsyncChatCompletions
     from openai.resources.completions import AsyncCompletions
     from openai.resources.embeddings import AsyncEmbeddings
+    from openai.resources.models import AsyncModels

     # Restore OpenAI client methods
+    AsyncModels.list = _original_methods["model_list"]
     AsyncChatCompletions.create = _original_methods["chat_completions_create"]
     AsyncCompletions.create = _original_methods["completions_create"]
     AsyncEmbeddings.create = _original_methods["embeddings_create"]
(A 12 KB binary file is also changed; contents not shown.)
