Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 1 addition & 36 deletions src/data_designer/config/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,49 +3,14 @@

from __future__ import annotations

from abc import ABC, abstractmethod
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, Optional, Protocol, TypeVar, Union
from typing import Any, Optional, Union

import pandas as pd
from pydantic import BaseModel, ConfigDict
import yaml

from .utils.io_helpers import serialize_data

if TYPE_CHECKING:
from .analysis.dataset_profiler import DatasetProfilerResults
from .config_builder import DataDesignerConfigBuilder
from .preview_results import PreviewResults

DEFAULT_NUM_RECORDS = 10


class ResultsProtocol(Protocol):
def load_analysis(self) -> DatasetProfilerResults: ...
def load_dataset(self) -> pd.DataFrame: ...


ResultsT = TypeVar("ResultsT", bound=ResultsProtocol)


class DataDesignerInterface(ABC, Generic[ResultsT]):
@abstractmethod
def create(
self,
config_builder: DataDesignerConfigBuilder,
*,
num_records: int = DEFAULT_NUM_RECORDS,
) -> ResultsT: ...

@abstractmethod
def preview(
self,
config_builder: DataDesignerConfigBuilder,
*,
num_records: int = DEFAULT_NUM_RECORDS,
) -> PreviewResults: ...


class ConfigBase(BaseModel):
model_config = ConfigDict(
Expand Down
16 changes: 7 additions & 9 deletions src/data_designer/config/config_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
SeedDatasetReference,
)
from .utils.constants import DEFAULT_REPR_HTML_STYLE, REPR_HTML_TEMPLATE
from .utils.info import DataDesignerInfo
from .utils.info import ConfigBuilderInfo
from .utils.io_helpers import serialize_data, smart_load_yaml
from .utils.misc import (
can_run_data_designer_locally,
Expand Down Expand Up @@ -132,22 +132,20 @@ def from_config(cls, config: Union[dict, str, Path, BuilderConfig]) -> Self:

return builder

def __init__(self, model_configs: Optional[Union[list[ModelConfig], str, Path]] = None):
def __init__(self, model_configs: Union[list[ModelConfig], str, Path]):
"""Initialize a new DataDesignerConfigBuilder instance.

Args:
model_configs: Optional model configurations. Can be:
model_configs: Model configurations. Can be:
- A list of ModelConfig objects
- A string or Path to a model configuration file
- None to use default model configurations
"""
self._column_configs = {}
self._model_configs = load_model_configs(model_configs)
self._processor_configs: list[ProcessorConfig] = []
self._seed_config: Optional[SeedConfig] = None
self._constraints: list[ColumnConstraintT] = []
self._profilers: list[ColumnProfilerConfigT] = []
self._info = DataDesignerInfo()
self._datastore_settings: Optional[DatastoreSettings] = None

@property
Expand All @@ -173,13 +171,13 @@ def allowed_references(self) -> list[str]:
return list(self._column_configs.keys()) + list(set(side_effect_columns))

@property
def info(self) -> DataDesignerInfo:
"""Get the DataDesignerInfo object for this builder.
def info(self) -> ConfigBuilderInfo:
"""Get the ConfigBuilderInfo object for this builder.

Returns:
An object containing metadata about the configuration.
An object containing information about the configuration.
"""
return self._info
return ConfigBuilderInfo(model_configs=self._model_configs)

def add_model_config(self, model_config: ModelConfig) -> Self:
"""Add a model configuration to the current Data Designer configuration.
Expand Down
54 changes: 54 additions & 0 deletions src/data_designer/config/interface.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Generic, Protocol, TypeVar

import pandas as pd

from .models import ModelConfig, ModelProvider
from .utils.constants import DEFAULT_NUM_RECORDS
from .utils.info import InterfaceInfo

if TYPE_CHECKING:
from .analysis.dataset_profiler import DatasetProfilerResults
from .config_builder import DataDesignerConfigBuilder
from .preview_results import PreviewResults


class ResultsProtocol(Protocol):
def load_analysis(self) -> DatasetProfilerResults: ...
def load_dataset(self) -> pd.DataFrame: ...


ResultsT = TypeVar("ResultsT", bound=ResultsProtocol)


class DataDesignerInterface(ABC, Generic[ResultsT]):
@abstractmethod
def create(
self,
config_builder: DataDesignerConfigBuilder,
*,
num_records: int = DEFAULT_NUM_RECORDS,
) -> ResultsT: ...

@abstractmethod
def preview(
self,
config_builder: DataDesignerConfigBuilder,
*,
num_records: int = DEFAULT_NUM_RECORDS,
) -> PreviewResults: ...

@abstractmethod
def get_default_model_configs(self) -> list[ModelConfig]: ...

@abstractmethod
def get_default_model_providers(self) -> list[ModelProvider]: ...

@property
@abstractmethod
def info(self) -> InterfaceInfo: ...
Comment on lines +46 to +54
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@mikeknep just want to make sure you are are of these changes for the nmp dd client.

130 changes: 126 additions & 4 deletions src/data_designer/config/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

from abc import ABC, abstractmethod
from enum import Enum
import logging
import os
from pathlib import Path
from typing import Any, Generic, List, Optional, TypeVar, Union

Expand All @@ -12,9 +14,20 @@

from .base import ConfigBase
from .errors import InvalidConfigError
from .utils.constants import MAX_TEMPERATURE, MAX_TOP_P, MIN_TEMPERATURE, MIN_TOP_P
from .utils.constants import (
MAX_TEMPERATURE,
MAX_TOP_P,
MIN_TEMPERATURE,
MIN_TOP_P,
NVIDIA_API_KEY_ENV_VAR_NAME,
NVIDIA_PROVIDER_NAME,
OPENAI_API_KEY_ENV_VAR_NAME,
OPENAI_PROVIDER_NAME,
)
from .utils.io_helpers import smart_load_yaml

logger = logging.getLogger(__name__)


class Modality(str, Enum):
IMAGE = "image"
Expand Down Expand Up @@ -204,9 +217,14 @@ class ModelConfig(ConfigBase):
provider: Optional[str] = None


def load_model_configs(model_configs: Union[list[ModelConfig], str, Path, None]) -> list[ModelConfig]:
if model_configs is None:
return []
class ModelProvider(ConfigBase):
name: str
endpoint: str
provider_type: str = "openai"
api_key: str | None = None


def load_model_configs(model_configs: Union[list[ModelConfig], str, Path]) -> list[ModelConfig]:
if isinstance(model_configs, list) and all(isinstance(mc, ModelConfig) for mc in model_configs):
return model_configs
json_config = smart_load_yaml(model_configs)
Expand All @@ -215,3 +233,107 @@ def load_model_configs(model_configs: Union[list[ModelConfig], str, Path, None])
"The list of model configs must be provided under model_configs in the configuration file."
)
return [ModelConfig.model_validate(mc) for mc in json_config["model_configs"]]


def get_default_text_alias_inference_parameters() -> InferenceParameters:
return InferenceParameters(
temperature=0.85,
top_p=0.95,
)


def get_default_reasoning_alias_inference_parameters() -> InferenceParameters:
return InferenceParameters(
temperature=0.35,
top_p=0.95,
)


def get_default_vision_alias_inference_parameters() -> InferenceParameters:
return InferenceParameters(
temperature=0.85,
top_p=0.95,
)


def get_default_nvidia_model_configs() -> list[ModelConfig]:
if not get_nvidia_api_key():
logger.warning(
f"🔑 {NVIDIA_API_KEY_ENV_VAR_NAME!r} environment variable is not set. Please set it to your API key from 'https://build.nvidia.com' if you want to use the default NVIDIA model configs."
)
return []
return [
ModelConfig(
alias=f"{NVIDIA_PROVIDER_NAME}-text",
model="nvidia/nvidia-nemotron-nano-9b-v2",
provider=NVIDIA_PROVIDER_NAME,
inference_parameters=get_default_text_alias_inference_parameters(),
),
ModelConfig(
alias=f"{NVIDIA_PROVIDER_NAME}-reasoning",
model="openai/gpt-oss-20b",
provider=NVIDIA_PROVIDER_NAME,
inference_parameters=get_default_reasoning_alias_inference_parameters(),
),
ModelConfig(
alias=f"{NVIDIA_PROVIDER_NAME}-vision",
model="nvidia/nemotron-nano-12b-v2-vl",
provider=NVIDIA_PROVIDER_NAME,
inference_parameters=get_default_vision_alias_inference_parameters(),
),
]


def get_default_openai_model_configs() -> list[ModelConfig]:
if not get_openai_api_key():
logger.warning(
f"🔑 {OPENAI_API_KEY_ENV_VAR_NAME!r} environment variable is not set. Please set it to your API key from 'https://platform.openai.com/api-keys' if you want to use the default OpenAI model configs."
)
return []
return [
ModelConfig(
alias=f"{OPENAI_PROVIDER_NAME}-text",
model="gpt-4.1",
provider=OPENAI_PROVIDER_NAME,
inference_parameters=get_default_text_alias_inference_parameters(),
),
ModelConfig(
alias=f"{OPENAI_PROVIDER_NAME}-reasoning",
model="gpt-5",
provider=OPENAI_PROVIDER_NAME,
inference_parameters=get_default_reasoning_alias_inference_parameters(),
),
ModelConfig(
alias=f"{OPENAI_PROVIDER_NAME}-vision",
model="gpt-5",
provider=OPENAI_PROVIDER_NAME,
inference_parameters=get_default_vision_alias_inference_parameters(),
),
]


def get_default_model_configs() -> list[ModelConfig]:
return get_default_nvidia_model_configs() + get_default_openai_model_configs()


def get_default_providers() -> list[ModelProvider]:
return [
ModelProvider(
name=NVIDIA_PROVIDER_NAME,
endpoint="https://integrate.api.nvidia.com/v1",
api_key=NVIDIA_API_KEY_ENV_VAR_NAME,
),
ModelProvider(
name=OPENAI_PROVIDER_NAME,
endpoint="https://api.openai.com/v1",
api_key=OPENAI_API_KEY_ENV_VAR_NAME,
),
]


def get_nvidia_api_key() -> Optional[str]:
return os.getenv(NVIDIA_API_KEY_ENV_VAR_NAME)


def get_openai_api_key() -> Optional[str]:
return os.getenv(OPENAI_API_KEY_ENV_VAR_NAME)
8 changes: 8 additions & 0 deletions src/data_designer/config/utils/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

from rich.theme import Theme

DEFAULT_NUM_RECORDS = 10

EPSILON = 1e-8
REPORTING_PRECISION = 2

Expand Down Expand Up @@ -255,3 +257,9 @@ class NordColor(Enum):
"zh_TW",
"zu_ZA",
]

NVIDIA_PROVIDER_NAME = "nvidia"
NVIDIA_API_KEY_ENV_VAR_NAME = "NVIDIA_API_KEY"

OPENAI_PROVIDER_NAME = "openai"
OPENAI_API_KEY_ENV_VAR_NAME = "OPENAI_API_KEY"
Loading