Skip to content

Commit 9a33192

Browse files
authored
chore: clean up the first-time user experience (FTUE) around default model provider errors (#66)
* Clean up the FTUE around default model provider errors
* `_resolve_model_providers` must always return a list, even if it is empty
* Spread error logs across multiple lines
1 parent 7a199c3 commit 9a33192

File tree

3 files changed

+42
-10
lines changed

3 files changed

+42
-10
lines changed

src/data_designer/config/default_model_settings.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from functools import lru_cache
66
import logging
7+
import os
78
from pathlib import Path
89
from typing import Any, Literal, Optional
910

@@ -15,7 +16,6 @@
1516
PREDEFINED_PROVIDERS,
1617
PREDEFINED_PROVIDERS_MODEL_MAP,
1718
)
18-
from .utils.info import ConfigBuilderInfo, InfoType, InterfaceInfo
1919
from .utils.io_helpers import load_config_file, save_config_file
2020

2121
logger = logging.getLogger(__name__)
@@ -78,6 +78,14 @@ def get_default_model_configs() -> list[ModelConfig]:
7878
return []
7979

8080

81+
def get_defaul_model_providers_missing_api_keys() -> list[str]:
82+
missing_api_keys = []
83+
for predefined_provider in PREDEFINED_PROVIDERS:
84+
if os.environ.get(predefined_provider["api_key"]) is None:
85+
missing_api_keys.append(predefined_provider["api_key"])
86+
return missing_api_keys
87+
88+
8189
def get_default_providers() -> list[ModelProvider]:
8290
config_dict = _get_default_providers_file_content(MODEL_PROVIDERS_FILE_PATH)
8391
if "providers" in config_dict:
@@ -91,21 +99,17 @@ def get_default_provider_name() -> Optional[str]:
9199

92100
def resolve_seed_default_model_settings() -> None:
93101
if not MODEL_CONFIGS_FILE_PATH.exists():
94-
logger.info(
102+
logger.debug(
95103
f"🍾 Default model configs were not found, so writing the following to {str(MODEL_CONFIGS_FILE_PATH)!r}"
96104
)
97-
config_builder_info = ConfigBuilderInfo(model_configs=get_builtin_model_configs())
98-
config_builder_info.display(info_type=InfoType.MODEL_CONFIGS)
99105
save_config_file(
100106
MODEL_CONFIGS_FILE_PATH, {"model_configs": [mc.model_dump() for mc in get_builtin_model_configs()]}
101107
)
102108

103109
if not MODEL_PROVIDERS_FILE_PATH.exists():
104-
logger.info(
110+
logger.debug(
105111
f"🪄 Default model providers were not found, so writing the following to {str(MODEL_PROVIDERS_FILE_PATH)!r}"
106112
)
107-
interface_info = InterfaceInfo(model_providers=get_builtin_model_providers())
108-
interface_info.display(info_type=InfoType.MODEL_PROVIDERS)
109113
save_config_file(
110114
MODEL_PROVIDERS_FILE_PATH, {"providers": [p.model_dump() for p in get_builtin_model_providers()]}
111115
)

src/data_designer/interface/data_designer.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from data_designer.config.analysis.dataset_profiler import DatasetProfilerResults
1010
from data_designer.config.config_builder import DataDesignerConfigBuilder
1111
from data_designer.config.default_model_settings import (
12+
get_defaul_model_providers_missing_api_keys,
1213
get_default_model_configs,
1314
get_default_provider_name,
1415
get_default_providers,
@@ -26,8 +27,9 @@
2627
MANAGED_ASSETS_PATH,
2728
MODEL_CONFIGS_FILE_PATH,
2829
MODEL_PROVIDERS_FILE_PATH,
30+
PREDEFINED_PROVIDERS,
2931
)
30-
from data_designer.config.utils.info import InterfaceInfo
32+
from data_designer.config.utils.info import InfoType, InterfaceInfo
3133
from data_designer.config.utils.io_helpers import write_seed_dataset
3234
from data_designer.config.utils.misc import can_run_data_designer_locally
3335
from data_designer.engine.analysis.dataset_profiler import (
@@ -103,7 +105,7 @@ def __init__(
103105
self._artifact_path = Path(artifact_path) if artifact_path is not None else Path.cwd() / "artifacts"
104106
self._buffer_size = DEFAULT_BUFFER_SIZE
105107
self._managed_assets_path = Path(managed_assets_path or MANAGED_ASSETS_PATH)
106-
self._model_providers = model_providers or self.get_default_model_providers()
108+
self._model_providers = self._resolve_model_providers(model_providers)
107109
self._model_provider_registry = resolve_model_provider_registry(
108110
self._model_providers, get_default_provider_name()
109111
)
@@ -151,7 +153,7 @@ def info(self) -> InterfaceInfo:
151153
Returns:
152154
InterfaceInfo object with information about the Data Designer interface.
153155
"""
154-
return InterfaceInfo(model_providers=self._model_providers)
156+
return self._get_interface_info(self._model_providers)
155157

156158
def create(
157159
self,
@@ -307,6 +309,22 @@ def set_buffer_size(self, buffer_size: int) -> None:
307309
raise InvalidBufferValueError("Buffer size must be greater than 0.")
308310
self._buffer_size = buffer_size
309311

312+
def _resolve_model_providers(self, model_providers: list[ModelProvider] | None) -> list[ModelProvider]:
313+
if model_providers is None:
314+
if can_run_data_designer_locally():
315+
model_providers = get_default_providers()
316+
missing_api_keys = get_defaul_model_providers_missing_api_keys()
317+
if len(missing_api_keys) == len(PREDEFINED_PROVIDERS):
318+
logger.warning(
319+
"🚨 You are trying to use a default model provider but your API keys are missing."
320+
"\n\t\t\tSet the API key for the default providers you intend to use and re-initialize the Data Designer object."
321+
"\n\t\t\tAlternatively, you can provide your own model providers during Data Designer object initialization."
322+
"\n\t\t\tSee https://nvidia-nemo.github.io/DataDesigner/models/model-providers/ for more information."
323+
)
324+
self._get_interface_info(model_providers).display(InfoType.MODEL_PROVIDERS)
325+
return model_providers
326+
return model_providers or []
327+
310328
def _create_dataset_builder(
311329
self, config_builder: DataDesignerConfigBuilder, resource_provider: ResourceProvider
312330
) -> ColumnWiseDatasetBuilder:
@@ -349,3 +367,6 @@ def _create_resource_provider(
349367
)
350368
),
351369
)
370+
371+
def _get_interface_info(self, model_providers: list[ModelProvider]) -> InterfaceInfo:
372+
return InterfaceInfo(model_providers=model_providers)

tests/config/test_default_model_settings.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from data_designer.config.default_model_settings import (
1212
get_builtin_model_configs,
1313
get_builtin_model_providers,
14+
get_defaul_model_providers_missing_api_keys,
1415
get_default_inference_parameters,
1516
get_default_model_configs,
1617
get_default_provider_name,
@@ -146,3 +147,9 @@ def test_resolve_seed_default_model_settings(tmp_path: Path):
146147
with open(model_providers_file_path) as f:
147148
providers_data = yaml.safe_load(f)
148149
assert providers_data == {"providers": [p.model_dump() for p in get_builtin_model_providers()]}
150+
151+
152+
@patch("data_designer.config.default_model_settings.os.environ.get")
153+
def test_get_default_model_providers_missing_api_keys(mock_environ_get):
154+
mock_environ_get.return_value = None
155+
assert get_defaul_model_providers_missing_api_keys() == ["NVIDIA_API_KEY", "OPENAI_API_KEY"]

0 commit comments

Comments
 (0)