diff --git a/src/data_designer/config/base.py b/src/data_designer/config/base.py index 64d4d6de..5dfe6d8b 100644 --- a/src/data_designer/config/base.py +++ b/src/data_designer/config/base.py @@ -3,49 +3,14 @@ from __future__ import annotations -from abc import ABC, abstractmethod from pathlib import Path -from typing import TYPE_CHECKING, Any, Generic, Optional, Protocol, TypeVar, Union +from typing import Any, Optional, Union -import pandas as pd from pydantic import BaseModel, ConfigDict import yaml from .utils.io_helpers import serialize_data -if TYPE_CHECKING: - from .analysis.dataset_profiler import DatasetProfilerResults - from .config_builder import DataDesignerConfigBuilder - from .preview_results import PreviewResults - -DEFAULT_NUM_RECORDS = 10 - - -class ResultsProtocol(Protocol): - def load_analysis(self) -> DatasetProfilerResults: ... - def load_dataset(self) -> pd.DataFrame: ... - - -ResultsT = TypeVar("ResultsT", bound=ResultsProtocol) - - -class DataDesignerInterface(ABC, Generic[ResultsT]): - @abstractmethod - def create( - self, - config_builder: DataDesignerConfigBuilder, - *, - num_records: int = DEFAULT_NUM_RECORDS, - ) -> ResultsT: ... - - @abstractmethod - def preview( - self, - config_builder: DataDesignerConfigBuilder, - *, - num_records: int = DEFAULT_NUM_RECORDS, - ) -> PreviewResults: ... - class ConfigBase(BaseModel): model_config = ConfigDict( diff --git a/src/data_designer/config/config_builder.py b/src/data_designer/config/config_builder.py index 78cfe724..eca394cf 100644 --- a/src/data_designer/config/config_builder.py +++ b/src/data_designer/config/config_builder.py @@ -44,7 +44,7 @@ SeedDatasetReference, ) from .utils.constants import DEFAULT_REPR_HTML_STYLE, REPR_HTML_TEMPLATE -from .utils.info import DataDesignerInfo +from .utils.info import ConfigBuilderInfo from .utils.io_helpers import serialize_data, smart_load_yaml from .utils.misc import ( can_run_data_designer_locally, @@ -132,14 +132,13 @@ def from_config(cls, config: Union[dict, str, Path, BuilderConfig]) -> Self: return builder - def __init__(self, model_configs: Optional[Union[list[ModelConfig], str, Path]] = None): + def __init__(self, model_configs: Union[list[ModelConfig], str, Path]): """Initialize a new DataDesignerConfigBuilder instance. Args: - model_configs: Optional model configurations. Can be: + model_configs: Model configurations. Can be: - A list of ModelConfig objects - A string or Path to a model configuration file - - None to use default model configurations """ self._column_configs = {} self._model_configs = load_model_configs(model_configs) @@ -147,7 +146,6 @@ def __init__(self, model_configs: Optional[Union[list[ModelConfig], str, Path]] self._seed_config: Optional[SeedConfig] = None self._constraints: list[ColumnConstraintT] = [] self._profilers: list[ColumnProfilerConfigT] = [] - self._info = DataDesignerInfo() self._datastore_settings: Optional[DatastoreSettings] = None @property @@ -173,13 +171,13 @@ def allowed_references(self) -> list[str]: return list(self._column_configs.keys()) + list(set(side_effect_columns)) @property - def info(self) -> DataDesignerInfo: - """Get the DataDesignerInfo object for this builder. + def info(self) -> ConfigBuilderInfo: + """Get the ConfigBuilderInfo object for this builder. Returns: - An object containing metadata about the configuration. + An object containing information about the configuration. """ - return self._info + return ConfigBuilderInfo(model_configs=self._model_configs) def add_model_config(self, model_config: ModelConfig) -> Self: """Add a model configuration to the current Data Designer configuration. diff --git a/src/data_designer/config/interface.py b/src/data_designer/config/interface.py new file mode 100644 index 00000000..fa8e17b4 --- /dev/null +++ b/src/data_designer/config/interface.py @@ -0,0 +1,54 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Generic, Protocol, TypeVar + +import pandas as pd + +from .models import ModelConfig, ModelProvider +from .utils.constants import DEFAULT_NUM_RECORDS +from .utils.info import InterfaceInfo + +if TYPE_CHECKING: + from .analysis.dataset_profiler import DatasetProfilerResults + from .config_builder import DataDesignerConfigBuilder + from .preview_results import PreviewResults + + +class ResultsProtocol(Protocol): + def load_analysis(self) -> DatasetProfilerResults: ... + def load_dataset(self) -> pd.DataFrame: ... + + +ResultsT = TypeVar("ResultsT", bound=ResultsProtocol) + + +class DataDesignerInterface(ABC, Generic[ResultsT]): + @abstractmethod + def create( + self, + config_builder: DataDesignerConfigBuilder, + *, + num_records: int = DEFAULT_NUM_RECORDS, + ) -> ResultsT: ... + + @abstractmethod + def preview( + self, + config_builder: DataDesignerConfigBuilder, + *, + num_records: int = DEFAULT_NUM_RECORDS, + ) -> PreviewResults: ... + + @abstractmethod + def get_default_model_configs(self) -> list[ModelConfig]: ... + + @abstractmethod + def get_default_model_providers(self) -> list[ModelProvider]: ... + + @property + @abstractmethod + def info(self) -> InterfaceInfo: ... diff --git a/src/data_designer/config/models.py b/src/data_designer/config/models.py index 4fedb0c9..17b61575 100644 --- a/src/data_designer/config/models.py +++ b/src/data_designer/config/models.py @@ -3,6 +3,8 @@ from abc import ABC, abstractmethod from enum import Enum +import logging +import os from pathlib import Path from typing import Any, Generic, List, Optional, TypeVar, Union @@ -12,9 +14,20 @@ from .base import ConfigBase from .errors import InvalidConfigError -from .utils.constants import MAX_TEMPERATURE, MAX_TOP_P, MIN_TEMPERATURE, MIN_TOP_P +from .utils.constants import ( + MAX_TEMPERATURE, + MAX_TOP_P, + MIN_TEMPERATURE, + MIN_TOP_P, + NVIDIA_API_KEY_ENV_VAR_NAME, + NVIDIA_PROVIDER_NAME, + OPENAI_API_KEY_ENV_VAR_NAME, + OPENAI_PROVIDER_NAME, +) from .utils.io_helpers import smart_load_yaml +logger = logging.getLogger(__name__) + class Modality(str, Enum): IMAGE = "image" @@ -204,9 +217,14 @@ class ModelConfig(ConfigBase): provider: Optional[str] = None -def load_model_configs(model_configs: Union[list[ModelConfig], str, Path, None]) -> list[ModelConfig]: - if model_configs is None: - return [] +class ModelProvider(ConfigBase): + name: str + endpoint: str + provider_type: str = "openai" + api_key: str | None = None + + +def load_model_configs(model_configs: Union[list[ModelConfig], str, Path]) -> list[ModelConfig]: if isinstance(model_configs, list) and all(isinstance(mc, ModelConfig) for mc in model_configs): return model_configs json_config = smart_load_yaml(model_configs) @@ -215,3 +233,107 @@ def load_model_configs(model_configs: Union[list[ModelConfig], str, Path, None]) "The list of model configs must be provided under model_configs in the configuration file." ) return [ModelConfig.model_validate(mc) for mc in json_config["model_configs"]] + + +def get_default_text_alias_inference_parameters() -> InferenceParameters: + return InferenceParameters( + temperature=0.85, + top_p=0.95, + ) + + +def get_default_reasoning_alias_inference_parameters() -> InferenceParameters: + return InferenceParameters( + temperature=0.35, + top_p=0.95, + ) + + +def get_default_vision_alias_inference_parameters() -> InferenceParameters: + return InferenceParameters( + temperature=0.85, + top_p=0.95, + ) + + +def get_default_nvidia_model_configs() -> list[ModelConfig]: + if not get_nvidia_api_key(): + logger.warning( + f"🔑 {NVIDIA_API_KEY_ENV_VAR_NAME!r} environment variable is not set. Please set it to your API key from 'https://build.nvidia.com' if you want to use the default NVIDIA model configs." + ) + return [] + return [ + ModelConfig( + alias=f"{NVIDIA_PROVIDER_NAME}-text", + model="nvidia/nvidia-nemotron-nano-9b-v2", + provider=NVIDIA_PROVIDER_NAME, + inference_parameters=get_default_text_alias_inference_parameters(), + ), + ModelConfig( + alias=f"{NVIDIA_PROVIDER_NAME}-reasoning", + model="openai/gpt-oss-20b", + provider=NVIDIA_PROVIDER_NAME, + inference_parameters=get_default_reasoning_alias_inference_parameters(), + ), + ModelConfig( + alias=f"{NVIDIA_PROVIDER_NAME}-vision", + model="nvidia/nemotron-nano-12b-v2-vl", + provider=NVIDIA_PROVIDER_NAME, + inference_parameters=get_default_vision_alias_inference_parameters(), + ), + ] + + +def get_default_openai_model_configs() -> list[ModelConfig]: + if not get_openai_api_key(): + logger.warning( + f"🔑 {OPENAI_API_KEY_ENV_VAR_NAME!r} environment variable is not set. Please set it to your API key from 'https://platform.openai.com/api-keys' if you want to use the default OpenAI model configs." + ) + return [] + return [ + ModelConfig( + alias=f"{OPENAI_PROVIDER_NAME}-text", + model="gpt-4.1", + provider=OPENAI_PROVIDER_NAME, + inference_parameters=get_default_text_alias_inference_parameters(), + ), + ModelConfig( + alias=f"{OPENAI_PROVIDER_NAME}-reasoning", + model="gpt-5", + provider=OPENAI_PROVIDER_NAME, + inference_parameters=get_default_reasoning_alias_inference_parameters(), + ), + ModelConfig( + alias=f"{OPENAI_PROVIDER_NAME}-vision", + model="gpt-5", + provider=OPENAI_PROVIDER_NAME, + inference_parameters=get_default_vision_alias_inference_parameters(), + ), + ] + + +def get_default_model_configs() -> list[ModelConfig]: + return get_default_nvidia_model_configs() + get_default_openai_model_configs() + + +def get_default_providers() -> list[ModelProvider]: + return [ + ModelProvider( + name=NVIDIA_PROVIDER_NAME, + endpoint="https://integrate.api.nvidia.com/v1", + api_key=NVIDIA_API_KEY_ENV_VAR_NAME, + ), + ModelProvider( + name=OPENAI_PROVIDER_NAME, + endpoint="https://api.openai.com/v1", + api_key=OPENAI_API_KEY_ENV_VAR_NAME, + ), + ] + + +def get_nvidia_api_key() -> Optional[str]: + return os.getenv(NVIDIA_API_KEY_ENV_VAR_NAME) + + +def get_openai_api_key() -> Optional[str]: + return os.getenv(OPENAI_API_KEY_ENV_VAR_NAME) diff --git a/src/data_designer/config/utils/constants.py b/src/data_designer/config/utils/constants.py index 7dd6f41a..4756da9c 100644 --- a/src/data_designer/config/utils/constants.py +++ b/src/data_designer/config/utils/constants.py @@ -5,6 +5,8 @@ from rich.theme import Theme +DEFAULT_NUM_RECORDS = 10 + EPSILON = 1e-8 REPORTING_PRECISION = 2 @@ -255,3 +257,9 @@ class NordColor(Enum): "zh_TW", "zu_ZA", ] + +NVIDIA_PROVIDER_NAME = "nvidia" +NVIDIA_API_KEY_ENV_VAR_NAME = "NVIDIA_API_KEY" + +OPENAI_PROVIDER_NAME = "openai" +OPENAI_API_KEY_ENV_VAR_NAME = "OPENAI_API_KEY" diff --git a/src/data_designer/config/utils/info.py b/src/data_designer/config/utils/info.py index 482d87cc..cf9765a3 100644 --- a/src/data_designer/config/utils/info.py +++ b/src/data_designer/config/utils/info.py @@ -1,23 +1,88 @@ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 +from abc import ABC, abstractmethod +from enum import Enum +from typing import Literal, TypeVar + +from ..models import ModelConfig, ModelProvider from ..sampler_params import SamplerType from .type_helpers import get_sampler_params -from .visualization import display_sampler_table +from .visualization import display_model_configs_table, display_model_providers_table, display_sampler_table + + +class InfoType(str, Enum): + SAMPLERS = "SAMPLERS" + MODEL_CONFIGS = "MODEL_CONFIGS" + MODEL_PROVIDERS = "MODEL_PROVIDERS" + + +ConfigBuilderInfoType = Literal[InfoType.SAMPLERS, InfoType.MODEL_CONFIGS] +DataDesignerInfoType = Literal[InfoType.MODEL_PROVIDERS] +InfoTypeT = TypeVar("InfoTypeT", bound=InfoType) + + +class InfoDisplay(ABC): + """Base class for info display classes that provide type-safe display methods.""" + + @abstractmethod + def display(self, info_type: InfoTypeT, **kwargs) -> None: + """Display information based on the provided info type. + Args: + info_type: Type of information to display. + """ + ... -class DataDesignerInfo: - def __init__(self): + +class ConfigBuilderInfo(InfoDisplay): + def __init__(self, model_configs: list[ModelConfig]): self._sampler_params = get_sampler_params() + self._model_configs = model_configs + + def display(self, info_type: ConfigBuilderInfoType, **kwargs) -> None: + """Display information based on the provided info type. + + Args: + info_type: Type of information to display. Only SAMPLERS and MODEL_CONFIGS are supported. + + Raises: + ValueError: If an unsupported info_type is provided. + """ + if info_type == InfoType.SAMPLERS: + self._display_sampler_info(sampler_type=kwargs.get("sampler_type")) + elif info_type == InfoType.MODEL_CONFIGS: + display_model_configs_table(self._model_configs) + else: + raise ValueError( + f"Unsupported info_type: {str(info_type)!r}. " + f"ConfigBuilderInfo only supports {InfoType.SAMPLERS.value!r} and {InfoType.MODEL_CONFIGS.value!r}." + ) + + def _display_sampler_info(self, sampler_type: SamplerType | None) -> None: + if sampler_type is not None: + title = f"{SamplerType(sampler_type).value.replace('_', ' ').title()} Sampler" + display_sampler_table({sampler_type: self._sampler_params[sampler_type]}, title=title) + else: + display_sampler_table(self._sampler_params) + + +class InterfaceInfo(InfoDisplay): + def __init__(self, model_providers: list[ModelProvider]): + self._model_providers = model_providers - @property - def sampler_table(self) -> None: - display_sampler_table(self._sampler_params) + def display(self, info_type: DataDesignerInfoType, **kwargs) -> None: + """Display information based on the provided info type. - @property - def sampler_types(self) -> list[str]: - return [s.value for s in SamplerType] + Args: + info_type: Type of information to display. Only MODEL_PROVIDERS is supported. - def display_sampler(self, sampler_type: SamplerType) -> None: - title = f"{SamplerType(sampler_type).value.replace('_', ' ').title()} Sampler" - display_sampler_table({sampler_type: self._sampler_params[sampler_type]}, title=title) + Raises: + ValueError: If an unsupported info_type is provided. + """ + if info_type == InfoType.MODEL_PROVIDERS: + display_model_providers_table(self._model_providers) + else: + raise ValueError( + f"Unsupported info_type: {str(info_type)!r}. InterfaceInfo only supports {InfoType.MODEL_PROVIDERS.value!r}." + ) diff --git a/src/data_designer/config/utils/visualization.py b/src/data_designer/config/utils/visualization.py index f245517c..4b245369 100644 --- a/src/data_designer/config/utils/visualization.py +++ b/src/data_designer/config/utils/visualization.py @@ -22,8 +22,10 @@ from ..base import ConfigBase from ..columns import DataDesignerColumnType +from ..models import ModelConfig, ModelProvider, get_nvidia_api_key, get_openai_api_key from ..sampler_params import SamplerType from .code_lang import code_lang_to_syntax_lexer +from .constants import NVIDIA_API_KEY_ENV_VAR_NAME, OPENAI_API_KEY_ENV_VAR_NAME from .errors import DatasetSampleDisplayError if TYPE_CHECKING: @@ -258,6 +260,55 @@ def display_sampler_table( console.print(group) +def display_model_configs_table(model_configs: list[ModelConfig]) -> None: + table_model_configs = Table(expand=True) + table_model_configs.add_column("Alias") + table_model_configs.add_column("Model") + table_model_configs.add_column("Provider") + table_model_configs.add_column("Temperature") + table_model_configs.add_column("Top P") + for model_config in model_configs: + table_model_configs.add_row( + model_config.alias, + model_config.model, + model_config.provider, + str(model_config.inference_parameters.temperature), + str(model_config.inference_parameters.top_p), + ) + group_args: list = [Rule(title="Model Configs"), table_model_configs] + if len(model_configs) == 0: + subtitle = Text( + "‼️ No model configs found. Please provide at least one model config to the config builder", + style="dim", + justify="center", + ) + group_args.insert(1, subtitle) + group = Group(*group_args) + console.print(group) + + +def display_model_providers_table(model_providers: list[ModelProvider]) -> None: + table_model_providers = Table(expand=True) + table_model_providers.add_column("Name") + table_model_providers.add_column("Endpoint") + table_model_providers.add_column("API Key") + for model_provider in model_providers: + api_key = model_provider.api_key + if model_provider.api_key == OPENAI_API_KEY_ENV_VAR_NAME: + if get_openai_api_key() is not None: + api_key = get_openai_api_key()[:1] + "********" + else: + api_key = f"* {OPENAI_API_KEY_ENV_VAR_NAME!r} not set in environment variables * " + elif model_provider.api_key == NVIDIA_API_KEY_ENV_VAR_NAME: + if get_nvidia_api_key() is not None: + api_key = get_nvidia_api_key()[:1] + "********" + else: + api_key = f"* {NVIDIA_API_KEY_ENV_VAR_NAME!r} not set in environment variables *" + table_model_providers.add_row(model_provider.name, model_provider.endpoint, api_key) + group = Group(Rule(title="Model Providers"), table_model_providers) + console.print(group) + + def convert_to_row_element(elem): try: elem = Pretty(json.loads(elem)) diff --git a/src/data_designer/engine/model_provider.py b/src/data_designer/engine/model_provider.py index e25f2e50..9e0680a6 100644 --- a/src/data_designer/engine/model_provider.py +++ b/src/data_designer/engine/model_provider.py @@ -6,16 +6,10 @@ from pydantic import BaseModel, field_validator, model_validator +from data_designer.config.models import ModelProvider from data_designer.engine.errors import NoModelProvidersError, UnknownProviderError -class ModelProvider(BaseModel): - name: str - endpoint: str - provider_type: str = "openai" - api_key: str | None = None - - class ModelProviderRegistry(BaseModel): providers: list[ModelProvider] default: str | None = None @@ -70,21 +64,10 @@ def get_provider(self, name: str | None) -> ModelProvider: raise UnknownProviderError(f"No provider named {name!r} registered") -def resolve_model_provider_registry(model_providers: list[ModelProvider] | None = None) -> ModelProviderRegistry: - if model_providers: - if len(model_providers) == 0: - raise NoModelProvidersError("At least one model provider must be defined") - return ModelProviderRegistry( - providers=model_providers, - default=model_providers[0].name, - ) +def resolve_model_provider_registry(model_providers: list[ModelProvider]) -> ModelProviderRegistry: + if len(model_providers) == 0: + raise NoModelProvidersError("At least one model provider must be defined") return ModelProviderRegistry( - providers=[ - ModelProvider( - name="nvidia", - endpoint="https://integrate.api.nvidia.com/v1", - api_key="NVIDIA_API_KEY", - ) - ], - default="nvidia", + providers=model_providers, + default=model_providers[0].name, ) diff --git a/src/data_designer/engine/models/registry.py b/src/data_designer/engine/models/registry.py index 52e1f23a..ef0590b1 100644 --- a/src/data_designer/engine/models/registry.py +++ b/src/data_designer/engine/models/registry.py @@ -71,7 +71,9 @@ def get_model_provider(self, *, model_alias: str) -> ModelProvider: def run_health_check(self) -> None: logger.info("🩺 Running health checks for models...") for model in self._models.values(): - logger.info(f" |-- 👀 Checking '{model.model_name}'...") + logger.info( + f" |-- 👀 Checking {model.model_name!r} in provider named {model.model_provider_name!r} for model alias {model.model_alias!r}..." + ) try: model.generate( prompt="Hello!", diff --git a/src/data_designer/engine/secret_resolver.py b/src/data_designer/engine/secret_resolver.py index 521064d4..2771544f 100644 --- a/src/data_designer/engine/secret_resolver.py +++ b/src/data_designer/engine/secret_resolver.py @@ -39,7 +39,9 @@ def resolve(self, secret: str) -> str: try: return os.environ[secret] except KeyError: - raise SecretResolutionError(f"No env var found with name {secret!r}") + raise SecretResolutionError( + f"Environment variable with name {secret!r} is required but not set. Please set it in your environment and try again." + ) class PlaintextResolver(SecretResolver): diff --git a/src/data_designer/essentials/__init__.py b/src/data_designer/essentials/__init__.py index 3b3fd96f..2d6516c6 100644 --- a/src/data_designer/essentials/__init__.py +++ b/src/data_designer/essentials/__init__.py @@ -51,6 +51,7 @@ ) from ..config.seed import DatastoreSeedDatasetReference, IndexRange, PartitionBlock, SamplingStrategy, SeedConfig from ..config.utils.code_lang import CodeLang +from ..config.utils.info import InfoType from ..config.utils.misc import can_run_data_designer_locally from ..config.validator_params import ( CodeValidatorParams, @@ -90,6 +91,7 @@ "ExpressionColumnConfig", "GaussianSamplerParams", "IndexRange", + "InfoType", "ImageContext", "ImageFormat", "InferenceParameters", diff --git a/src/data_designer/interface/data_designer.py b/src/data_designer/interface/data_designer.py index 13d16cb5..35ea9622 100644 --- a/src/data_designer/interface/data_designer.py +++ b/src/data_designer/interface/data_designer.py @@ -7,10 +7,22 @@ import pandas as pd from data_designer.config.analysis.dataset_profiler import DatasetProfilerResults -from data_designer.config.base import DEFAULT_NUM_RECORDS, DataDesignerInterface from data_designer.config.config_builder import DataDesignerConfigBuilder +from data_designer.config.interface import DataDesignerInterface +from data_designer.config.models import ( + ModelConfig, + ModelProvider, + get_default_model_configs, + get_default_providers, +) from data_designer.config.preview_results import PreviewResults from data_designer.config.seed import LocalSeedDatasetReference +from data_designer.config.utils.constants import ( + DEFAULT_NUM_RECORDS, + NVIDIA_API_KEY_ENV_VAR_NAME, + OPENAI_API_KEY_ENV_VAR_NAME, +) +from data_designer.config.utils.info import InterfaceInfo from data_designer.config.utils.io_helpers import write_seed_dataset from data_designer.engine.analysis.dataset_profiler import ( DataDesignerDatasetProfiler, @@ -19,7 +31,7 @@ from data_designer.engine.dataset_builders.artifact_storage import ArtifactStorage from data_designer.engine.dataset_builders.column_wise_builder import ColumnWiseDatasetBuilder from data_designer.engine.dataset_builders.utils.config_compiler import compile_dataset_builder_column_configs -from data_designer.engine.model_provider import ModelProvider, resolve_model_provider_registry +from data_designer.engine.model_provider import resolve_model_provider_registry from data_designer.engine.models.registry import create_model_registry from data_designer.engine.resources.managed_storage import init_managed_blob_storage from data_designer.engine.resources.resource_provider import ResourceProvider @@ -62,21 +74,22 @@ class DataDesigner(DataDesignerInterface[DatasetCreationResults]): def __init__( self, - artifact_path: Path | str, + artifact_path: Path | str | None = None, *, model_providers: list[ModelProvider] | None = None, secret_resolver: SecretResolver = EnvironmentResolver(), blob_storage_path: Path | str | None = None, ): self._secret_resolver = secret_resolver - self._artifact_path = Path(artifact_path) + self._artifact_path = Path(artifact_path) if artifact_path is not None else Path.cwd() / "artifacts" self._buffer_size = DEFAULT_BUFFER_SIZE self._blob_storage = ( init_managed_blob_storage() if blob_storage_path is None else init_managed_blob_storage(str(blob_storage_path)) ) - self._model_provider_registry = resolve_model_provider_registry(model_providers) + self._model_providers = model_providers or self.get_default_model_providers() + self._model_provider_registry = resolve_model_provider_registry(self._model_providers) @staticmethod def make_seed_reference_from_file(file_path: str | Path) -> LocalSeedDatasetReference: @@ -114,6 +127,10 @@ def make_seed_reference_from_dataframe( write_seed_dataset(dataframe, Path(file_path)) return cls.make_seed_reference_from_file(file_path) + @property + def info(self) -> InterfaceInfo: + return InterfaceInfo(model_providers=self._model_providers) + def create( self, config_builder: DataDesignerConfigBuilder, @@ -213,6 +230,17 @@ def preview( config_builder=config_builder, ) + def get_default_model_configs(self) -> list[ModelConfig]: + model_configs = get_default_model_configs() + if len(model_configs) == 0: + logger.warning( + f"‼️ Neither {NVIDIA_API_KEY_ENV_VAR_NAME!r} nor {OPENAI_API_KEY_ENV_VAR_NAME!r} environment variables are set. Please set at least one of them if you want to use the default model configs." + ) + return model_configs + + def get_default_model_providers(self) -> list[ModelProvider]: + return get_default_providers() + def set_buffer_size(self, buffer_size: int) -> None: """Set the buffer size for dataset generation. diff --git a/tests/config/test_config_builder.py b/tests/config/test_config_builder.py index 68a78e5f..76ce7bc7 100644 --- a/tests/config/test_config_builder.py +++ b/tests/config/test_config_builder.py @@ -32,6 +32,7 @@ from data_designer.config.sampler_params import SamplerType, UUIDSamplerParams from data_designer.config.seed import DatastoreSeedDatasetReference, SamplingStrategy from data_designer.config.utils.code_lang import CodeLang +from data_designer.config.utils.info import ConfigBuilderInfo from data_designer.config.validator_params import CodeValidatorParams @@ -97,7 +98,7 @@ def test_from_config(stub_data_designer_builder_config_str, mock_fetch_seed_data def test_info(stub_data_designer_builder): assert stub_data_designer_builder.info is not None - assert isinstance(stub_data_designer_builder.info.sampler_types, list) + assert isinstance(stub_data_designer_builder.info, ConfigBuilderInfo) def test_add_column_with_types(stub_empty_builder): diff --git a/tests/config/test_models.py b/tests/config/test_models.py index 812c01bb..f0c19714 100644 --- a/tests/config/test_models.py +++ b/tests/config/test_models.py @@ -208,7 +208,6 @@ def test_load_model_configs(): ] stub_model_configs_dict_list = [mc.model_dump() for mc in stub_model_configs] assert load_model_configs([]) == [] - assert load_model_configs(None) == [] assert load_model_configs(stub_model_configs) == stub_model_configs with tempfile.NamedTemporaryFile(prefix="model_configs", suffix=".yaml") as tmp_file: diff --git a/tests/config/utils/test_info.py b/tests/config/utils/test_info.py index deec7eb2..b8d4b593 100644 --- a/tests/config/utils/test_info.py +++ b/tests/config/utils/test_info.py @@ -1,33 +1,58 @@ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 - from unittest.mock import patch -from data_designer.config.sampler_params import BernoulliSamplerParams, BinomialSamplerParams, SamplerType -from data_designer.config.utils.info import DataDesignerInfo +import pytest + +from data_designer.config.sampler_params import SamplerType +from data_designer.config.utils.info import ConfigBuilderInfo, InfoType, InterfaceInfo +from data_designer.config.utils.type_helpers import get_sampler_params @patch("data_designer.config.utils.info.display_sampler_table") -@patch("data_designer.config.utils.info.get_sampler_params") -def test_data_designer_info(mock_get_sampler_params, mock_display_sampler_table): - stub_bernoulli_params = BernoulliSamplerParams(p=0.5) - stub_binomial_params = BinomialSamplerParams(n=100, p=0.5) - mock_get_sampler_params.return_value = { - SamplerType.BERNOULLI: stub_bernoulli_params, - SamplerType.BINOMIAL: stub_binomial_params, - } - info = DataDesignerInfo() - - assert SamplerType.BINOMIAL.value in info.sampler_types - mock_get_sampler_params.assert_called_once() - - _ = info.sampler_table - mock_display_sampler_table.assert_called_once_with( - {SamplerType.BERNOULLI: stub_bernoulli_params, SamplerType.BINOMIAL: stub_binomial_params} - ) +@patch("data_designer.config.utils.info.display_model_configs_table") +def test_config_builder_sampler_info(mock_display_model_configs_table, mock_display_sampler_table, stub_model_configs): + info = ConfigBuilderInfo(model_configs=stub_model_configs) + info.display(InfoType.MODEL_CONFIGS) + mock_display_model_configs_table.assert_called_once_with(stub_model_configs) + + sampler_params = get_sampler_params() + info.display(InfoType.SAMPLERS) + mock_display_sampler_table.assert_called_once_with(sampler_params) mock_display_sampler_table.reset_mock() - info.display_sampler(SamplerType.BERNOULLI) + info.display(InfoType.SAMPLERS, sampler_type=SamplerType.BERNOULLI) mock_display_sampler_table.assert_called_once_with( - {SamplerType.BERNOULLI: stub_bernoulli_params}, title="Bernoulli Sampler" + {SamplerType.BERNOULLI: sampler_params[SamplerType.BERNOULLI]}, title="Bernoulli Sampler" ) + + +@patch("data_designer.config.utils.info.display_model_configs_table") +def test_config_builder_model_configs_info(mock_display_model_configs_table, stub_model_configs): + info = ConfigBuilderInfo(model_configs=stub_model_configs) + info.display(InfoType.MODEL_CONFIGS) + mock_display_model_configs_table.assert_called_once_with(stub_model_configs) + + +def test_config_builder_unsupported_info_type(stub_model_configs): + info = ConfigBuilderInfo(model_configs=stub_model_configs) + with pytest.raises( + ValueError, + match="Unsupported info_type: 'unsupported_type'. ConfigBuilderInfo only supports 'SAMPLERS' and 'MODEL_CONFIGS'.", + ): + info.display("unsupported_type") + + +@patch("data_designer.config.utils.info.display_model_providers_table") +def test_interface_model_providers_info(mock_display_model_providers_table, stub_model_providers): + info = InterfaceInfo(model_providers=stub_model_providers) + info.display(InfoType.MODEL_PROVIDERS) + mock_display_model_providers_table.assert_called_once_with(stub_model_providers) + + +def test_interface_unsupported_info_type(stub_model_providers): + info = InterfaceInfo(model_providers=stub_model_providers) + with pytest.raises( + ValueError, match="Unsupported info_type: 'unsupported_type'. InterfaceInfo only supports 'MODEL_PROVIDERS'." + ): + info.display("unsupported_type") diff --git a/tests/config/utils/test_visualization.py b/tests/config/utils/test_visualization.py index aaa3201a..1cdc1a05 100644 --- a/tests/config/utils/test_visualization.py +++ b/tests/config/utils/test_visualization.py @@ -24,12 +24,12 @@ def validation_output(): @pytest.fixture -def config_builder_with_validation(): +def config_builder_with_validation(stub_model_configs): """Fixture providing a DataDesignerConfigBuilder with a validation column.""" with patch("data_designer.config.config_builder.fetch_seed_dataset_column_names") as mock_fetch: mock_fetch.return_value = ["code"] - builder = DataDesignerConfigBuilder() + builder = DataDesignerConfigBuilder(model_configs=stub_model_configs) # Add a validation column configuration builder.add_column( diff --git a/tests/conftest.py b/tests/conftest.py index 62ecc60a..f0e6546e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -17,7 +17,7 @@ from data_designer.config.config_builder import DataDesignerConfigBuilder from data_designer.config.data_designer_config import DataDesignerConfig from data_designer.config.datastore import DatastoreSettings -from data_designer.config.models import InferenceParameters, ModelConfig +from data_designer.config.models import InferenceParameters, ModelConfig, ModelProvider @pytest.fixture @@ -144,6 +144,17 @@ def stub_model_configs() -> list[ModelConfig]: ] +@pytest.fixture +def stub_model_providers() -> list[ModelProvider]: + return [ + ModelProvider( + name="provider-1", + endpoint="https://api.provider-1.com/v1", + api_key="PROVIDER_1_API_KEY", + ) + ] + + @pytest.fixture def stub_empty_builder(stub_model_configs: list[ModelConfig]) -> DataDesignerConfigBuilder: """Test builder with model configs.""" diff --git a/tests/engine/test_secret_resolver.py b/tests/engine/test_secret_resolver.py index ed20eb61..02862900 100644 --- a/tests/engine/test_secret_resolver.py +++ b/tests/engine/test_secret_resolver.py @@ -85,5 +85,5 @@ def test_composite_resolver_error(stub_secrets_file: Path): resolver.resolve("QUUX") # The composite error message aggregates the individual resolvers' error messages - assert "env var" in str(excinfo.value) + assert "Environment variable" in str(excinfo.value) assert "secret" in str(excinfo.value)