|
13 | 13 | get_builtin_model_providers, |
14 | 14 | get_default_inference_parameters, |
15 | 15 | get_default_model_configs, |
16 | | - get_default_model_providers_missing_api_keys, |
17 | 16 | get_default_provider_name, |
18 | 17 | get_default_providers, |
| 18 | + get_providers_with_missing_api_keys, |
19 | 19 | resolve_seed_default_model_settings, |
20 | 20 | ) |
21 | | -from data_designer.config.models import ChatCompletionInferenceParams, EmbeddingInferenceParams |
| 21 | +from data_designer.config.models import ChatCompletionInferenceParams, EmbeddingInferenceParams, ModelProvider |
22 | 22 | from data_designer.config.utils.visualization import get_nvidia_api_key, get_openai_api_key |
23 | 23 |
|
24 | 24 |
|
@@ -190,7 +190,77 @@ def test_resolve_seed_default_model_settings(tmp_path: Path): |
190 | 190 | assert providers_data == {"providers": [p.model_dump() for p in get_builtin_model_providers()]} |
191 | 191 |
|
192 | 192 |
|
193 | | -@patch("data_designer.config.default_model_settings.os.environ.get") |
194 | | -def test_get_default_model_providers_missing_api_keys(mock_environ_get): |
195 | | - mock_environ_get.return_value = None |
196 | | - assert get_default_model_providers_missing_api_keys() == ["NVIDIA_API_KEY", "OPENAI_API_KEY", "OPENROUTER_API_KEY"] |
| 193 | +def test_get_providers_with_missing_api_keys(): |
| 194 | + """Test detection of providers with missing API keys.""" |
| 195 | + # Test providers with various API key configurations |
| 196 | + providers = [ |
| 197 | + ModelProvider(name="provider1", endpoint="http://test1.com", api_key="NVIDIA_API_KEY"), # env var |
| 198 | + ModelProvider(name="provider2", endpoint="http://test2.com", api_key="sk-actual-key-12345"), # actual key |
| 199 | + ModelProvider(name="provider3", endpoint="http://test3.com", api_key=None), # no key |
| 200 | + ] |
| 201 | + |
| 202 | + with patch("data_designer.config.default_model_settings.os.environ.get") as mock_env: |
| 203 | + # Mock env to have NVIDIA_API_KEY set but not MISSING_VAR |
| 204 | + def mock_get(key: str) -> str | None: |
| 205 | + return "test-key" if key == "NVIDIA_API_KEY" else None |
| 206 | + |
| 207 | + mock_env.side_effect = mock_get |
| 208 | + |
| 209 | + missing = get_providers_with_missing_api_keys(providers) |
| 210 | + |
| 211 | + # provider1 has env var set -> OK |
| 212 | + # provider2 has actual API key -> OK |
| 213 | + # provider3 has no API key -> MISSING |
| 214 | + assert len(missing) == 1 |
| 215 | + assert missing[0].name == "provider3" |
| 216 | + |
| 217 | + |
| 218 | +def test_get_providers_with_missing_api_keys_env_var_not_set(): |
| 219 | + """Test detection when environment variable is not set.""" |
| 220 | + providers = [ |
| 221 | + ModelProvider(name="provider1", endpoint="http://test1.com", api_key="MISSING_ENV_VAR"), |
| 222 | + ] |
| 223 | + |
| 224 | + with patch("data_designer.config.default_model_settings.os.environ.get", return_value=None): |
| 225 | + missing = get_providers_with_missing_api_keys(providers) |
| 226 | + assert len(missing) == 1 |
| 227 | + assert missing[0].name == "provider1" |
| 228 | + |
| 229 | + |
| 230 | +def test_get_providers_with_missing_api_keys_all_valid(): |
| 231 | + """Test when all providers have valid API keys.""" |
| 232 | + providers = [ |
| 233 | + ModelProvider(name="provider1", endpoint="http://test1.com", api_key="sk-actual-key-1"), |
| 234 | + ModelProvider(name="provider2", endpoint="http://test2.com", api_key="sk-actual-key-2"), |
| 235 | + ] |
| 236 | + |
| 237 | + missing = get_providers_with_missing_api_keys(providers) |
| 238 | + assert len(missing) == 0 |
| 239 | + |
| 240 | + |
| 241 | +def test_get_providers_with_missing_api_keys_all_missing(): |
| 242 | + """Test when all providers have missing API keys.""" |
| 243 | + providers = [ |
| 244 | + ModelProvider(name="provider1", endpoint="http://test1.com", api_key="MISSING_VAR_1"), |
| 245 | + ModelProvider(name="provider2", endpoint="http://test2.com", api_key=None), |
| 246 | + ] |
| 247 | + |
| 248 | + with patch("data_designer.config.default_model_settings.os.environ.get", return_value=None): |
| 249 | + missing = get_providers_with_missing_api_keys(providers) |
| 250 | + assert len(missing) == 2 |
| 251 | + assert {p.name for p in missing} == {"provider1", "provider2"} |
| 252 | + |
| 253 | + |
| 254 | +def test_get_providers_with_missing_api_keys_mixed_case(): |
| 255 | + """Test that lowercase API keys are treated as actual keys, not env vars.""" |
| 256 | + providers = [ |
| 257 | + ModelProvider(name="provider1", endpoint="http://test1.com", api_key="lowercase_key"), |
| 258 | + ModelProvider(name="provider2", endpoint="http://test2.com", api_key="UPPERCASE_KEY"), |
| 259 | + ] |
| 260 | + |
| 261 | + with patch("data_designer.config.default_model_settings.os.environ.get", return_value=None): |
| 262 | + missing = get_providers_with_missing_api_keys(providers) |
| 263 | + # provider1 has lowercase key (treated as actual key) -> OK |
| 264 | + # provider2 has uppercase key but env var not set -> MISSING |
| 265 | + assert len(missing) == 1 |
| 266 | + assert missing[0].name == "provider2" |
0 commit comments