Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 1 addition & 36 deletions src/data_designer/config/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,49 +3,14 @@

from __future__ import annotations

from abc import ABC, abstractmethod
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, Optional, Protocol, TypeVar, Union
from typing import Any, Optional, Union

import pandas as pd
from pydantic import BaseModel, ConfigDict
import yaml

from .utils.io_helpers import serialize_data

if TYPE_CHECKING:
from .analysis.dataset_profiler import DatasetProfilerResults
from .config_builder import DataDesignerConfigBuilder
from .preview_results import PreviewResults

DEFAULT_NUM_RECORDS = 10


class ResultsProtocol(Protocol):
def load_analysis(self) -> DatasetProfilerResults: ...
def load_dataset(self) -> pd.DataFrame: ...


ResultsT = TypeVar("ResultsT", bound=ResultsProtocol)


class DataDesignerInterface(ABC, Generic[ResultsT]):
@abstractmethod
def create(
self,
config_builder: DataDesignerConfigBuilder,
*,
num_records: int = DEFAULT_NUM_RECORDS,
) -> ResultsT: ...

@abstractmethod
def preview(
self,
config_builder: DataDesignerConfigBuilder,
*,
num_records: int = DEFAULT_NUM_RECORDS,
) -> PreviewResults: ...


class ConfigBase(BaseModel):
model_config = ConfigDict(
Expand Down
16 changes: 7 additions & 9 deletions src/data_designer/config/config_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
SeedDatasetReference,
)
from .utils.constants import DEFAULT_REPR_HTML_STYLE, REPR_HTML_TEMPLATE
from .utils.info import DataDesignerInfo
from .utils.info import ConfigBuilderInfo
from .utils.io_helpers import serialize_data, smart_load_yaml
from .utils.misc import (
can_run_data_designer_locally,
Expand Down Expand Up @@ -132,22 +132,20 @@ def from_config(cls, config: Union[dict, str, Path, BuilderConfig]) -> Self:

return builder

def __init__(self, model_configs: Optional[Union[list[ModelConfig], str, Path]] = None):
def __init__(self, model_configs: Union[list[ModelConfig], str, Path]):
"""Initialize a new DataDesignerConfigBuilder instance.

Args:
model_configs: Optional model configurations. Can be:
model_configs: Model configurations. Can be:
- A list of ModelConfig objects
- A string or Path to a model configuration file
- None to use default model configurations
"""
self._column_configs = {}
self._model_configs = load_model_configs(model_configs)
self._processor_configs: list[ProcessorConfig] = []
self._seed_config: Optional[SeedConfig] = None
self._constraints: list[ColumnConstraintT] = []
self._profilers: list[ColumnProfilerConfigT] = []
self._info = DataDesignerInfo()
self._datastore_settings: Optional[DatastoreSettings] = None

@property
Expand All @@ -173,13 +171,13 @@ def allowed_references(self) -> list[str]:
return list(self._column_configs.keys()) + list(set(side_effect_columns))

@property
def info(self) -> DataDesignerInfo:
"""Get the DataDesignerInfo object for this builder.
def info(self) -> ConfigBuilderInfo:
"""Get the ConfigBuilderInfo object for this builder.

Returns:
An object containing metadata about the configuration.
An object containing information about the configuration.
"""
return self._info
return ConfigBuilderInfo(model_configs=self._model_configs)

def add_model_config(self, model_config: ModelConfig) -> Self:
"""Add a model configuration to the current Data Designer configuration.
Expand Down
54 changes: 54 additions & 0 deletions src/data_designer/config/interface.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Generic, Protocol, TypeVar

import pandas as pd

from .models import ModelConfig, ModelProvider
from .utils.constants import DEFAULT_NUM_RECORDS
from .utils.info import InterfaceInfo

if TYPE_CHECKING:
from .analysis.dataset_profiler import DatasetProfilerResults
from .config_builder import DataDesignerConfigBuilder
from .preview_results import PreviewResults


class ResultsProtocol(Protocol):
    """Structural type that any results object must satisfy.

    An object conforms by exposing both methods below; no inheritance is
    required. This protocol is the upper bound of the ``ResultsT`` type
    variable.
    """

    def load_analysis(self) -> DatasetProfilerResults:
        """Return the ``DatasetProfilerResults`` associated with these results."""

    def load_dataset(self) -> pd.DataFrame:
        """Return the dataset as a ``pandas.DataFrame``."""


# Type variable for the results type returned by DataDesignerInterface.create;
# any binding must satisfy ResultsProtocol.
ResultsT = TypeVar("ResultsT", bound=ResultsProtocol)


class DataDesignerInterface(ABC, Generic[ResultsT]):
    """Abstract contract for a Data Designer entry point.

    The class is generic over ``ResultsT``, the results type returned by
    ``create``. Every member below is abstract, so a concrete subclass must
    implement all of them before it can be instantiated.
    """

    @abstractmethod
    def create(
        self,
        config_builder: DataDesignerConfigBuilder,
        *,
        num_records: int = DEFAULT_NUM_RECORDS,
    ) -> ResultsT:
        """Run a creation job for the given configuration.

        Args:
            config_builder: Builder holding the configuration to run.
            num_records: Keyword-only record count; defaults to
                ``DEFAULT_NUM_RECORDS``. Presumably the number of records to
                generate -- exact semantics are defined by implementations.

        Returns:
            A results object of the implementation's ``ResultsT`` type.
        """

    @abstractmethod
    def preview(
        self,
        config_builder: DataDesignerConfigBuilder,
        *,
        num_records: int = DEFAULT_NUM_RECORDS,
    ) -> PreviewResults:
        """Run a preview for the given configuration.

        Args:
            config_builder: Builder holding the configuration to preview.
            num_records: Keyword-only record count; defaults to
                ``DEFAULT_NUM_RECORDS``.

        Returns:
            A ``PreviewResults`` object.
        """

    @abstractmethod
    def get_default_model_configs(self) -> list[ModelConfig]:
        """Return the default model configurations offered by this interface."""

    @abstractmethod
    def get_default_model_providers(self) -> list[ModelProvider]:
        """Return the default model providers offered by this interface."""

    @property
    @abstractmethod
    def info(self) -> InterfaceInfo:
        """Return the ``InterfaceInfo`` display object for this interface."""
Comment on lines +46 to +54
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@mikeknep just want to make sure you are aware of these changes for the nmp dd client.

72 changes: 68 additions & 4 deletions src/data_designer/config/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

from abc import ABC, abstractmethod
from enum import Enum
import logging
import os
from pathlib import Path
from typing import Any, Generic, List, Optional, TypeVar, Union

Expand All @@ -12,9 +14,18 @@

from .base import ConfigBase
from .errors import InvalidConfigError
from .utils.constants import MAX_TEMPERATURE, MAX_TOP_P, MIN_TEMPERATURE, MIN_TOP_P
from .utils.constants import (
MAX_TEMPERATURE,
MAX_TOP_P,
MIN_TEMPERATURE,
MIN_TOP_P,
NVIDIA_API_KEY_ENV_VAR_NAME,
NVIDIA_PROVIDER_NAME,
)
from .utils.io_helpers import smart_load_yaml

logger = logging.getLogger(__name__)


class Modality(str, Enum):
IMAGE = "image"
Expand Down Expand Up @@ -204,9 +215,14 @@ class ModelConfig(ConfigBase):
provider: Optional[str] = None


def load_model_configs(model_configs: Union[list[ModelConfig], str, Path, None]) -> list[ModelConfig]:
if model_configs is None:
return []
class ModelProvider(ConfigBase):
    """Configuration record describing a model provider."""

    # Provider identifier. The default NVIDIA model configs in this module
    # reference their provider by this same name.
    name: str
    # Endpoint URL of the provider's API.
    endpoint: str
    # Provider flavour; defaults to "openai".
    provider_type: str = "openai"
    # NOTE(review): the default NVIDIA provider in this module passes the
    # environment-variable *name* here rather than the key itself -- presumably
    # it is resolved to the real secret downstream; confirm.
    api_key: Optional[str] = None


def load_model_configs(model_configs: Union[list[ModelConfig], str, Path]) -> list[ModelConfig]:
if isinstance(model_configs, list) and all(isinstance(mc, ModelConfig) for mc in model_configs):
return model_configs
json_config = smart_load_yaml(model_configs)
Expand All @@ -215,3 +231,51 @@ def load_model_configs(model_configs: Union[list[ModelConfig], str, Path, None])
"The list of model configs must be provided under model_configs in the configuration file."
)
return [ModelConfig.model_validate(mc) for mc in json_config["model_configs"]]


def get_default_nvidia_model_configs() -> list[ModelConfig]:
if not get_nvidia_api_key():
logger.warning(
"‼️🔑 'NVIDIA_API_KEY' environment variable is not set. Please set it to your API key from 'build.nvidia.com' if you want to use the default NVIDIA model configs."
)
return [
ModelConfig(
alias="text",
model="nvidia/nvidia-nemotron-nano-9b-v2",
provider=NVIDIA_PROVIDER_NAME,
inference_parameters=InferenceParameters(
temperature=0.85,
top_p=0.95,
),
),
ModelConfig(
alias="reasoning",
model="nvidia/llama-3.3-nemotron-super-49b-v1.5",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@kirit93 – please chime in with thoughts on default models.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We all chatted offline and agreed to support multiple providers (nvidia, openai) to start with, along with their model configs. It's all handled in this commit:

Behaviors:

  1. This displays the default model providers (nvidia and openai). If a required API key cannot be found, it will be shown as missing.
data_designer = DataDesigner()
data_designer.info.display(InfoType.MODEL_PROVIDERS)
Screenshot 2025-11-10 at 3 42 34 PM
  1. This displays the available default model configs based on which API keys were detected. In this example, only openai model configs were shown because nvidia api key could not be found. There's a warning that gets printed out as well.
config_builder = DataDesignerConfigBuilder(model_configs=data_designer.get_default_model_configs())
config_builder.info.display(InfoType.MODEL_CONFIGS)
Screenshot 2025-11-10 at 3 45 16 PM
  1. If neither the nvidia nor the openai API key can be found, we'll see a warning for each missing key plus additional warnings.
Screenshot 2025-11-10 at 3 48 01 PM

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@johnnygreco if you are cool with the above, let's get this merged in. We can further tweak the model names, etc in follow up PRs.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, sounds good. We can tweak in follow ups. One thought I had was to have get_default_model_configs take an argument ("nvidia" or "openai"), which we can default to "nvidia". That way we can only raise a warning for the appropriate API key when needed.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, sounds good. I'll get that in the next PR.

provider=NVIDIA_PROVIDER_NAME,
inference_parameters=InferenceParameters(
temperature=0.35,
top_p=0.95,
),
),
ModelConfig(
alias="vision",
model="nvidia/nemotron-nano-12b-v2-vl",
provider=NVIDIA_PROVIDER_NAME,
inference_parameters=InferenceParameters(
temperature=0.85,
top_p=0.95,
),
),
]


def get_nvidia_api_key() -> Optional[str]:
    """Look up the NVIDIA API key in the process environment.

    Returns:
        The value of the environment variable named by
        ``NVIDIA_API_KEY_ENV_VAR_NAME``, or ``None`` when it is not set.
    """
    return os.environ.get(NVIDIA_API_KEY_ENV_VAR_NAME)


def get_default_nvidia_model_provider() -> ModelProvider:
    """Build the default ``ModelProvider`` entry for the NVIDIA provider.

    Returns:
        A provider named ``NVIDIA_PROVIDER_NAME`` that points at the NVIDIA
        integrate API endpoint.
    """
    provider_fields = {
        "name": NVIDIA_PROVIDER_NAME,
        "endpoint": "https://integrate.api.nvidia.com/v1",
        # This is the environment-variable name, not the secret value.
        # NOTE(review): presumably resolved to the actual key elsewhere -- confirm.
        "api_key": NVIDIA_API_KEY_ENV_VAR_NAME,
    }
    return ModelProvider(**provider_fields)
5 changes: 5 additions & 0 deletions src/data_designer/config/utils/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

from rich.theme import Theme

DEFAULT_NUM_RECORDS = 10

EPSILON = 1e-8
REPORTING_PRECISION = 2

Expand Down Expand Up @@ -255,3 +257,6 @@ class NordColor(Enum):
"zh_TW",
"zu_ZA",
]

NVIDIA_PROVIDER_NAME = "nvidia"
NVIDIA_API_KEY_ENV_VAR_NAME = "NVIDIA_API_KEY"
98 changes: 86 additions & 12 deletions src/data_designer/config/utils/info.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,97 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

from abc import ABC, abstractmethod
from enum import Enum
from typing import Literal, TypeVar, overload

from ..models import ModelConfig, ModelProvider
from ..sampler_params import SamplerType
from .type_helpers import get_sampler_params
from .visualization import display_sampler_table
from .visualization import display_model_configs_table, display_model_providers_table, display_sampler_table


class InfoType(str, Enum):
    """Kinds of information that the info display classes can render.

    Members are also ``str`` instances, and each member's value equals its
    name.
    """

    SAMPLERS = "SAMPLERS"
    MODEL_CONFIGS = "MODEL_CONFIGS"
    MODEL_PROVIDERS = "MODEL_PROVIDERS"


# Info types accepted by ConfigBuilderInfo.display.
ConfigBuilderInfoType = Literal[InfoType.SAMPLERS, InfoType.MODEL_CONFIGS]
# Info types accepted by InterfaceInfo.display.
DataDesignerInfoType = Literal[InfoType.MODEL_PROVIDERS]
# Type variable used in the abstract InfoDisplay.display signature; bounded by InfoType.
InfoTypeT = TypeVar("InfoTypeT", bound=InfoType)


class InfoDisplay(ABC):
    """Abstract base for info objects that expose a single ``display`` entry point."""

    @abstractmethod
    def display(self, info_type: InfoTypeT, **kwargs) -> None:
        """Render the information selected by ``info_type``.

        Args:
            info_type: Which kind of information to display.
            **kwargs: Optional, subclass-specific display options.
        """


class DataDesignerInfo:
def __init__(self):
class ConfigBuilderInfo(InfoDisplay):
    """Info display for a config builder: sampler parameters and model configs."""

    def __init__(self, model_configs: list[ModelConfig]):
        """Capture the data this object can display.

        Args:
            model_configs: Model configurations rendered for
                ``InfoType.MODEL_CONFIGS``.
        """
        self._model_configs = model_configs
        self._sampler_params = get_sampler_params()

    @overload
    def display(self, info_type: Literal[InfoType.SAMPLERS], **kwargs) -> None: ...

    @overload
    def display(self, info_type: Literal[InfoType.MODEL_CONFIGS], **kwargs) -> None: ...

    def display(self, info_type: ConfigBuilderInfoType, **kwargs) -> None:
        """Render the table that corresponds to ``info_type``.

        Args:
            info_type: Either ``InfoType.SAMPLERS`` or ``InfoType.MODEL_CONFIGS``.
            **kwargs: For ``InfoType.SAMPLERS``, an optional ``sampler_type``
                restricts the output to that single sampler.

        Raises:
            ValueError: If ``info_type`` is neither of the supported values.
        """
        if info_type == InfoType.SAMPLERS:
            requested_sampler = kwargs.get("sampler_type")
            self._display_sampler_info(sampler_type=requested_sampler)
            return

        if info_type == InfoType.MODEL_CONFIGS:
            display_model_configs_table(self._model_configs)
            return

        raise ValueError(
            f"Unsupported info_type: {str(info_type)!r}. "
            f"ConfigBuilderInfo only supports {InfoType.SAMPLERS.value!r} and {InfoType.MODEL_CONFIGS.value!r}."
        )

    def _display_sampler_info(self, sampler_type: SamplerType | None) -> None:
        """Show every sampler, or only ``sampler_type`` when one is given."""
        if sampler_type is None:
            display_sampler_table(self._sampler_params)
            return

        # Build the heading first so an invalid sampler type fails on the
        # SamplerType conversion before the parameter lookup is attempted.
        readable_name = SamplerType(sampler_type).value.replace("_", " ").title()
        heading = f"{readable_name} Sampler"
        selected = {sampler_type: self._sampler_params[sampler_type]}
        display_sampler_table(selected, title=heading)


class InterfaceInfo(InfoDisplay):
def __init__(self, model_providers: list[ModelProvider]):
self._model_providers = model_providers

@overload
def display(self, info_type: Literal[InfoType.MODEL_PROVIDERS], **kwargs) -> None: ...

@property
def sampler_table(self) -> None:
display_sampler_table(self._sampler_params)
def display(self, info_type: DataDesignerInfoType, **kwargs) -> None:
"""Display information based on the provided info type.
@property
def sampler_types(self) -> list[str]:
return [s.value for s in SamplerType]
Args:
info_type: Type of information to display. Only MODEL_PROVIDERS is supported.
def display_sampler(self, sampler_type: SamplerType) -> None:
title = f"{SamplerType(sampler_type).value.replace('_', ' ').title()} Sampler"
display_sampler_table({sampler_type: self._sampler_params[sampler_type]}, title=title)
Raises:
ValueError: If an unsupported info_type is provided.
"""
if info_type == InfoType.MODEL_PROVIDERS:
display_model_providers_table(self._model_providers)
else:
raise ValueError(
f"Unsupported info_type: {str(info_type)!r}. InterfaceInfo only supports {InfoType.MODEL_PROVIDERS.value!r}."
)
Loading
Loading