Skip to content

Commit aa22900

Browse files
authored
chore: (FTUE ) updated display pipeline for builder and interface like objects (#22)
1 parent affc46f commit aa22900

File tree

18 files changed

+435
-119
lines changed

18 files changed

+435
-119
lines changed

src/data_designer/config/base.py

Lines changed: 1 addition & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -3,49 +3,14 @@
33

44
from __future__ import annotations
55

6-
from abc import ABC, abstractmethod
76
from pathlib import Path
8-
from typing import TYPE_CHECKING, Any, Generic, Optional, Protocol, TypeVar, Union
7+
from typing import Any, Optional, Union
98

10-
import pandas as pd
119
from pydantic import BaseModel, ConfigDict
1210
import yaml
1311

1412
from .utils.io_helpers import serialize_data
1513

16-
if TYPE_CHECKING:
17-
from .analysis.dataset_profiler import DatasetProfilerResults
18-
from .config_builder import DataDesignerConfigBuilder
19-
from .preview_results import PreviewResults
20-
21-
DEFAULT_NUM_RECORDS = 10
22-
23-
24-
class ResultsProtocol(Protocol):
25-
def load_analysis(self) -> DatasetProfilerResults: ...
26-
def load_dataset(self) -> pd.DataFrame: ...
27-
28-
29-
ResultsT = TypeVar("ResultsT", bound=ResultsProtocol)
30-
31-
32-
class DataDesignerInterface(ABC, Generic[ResultsT]):
33-
@abstractmethod
34-
def create(
35-
self,
36-
config_builder: DataDesignerConfigBuilder,
37-
*,
38-
num_records: int = DEFAULT_NUM_RECORDS,
39-
) -> ResultsT: ...
40-
41-
@abstractmethod
42-
def preview(
43-
self,
44-
config_builder: DataDesignerConfigBuilder,
45-
*,
46-
num_records: int = DEFAULT_NUM_RECORDS,
47-
) -> PreviewResults: ...
48-
4914

5015
class ConfigBase(BaseModel):
5116
model_config = ConfigDict(

src/data_designer/config/config_builder.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
SeedDatasetReference,
4545
)
4646
from .utils.constants import DEFAULT_REPR_HTML_STYLE, REPR_HTML_TEMPLATE
47-
from .utils.info import DataDesignerInfo
47+
from .utils.info import ConfigBuilderInfo
4848
from .utils.io_helpers import serialize_data, smart_load_yaml
4949
from .utils.misc import (
5050
can_run_data_designer_locally,
@@ -132,22 +132,20 @@ def from_config(cls, config: Union[dict, str, Path, BuilderConfig]) -> Self:
132132

133133
return builder
134134

135-
def __init__(self, model_configs: Optional[Union[list[ModelConfig], str, Path]] = None):
135+
def __init__(self, model_configs: Union[list[ModelConfig], str, Path]):
136136
"""Initialize a new DataDesignerConfigBuilder instance.
137137
138138
Args:
139-
model_configs: Optional model configurations. Can be:
139+
model_configs: Model configurations. Can be:
140140
- A list of ModelConfig objects
141141
- A string or Path to a model configuration file
142-
- None to use default model configurations
143142
"""
144143
self._column_configs = {}
145144
self._model_configs = load_model_configs(model_configs)
146145
self._processor_configs: list[ProcessorConfig] = []
147146
self._seed_config: Optional[SeedConfig] = None
148147
self._constraints: list[ColumnConstraintT] = []
149148
self._profilers: list[ColumnProfilerConfigT] = []
150-
self._info = DataDesignerInfo()
151149
self._datastore_settings: Optional[DatastoreSettings] = None
152150

153151
@property
@@ -173,13 +171,13 @@ def allowed_references(self) -> list[str]:
173171
return list(self._column_configs.keys()) + list(set(side_effect_columns))
174172

175173
@property
176-
def info(self) -> DataDesignerInfo:
177-
"""Get the DataDesignerInfo object for this builder.
174+
def info(self) -> ConfigBuilderInfo:
175+
"""Get the ConfigBuilderInfo object for this builder.
178176
179177
Returns:
180-
An object containing metadata about the configuration.
178+
An object containing information about the configuration.
181179
"""
182-
return self._info
180+
return ConfigBuilderInfo(model_configs=self._model_configs)
183181

184182
def add_model_config(self, model_config: ModelConfig) -> Self:
185183
"""Add a model configuration to the current Data Designer configuration.
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
from __future__ import annotations
5+
6+
from abc import ABC, abstractmethod
7+
from typing import TYPE_CHECKING, Generic, Protocol, TypeVar
8+
9+
import pandas as pd
10+
11+
from .models import ModelConfig, ModelProvider
12+
from .utils.constants import DEFAULT_NUM_RECORDS
13+
from .utils.info import InterfaceInfo
14+
15+
if TYPE_CHECKING:
16+
from .analysis.dataset_profiler import DatasetProfilerResults
17+
from .config_builder import DataDesignerConfigBuilder
18+
from .preview_results import PreviewResults
19+
20+
21+
class ResultsProtocol(Protocol):
22+
def load_analysis(self) -> DatasetProfilerResults: ...
23+
def load_dataset(self) -> pd.DataFrame: ...
24+
25+
26+
ResultsT = TypeVar("ResultsT", bound=ResultsProtocol)
27+
28+
29+
class DataDesignerInterface(ABC, Generic[ResultsT]):
30+
@abstractmethod
31+
def create(
32+
self,
33+
config_builder: DataDesignerConfigBuilder,
34+
*,
35+
num_records: int = DEFAULT_NUM_RECORDS,
36+
) -> ResultsT: ...
37+
38+
@abstractmethod
39+
def preview(
40+
self,
41+
config_builder: DataDesignerConfigBuilder,
42+
*,
43+
num_records: int = DEFAULT_NUM_RECORDS,
44+
) -> PreviewResults: ...
45+
46+
@abstractmethod
47+
def get_default_model_configs(self) -> list[ModelConfig]: ...
48+
49+
@abstractmethod
50+
def get_default_model_providers(self) -> list[ModelProvider]: ...
51+
52+
@property
53+
@abstractmethod
54+
def info(self) -> InterfaceInfo: ...

src/data_designer/config/models.py

Lines changed: 126 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33

44
from abc import ABC, abstractmethod
55
from enum import Enum
6+
import logging
7+
import os
68
from pathlib import Path
79
from typing import Any, Generic, List, Optional, TypeVar, Union
810

@@ -12,9 +14,20 @@
1214

1315
from .base import ConfigBase
1416
from .errors import InvalidConfigError
15-
from .utils.constants import MAX_TEMPERATURE, MAX_TOP_P, MIN_TEMPERATURE, MIN_TOP_P
17+
from .utils.constants import (
18+
MAX_TEMPERATURE,
19+
MAX_TOP_P,
20+
MIN_TEMPERATURE,
21+
MIN_TOP_P,
22+
NVIDIA_API_KEY_ENV_VAR_NAME,
23+
NVIDIA_PROVIDER_NAME,
24+
OPENAI_API_KEY_ENV_VAR_NAME,
25+
OPENAI_PROVIDER_NAME,
26+
)
1627
from .utils.io_helpers import smart_load_yaml
1728

29+
logger = logging.getLogger(__name__)
30+
1831

1932
class Modality(str, Enum):
2033
IMAGE = "image"
@@ -204,9 +217,14 @@ class ModelConfig(ConfigBase):
204217
provider: Optional[str] = None
205218

206219

207-
def load_model_configs(model_configs: Union[list[ModelConfig], str, Path, None]) -> list[ModelConfig]:
208-
if model_configs is None:
209-
return []
220+
class ModelProvider(ConfigBase):
221+
name: str
222+
endpoint: str
223+
provider_type: str = "openai"
224+
api_key: str | None = None
225+
226+
227+
def load_model_configs(model_configs: Union[list[ModelConfig], str, Path]) -> list[ModelConfig]:
210228
if isinstance(model_configs, list) and all(isinstance(mc, ModelConfig) for mc in model_configs):
211229
return model_configs
212230
json_config = smart_load_yaml(model_configs)
@@ -215,3 +233,107 @@ def load_model_configs(model_configs: Union[list[ModelConfig], str, Path, None])
215233
"The list of model configs must be provided under model_configs in the configuration file."
216234
)
217235
return [ModelConfig.model_validate(mc) for mc in json_config["model_configs"]]
236+
237+
238+
def get_default_text_alias_inference_parameters() -> InferenceParameters:
239+
return InferenceParameters(
240+
temperature=0.85,
241+
top_p=0.95,
242+
)
243+
244+
245+
def get_default_reasoning_alias_inference_parameters() -> InferenceParameters:
246+
return InferenceParameters(
247+
temperature=0.35,
248+
top_p=0.95,
249+
)
250+
251+
252+
def get_default_vision_alias_inference_parameters() -> InferenceParameters:
253+
return InferenceParameters(
254+
temperature=0.85,
255+
top_p=0.95,
256+
)
257+
258+
259+
def get_default_nvidia_model_configs() -> list[ModelConfig]:
260+
if not get_nvidia_api_key():
261+
logger.warning(
262+
f"🔑 {NVIDIA_API_KEY_ENV_VAR_NAME!r} environment variable is not set. Please set it to your API key from 'https://build.nvidia.com' if you want to use the default NVIDIA model configs."
263+
)
264+
return []
265+
return [
266+
ModelConfig(
267+
alias=f"{NVIDIA_PROVIDER_NAME}-text",
268+
model="nvidia/nvidia-nemotron-nano-9b-v2",
269+
provider=NVIDIA_PROVIDER_NAME,
270+
inference_parameters=get_default_text_alias_inference_parameters(),
271+
),
272+
ModelConfig(
273+
alias=f"{NVIDIA_PROVIDER_NAME}-reasoning",
274+
model="openai/gpt-oss-20b",
275+
provider=NVIDIA_PROVIDER_NAME,
276+
inference_parameters=get_default_reasoning_alias_inference_parameters(),
277+
),
278+
ModelConfig(
279+
alias=f"{NVIDIA_PROVIDER_NAME}-vision",
280+
model="nvidia/nemotron-nano-12b-v2-vl",
281+
provider=NVIDIA_PROVIDER_NAME,
282+
inference_parameters=get_default_vision_alias_inference_parameters(),
283+
),
284+
]
285+
286+
287+
def get_default_openai_model_configs() -> list[ModelConfig]:
288+
if not get_openai_api_key():
289+
logger.warning(
290+
f"🔑 {OPENAI_API_KEY_ENV_VAR_NAME!r} environment variable is not set. Please set it to your API key from 'https://platform.openai.com/api-keys' if you want to use the default OpenAI model configs."
291+
)
292+
return []
293+
return [
294+
ModelConfig(
295+
alias=f"{OPENAI_PROVIDER_NAME}-text",
296+
model="gpt-4.1",
297+
provider=OPENAI_PROVIDER_NAME,
298+
inference_parameters=get_default_text_alias_inference_parameters(),
299+
),
300+
ModelConfig(
301+
alias=f"{OPENAI_PROVIDER_NAME}-reasoning",
302+
model="gpt-5",
303+
provider=OPENAI_PROVIDER_NAME,
304+
inference_parameters=get_default_reasoning_alias_inference_parameters(),
305+
),
306+
ModelConfig(
307+
alias=f"{OPENAI_PROVIDER_NAME}-vision",
308+
model="gpt-5",
309+
provider=OPENAI_PROVIDER_NAME,
310+
inference_parameters=get_default_vision_alias_inference_parameters(),
311+
),
312+
]
313+
314+
315+
def get_default_model_configs() -> list[ModelConfig]:
316+
return get_default_nvidia_model_configs() + get_default_openai_model_configs()
317+
318+
319+
def get_default_providers() -> list[ModelProvider]:
320+
return [
321+
ModelProvider(
322+
name=NVIDIA_PROVIDER_NAME,
323+
endpoint="https://integrate.api.nvidia.com/v1",
324+
api_key=NVIDIA_API_KEY_ENV_VAR_NAME,
325+
),
326+
ModelProvider(
327+
name=OPENAI_PROVIDER_NAME,
328+
endpoint="https://api.openai.com/v1",
329+
api_key=OPENAI_API_KEY_ENV_VAR_NAME,
330+
),
331+
]
332+
333+
334+
def get_nvidia_api_key() -> Optional[str]:
335+
return os.getenv(NVIDIA_API_KEY_ENV_VAR_NAME)
336+
337+
338+
def get_openai_api_key() -> Optional[str]:
339+
return os.getenv(OPENAI_API_KEY_ENV_VAR_NAME)

src/data_designer/config/utils/constants.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55

66
from rich.theme import Theme
77

8+
DEFAULT_NUM_RECORDS = 10
9+
810
EPSILON = 1e-8
911
REPORTING_PRECISION = 2
1012

@@ -255,3 +257,9 @@ class NordColor(Enum):
255257
"zh_TW",
256258
"zu_ZA",
257259
]
260+
261+
NVIDIA_PROVIDER_NAME = "nvidia"
262+
NVIDIA_API_KEY_ENV_VAR_NAME = "NVIDIA_API_KEY"
263+
264+
OPENAI_PROVIDER_NAME = "openai"
265+
OPENAI_API_KEY_ENV_VAR_NAME = "OPENAI_API_KEY"

0 commit comments

Comments
 (0)