Skip to content

Commit 57d500b

Browse files
committed
add support for IndexRange and PartitionBlock
1 parent 0d4b9a0 commit 57d500b

File tree

5 files changed

+240
-12
lines changed

5 files changed

+240
-12
lines changed

src/data_designer/config/seed.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,26 +19,41 @@ class SamplingStrategy(str, Enum):
1919

2020

2121
class IndexRange(ConfigBase):
22-
start: int = Field(..., ge=0)
23-
end: int = Field(..., ge=1)
22+
start: int = Field(ge=0, description="The start index of the index range (inclusive)")
23+
end: int = Field(ge=0, description="The end index of the index range (inclusive)")
2424

2525
@model_validator(mode="after")
2626
def _validate_index_range(self) -> Self:
27-
if self.start >= self.end:
28-
raise ValueError("'start' index must be less than 'end' index")
27+
if self.start > self.end:
28+
raise ValueError("'start' index must be less than or equal to 'end' index")
2929
return self
3030

31+
@property
32+
def size(self) -> int:
33+
return self.end - self.start + 1
34+
3135

3236
class PartitionBlock(ConfigBase):
33-
partition_index: int = Field(..., default=0, ge=0)
34-
num_partitions: int = Field(..., default=1, ge=1)
37+
partition_index: int = Field(default=0, ge=0, description="The index of the partition to sample from")
38+
num_partitions: int = Field(default=1, ge=1, description="The total number of partitions in the dataset")
3539

3640
@model_validator(mode="after")
3741
def _validate_partition_block(self) -> Self:
3842
if self.partition_index >= self.num_partitions:
3943
raise ValueError("'partition_index' must be less than 'num_partitions'")
4044
return self
4145

46+
def to_index_range(self, dataset_size: int) -> IndexRange:
47+
partition_size = dataset_size // self.num_partitions
48+
start = self.partition_index * partition_size
49+
50+
# For the last partition, extend to the end of the dataset to include remainder rows
51+
if self.partition_index == self.num_partitions - 1:
52+
end = dataset_size - 1
53+
else:
54+
end = ((self.partition_index + 1) * partition_size) - 1
55+
return IndexRange(start=start, end=end)
56+
4257

4358
class SeedConfig(ConfigBase):
4459
dataset: str

src/data_designer/engine/column_generators/generators/seed_dataset.py

Lines changed: 45 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,13 @@
77
import duckdb
88
import pandas as pd
99

10-
from data_designer.config.seed import SamplingStrategy
10+
from data_designer.config.seed import SamplingStrategy, IndexRange, PartitionBlock
1111
from data_designer.engine.column_generators.generators.base import (
1212
FromScratchColumnGenerator,
1313
GenerationStrategy,
1414
GeneratorMetadata,
1515
)
16+
from data_designer.engine.column_generators.utils.errors import SeedDatasetError
1617
from data_designer.engine.dataset_builders.multi_column_configs import SeedDatasetMultiColumnConfig
1718
from data_designer.engine.processing.utils import concat_datasets
1819
from data_designer.engine.resources.resource_provider import ResourceType
@@ -58,18 +59,60 @@ def _initialize(self) -> None:
5859
self._df_remaining = None
5960
self._dataset_uri = self.resource_provider.datastore.get_dataset_uri(self.config.dataset)
6061
self._seed_dataset_size = self.duckdb_conn.execute(f"SELECT COUNT(*) FROM '{self._dataset_uri}'").fetchone()[0]
62+
self._index_range = self._resolve_index_range()
63+
64+
def _validate_selection_strategy(self) -> None:
65+
err_msg = None
66+
if self.config.selection_strategy is not None:
67+
if isinstance(self.config.selection_strategy, IndexRange) and self.config.selection_strategy.end >= self._seed_dataset_size:
68+
err_msg = f"Selection strategy 'end' index {self.config.selection_strategy.end} is out of bounds for dataset size {self._seed_dataset_size}"
69+
elif isinstance(self.config.selection_strategy, PartitionBlock) and self.config.selection_strategy.num_partitions > self._seed_dataset_size:
70+
err_msg = f"Selection strategy 'num_partitions' {self.config.selection_strategy.num_partitions} is out of bounds for dataset size {self._seed_dataset_size}"
71+
if err_msg is not None:
72+
raise SeedDatasetError(err_msg)
73+
74+
def _resolve_index_range(self) -> IndexRange | None:
75+
self._validate_selection_strategy()
76+
index_range = None
77+
if self.config.selection_strategy is not None:
78+
if isinstance(self.config.selection_strategy, IndexRange):
79+
index_range = self.config.selection_strategy
80+
elif isinstance(self.config.selection_strategy, PartitionBlock):
81+
index_range = self.config.selection_strategy.to_index_range(self._seed_dataset_size)
82+
return index_range
6183

6284
def _reset_batch_reader(self, num_records: int) -> None:
6385
shuffle = self.config.sampling_strategy == SamplingStrategy.SHUFFLE
6486
shuffle_query = " ORDER BY RANDOM()" if shuffle else ""
65-
self._batch_reader = self.duckdb_conn.query(f"SELECT * FROM '{self._dataset_uri}'{shuffle_query}").record_batch(
87+
88+
if self._index_range is not None:
89+
# Use subquery with row_number() window function to filter by index range
90+
# IndexRange uses 0-based indexing [start, end] inclusive, row_number() is 1-based
91+
# To convert 0-based index i to 1-based row_number: row_number = i + 1
92+
# For inclusive range [start, end], we want: row_number > start AND row_number <= end + 1
93+
# This gives us 1-based rows [start+1, end+1] which maps to 0-based indices [start, end]
94+
read_query = f"""
95+
SELECT * EXCLUDE (row_num) FROM (
96+
SELECT *, row_number() OVER () as row_num
97+
FROM '{self._dataset_uri}'
98+
) sub
99+
WHERE row_num > {self._index_range.start} AND row_num <= {self._index_range.end + 1}
100+
{shuffle_query}
101+
"""
102+
else:
103+
read_query = f"SELECT * FROM '{self._dataset_uri}'{shuffle_query}"
104+
105+
self._batch_reader = self.duckdb_conn.query(read_query).record_batch(
66106
batch_size=num_records
67107
)
68108

69109
def _sample_records(self, num_records: int) -> pd.DataFrame:
70110
logger.info(f"🌱 Sampling {num_records} records from seed dataset")
71111
logger.info(f" |-- seed dataset size: {self._seed_dataset_size} records")
72112
logger.info(f" |-- sampling strategy: {self.config.sampling_strategy}")
113+
if self._index_range is not None:
114+
logger.info(f" |-- selection strategy: {self.config.selection_strategy.model_dump_json()}")
115+
logger.info(f" |-- seed dataset size after selection: {self._index_range.size} records")
73116

74117
df_batch = pd.DataFrame()
75118
df_sample = pd.DataFrame() if self._df_remaining is None else self._df_remaining

src/data_designer/engine/column_generators/utils/errors.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,6 @@ class PromptTemplateRenderError(DataDesignerError): ...
88

99

1010
class ExpressionTemplateRenderError(DataDesignerError): ...
11+
12+
13+
class SeedDatasetError(DataDesignerError): ...

tests/config/test_seed.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
import pytest
2+
3+
from data_designer.config.seed import IndexRange, PartitionBlock
4+
5+
def test_index_range_validation():
6+
with pytest.raises(ValueError, match="should be greater than or equal to 0"):
7+
IndexRange(start=-1, end=10)
8+
9+
with pytest.raises(ValueError, match="should be greater than or equal to 0"):
10+
IndexRange(start=0, end=-1)
11+
12+
with pytest.raises(ValueError, match="'start' index must be less than or equal to 'end' index"):
13+
IndexRange(start=11, end=10)
14+
15+
16+
def test_index_range_size():
17+
assert IndexRange(start=0, end=10).size == 11
18+
assert IndexRange(start=1, end=10).size == 10
19+
assert IndexRange(start=0, end=0).size == 1
20+
21+
22+
def test_partition_block_validation():
23+
with pytest.raises(ValueError, match="should be greater than or equal to 0"):
24+
PartitionBlock(partition_index=-1, num_partitions=10)
25+
26+
with pytest.raises(ValueError, match="should be greater than or equal to 1"):
27+
PartitionBlock(partition_index=0, num_partitions=0)
28+
29+
with pytest.raises(ValueError, match="'partition_index' must be less than 'num_partitions'"):
30+
PartitionBlock(partition_index=10, num_partitions=10)
31+
32+
33+
def test_partition_block_to_index_range():
34+
index_range = PartitionBlock(partition_index=0, num_partitions=10).to_index_range(101)
35+
assert index_range.start == 0
36+
assert index_range.end == 9
37+
assert index_range.size == 10
38+
39+
index_range = PartitionBlock(partition_index=1, num_partitions=10).to_index_range(105)
40+
assert index_range.start == 10
41+
assert index_range.end == 19
42+
assert index_range.size == 10
43+
44+
index_range = PartitionBlock(partition_index=2, num_partitions=10).to_index_range(105)
45+
assert index_range.start == 20
46+
assert index_range.end == 29
47+
assert index_range.size == 10
48+
49+
index_range = PartitionBlock(partition_index=9, num_partitions=10).to_index_range(105)
50+
assert index_range.start == 90
51+
assert index_range.end == 104
52+
assert index_range.size == 15

tests/engine/column_generators/generators/test_seed_dataset.py

Lines changed: 119 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,15 @@
1010
import pytest
1111

1212
from data_designer.config.columns import SeedDatasetColumnConfig
13-
from data_designer.config.seed import SamplingStrategy
13+
from data_designer.config.seed import SamplingStrategy, IndexRange, PartitionBlock
1414
from data_designer.engine.column_generators.generators.base import GenerationStrategy
1515
from data_designer.engine.column_generators.generators.seed_dataset import (
1616
MAX_ZERO_RECORD_RESPONSE_FACTOR,
1717
SeedDatasetColumnGenerator,
1818
)
1919
from data_designer.engine.dataset_builders.multi_column_configs import SeedDatasetMultiColumnConfig
20-
from data_designer.engine.resources.resource_provider import ResourceType
20+
from data_designer.engine.resources.resource_provider import ResourceType, ResourceProvider
21+
from data_designer.engine.column_generators.utils.errors import SeedDatasetError
2122

2223

2324
@pytest.fixture
@@ -333,7 +334,11 @@ def test_seed_dataset_column_generator_sample_records_multiple_batches(stub_seed
333334
# ============================================================================
334335

335336

336-
def create_generator_with_real_file(file_path: str, stub_resource_provider) -> SeedDatasetColumnGenerator:
337+
def create_generator_with_real_file(
338+
file_path: str,
339+
stub_resource_provider: ResourceProvider,
340+
sampling_strategy: SamplingStrategy = SamplingStrategy.ORDERED,
341+
selection_strategy: IndexRange | PartitionBlock | None = None) -> SeedDatasetColumnGenerator:
337342
"""Helper function to create a generator with a real file and DuckDB connection."""
338343
config = SeedDatasetMultiColumnConfig(
339344
columns=[
@@ -344,7 +349,8 @@ def create_generator_with_real_file(file_path: str, stub_resource_provider) -> S
344349
SeedDatasetColumnConfig(name="score"),
345350
],
346351
dataset=f"test/{os.path.basename(file_path)}",
347-
sampling_strategy=SamplingStrategy.ORDERED,
352+
sampling_strategy=sampling_strategy,
353+
selection_strategy=selection_strategy,
348354
)
349355

350356
# Create a real DuckDB connection (in-memory by default)
@@ -605,3 +611,112 @@ def test_seed_dataset_generator_uses_real_duckdb_connection(fixture_name, stub_r
605611
# Verify the connection can execute count queries
606612
count_result = generator.duckdb_conn.execute(f"SELECT COUNT(*) FROM '{file_path}'").fetchone()[0]
607613
assert count_result == 10
614+
615+
616+
# ============================================================================
617+
# Tests for SeedConfig selection strategies
618+
# ============================================================================
619+
@pytest.mark.parametrize(
620+
"fixture_name",
621+
[
622+
"seed_dataset_parquet",
623+
"seed_dataset_csv",
624+
"seed_dataset_json",
625+
"seed_dataset_jsonl",
626+
],
627+
)
628+
def test_seed_dataset_generator_index_range_selection_strategy(fixture_name, stub_resource_provider, request):
629+
"""Test that generator correctly applies index range selection strategy."""
630+
# Ordered Sampling
631+
632+
# Range with a subset of items
633+
file_path = request.getfixturevalue(fixture_name)
634+
generator = create_generator_with_real_file(file_path, stub_resource_provider, sampling_strategy=SamplingStrategy.ORDERED, selection_strategy=IndexRange(start=4, end=8))
635+
result = generator.generate_from_scratch(6)
636+
assert len(result) == 6
637+
assert list(result["name"]) == ["Eve", "Frank", "Grace", "Henry", "Ivy", "Eve"]
638+
639+
# Range with just one item
640+
generator = create_generator_with_real_file(file_path, stub_resource_provider, sampling_strategy=SamplingStrategy.ORDERED, selection_strategy=IndexRange(start=4, end=4))
641+
result = generator.generate_from_scratch(1)
642+
assert len(result) == 1
643+
assert list(result["name"]) == ["Eve"]
644+
645+
# Range with all items
646+
generator = create_generator_with_real_file(file_path, stub_resource_provider, sampling_strategy=SamplingStrategy.ORDERED, selection_strategy=IndexRange(start=0, end=9))
647+
result = generator.generate_from_scratch(10)
648+
assert len(result) == 10
649+
assert list(result["name"]) == ["Alice", "Bob", "Charlie", "David", "Eve", "Frank", "Grace", "Henry", "Ivy", "Jack"]
650+
651+
# Shuffle Sampling
652+
653+
# Range with a subset of items
654+
generator = create_generator_with_real_file(file_path, stub_resource_provider, sampling_strategy=SamplingStrategy.SHUFFLE, selection_strategy=IndexRange(start=4, end=8))
655+
result = generator.generate_from_scratch(10)
656+
assert len(result) == 10
657+
assert set(result["name"]).issubset({"Eve", "Frank", "Grace", "Henry", "Ivy"})
658+
659+
# Range with just one item
660+
generator = create_generator_with_real_file(file_path, stub_resource_provider, sampling_strategy=SamplingStrategy.SHUFFLE, selection_strategy=IndexRange(start=4, end=4))
661+
result = generator.generate_from_scratch(1)
662+
assert len(result) == 1
663+
assert list(result["name"]) == ["Eve"]
664+
665+
# Range with all items
666+
generator = create_generator_with_real_file(file_path, stub_resource_provider, sampling_strategy=SamplingStrategy.SHUFFLE, selection_strategy=IndexRange(start=0, end=9))
667+
result = generator.generate_from_scratch(10)
668+
assert len(result) == 10
669+
assert set(result["name"]).issubset({"Alice", "Bob", "Charlie", "David", "Eve", "Frank", "Grace", "Henry", "Ivy", "Jack"})
670+
671+
672+
@pytest.mark.parametrize(
673+
"fixture_name",
674+
[
675+
"seed_dataset_parquet",
676+
"seed_dataset_csv",
677+
"seed_dataset_json",
678+
"seed_dataset_jsonl",
679+
],
680+
)
681+
def test_seed_dataset_generator_partition_block_selection_strategy(fixture_name, stub_resource_provider, request):
682+
"""Test that generator correctly applies partition block selection strategy."""
683+
file_path = request.getfixturevalue(fixture_name)
684+
generator = create_generator_with_real_file(
685+
file_path,
686+
stub_resource_provider,
687+
sampling_strategy=SamplingStrategy.ORDERED,
688+
selection_strategy=PartitionBlock(partition_index=1, num_partitions=3)
689+
)
690+
result = generator.generate_from_scratch(5)
691+
assert len(result) == 5
692+
# Requesting 5 items from a 3-item partition should cycle:
693+
assert list(result["name"]) == ["David", "Eve", "Frank", "David", "Eve"]
694+
695+
generator = create_generator_with_real_file(
696+
file_path,
697+
stub_resource_provider,
698+
sampling_strategy=SamplingStrategy.SHUFFLE,
699+
selection_strategy=PartitionBlock(partition_index=4, num_partitions=5))
700+
result = generator.generate_from_scratch(10)
701+
assert len(result) == 10
702+
assert set(result["name"]).issubset({"Jack", "Ivy"})
703+
704+
705+
@pytest.mark.parametrize(
706+
"fixture_name",
707+
[
708+
"seed_dataset_parquet",
709+
"seed_dataset_csv",
710+
"seed_dataset_json",
711+
"seed_dataset_jsonl",
712+
],
713+
)
714+
def test_seed_dataset_generator_invalid_selection_strategies(fixture_name, stub_resource_provider, request):
715+
"""Test that generator raises an error for invalid selection strategies."""
716+
file_path = request.getfixturevalue(fixture_name)
717+
with pytest.raises(SeedDatasetError, match="Selection strategy 'end' index 10 is out of bounds for dataset size 10"):
718+
generator = create_generator_with_real_file(file_path, stub_resource_provider, selection_strategy=IndexRange(start=1, end=10))
719+
generator.generate_from_scratch(1)
720+
with pytest.raises(SeedDatasetError, match="Selection strategy 'num_partitions' 11 is out of bounds for dataset size 10"):
721+
generator = create_generator_with_real_file(file_path, stub_resource_provider, selection_strategy=PartitionBlock(partition_index=0, num_partitions=11))
722+
generator.generate_from_scratch(1)

0 commit comments

Comments
 (0)