|
9 | 9 | from data_designer.config.analysis.dataset_profiler import DatasetProfilerResults |
10 | 10 | from data_designer.config.config_builder import DataDesignerConfigBuilder |
11 | 11 | from data_designer.config.default_model_settings import ( |
| 12 | + get_defaul_model_providers_missing_api_keys, |
12 | 13 | get_default_model_configs, |
13 | 14 | get_default_provider_name, |
14 | 15 | get_default_providers, |
|
26 | 27 | MANAGED_ASSETS_PATH, |
27 | 28 | MODEL_CONFIGS_FILE_PATH, |
28 | 29 | MODEL_PROVIDERS_FILE_PATH, |
| 30 | + PREDEFINED_PROVIDERS, |
29 | 31 | ) |
30 | | -from data_designer.config.utils.info import InterfaceInfo |
| 32 | +from data_designer.config.utils.info import InfoType, InterfaceInfo |
31 | 33 | from data_designer.config.utils.io_helpers import write_seed_dataset |
32 | 34 | from data_designer.config.utils.misc import can_run_data_designer_locally |
33 | 35 | from data_designer.engine.analysis.dataset_profiler import ( |
@@ -103,7 +105,7 @@ def __init__( |
103 | 105 | self._artifact_path = Path(artifact_path) if artifact_path is not None else Path.cwd() / "artifacts" |
104 | 106 | self._buffer_size = DEFAULT_BUFFER_SIZE |
105 | 107 | self._managed_assets_path = Path(managed_assets_path or MANAGED_ASSETS_PATH) |
106 | | - self._model_providers = model_providers or self.get_default_model_providers() |
| 108 | + self._model_providers = self._resolve_model_providers(model_providers) |
107 | 109 | self._model_provider_registry = resolve_model_provider_registry( |
108 | 110 | self._model_providers, get_default_provider_name() |
109 | 111 | ) |
@@ -151,7 +153,7 @@ def info(self) -> InterfaceInfo: |
151 | 153 | Returns: |
152 | 154 | InterfaceInfo object with information about the Data Designer interface. |
153 | 155 | """ |
154 | | - return InterfaceInfo(model_providers=self._model_providers) |
| 156 | + return self._get_interface_info(self._model_providers) |
155 | 157 |
|
156 | 158 | def create( |
157 | 159 | self, |
@@ -307,6 +309,22 @@ def set_buffer_size(self, buffer_size: int) -> None: |
307 | 309 | raise InvalidBufferValueError("Buffer size must be greater than 0.") |
308 | 310 | self._buffer_size = buffer_size |
309 | 311 |
|
| 312 | + def _resolve_model_providers(self, model_providers: list[ModelProvider] | None) -> list[ModelProvider]: |
| 313 | + if model_providers is None: |
| 314 | + if can_run_data_designer_locally(): |
| 315 | + model_providers = get_default_providers() |
| 316 | + missing_api_keys = get_defaul_model_providers_missing_api_keys() |
| 317 | + if len(missing_api_keys) == len(PREDEFINED_PROVIDERS): |
| 318 | + logger.warning( |
| 319 | + "🚨 You are trying to use a default model provider but your API keys are missing." |
| 320 | + "\n\t\t\tSet the API key for the default providers you intend to use and re-initialize the Data Designer object." |
| 321 | + "\n\t\t\tAlternatively, you can provide your own model providers during Data Designer object initialization." |
| 322 | + "\n\t\t\tSee https://nvidia-nemo.github.io/DataDesigner/models/model-providers/ for more information." |
| 323 | + ) |
| 324 | + self._get_interface_info(model_providers).display(InfoType.MODEL_PROVIDERS) |
| 325 | + return model_providers |
| 326 | + return model_providers or [] |
| 327 | + |
310 | 328 | def _create_dataset_builder( |
311 | 329 | self, config_builder: DataDesignerConfigBuilder, resource_provider: ResourceProvider |
312 | 330 | ) -> ColumnWiseDatasetBuilder: |
@@ -349,3 +367,6 @@ def _create_resource_provider( |
349 | 367 | ) |
350 | 368 | ), |
351 | 369 | ) |
| 370 | + |
| 371 | + def _get_interface_info(self, model_providers: list[ModelProvider]) -> InterfaceInfo: |
| 372 | + return InterfaceInfo(model_providers=model_providers) |
0 commit comments