Skip to content

Commit 08ba0d2

Browse files
committed
Refactor seed datasets
1 parent d7e93c5 commit 08ba0d2

32 files changed

+753
-1275
lines changed

src/data_designer/config/config_builder.py

Lines changed: 39 additions & 105 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,8 @@
2424
)
2525
from data_designer.config.data_designer_config import DataDesignerConfig
2626
from data_designer.config.dataset_builders import BuildStage
27-
from data_designer.config.datastore import DatastoreSettings, fetch_seed_dataset_column_names
2827
from data_designer.config.default_model_settings import get_default_model_configs
29-
from data_designer.config.errors import BuilderConfigurationError, InvalidColumnTypeError, InvalidConfigError
28+
from data_designer.config.errors import BuilderConfigurationError, InvalidColumnTypeError
3029
from data_designer.config.models import ModelConfig, load_model_configs
3130
from data_designer.config.processors import ProcessorConfigT, ProcessorType, get_processor_config_from_kwargs
3231
from data_designer.config.sampler_constraints import (
@@ -36,20 +35,17 @@
3635
ScalarInequalityConstraint,
3736
)
3837
from data_designer.config.seed import (
39-
DatastoreSeedDatasetReference,
4038
IndexRange,
41-
LocalSeedDatasetReference,
4239
PartitionBlock,
4340
SamplingStrategy,
4441
SeedConfig,
45-
SeedDatasetReference,
4642
)
43+
from data_designer.config.seed_dataset import DataFrameSeedConfig, SeedDatasetConfig
4744
from data_designer.config.utils.constants import DEFAULT_REPR_HTML_STYLE, REPR_HTML_TEMPLATE
4845
from data_designer.config.utils.info import ConfigBuilderInfo
4946
from data_designer.config.utils.io_helpers import serialize_data, smart_load_yaml
5047
from data_designer.config.utils.misc import can_run_data_designer_locally, json_indent_list_of_strings, kebab_to_snake
5148
from data_designer.config.utils.type_helpers import resolve_string_enum
52-
from data_designer.config.utils.validation import ViolationLevel, rich_print_violations, validate_data_designer_config
5349

5450
logger = logging.getLogger(__name__)
5551

@@ -63,12 +59,9 @@ class BuilderConfig(ExportableConfigBase):
6359
Attributes:
6460
data_designer: The main Data Designer configuration containing columns,
6561
constraints, profilers, and other settings.
66-
datastore_settings: Optional datastore settings for accessing external
67-
datasets.
6862
"""
6963

7064
data_designer: DataDesignerConfig
71-
datastore_settings: DatastoreSettings | None
7265

7366

7467
class DataDesignerConfigBuilder:
@@ -101,31 +94,27 @@ def from_config(cls, config: dict | str | Path | BuilderConfig) -> Self:
10194
builder_config = BuilderConfig.model_validate(json_config)
10295

10396
builder = cls(model_configs=builder_config.data_designer.model_configs)
104-
config = builder_config.data_designer
97+
data_designer_config = builder_config.data_designer
10598

106-
for col in config.columns:
99+
for col in data_designer_config.columns:
107100
builder.add_column(col)
108101

109-
for constraint in config.constraints or []:
102+
for constraint in data_designer_config.constraints or []:
110103
builder.add_constraint(constraint=constraint)
111104

112-
if config.seed_config:
113-
if builder_config.datastore_settings is None:
114-
if can_run_data_designer_locally():
115-
seed_dataset_reference = LocalSeedDatasetReference(dataset=config.seed_config.dataset)
116-
else:
117-
raise BuilderConfigurationError("🛑 Datastore settings are required.")
105+
if (seed_config := data_designer_config.seed_config) is not None:
106+
if isinstance(seed_config.config, DataFrameSeedConfig):
107+
logger.warning(
108+
"This builder was originally configured with a DataFrame seed dataset. "
109+
"DataFrame seeds cannot be serialized to config files. "
110+
"You must re-run `with_seed_dataset` to reconfigure the seed data."
111+
)
118112
else:
119-
seed_dataset_reference = DatastoreSeedDatasetReference(
120-
dataset=config.seed_config.dataset,
121-
datastore_settings=builder_config.datastore_settings,
113+
builder.with_seed_dataset(
114+
seed_config.config,
115+
sampling_strategy=seed_config.sampling_strategy,
116+
selection_strategy=seed_config.selection_strategy,
122117
)
123-
builder.set_seed_datastore_settings(builder_config.datastore_settings)
124-
builder.with_seed_dataset(
125-
seed_dataset_reference,
126-
sampling_strategy=config.seed_config.sampling_strategy,
127-
selection_strategy=config.seed_config.selection_strategy,
128-
)
129118

130119
return builder
131120

@@ -144,7 +133,6 @@ def __init__(self, model_configs: list[ModelConfig] | str | Path | None = None):
144133
self._seed_config: SeedConfig | None = None
145134
self._constraints: list[ColumnConstraintT] = []
146135
self._profilers: list[ColumnProfilerConfigT] = []
147-
self._datastore_settings: DatastoreSettings | None = None
148136

149137
@property
150138
def model_configs(self) -> list[ModelConfig]:
@@ -243,6 +231,10 @@ def add_column(
243231
f"{', '.join([t.__name__ for t in allowed_column_configs])}"
244232
)
245233

234+
# TODO: the config builder will no longer have any SeedDatasetColumnConfigs, because seed columns
235+
# aren't fetched until we get to the new compiler fn in engine code. We could just remove this.
236+
# Or, should we keep it for all columns (not just seeds)? Is there any reason to add a column and
237+
# then overwrite it? (Alternatively, we could keep it but just log a warning instead of raising.)
246238
existing_config = self._column_configs.get(column_config.name)
247239
if existing_config is not None and isinstance(existing_config, SeedDatasetColumnConfig):
248240
raise BuilderConfigurationError(
@@ -371,19 +363,12 @@ def get_profilers(self) -> list[ColumnProfilerConfigT]:
371363
"""
372364
return self._profilers
373365

374-
def build(self, *, skip_validation: bool = False, raise_exceptions: bool = False) -> DataDesignerConfig:
366+
def build(self) -> DataDesignerConfig:
375367
"""Build a DataDesignerConfig instance based on the current builder configuration.
376368
377-
Args:
378-
skip_validation: Whether to skip validation of the configuration.
379-
raise_exceptions: Whether to raise an exception if the configuration is invalid.
380-
381369
Returns:
382370
The current Data Designer config object.
383371
"""
384-
if not skip_validation:
385-
self.validate(raise_exceptions=raise_exceptions)
386-
387372
return DataDesignerConfig(
388373
model_configs=self._model_configs,
389374
seed_config=self._seed_config,
@@ -512,14 +497,6 @@ def get_seed_config(self) -> SeedConfig | None:
512497
"""
513498
return self._seed_config
514499

515-
def get_seed_datastore_settings(self) -> DatastoreSettings | None:
516-
"""Get most recent datastore settings for the current Data Designer configuration.
517-
518-
Returns:
519-
The datastore settings if configured, None otherwise.
520-
"""
521-
return None if not self._datastore_settings else DatastoreSettings.model_validate(self._datastore_settings)
522-
523500
def num_columns_of_type(self, column_type: DataDesignerColumnType) -> int:
524501
"""Get the count of columns of the specified type.
525502
@@ -531,85 +508,33 @@ def num_columns_of_type(self, column_type: DataDesignerColumnType) -> int:
531508
"""
532509
return len(self.get_columns_of_type(column_type))
533510

534-
def set_seed_datastore_settings(self, datastore_settings: DatastoreSettings | None) -> Self:
535-
"""Set the datastore settings for the seed dataset.
536-
537-
Args:
538-
datastore_settings: The datastore settings to use for the seed dataset.
539-
"""
540-
self._datastore_settings = datastore_settings
541-
return self
542-
543-
def validate(self, *, raise_exceptions: bool = False) -> Self:
544-
"""Validate the current Data Designer configuration.
545-
546-
Args:
547-
raise_exceptions: Whether to raise an exception if the configuration is invalid.
548-
549-
Returns:
550-
The current Data Designer config builder instance.
551-
552-
Raises:
553-
InvalidConfigError: If the configuration is invalid and raise_exceptions is True.
554-
"""
555-
556-
violations = validate_data_designer_config(
557-
columns=list(self._column_configs.values()),
558-
processor_configs=self._processor_configs,
559-
allowed_references=self.allowed_references,
560-
)
561-
rich_print_violations(violations)
562-
if raise_exceptions and len([v for v in violations if v.level == ViolationLevel.ERROR]) > 0:
563-
raise InvalidConfigError(
564-
"🛑 Your configuration contains validation errors. Please address the indicated issues and try again."
565-
)
566-
if len(violations) == 0:
567-
logger.info("✅ Validation passed")
568-
return self
569-
570511
def with_seed_dataset(
571512
self,
572-
dataset_reference: SeedDatasetReference,
513+
seed_dataset_config: SeedDatasetConfig,
573514
*,
574515
sampling_strategy: SamplingStrategy = SamplingStrategy.ORDERED,
575516
selection_strategy: IndexRange | PartitionBlock | None = None,
576517
) -> Self:
577518
"""Add a seed dataset to the current Data Designer configuration.
578519
579-
This method sets the seed dataset for the configuration and automatically creates
580-
SeedDatasetColumnConfig objects for each column found in the dataset. The column
581-
names are fetched from the dataset source, which can be the Hugging Face Hub, the
582-
NeMo Microservices Datastore, or in the case of direct library usage, a local file.
520+
This method sets the seed dataset for the configuration, but columns are not resolved until
521+
compilation (including validation) is performed by the engine using a SeedDatasetReader.
583522
584523
Args:
585-
dataset_reference: Seed dataset reference for fetching from the datastore.
524+
seed_dataset_config: The config providing a pointer to the seed dataset.
586525
sampling_strategy: The sampling strategy to use when generating data from the seed dataset.
587526
Defaults to ORDERED sampling.
527+
selection_strategy: An optional selection strategy to use when generating data from the seed dataset.
528+
Defaults to None.
588529
589530
Returns:
590531
The current Data Designer config builder instance.
591-
592-
Raises:
593-
BuilderConfigurationError: If any seed dataset column name collides with an existing column.
594532
"""
595-
seed_column_names = fetch_seed_dataset_column_names(dataset_reference)
596-
colliding_columns = [name for name in seed_column_names if name in self._column_configs]
597-
if colliding_columns:
598-
raise BuilderConfigurationError(
599-
f"🛑 Seed dataset column(s) {colliding_columns} collide with existing column(s). "
600-
"Please remove the conflicting columns or use a seed dataset with different column names."
601-
)
602-
603533
self._seed_config = SeedConfig(
604-
dataset=dataset_reference.dataset,
534+
config=seed_dataset_config,
605535
sampling_strategy=sampling_strategy,
606536
selection_strategy=selection_strategy,
607537
)
608-
self.set_seed_datastore_settings(
609-
dataset_reference.datastore_settings if hasattr(dataset_reference, "datastore_settings") else None
610-
)
611-
for column_name in seed_column_names:
612-
self._column_configs[column_name] = SeedDatasetColumnConfig(name=column_name)
613538
return self
614539

615540
def write_config(self, path: str | Path, indent: int | None = 2, **kwargs) -> None:
@@ -632,13 +557,22 @@ def write_config(self, path: str | Path, indent: int | None = 2, **kwargs) -> No
632557
else:
633558
raise BuilderConfigurationError(f"🛑 Unsupported file type: {suffix}. Must be `.yaml`, `.yml` or `.json`.")
634559

560+
if (seed_config := self.get_seed_config()) is not None and isinstance(seed_config.config, DataFrameSeedConfig):
561+
logger.warning(
562+
"This builder was configured with a DataFrame seed dataset. "
563+
"DataFrame seeds cannot be serialized to config files. "
564+
"If you recreate this builder using `from_config`, you will need to re-run `with_seed_dataset`.\n"
565+
"Alternatively, consider writing your DataFrame to a file and re-running `with_seed_dataset` with "
566+
"a LocalFileSeedConfig so that you can share the config and data files."
567+
)
568+
635569
def get_builder_config(self) -> BuilderConfig:
636570
"""Get the builder config for the current Data Designer configuration.
637571
638572
Returns:
639573
The builder config.
640574
"""
641-
return BuilderConfig(data_designer=self.build(), datastore_settings=self._datastore_settings)
575+
return BuilderConfig(data_designer=self.build())
642576

643577
def __repr__(self) -> str:
644578
"""Generates a string representation of the DataDesignerConfigBuilder instance.
@@ -650,7 +584,7 @@ def __repr__(self) -> str:
650584
return f"{self.__class__.__name__}()"
651585

652586
props_to_repr = {
653-
"seed_dataset": (None if self._seed_config is None else f"'{self._seed_config.dataset}'"),
587+
"seed_dataset": (None if self._seed_config is None else f"{self._seed_config.config.seed_type} seed"),
654588
}
655589

656590
for column_type in get_column_display_order():

0 commit comments

Comments
 (0)