Skip to content

Commit 97d409b

Browse files
authored
Fix pulling default provider name from config file and use it (#36)
1 parent bc2ed96 commit 97d409b

File tree

4 files changed

+55
-11
lines changed

4 files changed

+55
-11
lines changed

src/data_designer/config/default_model_settings.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
# SPDX-License-Identifier: Apache-2.0
33

4+
5+
from functools import lru_cache
46
import logging
5-
from typing import Literal
7+
from pathlib import Path
8+
from typing import Any, Literal, Optional
69

710
from .models import InferenceParameters, ModelConfig, ModelProvider
811
from .utils.constants import (
@@ -70,18 +73,19 @@ def get_default_model_configs() -> list[ModelConfig]:
7073
if MODEL_CONFIGS_FILE_PATH.exists():
7174
config_dict = load_config_file(MODEL_CONFIGS_FILE_PATH)
7275
if "model_configs" in config_dict:
73-
logger.info(f"♻️ Using default model configs from {str(MODEL_CONFIGS_FILE_PATH)!r}")
7476
return [ModelConfig.model_validate(mc) for mc in config_dict["model_configs"]]
7577
raise FileNotFoundError(f"Default model configs file not found at {str(MODEL_CONFIGS_FILE_PATH)!r}")
7678

7779

7880
def get_default_providers() -> list[ModelProvider]:
79-
if MODEL_PROVIDERS_FILE_PATH.exists():
80-
config_dict = load_config_file(MODEL_PROVIDERS_FILE_PATH)
81-
if "providers" in config_dict:
82-
logger.info(f"♻️ Using default model providers from {str(MODEL_PROVIDERS_FILE_PATH)!r}")
83-
return [ModelProvider.model_validate(p) for p in config_dict["providers"]]
84-
raise FileNotFoundError(f"Default model providers file not found at {str(MODEL_PROVIDERS_FILE_PATH)!r}")
81+
config_dict = _get_default_providers_file_content(MODEL_PROVIDERS_FILE_PATH)
82+
if "providers" in config_dict:
83+
return [ModelProvider.model_validate(p) for p in config_dict["providers"]]
84+
return []
85+
86+
87+
def get_default_provider_name() -> Optional[str]:
88+
return _get_default_providers_file_content(MODEL_PROVIDERS_FILE_PATH).get("default")
8589

8690

8791
def resolve_seed_default_model_settings() -> None:
@@ -104,3 +108,11 @@ def resolve_seed_default_model_settings() -> None:
104108
save_config_file(
105109
MODEL_PROVIDERS_FILE_PATH, {"providers": [p.model_dump() for p in get_builtin_model_providers()]}
106110
)
111+
112+
113+
@lru_cache(maxsize=1)
114+
def _get_default_providers_file_content(file_path: Path) -> dict[str, Any]:
115+
"""Load and cache the default providers file content."""
116+
if file_path.exists():
117+
return load_config_file(file_path)
118+
raise FileNotFoundError(f"Default model providers file not found at {str(file_path)!r}")

src/data_designer/engine/model_provider.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,10 +64,12 @@ def get_provider(self, name: str | None) -> ModelProvider:
6464
raise UnknownProviderError(f"No provider named {name!r} registered")
6565

6666

67-
def resolve_model_provider_registry(model_providers: list[ModelProvider]) -> ModelProviderRegistry:
67+
def resolve_model_provider_registry(
68+
model_providers: list[ModelProvider], default_provider_name: str | None = None
69+
) -> ModelProviderRegistry:
6870
if len(model_providers) == 0:
6971
raise NoModelProvidersError("At least one model provider must be defined")
7072
return ModelProviderRegistry(
7173
providers=model_providers,
72-
default=model_providers[0].name,
74+
default=default_provider_name or model_providers[0].name,
7375
)

src/data_designer/interface/data_designer.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from data_designer.config.config_builder import DataDesignerConfigBuilder
1111
from data_designer.config.default_model_settings import (
1212
get_default_model_configs,
13+
get_default_provider_name,
1314
get_default_providers,
1415
resolve_seed_default_model_settings,
1516
)
@@ -22,6 +23,8 @@
2223
from data_designer.config.seed import LocalSeedDatasetReference
2324
from data_designer.config.utils.constants import (
2425
DEFAULT_NUM_RECORDS,
26+
MODEL_CONFIGS_FILE_PATH,
27+
MODEL_PROVIDERS_FILE_PATH,
2528
)
2629
from data_designer.config.utils.info import InterfaceInfo
2730
from data_designer.config.utils.io_helpers import write_seed_dataset
@@ -96,7 +99,9 @@ def __init__(
9699
else init_managed_blob_storage(str(blob_storage_path))
97100
)
98101
self._model_providers = model_providers or self.get_default_model_providers()
99-
self._model_provider_registry = resolve_model_provider_registry(self._model_providers)
102+
self._model_provider_registry = resolve_model_provider_registry(
103+
self._model_providers, get_default_provider_name()
104+
)
100105

101106
@staticmethod
102107
def make_seed_reference_from_file(file_path: str | Path) -> LocalSeedDatasetReference:
@@ -248,6 +253,7 @@ def get_default_model_configs(self) -> list[ModelConfig]:
248253
Returns:
249254
List of default model configurations.
250255
"""
256+
logger.info(f"♻️ Using default model configs from {str(MODEL_CONFIGS_FILE_PATH)!r}")
251257
return get_default_model_configs()
252258

253259
def get_default_model_providers(self) -> list[ModelProvider]:
@@ -256,6 +262,7 @@ def get_default_model_providers(self) -> list[ModelProvider]:
256262
Returns:
257263
List of default model providers.
258264
"""
265+
logger.info(f"♻️ Using default model providers from {str(MODEL_PROVIDERS_FILE_PATH)!r}")
259266
return get_default_providers()
260267

261268
def set_buffer_size(self, buffer_size: int) -> None:

tests/config/test_default_model_settings.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
get_builtin_model_providers,
1414
get_default_inference_parameters,
1515
get_default_model_configs,
16+
get_default_provider_name,
1617
get_default_providers,
1718
resolve_seed_default_model_settings,
1819
)
@@ -95,6 +96,28 @@ def test_get_default_providers_path_does_not_exist():
9596
get_default_providers()
9697

9798

99+
def test_get_default_provider_name_with_default_key(tmp_path: Path):
100+
providers_file_path = tmp_path / "providers.yaml"
101+
providers_file_path.write_text(
102+
json.dumps(dict(providers=[p.model_dump() for p in get_builtin_model_providers()], default="nvidia"))
103+
)
104+
with patch("data_designer.config.default_model_settings.MODEL_PROVIDERS_FILE_PATH", new=providers_file_path):
105+
assert get_default_provider_name() == "nvidia"
106+
107+
108+
def test_get_default_provider_name_without_default_key(tmp_path: Path):
109+
providers_file_path = tmp_path / "providers.yaml"
110+
providers_file_path.write_text(json.dumps({"providers": [p.model_dump() for p in get_builtin_model_providers()]}))
111+
with patch("data_designer.config.default_model_settings.MODEL_PROVIDERS_FILE_PATH", new=providers_file_path):
112+
assert get_default_provider_name() is None
113+
114+
115+
def test_get_default_provider_name_path_does_not_exist():
116+
with patch("data_designer.config.default_model_settings.MODEL_PROVIDERS_FILE_PATH", new=Path("non_existent_path")):
117+
with pytest.raises(FileNotFoundError, match=r"Default model providers file not found at 'non_existent_path'"):
118+
get_default_provider_name()
119+
120+
98121
def test_get_nvidia_api_key():
99122
with patch("data_designer.config.utils.visualization.os.getenv", return_value="nvidia_api_key"):
100123
assert get_nvidia_api_key() == "nvidia_api_key"

0 commit comments

Comments
 (0)