
Commit c2f2ca0

mattf authored and iamemilio committed
feat: include all models from provider's /v1/models (llamastack#3471)
# What does this PR do?

This replaces the static model listing for any provider using `OpenAIMixin`. Currently that covers:

- anthropic
- azure openai
- gemini
- groq
- llama-api
- nvidia
- openai
- sambanova
- tgi
- vertexai
- vllm

Not changed: together has its own impl.

## Test Plan

- new unit tests
- manual for llama-api, openai, groq, gemini

```
for provider in llama-openai-compat openai groq gemini; do
    uv run llama stack build --image-type venv --providers inference=remote::${provider} --run &
    uv run --with llama-stack-client llama-stack-client models list | grep Total
done
```

Results (17 Sep 2025):

- llama-api: 4
- openai: 86
- groq: 21
- gemini: 66

Closes llamastack#3467
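For context, this is roughly what the dynamic listing boils down to for any OpenAI-compatible provider. A minimal standalone sketch, assuming the official `openai` Python SDK; the base URL and API key below are illustrative placeholders, not values from this PR:

```python
import asyncio

from openai import AsyncOpenAI


async def main() -> None:
    # Placeholder endpoint and key; any OpenAI-compatible provider works the same way.
    client = AsyncOpenAI(base_url="https://api.openai.com/v1", api_key="sk-...")

    # The SDK exposes the paginated /v1/models endpoint as an async iterator.
    model_ids = [m.id async for m in client.models.list()]
    print(f"Total: {len(model_ids)} models")


asyncio.run(main())
```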
1 parent 1ab51e1 commit c2f2ca0

File tree

3 files changed: +243 -21 lines changed

llama_stack/providers/utils/inference/openai_mixin.py

Lines changed: 31 additions & 14 deletions
```diff
@@ -9,7 +9,6 @@
 from collections.abc import AsyncIterator
 from typing import Any
 
-import openai
 from openai import NOT_GIVEN, AsyncOpenAI
 
 from llama_stack.apis.inference import (
@@ -23,6 +22,7 @@
     OpenAIMessageParam,
     OpenAIResponseFormatParam,
 )
+from llama_stack.apis.models import ModelType
 from llama_stack.log import get_logger
 from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params
 
@@ -50,6 +50,10 @@ class OpenAIMixin(ABC):
     # This is useful for providers that do not return a unique id in the response.
     overwrite_completion_id: bool = False
 
+    # Cache of available models keyed by model ID
+    # This is set in list_models() and used in check_model_availability()
+    _model_cache: dict[str, Model] = {}
+
     @abstractmethod
     def get_api_key(self) -> str:
         """
@@ -296,22 +300,35 @@ async def openai_embeddings(
             usage=usage,
         )
 
+    async def list_models(self) -> list[Model] | None:
+        """
+        List available models from the provider's /v1/models endpoint.
+
+        Also, caches the models in self._model_cache for use in check_model_availability().
+
+        :return: A list of Model instances representing available models.
+        """
+        self._model_cache = {
+            m.id: Model(
+                # __provider_id__ is dynamically added by instantiate_provider in resolver.py
+                provider_id=self.__provider_id__,  # type: ignore[attr-defined]
+                provider_resource_id=m.id,
+                identifier=m.id,
+                model_type=ModelType.llm,
+            )
+            async for m in self.client.models.list()
+        }
+
+        return list(self._model_cache.values())
+
     async def check_model_availability(self, model: str) -> bool:
         """
-        Check if a specific model is available from OpenAI.
+        Check if a specific model is available from the provider's /v1/models.
 
         :param model: The model identifier to check.
        :return: True if the model is available dynamically, False otherwise.
         """
-        try:
-            # Direct model lookup - returns model or raises NotFoundError
-            await self.client.models.retrieve(model)
-            return True
-        except openai.NotFoundError:
-            # Model doesn't exist - this is expected for unavailable models
-            pass
-        except Exception as e:
-            # All other errors (auth, rate limit, network, etc.)
-            logger.warning(f"Failed to check model availability for {model}: {e}")
-
-        return False
+        if not self._model_cache:
+            await self.list_models()
+
+        return model in self._model_cache
```
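Two things worth noting about the new code: `list_models()` rebuilds `_model_cache` on every call, and `check_model_availability()` is a lazy read-through of that cache, so the first check triggers one `/v1/models` fetch and every later check is a plain dict lookup. A toy reproduction of that pattern (a standalone sketch with made-up names, not the actual `OpenAIMixin`):

```python
import asyncio


class LazyModelCache:
    """Toy stand-in for the mixin's _model_cache behavior."""

    def __init__(self, fetch):
        self._fetch = fetch  # async callable returning an iterable of model IDs
        self._cache: dict[str, bool] = {}

    async def check(self, model: str) -> bool:
        # First call populates the cache; later calls never hit the network.
        if not self._cache:
            self._cache = {m: True for m in await self._fetch()}
        return model in self._cache


async def main() -> None:
    async def fetch() -> list[str]:
        print("fetching /v1/models ...")  # printed only once
        return ["gpt-4", "gpt-4o-mini"]

    cache = LazyModelCache(fetch)
    print(await cache.check("gpt-4"))  # triggers the fetch, then True
    print(await cache.check("not-a-model"))  # dict lookup only, False


asyncio.run(main())
```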

tests/unit/providers/inference/test_openai_base_url_config.py

Lines changed: 29 additions & 7 deletions
```diff
@@ -5,7 +5,7 @@
 # the root directory of this source tree.
 
 import os
-from unittest.mock import AsyncMock, MagicMock, patch
+from unittest.mock import MagicMock, patch
 
 from llama_stack.core.stack import replace_env_vars
 from llama_stack.providers.remote.inference.openai.config import OpenAIConfig
@@ -80,11 +80,22 @@ async def test_check_model_availability_uses_configured_url(self, mock_openai_cl
         # Mock the get_api_key method
         adapter.get_api_key = MagicMock(return_value="test-key")
 
-        # Mock the AsyncOpenAI client and its models.retrieve method
+        # Mock a model object that will be returned by models.list()
+        mock_model = MagicMock()
+        mock_model.id = "gpt-4"
+
+        # Create an async iterator that yields our mock model
+        async def mock_async_iterator():
+            yield mock_model
+
+        # Mock the AsyncOpenAI client and its models.list method
         mock_client = MagicMock()
-        mock_client.models.retrieve = AsyncMock(return_value=MagicMock())
+        mock_client.models.list = MagicMock(return_value=mock_async_iterator())
         mock_openai_class.return_value = mock_client
 
+        # Set the __provider_id__ attribute that's expected by list_models
+        adapter.__provider_id__ = "openai"
+
         # Call check_model_availability and verify it returns True
         assert await adapter.check_model_availability("gpt-4")
 
@@ -94,8 +105,8 @@ async def test_check_model_availability_uses_configured_url(self, mock_openai_cl
             base_url=custom_url,
         )
 
-        # Verify the method was called and returned True
-        mock_client.models.retrieve.assert_called_once_with("gpt-4")
+        # Verify the models.list method was called
+        mock_client.models.list.assert_called_once()
 
     @patch.dict(os.environ, {"OPENAI_BASE_URL": "https://proxy.openai.com/v1"})
     @patch("llama_stack.providers.utils.inference.openai_mixin.AsyncOpenAI")
@@ -110,11 +121,22 @@ async def test_environment_variable_affects_model_availability_check(self, mock_
         # Mock the get_api_key method
         adapter.get_api_key = MagicMock(return_value="test-key")
 
-        # Mock the AsyncOpenAI client
+        # Mock a model object that will be returned by models.list()
+        mock_model = MagicMock()
+        mock_model.id = "gpt-4"
+
+        # Create an async iterator that yields our mock model
+        async def mock_async_iterator():
+            yield mock_model
+
+        # Mock the AsyncOpenAI client and its models.list method
         mock_client = MagicMock()
-        mock_client.models.retrieve = AsyncMock(return_value=MagicMock())
+        mock_client.models.list = MagicMock(return_value=mock_async_iterator())
         mock_openai_class.return_value = mock_client
 
+        # Set the __provider_id__ attribute that's expected by list_models
+        adapter.__provider_id__ = "openai"
+
         # Call check_model_availability and verify it returns True
         assert await adapter.check_model_availability("gpt-4")
 
```
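Both tests use the same mocking trick: `models.list` is replaced with a `MagicMock` whose return value is an async generator, so the mixin's `async for` loop runs without any HTTP. A standalone sketch of just that pattern; note an async generator can be consumed only once, which is why each test constructs a fresh one:

```python
import asyncio
from unittest.mock import MagicMock


async def main() -> None:
    mock_model = MagicMock()
    mock_model.id = "gpt-4"

    # Async generator standing in for the SDK's /v1/models paginator.
    async def mock_async_iterator():
        yield mock_model

    client = MagicMock()
    client.models.list = MagicMock(return_value=mock_async_iterator())

    # `async for` works exactly as it would against the real client.
    ids = [m.id async for m in client.models.list()]
    assert ids == ["gpt-4"]
    print(ids)


asyncio.run(main())
```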

Lines changed: 183 additions & 0 deletions

New file — unit tests for `OpenAIMixin.list_models` and `check_model_availability`:

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from unittest.mock import MagicMock, PropertyMock, patch

import pytest

from llama_stack.apis.inference import Model
from llama_stack.apis.models import ModelType
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin


# Test implementation of OpenAIMixin for testing purposes
class OpenAIMixinImpl(OpenAIMixin):
    def __init__(self):
        self.__provider_id__ = "test-provider"

    def get_api_key(self) -> str:
        raise NotImplementedError("This method should be mocked in tests")

    def get_base_url(self) -> str:
        raise NotImplementedError("This method should be mocked in tests")


@pytest.fixture
def mixin():
    """Create a test instance of OpenAIMixin"""
    return OpenAIMixinImpl()


@pytest.fixture
def mock_models():
    """Create multiple mock OpenAI model objects"""
    models = [MagicMock(id=id) for id in ["some-mock-model-id", "another-mock-model-id", "final-mock-model-id"]]
    return models


@pytest.fixture
def mock_client_with_models(mock_models):
    """Create a mock client with models.list() set up to return mock_models"""
    mock_client = MagicMock()

    async def mock_models_list():
        for model in mock_models:
            yield model

    mock_client.models.list.return_value = mock_models_list()
    return mock_client


@pytest.fixture
def mock_client_with_empty_models():
    """Create a mock client with models.list() set up to return an empty list"""
    mock_client = MagicMock()

    async def mock_empty_models_list():
        return
        yield  # Make it an async generator but don't yield anything

    mock_client.models.list.return_value = mock_empty_models_list()
    return mock_client


@pytest.fixture
def mock_client_with_exception():
    """Create a mock client with models.list() set up to raise an exception"""
    mock_client = MagicMock()
    mock_client.models.list.side_effect = Exception("API Error")
    return mock_client


@pytest.fixture
def mock_client_context():
    """Fixture that provides a context manager for mocking the OpenAI client"""

    def _mock_client_context(mixin, mock_client):
        return patch.object(type(mixin), "client", new_callable=PropertyMock, return_value=mock_client)

    return _mock_client_context


class TestOpenAIMixinListModels:
    """Test cases for the list_models method"""

    async def test_list_models_success(self, mixin, mock_client_with_models, mock_client_context):
        """Test successful model listing"""
        assert len(mixin._model_cache) == 0

        with mock_client_context(mixin, mock_client_with_models):
            result = await mixin.list_models()

            assert result is not None
            assert len(result) == 3

            model_ids = [model.identifier for model in result]
            assert "some-mock-model-id" in model_ids
            assert "another-mock-model-id" in model_ids
            assert "final-mock-model-id" in model_ids

            for model in result:
                assert model.provider_id == "test-provider"
                assert model.model_type == ModelType.llm
                assert model.provider_resource_id == model.identifier

            assert len(mixin._model_cache) == 3
            for model_id in ["some-mock-model-id", "another-mock-model-id", "final-mock-model-id"]:
                assert model_id in mixin._model_cache
                cached_model = mixin._model_cache[model_id]
                assert cached_model.identifier == model_id
                assert cached_model.provider_resource_id == model_id

    async def test_list_models_empty_response(self, mixin, mock_client_with_empty_models, mock_client_context):
        """Test handling of empty model list"""
        with mock_client_context(mixin, mock_client_with_empty_models):
            result = await mixin.list_models()

            assert result is not None
            assert len(result) == 0
            assert len(mixin._model_cache) == 0


class TestOpenAIMixinCheckModelAvailability:
    """Test cases for the check_model_availability method"""

    async def test_check_model_availability_with_cache(self, mixin, mock_client_with_models, mock_client_context):
        """Test model availability check when cache is populated"""
        with mock_client_context(mixin, mock_client_with_models):
            mock_client_with_models.models.list.assert_not_called()
            await mixin.list_models()
            mock_client_with_models.models.list.assert_called_once()

            assert await mixin.check_model_availability("some-mock-model-id")
            assert await mixin.check_model_availability("another-mock-model-id")
            assert await mixin.check_model_availability("final-mock-model-id")
            assert not await mixin.check_model_availability("non-existent-model")
            mock_client_with_models.models.list.assert_called_once()

    async def test_check_model_availability_without_cache(self, mixin, mock_client_with_models, mock_client_context):
        """Test model availability check when cache is empty (calls list_models)"""
        assert len(mixin._model_cache) == 0

        with mock_client_context(mixin, mock_client_with_models):
            mock_client_with_models.models.list.assert_not_called()
            assert await mixin.check_model_availability("some-mock-model-id")
            mock_client_with_models.models.list.assert_called_once()

            assert len(mixin._model_cache) == 3
            assert "some-mock-model-id" in mixin._model_cache

    async def test_check_model_availability_model_not_found(self, mixin, mock_client_with_models, mock_client_context):
        """Test model availability check for non-existent model"""
        with mock_client_context(mixin, mock_client_with_models):
            mock_client_with_models.models.list.assert_not_called()
            assert not await mixin.check_model_availability("non-existent-model")
            mock_client_with_models.models.list.assert_called_once()

            assert len(mixin._model_cache) == 3


class TestOpenAIMixinCacheBehavior:
    """Test cases for cache behavior and edge cases"""

    async def test_cache_overwrites_on_list_models_call(self, mixin, mock_client_with_models, mock_client_context):
        """Test that calling list_models overwrites existing cache"""
        initial_model = Model(
            provider_id="test-provider",
            provider_resource_id="old-model",
            identifier="old-model",
            model_type=ModelType.llm,
        )
        mixin._model_cache = {"old-model": initial_model}

        with mock_client_context(mixin, mock_client_with_models):
            await mixin.list_models()

            assert len(mixin._model_cache) == 3
            assert "old-model" not in mixin._model_cache
            assert "some-mock-model-id" in mixin._model_cache
            assert "another-mock-model-id" in mixin._model_cache
            assert "final-mock-model-id" in mixin._model_cache
```
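A side note on the `mock_client_context` fixture above: `OpenAIMixin.client` is a property, and properties live on the class, so it has to be patched on `type(mixin)` with `PropertyMock` rather than assigned on the instance. An isolated sketch of that pattern (the `Adapter` class here is illustrative, not from the PR):

```python
from unittest.mock import MagicMock, PropertyMock, patch


class Adapter:
    @property
    def client(self):
        raise RuntimeError("would construct a real AsyncOpenAI client")


adapter = Adapter()
fake_client = MagicMock(name="fake-openai-client")

# Patch the property on the *type*; extra kwargs configure the PropertyMock.
with patch.object(type(adapter), "client", new_callable=PropertyMock, return_value=fake_client):
    assert adapter.client is fake_client  # property access now returns the mock

# Outside the `with` block the original property is restored.
```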
