Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions src/data_designer/config/config_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@
)
from .seed import (
DatastoreSeedDatasetReference,
IndexRange,
LocalSeedDatasetReference,
PartitionBlock,
SamplingStrategy,
SeedConfig,
SeedDatasetReference,
Expand Down Expand Up @@ -116,7 +118,11 @@ def from_config(cls, config: Union[dict, str, Path, BuilderConfig]) -> Self:
datastore_settings=builder_config.datastore_settings,
)
builder.set_seed_datastore_settings(builder_config.datastore_settings)
builder.with_seed_dataset(seed_dataset_reference, sampling_strategy=config.seed_config.sampling_strategy)
builder.with_seed_dataset(
seed_dataset_reference,
sampling_strategy=config.seed_config.sampling_strategy,
selection_strategy=config.seed_config.selection_strategy,
)

return builder

Expand Down Expand Up @@ -545,6 +551,7 @@ def with_seed_dataset(
dataset_reference: SeedDatasetReference,
*,
sampling_strategy: SamplingStrategy = SamplingStrategy.ORDERED,
selection_strategy: Optional[Union[IndexRange, PartitionBlock]] = None,
) -> Self:
"""Add a seed dataset to the current Data Designer configuration.

Expand All @@ -560,7 +567,11 @@ def with_seed_dataset(
Returns:
The current Data Designer config builder instance.
"""
self._seed_config = SeedConfig(dataset=dataset_reference.dataset, sampling_strategy=sampling_strategy)
self._seed_config = SeedConfig(
dataset=dataset_reference.dataset,
sampling_strategy=sampling_strategy,
selection_strategy=selection_strategy,
)
self.set_seed_datastore_settings(
dataset_reference.datastore_settings if hasattr(dataset_reference, "datastore_settings") else None
)
Expand Down
92 changes: 91 additions & 1 deletion src/data_designer/config/seed.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@

from abc import ABC
from enum import Enum
from typing import Optional, Union

from pydantic import field_validator
from pydantic import Field, field_validator, model_validator
from typing_extensions import Self

from .base import ConfigBase
from .datastore import DatastoreSettings
Expand All @@ -16,9 +18,97 @@ class SamplingStrategy(str, Enum):
SHUFFLE = "shuffle"


class IndexRange(ConfigBase):
start: int = Field(ge=0, description="The start index of the index range (inclusive)")
end: int = Field(ge=0, description="The end index of the index range (inclusive)")

@model_validator(mode="after")
def _validate_index_range(self) -> Self:
if self.start > self.end:
raise ValueError("'start' index must be less than or equal to 'end' index")
return self

@property
def size(self) -> int:
return self.end - self.start + 1


class PartitionBlock(ConfigBase):
index: int = Field(default=0, ge=0, description="The index of the partition to sample from")
num_partitions: int = Field(default=1, ge=1, description="The total number of partitions in the dataset")

@model_validator(mode="after")
def _validate_partition_block(self) -> Self:
if self.index >= self.num_partitions:
raise ValueError("'index' must be less than 'num_partitions'")
return self

def to_index_range(self, dataset_size: int) -> IndexRange:
partition_size = dataset_size // self.num_partitions
start = self.index * partition_size

# For the last partition, extend to the end of the dataset to include remainder rows
if self.index == self.num_partitions - 1:
end = dataset_size - 1
else:
end = ((self.index + 1) * partition_size) - 1
return IndexRange(start=start, end=end)


class SeedConfig(ConfigBase):
"""Configuration for sampling data from a seed dataset.

Args:
dataset: Path or identifier for the seed dataset.
sampling_strategy: Strategy for how to sample rows from the dataset.
- ORDERED: Read rows sequentially in their original order.
- SHUFFLE: Randomly shuffle rows before sampling. When used with
selection_strategy, shuffling occurs within the selected range/partition.
selection_strategy: Optional strategy to select a subset of the dataset.
- IndexRange: Select a specific range of indices (e.g., rows 100-200).
- PartitionBlock: Select a partition by splitting the dataset into N equal parts.
Partition indices are zero-based (index=0 is the first partition, index=1 is
the second, etc.).

Examples:
Read rows sequentially from start to end:
SeedConfig(dataset="my_data.parquet", sampling_strategy=SamplingStrategy.ORDERED)

Read rows in random order:
SeedConfig(dataset="my_data.parquet", sampling_strategy=SamplingStrategy.SHUFFLE)

Read specific index range (rows 100-199):
SeedConfig(
dataset="my_data.parquet",
sampling_strategy=SamplingStrategy.ORDERED,
selection_strategy=IndexRange(start=100, end=199)
)

Read random rows from a specific index range (shuffles within rows 100-199):
SeedConfig(
dataset="my_data.parquet",
sampling_strategy=SamplingStrategy.SHUFFLE,
selection_strategy=IndexRange(start=100, end=199)
)

Read from partition 2 (3rd partition, zero-based) of 5 partitions (20% of dataset):
SeedConfig(
dataset="my_data.parquet",
sampling_strategy=SamplingStrategy.ORDERED,
selection_strategy=PartitionBlock(index=2, num_partitions=5)
)

Read shuffled rows from partition 0 of 10 partitions (shuffles within the partition):
SeedConfig(
dataset="my_data.parquet",
sampling_strategy=SamplingStrategy.SHUFFLE,
selection_strategy=PartitionBlock(index=0, num_partitions=10)
)
"""

dataset: str
sampling_strategy: SamplingStrategy = SamplingStrategy.ORDERED
selection_strategy: Optional[Union[IndexRange, PartitionBlock]] = None


class SeedDatasetReference(ABC, ConfigBase):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,13 @@
import duckdb
import pandas as pd

from data_designer.config.seed import SamplingStrategy
from data_designer.config.seed import IndexRange, PartitionBlock, SamplingStrategy
from data_designer.engine.column_generators.generators.base import (
FromScratchColumnGenerator,
GenerationStrategy,
GeneratorMetadata,
)
from data_designer.engine.column_generators.utils.errors import SeedDatasetError
from data_designer.engine.dataset_builders.multi_column_configs import SeedDatasetMultiColumnConfig
from data_designer.engine.processing.utils import concat_datasets
from data_designer.engine.resources.resource_provider import ResourceType
Expand Down Expand Up @@ -58,19 +59,67 @@ def _initialize(self) -> None:
self._df_remaining = None
self._dataset_uri = self.resource_provider.datastore.get_dataset_uri(self.config.dataset)
self._seed_dataset_size = self.duckdb_conn.execute(f"SELECT COUNT(*) FROM '{self._dataset_uri}'").fetchone()[0]
self._index_range = self._resolve_index_range()

def _validate_selection_strategy(self) -> None:
err_msg = None
if self.config.selection_strategy is not None:
if (
isinstance(self.config.selection_strategy, IndexRange)
and self.config.selection_strategy.end >= self._seed_dataset_size
):
err_msg = f"Selection strategy 'end' index {self.config.selection_strategy.end} is out of bounds for dataset size {self._seed_dataset_size}"
elif (
isinstance(self.config.selection_strategy, PartitionBlock)
and self.config.selection_strategy.num_partitions > self._seed_dataset_size
):
err_msg = f"Selection strategy 'num_partitions' {self.config.selection_strategy.num_partitions} is out of bounds for dataset size {self._seed_dataset_size}"
if err_msg is not None:
raise SeedDatasetError(err_msg)

def _resolve_index_range(self) -> IndexRange | None:
self._validate_selection_strategy()
index_range = None
if self.config.selection_strategy is not None:
if isinstance(self.config.selection_strategy, IndexRange):
index_range = self.config.selection_strategy
elif isinstance(self.config.selection_strategy, PartitionBlock):
index_range = self.config.selection_strategy.to_index_range(self._seed_dataset_size)
return index_range

def _reset_batch_reader(self, num_records: int) -> None:
shuffle = self.config.sampling_strategy == SamplingStrategy.SHUFFLE
shuffle_query = " ORDER BY RANDOM()" if shuffle else ""
self._batch_reader = self.duckdb_conn.query(f"SELECT * FROM '{self._dataset_uri}'{shuffle_query}").record_batch(
batch_size=num_records
)

if self._index_range is not None:
# Use LIMIT and OFFSET for efficient index range filtering
# IndexRange uses 0-based indexing [start, end] inclusive
# OFFSET skips the first 'start' rows (0-based)
# LIMIT takes 'end - start + 1' rows to include both start and end (inclusive)
offset_value = self._index_range.start
limit_value = self._index_range.end - self._index_range.start + 1
read_query = f"""
SELECT * FROM '{self._dataset_uri}'
LIMIT {limit_value} OFFSET {offset_value}
"""

read_query = f"SELECT * FROM ({read_query}){shuffle_query}"
else:
read_query = f"SELECT * FROM '{self._dataset_uri}'{shuffle_query}"
self._batch_reader = self.duckdb_conn.query(read_query).record_batch(batch_size=num_records)

def _sample_records(self, num_records: int) -> pd.DataFrame:
logger.info(f"🌱 Sampling {num_records} records from seed dataset")
logger.info(f" |-- seed dataset size: {self._seed_dataset_size} records")
logger.info(f" |-- sampling strategy: {self.config.sampling_strategy}")

if self._index_range is not None:
if isinstance(self.config.selection_strategy, IndexRange):
logger.info(f" |-- selection: rows [{self._index_range.start} to {self._index_range.end}] inclusive")
else:
logger.info(
f" |-- selection: partition {self.config.selection_strategy.index + 1} of {self.config.selection_strategy.num_partitions}"
)
logger.info(f" |-- seed dataset size after selection: {self._index_range.size} records")
df_batch = pd.DataFrame()
df_sample = pd.DataFrame() if self._df_remaining is None else self._df_remaining
num_zero_record_responses = 0
Expand Down
3 changes: 3 additions & 0 deletions src/data_designer/engine/column_generators/utils/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,6 @@ class PromptTemplateRenderError(DataDesignerError): ...


class ExpressionTemplateRenderError(DataDesignerError): ...


class SeedDatasetError(DataDesignerError): ...
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def compile_dataset_builder_column_configs(config: DataDesignerConfig) -> list[D
columns=seed_column_configs,
dataset=config.seed_config.dataset,
sampling_strategy=config.seed_config.sampling_strategy,
selection_strategy=config.seed_config.selection_strategy,
)
)

Expand Down
4 changes: 3 additions & 1 deletion src/data_designer/essentials/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
UniformSamplerParams,
UUIDSamplerParams,
)
from ..config.seed import DatastoreSeedDatasetReference, SamplingStrategy, SeedConfig
from ..config.seed import DatastoreSeedDatasetReference, IndexRange, PartitionBlock, SamplingStrategy, SeedConfig
from ..config.utils.code_lang import CodeLang
from ..config.utils.misc import can_run_data_designer_locally
from ..config.validator_params import (
Expand Down Expand Up @@ -89,6 +89,7 @@
"DropColumnsProcessorConfig",
"ExpressionColumnConfig",
"GaussianSamplerParams",
"IndexRange",
"ImageContext",
"ImageFormat",
"InferenceParameters",
Expand All @@ -104,6 +105,7 @@
"ModalityContext",
"ModalityDataType",
"ModelConfig",
"PartitionBlock",
"PersonSamplerParams",
"PoissonSamplerParams",
"ProcessorType",
Expand Down
56 changes: 56 additions & 0 deletions tests/config/test_seed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import pytest

from data_designer.config.seed import IndexRange, PartitionBlock


def test_index_range_validation():
with pytest.raises(ValueError, match="should be greater than or equal to 0"):
IndexRange(start=-1, end=10)

with pytest.raises(ValueError, match="should be greater than or equal to 0"):
IndexRange(start=0, end=-1)

with pytest.raises(ValueError, match="'start' index must be less than or equal to 'end' index"):
IndexRange(start=11, end=10)


def test_index_range_size():
assert IndexRange(start=0, end=10).size == 11
assert IndexRange(start=1, end=10).size == 10
assert IndexRange(start=0, end=0).size == 1


def test_partition_block_validation():
with pytest.raises(ValueError, match="should be greater than or equal to 0"):
PartitionBlock(index=-1, num_partitions=10)

with pytest.raises(ValueError, match="should be greater than or equal to 1"):
PartitionBlock(index=0, num_partitions=0)

with pytest.raises(ValueError, match="'index' must be less than 'num_partitions'"):
PartitionBlock(index=10, num_partitions=10)


def test_partition_block_to_index_range():
index_range = PartitionBlock(index=0, num_partitions=10).to_index_range(101)
assert index_range.start == 0
assert index_range.end == 9
assert index_range.size == 10

index_range = PartitionBlock(index=1, num_partitions=10).to_index_range(105)
assert index_range.start == 10
assert index_range.end == 19
assert index_range.size == 10

index_range = PartitionBlock(index=2, num_partitions=10).to_index_range(105)
assert index_range.start == 20
assert index_range.end == 29
assert index_range.size == 10

index_range = PartitionBlock(index=9, num_partitions=10).to_index_range(105)
assert index_range.start == 90
assert index_range.end == 104
assert index_range.size == 15
Loading