Skip to content

Commit 7268290

Browse files
authored
feat: support IndexRange and PartitionBlock seed selection strategy (#8)
Support `IndexRange` and `PartitionBlock` seed selection strategy
1 parent 7c88230 commit 7268290

File tree

8 files changed

+406
-13
lines changed

8 files changed

+406
-13
lines changed

src/data_designer/config/config_builder.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,9 @@
3030
)
3131
from .seed import (
3232
DatastoreSeedDatasetReference,
33+
IndexRange,
3334
LocalSeedDatasetReference,
35+
PartitionBlock,
3436
SamplingStrategy,
3537
SeedConfig,
3638
SeedDatasetReference,
@@ -116,7 +118,11 @@ def from_config(cls, config: Union[dict, str, Path, BuilderConfig]) -> Self:
116118
datastore_settings=builder_config.datastore_settings,
117119
)
118120
builder.set_seed_datastore_settings(builder_config.datastore_settings)
119-
builder.with_seed_dataset(seed_dataset_reference, sampling_strategy=config.seed_config.sampling_strategy)
121+
builder.with_seed_dataset(
122+
seed_dataset_reference,
123+
sampling_strategy=config.seed_config.sampling_strategy,
124+
selection_strategy=config.seed_config.selection_strategy,
125+
)
120126

121127
return builder
122128

@@ -545,6 +551,7 @@ def with_seed_dataset(
545551
dataset_reference: SeedDatasetReference,
546552
*,
547553
sampling_strategy: SamplingStrategy = SamplingStrategy.ORDERED,
554+
selection_strategy: Optional[Union[IndexRange, PartitionBlock]] = None,
548555
) -> Self:
549556
"""Add a seed dataset to the current Data Designer configuration.
550557
@@ -560,7 +567,11 @@ def with_seed_dataset(
560567
Returns:
561568
The current Data Designer config builder instance.
562569
"""
563-
self._seed_config = SeedConfig(dataset=dataset_reference.dataset, sampling_strategy=sampling_strategy)
570+
self._seed_config = SeedConfig(
571+
dataset=dataset_reference.dataset,
572+
sampling_strategy=sampling_strategy,
573+
selection_strategy=selection_strategy,
574+
)
564575
self.set_seed_datastore_settings(
565576
dataset_reference.datastore_settings if hasattr(dataset_reference, "datastore_settings") else None
566577
)

src/data_designer/config/seed.py

Lines changed: 91 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,10 @@
33

44
from abc import ABC
55
from enum import Enum
6+
from typing import Optional, Union
67

7-
from pydantic import field_validator
8+
from pydantic import Field, field_validator, model_validator
9+
from typing_extensions import Self
810

911
from .base import ConfigBase
1012
from .datastore import DatastoreSettings
@@ -16,9 +18,97 @@ class SamplingStrategy(str, Enum):
1618
SHUFFLE = "shuffle"
1719

1820

21+
class IndexRange(ConfigBase):
    """Inclusive, zero-based range of row indices within a seed dataset.

    Both bounds must be non-negative, and ``start`` may not be greater than
    ``end``. A range with ``start == end`` selects exactly one row.
    """

    start: int = Field(ge=0, description="The start index of the index range (inclusive)")
    end: int = Field(ge=0, description="The end index of the index range (inclusive)")

    @model_validator(mode="after")
    def _validate_index_range(self) -> Self:
        """Reject ranges whose bounds are reversed."""
        if self.end < self.start:
            raise ValueError("'start' index must be less than or equal to 'end' index")
        return self

    @property
    def size(self) -> int:
        """Number of rows covered by the range, counting both endpoints."""
        span = self.end - self.start
        return span + 1
34+
35+
36+
class PartitionBlock(ConfigBase):
    """One contiguous block of a dataset that is split into ``num_partitions`` blocks.

    Partition indices are zero-based, so ``index`` must be strictly less than
    ``num_partitions``. The defaults (``index=0``, ``num_partitions=1``) describe
    a single partition covering the whole dataset.
    """

    index: int = Field(default=0, ge=0, description="The index of the partition to sample from")
    num_partitions: int = Field(default=1, ge=1, description="The total number of partitions in the dataset")

    @model_validator(mode="after")
    def _validate_partition_block(self) -> Self:
        """Ensure the partition index refers to an existing partition."""
        if self.index >= self.num_partitions:
            raise ValueError("'index' must be less than 'num_partitions'")
        return self

    def to_index_range(self, dataset_size: int) -> IndexRange:
        """Translate this partition into an inclusive row range.

        Every partition holds ``dataset_size // num_partitions`` rows, except
        the last one, which is extended to the end of the dataset so that the
        remainder rows are not dropped.

        Args:
            dataset_size: Total number of rows in the dataset being partitioned.

        Returns:
            The inclusive ``IndexRange`` of rows belonging to this partition.

        Raises:
            ValueError: If ``dataset_size`` is smaller than ``num_partitions``,
                i.e. at least one partition would be empty.
        """
        # With fewer rows than partitions the per-partition size is 0: non-last
        # partitions would produce end == -1 (an unrelated IndexRange validation
        # error) and the last partition would silently cover the whole dataset.
        if dataset_size < self.num_partitions:
            raise ValueError(
                f"Cannot split a dataset of size {dataset_size} into {self.num_partitions} non-empty partitions"
            )

        partition_size = dataset_size // self.num_partitions
        start = self.index * partition_size

        # For the last partition, extend to the end of the dataset to include remainder rows
        is_last_partition = self.index == self.num_partitions - 1
        end = dataset_size - 1 if is_last_partition else start + partition_size - 1
        return IndexRange(start=start, end=end)
56+
57+
1958
class SeedConfig(ConfigBase):
    """Describe which seed dataset to read and how its rows are chosen.

    Attributes:
        dataset: Path or identifier for the seed dataset.
        sampling_strategy: Order in which rows are read.
            - ORDERED: rows are read sequentially in their original order.
            - SHUFFLE: rows are randomly shuffled before sampling. Combined with
              a ``selection_strategy``, the shuffle is applied only within the
              selected range/partition.
        selection_strategy: Optional restriction to a subset of the dataset.
            - IndexRange: an explicit, inclusive range of row indices
              (e.g., rows 100-199).
            - PartitionBlock: one of N contiguous partitions of the dataset.
              Partitions have the same size except the last one, which also
              takes any remainder rows. Partition indices are zero-based
              (index=0 is the first partition, index=1 is the second, etc.).
            When left as ``None``, no subset selection is applied.

    Examples:
        Read rows sequentially from start to end:
            SeedConfig(dataset="my_data.parquet", sampling_strategy=SamplingStrategy.ORDERED)

        Read rows in random order:
            SeedConfig(dataset="my_data.parquet", sampling_strategy=SamplingStrategy.SHUFFLE)

        Read a specific index range (rows 100-199, both inclusive):
            SeedConfig(
                dataset="my_data.parquet",
                sampling_strategy=SamplingStrategy.ORDERED,
                selection_strategy=IndexRange(start=100, end=199)
            )

        Read rows 100-199 in random order (the shuffle stays within that range):
            SeedConfig(
                dataset="my_data.parquet",
                sampling_strategy=SamplingStrategy.SHUFFLE,
                selection_strategy=IndexRange(start=100, end=199)
            )

        Read partition 2 (the 3rd partition, zero-based) of 5 partitions (roughly 20% of the dataset):
            SeedConfig(
                dataset="my_data.parquet",
                sampling_strategy=SamplingStrategy.ORDERED,
                selection_strategy=PartitionBlock(index=2, num_partitions=5)
            )

        Read partition 0 of 10 partitions in random order (the shuffle stays within the partition):
            SeedConfig(
                dataset="my_data.parquet",
                sampling_strategy=SamplingStrategy.SHUFFLE,
                selection_strategy=PartitionBlock(index=0, num_partitions=10)
            )
    """

    dataset: str
    sampling_strategy: SamplingStrategy = SamplingStrategy.ORDERED
    selection_strategy: Optional[Union[IndexRange, PartitionBlock]] = None
22112

23113

24114
class SeedDatasetReference(ABC, ConfigBase):

src/data_designer/engine/column_generators/generators/seed_dataset.py

Lines changed: 54 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,13 @@
77
import duckdb
88
import pandas as pd
99

10-
from data_designer.config.seed import SamplingStrategy
10+
from data_designer.config.seed import IndexRange, PartitionBlock, SamplingStrategy
1111
from data_designer.engine.column_generators.generators.base import (
1212
FromScratchColumnGenerator,
1313
GenerationStrategy,
1414
GeneratorMetadata,
1515
)
16+
from data_designer.engine.column_generators.utils.errors import SeedDatasetError
1617
from data_designer.engine.dataset_builders.multi_column_configs import SeedDatasetMultiColumnConfig
1718
from data_designer.engine.processing.utils import concat_datasets
1819
from data_designer.engine.resources.resource_provider import ResourceType
@@ -58,19 +59,67 @@ def _initialize(self) -> None:
5859
self._df_remaining = None
5960
self._dataset_uri = self.resource_provider.datastore.get_dataset_uri(self.config.dataset)
6061
self._seed_dataset_size = self.duckdb_conn.execute(f"SELECT COUNT(*) FROM '{self._dataset_uri}'").fetchone()[0]
62+
self._index_range = self._resolve_index_range()
63+
64+
def _validate_selection_strategy(self) -> None:
    """Check that the configured selection strategy fits the seed dataset.

    Does nothing when no selection strategy is configured.

    Raises:
        SeedDatasetError: If an ``IndexRange`` ends at or beyond the dataset
            size, or a ``PartitionBlock`` asks for more partitions than the
            dataset has rows.
    """
    strategy = self.config.selection_strategy
    if strategy is None:
        return

    dataset_size = self._seed_dataset_size
    if isinstance(strategy, IndexRange):
        # 'end' is inclusive, so the largest valid value is dataset_size - 1.
        if strategy.end >= dataset_size:
            raise SeedDatasetError(
                f"Selection strategy 'end' index {strategy.end} is out of bounds for dataset size {dataset_size}"
            )
    elif isinstance(strategy, PartitionBlock):
        if strategy.num_partitions > dataset_size:
            raise SeedDatasetError(
                f"Selection strategy 'num_partitions' {strategy.num_partitions} is out of bounds for dataset size {dataset_size}"
            )
79+
80+
def _resolve_index_range(self) -> IndexRange | None:
    """Turn the configured selection strategy into a concrete row range.

    The strategy is validated against the seed dataset size first.

    Returns:
        The configured ``IndexRange`` as-is, the range derived from a
        ``PartitionBlock`` for the current dataset size, or ``None`` when no
        selection strategy is configured.
    """
    self._validate_selection_strategy()

    strategy = self.config.selection_strategy
    if isinstance(strategy, IndexRange):
        return strategy
    if isinstance(strategy, PartitionBlock):
        return strategy.to_index_range(self._seed_dataset_size)
    return None
6189

6290
def _reset_batch_reader(self, num_records: int) -> None:
    """Rebuild the DuckDB batch reader over the (optionally restricted) seed dataset.

    Args:
        num_records: Batch size passed to ``record_batch``.
    """
    order_clause = " ORDER BY RANDOM()" if self.config.sampling_strategy == SamplingStrategy.SHUFFLE else ""
    selected_range = self._index_range

    if selected_range is None:
        read_query = f"SELECT * FROM '{self._dataset_uri}'{order_clause}"
    else:
        # IndexRange is 0-based and inclusive on both ends:
        #   OFFSET skips the first 'start' rows,
        #   LIMIT keeps 'end - start + 1' rows (the size of the range).
        # NOTE(review): LIMIT/OFFSET without ORDER BY relies on the scan
        # returning rows in their stored order — confirm for the backing store.
        offset_value = selected_range.start
        limit_value = selected_range.size
        windowed_query = f"""
            SELECT * FROM '{self._dataset_uri}'
            LIMIT {limit_value} OFFSET {offset_value}
        """

        # The shuffle wraps the windowed subquery, so it only reorders the selected rows.
        read_query = f"SELECT * FROM ({windowed_query}){order_clause}"

    self._batch_reader = self.duckdb_conn.query(read_query).record_batch(batch_size=num_records)
68110

69111
def _sample_records(self, num_records: int) -> pd.DataFrame:
70112
logger.info(f"🌱 Sampling {num_records} records from seed dataset")
71113
logger.info(f" |-- seed dataset size: {self._seed_dataset_size} records")
72114
logger.info(f" |-- sampling strategy: {self.config.sampling_strategy}")
73-
115+
if self._index_range is not None:
116+
if isinstance(self.config.selection_strategy, IndexRange):
117+
logger.info(f" |-- selection: rows [{self._index_range.start} to {self._index_range.end}] inclusive")
118+
else:
119+
logger.info(
120+
f" |-- selection: partition {self.config.selection_strategy.index + 1} of {self.config.selection_strategy.num_partitions}"
121+
)
122+
logger.info(f" |-- seed dataset size after selection: {self._index_range.size} records")
74123
df_batch = pd.DataFrame()
75124
df_sample = pd.DataFrame() if self._df_remaining is None else self._df_remaining
76125
num_zero_record_responses = 0

src/data_designer/engine/column_generators/utils/errors.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,6 @@ class PromptTemplateRenderError(DataDesignerError): ...
88

99

1010
# NOTE(review): presumably signals a failure while rendering an expression template;
# the raising site is not visible here — confirm before documenting further.
class ExpressionTemplateRenderError(DataDesignerError): ...
11+
12+
13+
class SeedDatasetError(DataDesignerError):
    """Raised when a seed dataset selection strategy is out of bounds for the dataset size."""

src/data_designer/engine/dataset_builders/utils/config_compiler.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ def compile_dataset_builder_column_configs(config: DataDesignerConfig) -> list[D
3636
columns=seed_column_configs,
3737
dataset=config.seed_config.dataset,
3838
sampling_strategy=config.seed_config.sampling_strategy,
39+
selection_strategy=config.seed_config.selection_strategy,
3940
)
4041
)
4142

src/data_designer/essentials/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@
4949
UniformSamplerParams,
5050
UUIDSamplerParams,
5151
)
52-
from ..config.seed import DatastoreSeedDatasetReference, SamplingStrategy, SeedConfig
52+
from ..config.seed import DatastoreSeedDatasetReference, IndexRange, PartitionBlock, SamplingStrategy, SeedConfig
5353
from ..config.utils.code_lang import CodeLang
5454
from ..config.utils.misc import can_run_data_designer_locally
5555
from ..config.validator_params import (
@@ -89,6 +89,7 @@
8989
"DropColumnsProcessorConfig",
9090
"ExpressionColumnConfig",
9191
"GaussianSamplerParams",
92+
"IndexRange",
9293
"ImageContext",
9394
"ImageFormat",
9495
"InferenceParameters",
@@ -104,6 +105,7 @@
104105
"ModalityContext",
105106
"ModalityDataType",
106107
"ModelConfig",
108+
"PartitionBlock",
107109
"PersonSamplerParams",
108110
"PoissonSamplerParams",
109111
"ProcessorType",

tests/config/test_seed.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
import pytest
5+
6+
from data_designer.config.seed import IndexRange, PartitionBlock
7+
8+
9+
def test_index_range_validation():
    """IndexRange rejects negative bounds and a start greater than the end."""
    invalid_cases = [
        ({"start": -1, "end": 10}, "should be greater than or equal to 0"),
        ({"start": 0, "end": -1}, "should be greater than or equal to 0"),
        ({"start": 11, "end": 10}, "'start' index must be less than or equal to 'end' index"),
    ]
    for bounds, expected_message in invalid_cases:
        with pytest.raises(ValueError, match=expected_message):
            IndexRange(**bounds)
18+
19+
20+
def test_index_range_size():
    """IndexRange.size counts both endpoints of the range."""
    expected_sizes = [
        # (start, end, expected size)
        (0, 10, 11),
        (1, 10, 10),
        (0, 0, 1),
    ]
    for start, end, expected in expected_sizes:
        assert IndexRange(start=start, end=end).size == expected
24+
25+
26+
def test_partition_block_validation():
    """PartitionBlock rejects a negative index, zero partitions, and an index past the last partition."""
    invalid_cases = [
        ({"index": -1, "num_partitions": 10}, "should be greater than or equal to 0"),
        ({"index": 0, "num_partitions": 0}, "should be greater than or equal to 1"),
        ({"index": 10, "num_partitions": 10}, "'index' must be less than 'num_partitions'"),
    ]
    for fields, expected_message in invalid_cases:
        with pytest.raises(ValueError, match=expected_message):
            PartitionBlock(**fields)
35+
36+
37+
def test_partition_block_to_index_range():
    """Partitions map to equal-sized ranges, with the last one absorbing the remainder rows."""
    cases = [
        # (partition index, dataset size, expected start, expected end, expected size)
        (0, 101, 0, 9, 10),
        (1, 105, 10, 19, 10),
        (2, 105, 20, 29, 10),
        (9, 105, 90, 104, 15),
    ]
    for partition_index, dataset_size, expected_start, expected_end, expected_size in cases:
        index_range = PartitionBlock(index=partition_index, num_partitions=10).to_index_range(dataset_size)
        assert index_range.start == expected_start
        assert index_range.end == expected_end
        assert index_range.size == expected_size

0 commit comments

Comments
 (0)