Skip to content

Commit e8e22be

Browse files
committed
Update info pipeine
1 parent 446a00e commit e8e22be

File tree

8 files changed

+177
-63
lines changed

8 files changed

+177
-63
lines changed

src/data_designer/config/config_builder.py

Lines changed: 8 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
SeedDatasetReference,
4545
)
4646
from .utils.constants import DEFAULT_REPR_HTML_STYLE, REPR_HTML_TEMPLATE
47-
from .utils.info import DataDesignerInfo
47+
from .utils.info import ConfigBuilderInfo
4848
from .utils.io_helpers import serialize_data, smart_load_yaml
4949
from .utils.misc import (
5050
can_run_data_designer_locally,
@@ -141,8 +141,7 @@ def __init__(self, model_configs: Union[list[ModelConfig], str, Path]):
141141
- A string or Path to a model configuration file
142142
"""
143143
self._column_configs = {}
144-
self._info = DataDesignerInfo()
145-
self.set_model_configs(load_model_configs(model_configs))
144+
self._model_configs = load_model_configs(model_configs)
146145
self._processor_configs: list[ProcessorConfig] = []
147146
self._seed_config: Optional[SeedConfig] = None
148147
self._constraints: list[ColumnConstraintT] = []
@@ -172,23 +171,13 @@ def allowed_references(self) -> list[str]:
172171
return list(self._column_configs.keys()) + list(set(side_effect_columns))
173172

174173
@property
175-
def info(self) -> DataDesignerInfo:
176-
"""Get the DataDesignerInfo object for this builder.
174+
def info(self) -> ConfigBuilderInfo:
175+
"""Get the ConfigBuilderInfo object for this builder.
177176
178177
Returns:
179-
An object containing metadata about the configuration.
178+
An object containing information about the configuration.
180179
"""
181-
return self._info
182-
183-
def set_model_configs(self, model_configs: list[ModelConfig]) -> Self:
184-
"""Set the model configurations for this builder.
185-
186-
Args:
187-
model_configs: The model configurations to set.
188-
"""
189-
self._model_configs = model_configs
190-
self._info.set_model_configs(model_configs=self._model_configs)
191-
return self
180+
return ConfigBuilderInfo(model_configs=self._model_configs)
192181

193182
def add_model_config(self, model_config: ModelConfig) -> Self:
194183
"""Add a model configuration to the current Data Designer configuration.
@@ -200,7 +189,7 @@ def add_model_config(self, model_config: ModelConfig) -> Self:
200189
raise BuilderConfigurationError(
201190
f"🛑 Model configuration with alias {model_config.alias} already exists. Please delete the existing model configuration or choose a different alias."
202191
)
203-
self.set_model_configs(self._model_configs + [model_config])
192+
self._model_configs.append(model_config)
204193
return self
205194

206195
def delete_model_config(self, alias: str) -> Self:
@@ -209,7 +198,7 @@ def delete_model_config(self, alias: str) -> Self:
209198
Args:
210199
alias: The alias of the model configuration to delete.
211200
"""
212-
self.set_model_configs([mc for mc in self._model_configs if mc.alias != alias])
201+
self._model_configs = [mc for mc in self._model_configs if mc.alias != alias]
213202
if len(self._model_configs) == 0:
214203
logger.warning(
215204
f"⚠️ No model configurations found after deleting model configuration with alias {alias}. Please add a model configuration before building the configuration."

src/data_designer/config/interface.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
from .models import ModelConfig, ModelProvider
1212
from .utils.constants import DEFAULT_NUM_RECORDS
13+
from .utils.info import InterfaceInfo
1314

1415
if TYPE_CHECKING:
1516
from .analysis.dataset_profiler import DatasetProfilerResults
@@ -47,3 +48,7 @@ def get_default_model_configs(self) -> list[ModelConfig]: ...
4748

4849
@abstractmethod
4950
def get_default_model_providers(self) -> list[ModelProvider]: ...
51+
52+
@property
53+
@abstractmethod
54+
def info(self) -> InterfaceInfo: ...
Lines changed: 84 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,97 @@
11
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
# SPDX-License-Identifier: Apache-2.0
33

4-
from ..models import ModelConfig
4+
from abc import ABC, abstractmethod
5+
from enum import Enum
6+
from typing import Literal, TypeVar, overload
7+
8+
from ..models import ModelConfig, ModelProvider
59
from ..sampler_params import SamplerType
610
from .type_helpers import get_sampler_params
7-
from .visualization import display_model_configs_table, display_sampler_table
11+
from .visualization import display_model_configs_table, display_model_providers_table, display_sampler_table
812

913

10-
class DataDesignerInfo:
11-
def __init__(self, model_configs: list[ModelConfig] | None = None):
12-
self._sampler_params = get_sampler_params()
13-
self._model_configs = model_configs or []
14+
class InfoType(str, Enum):
15+
SAMPLERS = "SAMPLERS"
16+
MODEL_CONFIGS = "MODEL_CONFIGS"
17+
MODEL_PROVIDERS = "MODEL_PROVIDERS"
18+
19+
20+
ConfigBuilderInfoType = Literal[InfoType.SAMPLERS, InfoType.MODEL_CONFIGS]
21+
DataDesignerInfoType = Literal[InfoType.MODEL_PROVIDERS]
22+
InfoTypeT = TypeVar("InfoTypeT", bound=InfoType)
23+
1424

15-
@property
16-
def sampler_table(self) -> None:
17-
display_sampler_table(self._sampler_params)
25+
class InfoDisplay(ABC):
26+
"""Base class for info display classes that provide type-safe display methods."""
1827

19-
@property
20-
def sampler_types(self) -> list[str]:
21-
return [s.value for s in SamplerType]
28+
@abstractmethod
29+
def display(self, info_type: InfoTypeT, **kwargs) -> None:
30+
"""Display information based on the provided info type.
2231
23-
def set_model_configs(self, model_configs: list[ModelConfig]) -> None:
32+
Args:
33+
info_type: Type of information to display.
34+
"""
35+
...
36+
37+
38+
class ConfigBuilderInfo(InfoDisplay):
39+
def __init__(self, model_configs: list[ModelConfig]):
40+
self._sampler_params = get_sampler_params()
2441
self._model_configs = model_configs
2542

26-
def display_sampler(self, sampler_type: SamplerType) -> None:
27-
title = f"{SamplerType(sampler_type).value.replace('_', ' ').title()} Sampler"
28-
display_sampler_table({sampler_type: self._sampler_params[sampler_type]}, title=title)
43+
@overload
44+
def display(self, info_type: Literal[InfoType.SAMPLERS], **kwargs) -> None: ...
45+
46+
@overload
47+
def display(self, info_type: Literal[InfoType.MODEL_CONFIGS], **kwargs) -> None: ...
48+
49+
def display(self, info_type: ConfigBuilderInfoType, **kwargs) -> None:
50+
"""Display information based on the provided info type.
51+
52+
Args:
53+
info_type: Type of information to display. Only SAMPLERS and MODEL_CONFIGS are supported.
54+
55+
Raises:
56+
ValueError: If an unsupported info_type is provided.
57+
"""
58+
if info_type == InfoType.SAMPLERS:
59+
self._display_sampler_info(sampler_type=kwargs.get("sampler_type"))
60+
elif info_type == InfoType.MODEL_CONFIGS:
61+
display_model_configs_table(self._model_configs)
62+
else:
63+
raise ValueError(
64+
f"Unsupported info_type: {info_type!r}. "
65+
f"ConfigBuilderInfo only supports {InfoType.SAMPLERS.value!r} and {InfoType.MODEL_CONFIGS.value!r}."
66+
)
67+
68+
def _display_sampler_info(self, sampler_type: SamplerType | None) -> None:
69+
if sampler_type is not None:
70+
title = f"{SamplerType(sampler_type).value.replace('_', ' ').title()} Sampler"
71+
display_sampler_table({sampler_type: self._sampler_params[sampler_type]}, title=title)
72+
else:
73+
display_sampler_table(self._sampler_params)
74+
75+
76+
class InterfaceInfo(InfoDisplay):
77+
def __init__(self, model_providers: list[ModelProvider]):
78+
self._model_providers = model_providers
79+
80+
@overload
81+
def display(self, info_type: Literal[InfoType.MODEL_PROVIDERS], **kwargs) -> None: ...
82+
83+
def display(self, info_type: DataDesignerInfoType, **kwargs) -> None:
84+
"""Display information based on the provided info type.
85+
86+
Args:
87+
info_type: Type of information to display. Only MODEL_PROVIDERS is supported.
2988
30-
def display_model_configs(self) -> None:
31-
display_model_configs_table(model_configs=self._model_configs)
89+
Raises:
90+
ValueError: If an unsupported info_type is provided.
91+
"""
92+
if info_type == InfoType.MODEL_PROVIDERS:
93+
display_model_providers_table(self._model_providers)
94+
else:
95+
raise ValueError(
96+
f"Unsupported info_type: {info_type!r}. InterfaceInfo only supports {InfoType.MODEL_PROVIDERS.value!r}."
97+
)

src/data_designer/config/utils/visualization.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323
from ..base import ConfigBase
2424
from ..columns import DataDesignerColumnType
25-
from ..models import ModelConfig
25+
from ..models import ModelConfig, ModelProvider
2626
from ..sampler_params import SamplerType
2727
from .code_lang import code_lang_to_syntax_lexer
2828
from .errors import DatasetSampleDisplayError
@@ -278,6 +278,17 @@ def display_model_configs_table(model_configs: list[ModelConfig]) -> None:
278278
console.print(group)
279279

280280

281+
def display_model_providers_table(model_providers: list[ModelProvider]) -> None:
282+
table_model_providers = Table(expand=True)
283+
table_model_providers.add_column("Name")
284+
table_model_providers.add_column("Endpoint")
285+
table_model_providers.add_column("API Key")
286+
for model_provider in model_providers:
287+
table_model_providers.add_row(model_provider.name, model_provider.endpoint, model_provider.api_key)
288+
group = Group(Rule(title="Model Providers"), table_model_providers)
289+
console.print(group)
290+
291+
281292
def convert_to_row_element(elem):
282293
try:
283294
elem = Pretty(json.loads(elem))

src/data_designer/essentials/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
)
5252
from ..config.seed import DatastoreSeedDatasetReference, IndexRange, PartitionBlock, SamplingStrategy, SeedConfig
5353
from ..config.utils.code_lang import CodeLang
54+
from ..config.utils.info import InfoType
5455
from ..config.utils.misc import can_run_data_designer_locally
5556
from ..config.validator_params import (
5657
CodeValidatorParams,
@@ -90,6 +91,7 @@
9091
"ExpressionColumnConfig",
9192
"GaussianSamplerParams",
9293
"IndexRange",
94+
"InfoType",
9395
"ImageContext",
9496
"ImageFormat",
9597
"InferenceParameters",

src/data_designer/interface/data_designer.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from data_designer.config.preview_results import PreviewResults
1919
from data_designer.config.seed import LocalSeedDatasetReference
2020
from data_designer.config.utils.constants import DEFAULT_NUM_RECORDS
21+
from data_designer.config.utils.info import InterfaceInfo
2122
from data_designer.config.utils.io_helpers import write_seed_dataset
2223
from data_designer.engine.analysis.dataset_profiler import (
2324
DataDesignerDatasetProfiler,
@@ -83,8 +84,8 @@ def __init__(
8384
if blob_storage_path is None
8485
else init_managed_blob_storage(str(blob_storage_path))
8586
)
86-
model_providers = model_providers or self.get_default_model_providers()
87-
self._model_provider_registry = resolve_model_provider_registry(model_providers)
87+
self._model_providers = model_providers or self.get_default_model_providers()
88+
self._model_provider_registry = resolve_model_provider_registry(self._model_providers)
8889

8990
@staticmethod
9091
def make_seed_reference_from_file(file_path: str | Path) -> LocalSeedDatasetReference:
@@ -122,6 +123,10 @@ def make_seed_reference_from_dataframe(
122123
write_seed_dataset(dataframe, Path(file_path))
123124
return cls.make_seed_reference_from_file(file_path)
124125

126+
@property
127+
def info(self) -> InterfaceInfo:
128+
return InterfaceInfo(model_providers=self._model_providers)
129+
125130
def create(
126131
self,
127132
config_builder: DataDesignerConfigBuilder,

tests/config/utils/test_info.py

Lines changed: 47 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,58 @@
11
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
# SPDX-License-Identifier: Apache-2.0
3-
43
from unittest.mock import patch
54

6-
from data_designer.config.sampler_params import BernoulliSamplerParams, BinomialSamplerParams, SamplerType
7-
from data_designer.config.utils.info import DataDesignerInfo
5+
import pytest
6+
7+
from data_designer.config.sampler_params import SamplerType
8+
from data_designer.config.utils.info import ConfigBuilderInfo, InfoType, InterfaceInfo
9+
from data_designer.config.utils.type_helpers import get_sampler_params
810

911

1012
@patch("data_designer.config.utils.info.display_sampler_table")
11-
@patch("data_designer.config.utils.info.get_sampler_params")
12-
def test_data_designer_info(mock_get_sampler_params, mock_display_sampler_table):
13-
stub_bernoulli_params = BernoulliSamplerParams(p=0.5)
14-
stub_binomial_params = BinomialSamplerParams(n=100, p=0.5)
15-
mock_get_sampler_params.return_value = {
16-
SamplerType.BERNOULLI: stub_bernoulli_params,
17-
SamplerType.BINOMIAL: stub_binomial_params,
18-
}
19-
info = DataDesignerInfo()
20-
21-
assert SamplerType.BINOMIAL.value in info.sampler_types
22-
mock_get_sampler_params.assert_called_once()
23-
24-
_ = info.sampler_table
25-
mock_display_sampler_table.assert_called_once_with(
26-
{SamplerType.BERNOULLI: stub_bernoulli_params, SamplerType.BINOMIAL: stub_binomial_params}
27-
)
13+
@patch("data_designer.config.utils.info.display_model_configs_table")
14+
def test_config_builder_sampler_info(mock_display_model_configs_table, mock_display_sampler_table, stub_model_configs):
15+
info = ConfigBuilderInfo(model_configs=stub_model_configs)
16+
info.display(InfoType.MODEL_CONFIGS)
17+
mock_display_model_configs_table.assert_called_once_with(stub_model_configs)
18+
19+
sampler_params = get_sampler_params()
20+
info.display(InfoType.SAMPLERS)
21+
mock_display_sampler_table.assert_called_once_with(sampler_params)
2822

2923
mock_display_sampler_table.reset_mock()
30-
info.display_sampler(SamplerType.BERNOULLI)
24+
info.display(InfoType.SAMPLERS, sampler_type=SamplerType.BERNOULLI)
3125
mock_display_sampler_table.assert_called_once_with(
32-
{SamplerType.BERNOULLI: stub_bernoulli_params}, title="Bernoulli Sampler"
26+
{SamplerType.BERNOULLI: sampler_params[SamplerType.BERNOULLI]}, title="Bernoulli Sampler"
3327
)
28+
29+
30+
@patch("data_designer.config.utils.info.display_model_configs_table")
31+
def test_config_builder_model_configs_info(mock_display_model_configs_table, stub_model_configs):
32+
info = ConfigBuilderInfo(model_configs=stub_model_configs)
33+
info.display(InfoType.MODEL_CONFIGS)
34+
mock_display_model_configs_table.assert_called_once_with(stub_model_configs)
35+
36+
37+
def test_config_builder_unsupported_info_type(stub_model_configs):
38+
info = ConfigBuilderInfo(model_configs=stub_model_configs)
39+
with pytest.raises(
40+
ValueError,
41+
match="Unsupported info_type: 'unsupported_type'. ConfigBuilderInfo only supports 'SAMPLERS' and 'MODEL_CONFIGS'.",
42+
):
43+
info.display("unsupported_type")
44+
45+
46+
@patch("data_designer.config.utils.info.display_model_providers_table")
47+
def test_interface_model_providers_info(mock_display_model_providers_table, stub_model_providers):
48+
info = InterfaceInfo(model_providers=stub_model_providers)
49+
info.display(InfoType.MODEL_PROVIDERS)
50+
mock_display_model_providers_table.assert_called_once_with(stub_model_providers)
51+
52+
53+
def test_interface_unsupported_info_type(stub_model_providers):
54+
info = InterfaceInfo(model_providers=stub_model_providers)
55+
with pytest.raises(
56+
ValueError, match="Unsupported info_type: 'unsupported_type'. InterfaceInfo only supports 'MODEL_PROVIDERS'."
57+
):
58+
info.display("unsupported_type")

tests/conftest.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from data_designer.config.config_builder import DataDesignerConfigBuilder
1818
from data_designer.config.data_designer_config import DataDesignerConfig
1919
from data_designer.config.datastore import DatastoreSettings
20-
from data_designer.config.models import InferenceParameters, ModelConfig
20+
from data_designer.config.models import InferenceParameters, ModelConfig, ModelProvider
2121

2222

2323
@pytest.fixture
@@ -144,6 +144,17 @@ def stub_model_configs() -> list[ModelConfig]:
144144
]
145145

146146

147+
@pytest.fixture
148+
def stub_model_providers() -> list[ModelProvider]:
149+
return [
150+
ModelProvider(
151+
name="provider-1",
152+
endpoint="https://api.provider-1.com/v1",
153+
api_key="PROVIDER_1_API_KEY",
154+
)
155+
]
156+
157+
147158
@pytest.fixture
148159
def stub_empty_builder(stub_model_configs: list[ModelConfig]) -> DataDesignerConfigBuilder:
149160
"""Test builder with model configs."""

0 commit comments

Comments
 (0)