|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# This source code is licensed under the terms described in the LICENSE file in |
| 5 | +# the root directory of this source tree. |
| 6 | + |
| 7 | +from unittest.mock import MagicMock, PropertyMock, patch |
| 8 | + |
| 9 | +import pytest |
| 10 | + |
| 11 | +from llama_stack.apis.inference import Model |
| 12 | +from llama_stack.apis.models import ModelType |
| 13 | +from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin |
| 14 | + |
| 15 | + |
| 16 | +# Test implementation of OpenAIMixin for testing purposes |
| 17 | +class OpenAIMixinImpl(OpenAIMixin): |
| 18 | + def __init__(self): |
| 19 | + self.__provider_id__ = "test-provider" |
| 20 | + |
| 21 | + def get_api_key(self) -> str: |
| 22 | + raise NotImplementedError("This method should be mocked in tests") |
| 23 | + |
| 24 | + def get_base_url(self) -> str: |
| 25 | + raise NotImplementedError("This method should be mocked in tests") |
| 26 | + |
| 27 | + |
| 28 | +@pytest.fixture |
| 29 | +def mixin(): |
| 30 | + """Create a test instance of OpenAIMixin""" |
| 31 | + return OpenAIMixinImpl() |
| 32 | + |
| 33 | + |
| 34 | +@pytest.fixture |
| 35 | +def mock_models(): |
| 36 | + """Create multiple mock OpenAI model objects""" |
| 37 | + models = [MagicMock(id=id) for id in ["some-mock-model-id", "another-mock-model-id", "final-mock-model-id"]] |
| 38 | + return models |
| 39 | + |
| 40 | + |
| 41 | +@pytest.fixture |
| 42 | +def mock_client_with_models(mock_models): |
| 43 | + """Create a mock client with models.list() set up to return mock_models""" |
| 44 | + mock_client = MagicMock() |
| 45 | + |
| 46 | + async def mock_models_list(): |
| 47 | + for model in mock_models: |
| 48 | + yield model |
| 49 | + |
| 50 | + mock_client.models.list.return_value = mock_models_list() |
| 51 | + return mock_client |
| 52 | + |
| 53 | + |
| 54 | +@pytest.fixture |
| 55 | +def mock_client_with_empty_models(): |
| 56 | + """Create a mock client with models.list() set up to return empty list""" |
| 57 | + mock_client = MagicMock() |
| 58 | + |
| 59 | + async def mock_empty_models_list(): |
| 60 | + return |
| 61 | + yield # Make it an async generator but don't yield anything |
| 62 | + |
| 63 | + mock_client.models.list.return_value = mock_empty_models_list() |
| 64 | + return mock_client |
| 65 | + |
| 66 | + |
| 67 | +@pytest.fixture |
| 68 | +def mock_client_with_exception(): |
| 69 | + """Create a mock client with models.list() set up to raise an exception""" |
| 70 | + mock_client = MagicMock() |
| 71 | + mock_client.models.list.side_effect = Exception("API Error") |
| 72 | + return mock_client |
| 73 | + |
| 74 | + |
| 75 | +@pytest.fixture |
| 76 | +def mock_client_context(): |
| 77 | + """Fixture that provides a context manager for mocking the OpenAI client""" |
| 78 | + |
| 79 | + def _mock_client_context(mixin, mock_client): |
| 80 | + return patch.object(type(mixin), "client", new_callable=PropertyMock, return_value=mock_client) |
| 81 | + |
| 82 | + return _mock_client_context |
| 83 | + |
| 84 | + |
| 85 | +class TestOpenAIMixinListModels: |
| 86 | + """Test cases for the list_models method""" |
| 87 | + |
| 88 | + async def test_list_models_success(self, mixin, mock_client_with_models, mock_client_context): |
| 89 | + """Test successful model listing""" |
| 90 | + assert len(mixin._model_cache) == 0 |
| 91 | + |
| 92 | + with mock_client_context(mixin, mock_client_with_models): |
| 93 | + result = await mixin.list_models() |
| 94 | + |
| 95 | + assert result is not None |
| 96 | + assert len(result) == 3 |
| 97 | + |
| 98 | + model_ids = [model.identifier for model in result] |
| 99 | + assert "some-mock-model-id" in model_ids |
| 100 | + assert "another-mock-model-id" in model_ids |
| 101 | + assert "final-mock-model-id" in model_ids |
| 102 | + |
| 103 | + for model in result: |
| 104 | + assert model.provider_id == "test-provider" |
| 105 | + assert model.model_type == ModelType.llm |
| 106 | + assert model.provider_resource_id == model.identifier |
| 107 | + |
| 108 | + assert len(mixin._model_cache) == 3 |
| 109 | + for model_id in ["some-mock-model-id", "another-mock-model-id", "final-mock-model-id"]: |
| 110 | + assert model_id in mixin._model_cache |
| 111 | + cached_model = mixin._model_cache[model_id] |
| 112 | + assert cached_model.identifier == model_id |
| 113 | + assert cached_model.provider_resource_id == model_id |
| 114 | + |
| 115 | + async def test_list_models_empty_response(self, mixin, mock_client_with_empty_models, mock_client_context): |
| 116 | + """Test handling of empty model list""" |
| 117 | + with mock_client_context(mixin, mock_client_with_empty_models): |
| 118 | + result = await mixin.list_models() |
| 119 | + |
| 120 | + assert result is not None |
| 121 | + assert len(result) == 0 |
| 122 | + assert len(mixin._model_cache) == 0 |
| 123 | + |
| 124 | + |
| 125 | +class TestOpenAIMixinCheckModelAvailability: |
| 126 | + """Test cases for the check_model_availability method""" |
| 127 | + |
| 128 | + async def test_check_model_availability_with_cache(self, mixin, mock_client_with_models, mock_client_context): |
| 129 | + """Test model availability check when cache is populated""" |
| 130 | + with mock_client_context(mixin, mock_client_with_models): |
| 131 | + mock_client_with_models.models.list.assert_not_called() |
| 132 | + await mixin.list_models() |
| 133 | + mock_client_with_models.models.list.assert_called_once() |
| 134 | + |
| 135 | + assert await mixin.check_model_availability("some-mock-model-id") |
| 136 | + assert await mixin.check_model_availability("another-mock-model-id") |
| 137 | + assert await mixin.check_model_availability("final-mock-model-id") |
| 138 | + assert not await mixin.check_model_availability("non-existent-model") |
| 139 | + mock_client_with_models.models.list.assert_called_once() |
| 140 | + |
| 141 | + async def test_check_model_availability_without_cache(self, mixin, mock_client_with_models, mock_client_context): |
| 142 | + """Test model availability check when cache is empty (calls list_models)""" |
| 143 | + assert len(mixin._model_cache) == 0 |
| 144 | + |
| 145 | + with mock_client_context(mixin, mock_client_with_models): |
| 146 | + mock_client_with_models.models.list.assert_not_called() |
| 147 | + assert await mixin.check_model_availability("some-mock-model-id") |
| 148 | + mock_client_with_models.models.list.assert_called_once() |
| 149 | + |
| 150 | + assert len(mixin._model_cache) == 3 |
| 151 | + assert "some-mock-model-id" in mixin._model_cache |
| 152 | + |
| 153 | + async def test_check_model_availability_model_not_found(self, mixin, mock_client_with_models, mock_client_context): |
| 154 | + """Test model availability check for non-existent model""" |
| 155 | + with mock_client_context(mixin, mock_client_with_models): |
| 156 | + mock_client_with_models.models.list.assert_not_called() |
| 157 | + assert not await mixin.check_model_availability("non-existent-model") |
| 158 | + mock_client_with_models.models.list.assert_called_once() |
| 159 | + |
| 160 | + assert len(mixin._model_cache) == 3 |
| 161 | + |
| 162 | + |
| 163 | +class TestOpenAIMixinCacheBehavior: |
| 164 | + """Test cases for cache behavior and edge cases""" |
| 165 | + |
| 166 | + async def test_cache_overwrites_on_list_models_call(self, mixin, mock_client_with_models, mock_client_context): |
| 167 | + """Test that calling list_models overwrites existing cache""" |
| 168 | + initial_model = Model( |
| 169 | + provider_id="test-provider", |
| 170 | + provider_resource_id="old-model", |
| 171 | + identifier="old-model", |
| 172 | + model_type=ModelType.llm, |
| 173 | + ) |
| 174 | + mixin._model_cache = {"old-model": initial_model} |
| 175 | + |
| 176 | + with mock_client_context(mixin, mock_client_with_models): |
| 177 | + await mixin.list_models() |
| 178 | + |
| 179 | + assert len(mixin._model_cache) == 3 |
| 180 | + assert "old-model" not in mixin._model_cache |
| 181 | + assert "some-mock-model-id" in mixin._model_cache |
| 182 | + assert "another-mock-model-id" in mixin._model_cache |
| 183 | + assert "final-mock-model-id" in mixin._model_cache |
0 commit comments