Skip to content

Commit 23e659c

Browse files
authored
fix (#238) non api key warning on default model providers
1 parent 0ab3613 commit 23e659c

File tree

3 files changed

+96
-16
lines changed

3 files changed

+96
-16
lines changed

src/data_designer/config/default_model_settings.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -71,12 +71,20 @@ def get_default_model_configs() -> list[ModelConfig]:
7171
return []
7272

7373

74-
def get_default_model_providers_missing_api_keys() -> list[str]:
    """Return the API-key environment variable names that are not set.

    The ``"api_key"`` entry of every item in ``PREDEFINED_PROVIDERS`` is
    looked up in the process environment; names whose lookup yields ``None``
    are collected in their original order.
    """
    env_var_names = (entry["api_key"] for entry in PREDEFINED_PROVIDERS)
    return [env_var for env_var in env_var_names if os.environ.get(env_var) is None]
74+
def get_providers_with_missing_api_keys(providers: list[ModelProvider]) -> list[ModelProvider]:
75+
providers_with_missing_keys = []
76+
77+
for provider in providers:
78+
if provider.api_key is None:
79+
# No API key specified at all
80+
providers_with_missing_keys.append(provider)
81+
elif provider.api_key.isupper() and "_" in provider.api_key:
82+
# Looks like an environment variable name, check if it's set
83+
if os.environ.get(provider.api_key) is None:
84+
providers_with_missing_keys.append(provider)
85+
# else: It's an actual API key value (not an env var), so it's valid
86+
87+
return providers_with_missing_keys
8088

8189

8290
def get_default_providers() -> list[ModelProvider]:

src/data_designer/interface/data_designer.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@
1212
from data_designer.config.data_designer_config import DataDesignerConfig
1313
from data_designer.config.default_model_settings import (
1414
get_default_model_configs,
15-
get_default_model_providers_missing_api_keys,
1615
get_default_provider_name,
1716
get_default_providers,
17+
get_providers_with_missing_api_keys,
1818
)
1919
from data_designer.config.interface import DataDesignerInterface
2020
from data_designer.config.models import (
@@ -28,7 +28,6 @@
2828
MANAGED_ASSETS_PATH,
2929
MODEL_CONFIGS_FILE_PATH,
3030
MODEL_PROVIDERS_FILE_PATH,
31-
PREDEFINED_PROVIDERS,
3231
)
3332
from data_designer.config.utils.info import InfoType, InterfaceInfo
3433
from data_designer.engine.analysis.dataset_profiler import DataDesignerDatasetProfiler, DatasetProfilerConfig
@@ -334,8 +333,11 @@ def set_run_config(self, run_config: RunConfig) -> None:
334333
def _resolve_model_providers(self, model_providers: list[ModelProvider] | None) -> list[ModelProvider]:
335334
if model_providers is None:
336335
model_providers = get_default_providers()
337-
missing_api_keys = get_default_model_providers_missing_api_keys()
338-
if len(missing_api_keys) == len(PREDEFINED_PROVIDERS):
336+
# Check which providers have missing API keys (from YAML file or env vars)
337+
providers_with_missing_keys = get_providers_with_missing_api_keys(model_providers)
338+
339+
if len(providers_with_missing_keys) == len(model_providers):
340+
# All providers have missing API keys
339341
logger.warning(
340342
"🚨 You are trying to use a default model provider but your API keys are missing."
341343
"\n\t\t\tSet the API key for the default providers you intend to use and re-initialize the Data Designer object."

tests/config/test_default_model_settings.py

Lines changed: 76 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,12 @@
1313
get_builtin_model_providers,
1414
get_default_inference_parameters,
1515
get_default_model_configs,
16-
get_default_model_providers_missing_api_keys,
1716
get_default_provider_name,
1817
get_default_providers,
18+
get_providers_with_missing_api_keys,
1919
resolve_seed_default_model_settings,
2020
)
21-
from data_designer.config.models import ChatCompletionInferenceParams, EmbeddingInferenceParams
21+
from data_designer.config.models import ChatCompletionInferenceParams, EmbeddingInferenceParams, ModelProvider
2222
from data_designer.config.utils.visualization import get_nvidia_api_key, get_openai_api_key
2323

2424

@@ -190,7 +190,77 @@ def test_resolve_seed_default_model_settings(tmp_path: Path):
190190
assert providers_data == {"providers": [p.model_dump() for p in get_builtin_model_providers()]}
191191

192192

193-
@patch("data_designer.config.default_model_settings.os.environ.get")
def test_get_default_model_providers_missing_api_keys(mock_environ_get):
    """Every predefined key name is reported when all environment lookups return None."""
    # Make every environment lookup come back empty
    mock_environ_get.return_value = None
    expected_missing = ["NVIDIA_API_KEY", "OPENAI_API_KEY", "OPENROUTER_API_KEY"]
    assert get_default_model_providers_missing_api_keys() == expected_missing
193+
def test_get_providers_with_missing_api_keys():
    """Only the provider without any API key is flagged when the env var resolves."""
    env_backed = ModelProvider(name="provider1", endpoint="http://test1.com", api_key="NVIDIA_API_KEY")
    literal_key = ModelProvider(name="provider2", endpoint="http://test2.com", api_key="sk-actual-key-12345")
    keyless = ModelProvider(name="provider3", endpoint="http://test3.com", api_key=None)

    # Simulated environment: NVIDIA_API_KEY resolves, every other lookup returns None
    fake_environment = {"NVIDIA_API_KEY": "test-key"}
    with patch(
        "data_designer.config.default_model_settings.os.environ.get",
        side_effect=fake_environment.get,
    ):
        flagged = get_providers_with_missing_api_keys([env_backed, literal_key, keyless])

    # provider1 resolves through the environment and provider2 carries a literal
    # key, so only provider3 (no key at all) should be reported
    assert [p.name for p in flagged] == ["provider3"]
216+
217+
218+
def test_get_providers_with_missing_api_keys_env_var_not_set():
    """A provider whose key names an unset environment variable is flagged."""
    unresolved = ModelProvider(name="provider1", endpoint="http://test1.com", api_key="MISSING_ENV_VAR")

    env_get_target = "data_designer.config.default_model_settings.os.environ.get"
    with patch(env_get_target, return_value=None):
        flagged = get_providers_with_missing_api_keys([unresolved])

    assert [p.name for p in flagged] == ["provider1"]
228+
229+
230+
def test_get_providers_with_missing_api_keys_all_valid():
    """No provider is flagged when every one carries a literal API key."""
    first = ModelProvider(name="provider1", endpoint="http://test1.com", api_key="sk-actual-key-1")
    second = ModelProvider(name="provider2", endpoint="http://test2.com", api_key="sk-actual-key-2")

    flagged = get_providers_with_missing_api_keys([first, second])

    assert flagged == []
239+
240+
241+
def test_get_providers_with_missing_api_keys_all_missing():
    """Every provider is flagged when none of them has a resolvable key."""
    unset_env_var = ModelProvider(name="provider1", endpoint="http://test1.com", api_key="MISSING_VAR_1")
    keyless = ModelProvider(name="provider2", endpoint="http://test2.com", api_key=None)

    env_get_target = "data_designer.config.default_model_settings.os.environ.get"
    with patch(env_get_target, return_value=None):
        flagged = get_providers_with_missing_api_keys([unset_env_var, keyless])

    # Compare on sorted names so the check does not depend on result ordering
    assert sorted(p.name for p in flagged) == ["provider1", "provider2"]
252+
253+
254+
def test_get_providers_with_missing_api_keys_mixed_case():
    """A lowercase key is taken as a literal value; an uppercase one as an env var name."""
    lowercase = ModelProvider(name="provider1", endpoint="http://test1.com", api_key="lowercase_key")
    uppercase = ModelProvider(name="provider2", endpoint="http://test2.com", api_key="UPPERCASE_KEY")

    env_get_target = "data_designer.config.default_model_settings.os.environ.get"
    with patch(env_get_target, return_value=None):
        flagged = get_providers_with_missing_api_keys([lowercase, uppercase])

    # provider1's lowercase key counts as a real key; provider2's uppercase key is
    # read as an environment variable name, which is not set here
    assert [p.name for p in flagged] == ["provider2"]

0 commit comments

Comments
 (0)