Skip to content

Commit 12301b3

Browse files
committed
update tests + test e2e
1 parent 6bd7760 commit 12301b3

File tree

4 files changed

+33
-6
lines changed

4 files changed

+33
-6
lines changed

src/data_designer/config/config_builder.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import json
77
import logging
88
from pathlib import Path
9-
9+
from typing import Union, Optional
1010
from pygments import highlight
1111
from pygments.formatters import HtmlFormatter
1212
from pygments.lexers import PythonLexer
@@ -31,6 +31,8 @@
3131
SamplingStrategy,
3232
SeedConfig,
3333
SeedDatasetReference,
34+
IndexRange,
35+
PartitionBlock,
3436
)
3537
from .utils.constants import DEFAULT_REPR_HTML_STYLE, REPR_HTML_TEMPLATE
3638
from .utils.info import DataDesignerInfo
@@ -113,7 +115,7 @@ def from_config(cls, config: dict | str | Path | BuilderConfig) -> Self:
113115
datastore_settings=builder_config.datastore_settings,
114116
)
115117
builder.set_seed_datastore_settings(builder_config.datastore_settings)
116-
builder.with_seed_dataset(seed_dataset_reference, sampling_strategy=config.seed_config.sampling_strategy)
118+
builder.with_seed_dataset(seed_dataset_reference, sampling_strategy=config.seed_config.sampling_strategy, selection_strategy=config.seed_config.selection_strategy)
117119

118120
return builder
119121

@@ -493,6 +495,7 @@ def with_seed_dataset(
493495
dataset_reference: SeedDatasetReference,
494496
*,
495497
sampling_strategy: SamplingStrategy = SamplingStrategy.ORDERED,
498+
selection_strategy: Optional[Union[IndexRange, PartitionBlock]] = None,
496499
) -> Self:
497500
"""Add a seed dataset to the current Data Designer configuration.
498501
@@ -508,7 +511,7 @@ def with_seed_dataset(
508511
Returns:
509512
The current Data Designer config builder instance.
510513
"""
511-
self._seed_config = SeedConfig(dataset=dataset_reference.dataset, sampling_strategy=sampling_strategy)
514+
self._seed_config = SeedConfig(dataset=dataset_reference.dataset, sampling_strategy=sampling_strategy, selection_strategy=selection_strategy)
512515
self.set_seed_datastore_settings(
513516
dataset_reference.datastore_settings if hasattr(dataset_reference, "datastore_settings") else None
514517
)

src/data_designer/engine/column_generators/generators/seed_dataset.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def _initialize(self) -> None:
6060
self._dataset_uri = self.resource_provider.datastore.get_dataset_uri(self.config.dataset)
6161
self._seed_dataset_size = self.duckdb_conn.execute(f"SELECT COUNT(*) FROM '{self._dataset_uri}'").fetchone()[0]
6262
self._index_range = self._resolve_index_range()
63-
63+
6464
def _validate_selection_strategy(self) -> None:
6565
err_msg = None
6666
if self.config.selection_strategy is not None:
@@ -115,9 +115,8 @@ def _sample_records(self, num_records: int) -> pd.DataFrame:
115115
logger.info(f" |-- seed dataset size: {self._seed_dataset_size} records")
116116
logger.info(f" |-- sampling strategy: {self.config.sampling_strategy}")
117117
if self._index_range is not None:
118-
logger.info(f" |-- selection strategy: {self.config.selection_strategy.model_dump_json()}")
118+
logger.info(f" |-- selection strategy: {type(self.config.selection_strategy).__name__}\n{self.config.selection_strategy.model_dump_json(indent=4)}")
119119
logger.info(f" |-- seed dataset size after selection: {self._index_range.size} records")
120-
121120
df_batch = pd.DataFrame()
122121
df_sample = pd.DataFrame() if self._df_remaining is None else self._df_remaining
123122
num_zero_record_responses = 0

src/data_designer/engine/dataset_builders/utils/config_compiler.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ def compile_dataset_builder_column_configs(config: DataDesignerConfig) -> list[D
3535
columns=seed_column_configs,
3636
dataset=config.seed_config.dataset,
3737
sampling_strategy=config.seed_config.sampling_strategy,
38+
selection_strategy=config.seed_config.selection_strategy,
3839
)
3940
)
4041

tests/engine/column_generators/generators/test_seed_dataset.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,30 @@ def test_seed_dataset_column_generator_config_structure():
124124
assert config.columns[0].column_type.value == "seed-dataset"
125125
assert config.columns[1].name == "col2"
126126
assert config.columns[1].column_type.value == "seed-dataset"
127+
assert config.selection_strategy is None
128+
129+
# Test PartitionBlock selection strategy
130+
config = SeedDatasetMultiColumnConfig(
131+
columns=[SeedDatasetColumnConfig(name="col1"), SeedDatasetColumnConfig(name="col2")],
132+
dataset="test/dataset",
133+
sampling_strategy=SamplingStrategy.SHUFFLE,
134+
selection_strategy=PartitionBlock(partition_index=1, num_partitions=3),
135+
)
136+
assert isinstance(config.selection_strategy, PartitionBlock)
137+
assert config.selection_strategy.partition_index == 1
138+
assert config.selection_strategy.num_partitions == 3
139+
140+
# Test IndexRange selection strategy
141+
config = SeedDatasetMultiColumnConfig(
142+
columns=[SeedDatasetColumnConfig(name="col1"), SeedDatasetColumnConfig(name="col2")],
143+
dataset="test/dataset",
144+
sampling_strategy=SamplingStrategy.SHUFFLE,
145+
selection_strategy=IndexRange(start=0, end=1),
146+
)
147+
assert isinstance(config.selection_strategy, IndexRange)
148+
assert config.selection_strategy.start == 0
149+
assert config.selection_strategy.end == 1
150+
assert config.selection_strategy.size == 2
127151

128152
# Test constants and enum values
129153
assert MAX_ZERO_RECORD_RESPONSE_FACTOR == 2

0 commit comments

Comments
 (0)