diff --git a/src/data_designer/config/config_builder.py b/src/data_designer/config/config_builder.py index a33b3658..8bfe95e6 100644 --- a/src/data_designer/config/config_builder.py +++ b/src/data_designer/config/config_builder.py @@ -30,7 +30,9 @@ ) from .seed import ( DatastoreSeedDatasetReference, + IndexRange, LocalSeedDatasetReference, + PartitionBlock, SamplingStrategy, SeedConfig, SeedDatasetReference, @@ -116,7 +118,11 @@ def from_config(cls, config: Union[dict, str, Path, BuilderConfig]) -> Self: datastore_settings=builder_config.datastore_settings, ) builder.set_seed_datastore_settings(builder_config.datastore_settings) - builder.with_seed_dataset(seed_dataset_reference, sampling_strategy=config.seed_config.sampling_strategy) + builder.with_seed_dataset( + seed_dataset_reference, + sampling_strategy=config.seed_config.sampling_strategy, + selection_strategy=config.seed_config.selection_strategy, + ) return builder @@ -545,6 +551,7 @@ def with_seed_dataset( dataset_reference: SeedDatasetReference, *, sampling_strategy: SamplingStrategy = SamplingStrategy.ORDERED, + selection_strategy: Optional[Union[IndexRange, PartitionBlock]] = None, ) -> Self: """Add a seed dataset to the current Data Designer configuration. @@ -560,7 +567,11 @@ def with_seed_dataset( Returns: The current Data Designer config builder instance. """ - self._seed_config = SeedConfig(dataset=dataset_reference.dataset, sampling_strategy=sampling_strategy) + self._seed_config = SeedConfig( + dataset=dataset_reference.dataset, + sampling_strategy=sampling_strategy, + selection_strategy=selection_strategy, + ) self.set_seed_datastore_settings( dataset_reference.datastore_settings if hasattr(dataset_reference, "datastore_settings") else None ) diff --git a/src/data_designer/config/seed.py b/src/data_designer/config/seed.py index 9f7480eb..a467b98a 100644 --- a/src/data_designer/config/seed.py +++ b/src/data_designer/config/seed.py @@ -3,8 +3,10 @@ from abc import ABC from enum import Enum +from typing import Optional, Union -from pydantic import field_validator +from pydantic import Field, field_validator, model_validator +from typing_extensions import Self from .base import ConfigBase from .datastore import DatastoreSettings @@ -16,9 +18,97 @@ class SamplingStrategy(str, Enum): SHUFFLE = "shuffle" +class IndexRange(ConfigBase): + start: int = Field(ge=0, description="The start index of the index range (inclusive)") + end: int = Field(ge=0, description="The end index of the index range (inclusive)") + + @model_validator(mode="after") + def _validate_index_range(self) -> Self: + if self.start > self.end: + raise ValueError("'start' index must be less than or equal to 'end' index") + return self + + @property + def size(self) -> int: + return self.end - self.start + 1 + + +class PartitionBlock(ConfigBase): + index: int = Field(default=0, ge=0, description="The index of the partition to sample from") + num_partitions: int = Field(default=1, ge=1, description="The total number of partitions in the dataset") + + @model_validator(mode="after") + def _validate_partition_block(self) -> Self: + if self.index >= self.num_partitions: + raise ValueError("'index' must be less than 'num_partitions'") + return self + + def to_index_range(self, dataset_size: int) -> IndexRange: + partition_size = dataset_size // self.num_partitions + start = self.index * partition_size + + # For the last partition, extend to the end of the dataset to include remainder rows + if self.index == self.num_partitions - 1: + end = dataset_size - 1 + else: + end = ((self.index + 1) * partition_size) - 1 + return IndexRange(start=start, end=end) + + class SeedConfig(ConfigBase): + """Configuration for sampling data from a seed dataset. + + Args: + dataset: Path or identifier for the seed dataset. + sampling_strategy: Strategy for how to sample rows from the dataset. + - ORDERED: Read rows sequentially in their original order. + - SHUFFLE: Randomly shuffle rows before sampling. When used with + selection_strategy, shuffling occurs within the selected range/partition. + selection_strategy: Optional strategy to select a subset of the dataset. + - IndexRange: Select a specific range of indices (e.g., rows 100-200). + - PartitionBlock: Select a partition by splitting the dataset into N equal parts. + Partition indices are zero-based (index=0 is the first partition, index=1 is + the second, etc.). + + Examples: + Read rows sequentially from start to end: + SeedConfig(dataset="my_data.parquet", sampling_strategy=SamplingStrategy.ORDERED) + + Read rows in random order: + SeedConfig(dataset="my_data.parquet", sampling_strategy=SamplingStrategy.SHUFFLE) + + Read specific index range (rows 100-199): + SeedConfig( + dataset="my_data.parquet", + sampling_strategy=SamplingStrategy.ORDERED, + selection_strategy=IndexRange(start=100, end=199) + ) + + Read random rows from a specific index range (shuffles within rows 100-199): + SeedConfig( + dataset="my_data.parquet", + sampling_strategy=SamplingStrategy.SHUFFLE, + selection_strategy=IndexRange(start=100, end=199) + ) + + Read from partition 2 (3rd partition, zero-based) of 5 partitions (20% of dataset): + SeedConfig( + dataset="my_data.parquet", + sampling_strategy=SamplingStrategy.ORDERED, + selection_strategy=PartitionBlock(index=2, num_partitions=5) + ) + + Read shuffled rows from partition 0 of 10 partitions (shuffles within the partition): + SeedConfig( + dataset="my_data.parquet", + sampling_strategy=SamplingStrategy.SHUFFLE, + selection_strategy=PartitionBlock(index=0, num_partitions=10) + ) + """ + dataset: str sampling_strategy: SamplingStrategy = SamplingStrategy.ORDERED + selection_strategy: Optional[Union[IndexRange, PartitionBlock]] = None class SeedDatasetReference(ABC, ConfigBase): diff --git a/src/data_designer/engine/column_generators/generators/seed_dataset.py b/src/data_designer/engine/column_generators/generators/seed_dataset.py index ffec22be..35578602 100644 --- a/src/data_designer/engine/column_generators/generators/seed_dataset.py +++ b/src/data_designer/engine/column_generators/generators/seed_dataset.py @@ -7,12 +7,13 @@ import duckdb import pandas as pd -from data_designer.config.seed import SamplingStrategy +from data_designer.config.seed import IndexRange, PartitionBlock, SamplingStrategy from data_designer.engine.column_generators.generators.base import ( FromScratchColumnGenerator, GenerationStrategy, GeneratorMetadata, ) +from data_designer.engine.column_generators.utils.errors import SeedDatasetError from data_designer.engine.dataset_builders.multi_column_configs import SeedDatasetMultiColumnConfig from data_designer.engine.processing.utils import concat_datasets from data_designer.engine.resources.resource_provider import ResourceType @@ -58,19 +59,67 @@ def _initialize(self) -> None: self._df_remaining = None self._dataset_uri = self.resource_provider.datastore.get_dataset_uri(self.config.dataset) self._seed_dataset_size = self.duckdb_conn.execute(f"SELECT COUNT(*) FROM '{self._dataset_uri}'").fetchone()[0] + self._index_range = self._resolve_index_range() + + def _validate_selection_strategy(self) -> None: + err_msg = None + if self.config.selection_strategy is not None: + if ( + isinstance(self.config.selection_strategy, IndexRange) + and self.config.selection_strategy.end >= self._seed_dataset_size + ): + err_msg = f"Selection strategy 'end' index {self.config.selection_strategy.end} is out of bounds for dataset size {self._seed_dataset_size}" + elif ( + isinstance(self.config.selection_strategy, PartitionBlock) + and self.config.selection_strategy.num_partitions > self._seed_dataset_size + ): + err_msg = f"Selection strategy 'num_partitions' {self.config.selection_strategy.num_partitions} is out of bounds for dataset size {self._seed_dataset_size}" + if err_msg is not None: + raise SeedDatasetError(err_msg) + + def _resolve_index_range(self) -> IndexRange | None: + self._validate_selection_strategy() + index_range = None + if self.config.selection_strategy is not None: + if isinstance(self.config.selection_strategy, IndexRange): + index_range = self.config.selection_strategy + elif isinstance(self.config.selection_strategy, PartitionBlock): + index_range = self.config.selection_strategy.to_index_range(self._seed_dataset_size) + return index_range def _reset_batch_reader(self, num_records: int) -> None: shuffle = self.config.sampling_strategy == SamplingStrategy.SHUFFLE shuffle_query = " ORDER BY RANDOM()" if shuffle else "" - self._batch_reader = self.duckdb_conn.query(f"SELECT * FROM '{self._dataset_uri}'{shuffle_query}").record_batch( - batch_size=num_records - ) + + if self._index_range is not None: + # Use LIMIT and OFFSET for efficient index range filtering + # IndexRange uses 0-based indexing [start, end] inclusive + # OFFSET skips the first 'start' rows (0-based) + # LIMIT takes 'end - start + 1' rows to include both start and end (inclusive) + offset_value = self._index_range.start + limit_value = self._index_range.end - self._index_range.start + 1 + read_query = f""" + SELECT * FROM '{self._dataset_uri}' + LIMIT {limit_value} OFFSET {offset_value} + """ + + read_query = f"SELECT * FROM ({read_query}){shuffle_query}" + else: + read_query = f"SELECT * FROM '{self._dataset_uri}'{shuffle_query}" + self._batch_reader = self.duckdb_conn.query(read_query).record_batch(batch_size=num_records) def _sample_records(self, num_records: int) -> pd.DataFrame: logger.info(f"🌱 Sampling {num_records} records from seed dataset") logger.info(f" |-- seed dataset size: {self._seed_dataset_size} records") logger.info(f" |-- sampling strategy: {self.config.sampling_strategy}") - + if self._index_range is not None: + if isinstance(self.config.selection_strategy, IndexRange): + logger.info(f" |-- selection: rows [{self._index_range.start} to {self._index_range.end}] inclusive") + else: + logger.info( + f" |-- selection: partition {self.config.selection_strategy.index + 1} of {self.config.selection_strategy.num_partitions}" + ) + logger.info(f" |-- seed dataset size after selection: {self._index_range.size} records") df_batch = pd.DataFrame() df_sample = pd.DataFrame() if self._df_remaining is None else self._df_remaining num_zero_record_responses = 0 diff --git a/src/data_designer/engine/column_generators/utils/errors.py b/src/data_designer/engine/column_generators/utils/errors.py index 7820406d..f467d862 100644 --- a/src/data_designer/engine/column_generators/utils/errors.py +++ b/src/data_designer/engine/column_generators/utils/errors.py @@ -8,3 +8,6 @@ class PromptTemplateRenderError(DataDesignerError): ... class ExpressionTemplateRenderError(DataDesignerError): ... + + +class SeedDatasetError(DataDesignerError): ... diff --git a/src/data_designer/engine/dataset_builders/utils/config_compiler.py b/src/data_designer/engine/dataset_builders/utils/config_compiler.py index 2a784212..d80ec37a 100644 --- a/src/data_designer/engine/dataset_builders/utils/config_compiler.py +++ b/src/data_designer/engine/dataset_builders/utils/config_compiler.py @@ -36,6 +36,7 @@ def compile_dataset_builder_column_configs(config: DataDesignerConfig) -> list[D columns=seed_column_configs, dataset=config.seed_config.dataset, sampling_strategy=config.seed_config.sampling_strategy, + selection_strategy=config.seed_config.selection_strategy, ) ) diff --git a/src/data_designer/essentials/__init__.py b/src/data_designer/essentials/__init__.py index 3be1950f..3b3fd96f 100644 --- a/src/data_designer/essentials/__init__.py +++ b/src/data_designer/essentials/__init__.py @@ -49,7 +49,7 @@ UniformSamplerParams, UUIDSamplerParams, ) -from ..config.seed import DatastoreSeedDatasetReference, SamplingStrategy, SeedConfig +from ..config.seed import DatastoreSeedDatasetReference, IndexRange, PartitionBlock, SamplingStrategy, SeedConfig from ..config.utils.code_lang import CodeLang from ..config.utils.misc import can_run_data_designer_locally from ..config.validator_params import ( @@ -89,6 +89,7 @@ "DropColumnsProcessorConfig", "ExpressionColumnConfig", "GaussianSamplerParams", + "IndexRange", "ImageContext", "ImageFormat", "InferenceParameters", @@ -104,6 +105,7 @@ "ModalityContext", "ModalityDataType", "ModelConfig", + "PartitionBlock", "PersonSamplerParams", "PoissonSamplerParams", "ProcessorType", diff --git a/tests/config/test_seed.py b/tests/config/test_seed.py new file mode 100644 index 00000000..2e7a2e1b --- /dev/null +++ b/tests/config/test_seed.py @@ -0,0 +1,56 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import pytest + +from data_designer.config.seed import IndexRange, PartitionBlock + + +def test_index_range_validation(): + with pytest.raises(ValueError, match="should be greater than or equal to 0"): + IndexRange(start=-1, end=10) + + with pytest.raises(ValueError, match="should be greater than or equal to 0"): + IndexRange(start=0, end=-1) + + with pytest.raises(ValueError, match="'start' index must be less than or equal to 'end' index"): + IndexRange(start=11, end=10) + + +def test_index_range_size(): + assert IndexRange(start=0, end=10).size == 11 + assert IndexRange(start=1, end=10).size == 10 + assert IndexRange(start=0, end=0).size == 1 + + +def test_partition_block_validation(): + with pytest.raises(ValueError, match="should be greater than or equal to 0"): + PartitionBlock(index=-1, num_partitions=10) + + with pytest.raises(ValueError, match="should be greater than or equal to 1"): + PartitionBlock(index=0, num_partitions=0) + + with pytest.raises(ValueError, match="'index' must be less than 'num_partitions'"): + PartitionBlock(index=10, num_partitions=10) + + +def test_partition_block_to_index_range(): + index_range = PartitionBlock(index=0, num_partitions=10).to_index_range(101) + assert index_range.start == 0 + assert index_range.end == 9 + assert index_range.size == 10 + + index_range = PartitionBlock(index=1, num_partitions=10).to_index_range(105) + assert index_range.start == 10 + assert index_range.end == 19 + assert index_range.size == 10 + + index_range = PartitionBlock(index=2, num_partitions=10).to_index_range(105) + assert index_range.start == 20 + assert index_range.end == 29 + assert index_range.size == 10 + + index_range = PartitionBlock(index=9, num_partitions=10).to_index_range(105) + assert index_range.start == 90 + assert index_range.end == 104 + assert index_range.size == 15 diff --git a/tests/engine/column_generators/generators/test_seed_dataset.py b/tests/engine/column_generators/generators/test_seed_dataset.py index cebb68ca..487d143f 100644 --- a/tests/engine/column_generators/generators/test_seed_dataset.py +++ b/tests/engine/column_generators/generators/test_seed_dataset.py @@ -10,14 +10,15 @@ import pytest from data_designer.config.columns import SeedDatasetColumnConfig -from data_designer.config.seed import SamplingStrategy +from data_designer.config.seed import IndexRange, PartitionBlock, SamplingStrategy from data_designer.engine.column_generators.generators.base import GenerationStrategy from data_designer.engine.column_generators.generators.seed_dataset import ( MAX_ZERO_RECORD_RESPONSE_FACTOR, SeedDatasetColumnGenerator, ) +from data_designer.engine.column_generators.utils.errors import SeedDatasetError from data_designer.engine.dataset_builders.multi_column_configs import SeedDatasetMultiColumnConfig -from data_designer.engine.resources.resource_provider import ResourceType +from data_designer.engine.resources.resource_provider import ResourceProvider, ResourceType @pytest.fixture @@ -123,6 +124,30 @@ def test_seed_dataset_column_generator_config_structure(): assert config.columns[0].column_type.value == "seed-dataset" assert config.columns[1].name == "col2" assert config.columns[1].column_type.value == "seed-dataset" + assert config.selection_strategy is None + + # Test PartitionBlock selection strategy + config = SeedDatasetMultiColumnConfig( + columns=[SeedDatasetColumnConfig(name="col1"), SeedDatasetColumnConfig(name="col2")], + dataset="test/dataset", + sampling_strategy=SamplingStrategy.SHUFFLE, + selection_strategy=PartitionBlock(index=1, num_partitions=3), + ) + assert isinstance(config.selection_strategy, PartitionBlock) + assert config.selection_strategy.index == 1 + assert config.selection_strategy.num_partitions == 3 + + # Test IndexRange selection strategy + config = SeedDatasetMultiColumnConfig( + columns=[SeedDatasetColumnConfig(name="col1"), SeedDatasetColumnConfig(name="col2")], + dataset="test/dataset", + sampling_strategy=SamplingStrategy.SHUFFLE, + selection_strategy=IndexRange(start=0, end=1), + ) + assert isinstance(config.selection_strategy, IndexRange) + assert config.selection_strategy.start == 0 + assert config.selection_strategy.end == 1 + assert config.selection_strategy.size == 2 # Test constants and enum values assert MAX_ZERO_RECORD_RESPONSE_FACTOR == 2 @@ -333,7 +358,12 @@ def test_seed_dataset_column_generator_sample_records_multiple_batches(stub_seed # ============================================================================ -def create_generator_with_real_file(file_path: str, stub_resource_provider) -> SeedDatasetColumnGenerator: +def create_generator_with_real_file( + file_path: str, + stub_resource_provider: ResourceProvider, + sampling_strategy: SamplingStrategy = SamplingStrategy.ORDERED, + selection_strategy: IndexRange | PartitionBlock | None = None, +) -> SeedDatasetColumnGenerator: """Helper function to create a generator with a real file and DuckDB connection.""" config = SeedDatasetMultiColumnConfig( columns=[ @@ -344,7 +374,8 @@ def create_generator_with_real_file(file_path: str, stub_resource_provider) -> S SeedDatasetColumnConfig(name="score"), ], dataset=f"test/{os.path.basename(file_path)}", - sampling_strategy=SamplingStrategy.ORDERED, + sampling_strategy=sampling_strategy, + selection_strategy=selection_strategy, ) # Create a real DuckDB connection (in-memory by default) @@ -605,3 +636,153 @@ def test_seed_dataset_generator_uses_real_duckdb_connection(fixture_name, stub_r # Verify the connection can execute count queries count_result = generator.duckdb_conn.execute(f"SELECT COUNT(*) FROM '{file_path}'").fetchone()[0] assert count_result == 10 + + +# ============================================================================ +# Tests for SeedConfig selection strategies +# ============================================================================ +@pytest.mark.parametrize( + "fixture_name", + [ + "seed_dataset_parquet", + "seed_dataset_csv", + "seed_dataset_json", + "seed_dataset_jsonl", + ], +) +def test_seed_dataset_generator_index_range_selection_strategy(fixture_name, stub_resource_provider, request): + """Test that generator correctly applies index range selection strategy.""" + # Ordered Sampling + + # Range with a subset of items + file_path = request.getfixturevalue(fixture_name) + generator = create_generator_with_real_file( + file_path, + stub_resource_provider, + sampling_strategy=SamplingStrategy.ORDERED, + selection_strategy=IndexRange(start=4, end=8), + ) + result = generator.generate_from_scratch(6) + assert len(result) == 6 + assert list(result["name"]) == ["Eve", "Frank", "Grace", "Henry", "Ivy", "Eve"] + + # Range with just one item + generator = create_generator_with_real_file( + file_path, + stub_resource_provider, + sampling_strategy=SamplingStrategy.ORDERED, + selection_strategy=IndexRange(start=4, end=4), + ) + result = generator.generate_from_scratch(1) + assert len(result) == 1 + assert list(result["name"]) == ["Eve"] + + # Range with all items + generator = create_generator_with_real_file( + file_path, + stub_resource_provider, + sampling_strategy=SamplingStrategy.ORDERED, + selection_strategy=IndexRange(start=0, end=9), + ) + result = generator.generate_from_scratch(10) + assert len(result) == 10 + assert list(result["name"]) == ["Alice", "Bob", "Charlie", "David", "Eve", "Frank", "Grace", "Henry", "Ivy", "Jack"] + + # Shuffle Sampling + + # Range with a subset of items + generator = create_generator_with_real_file( + file_path, + stub_resource_provider, + sampling_strategy=SamplingStrategy.SHUFFLE, + selection_strategy=IndexRange(start=4, end=8), + ) + result = generator.generate_from_scratch(10) + assert len(result) == 10 + assert set(result["name"]).issubset({"Eve", "Frank", "Grace", "Henry", "Ivy"}) + + # Range with just one item + generator = create_generator_with_real_file( + file_path, + stub_resource_provider, + sampling_strategy=SamplingStrategy.SHUFFLE, + selection_strategy=IndexRange(start=4, end=4), + ) + result = generator.generate_from_scratch(1) + assert len(result) == 1 + assert list(result["name"]) == ["Eve"] + + # Range with all items + generator = create_generator_with_real_file( + file_path, + stub_resource_provider, + sampling_strategy=SamplingStrategy.SHUFFLE, + selection_strategy=IndexRange(start=0, end=9), + ) + result = generator.generate_from_scratch(10) + assert len(result) == 10 + assert set(result["name"]).issubset( + {"Alice", "Bob", "Charlie", "David", "Eve", "Frank", "Grace", "Henry", "Ivy", "Jack"} + ) + + +@pytest.mark.parametrize( + "fixture_name", + [ + "seed_dataset_parquet", + "seed_dataset_csv", + "seed_dataset_json", + "seed_dataset_jsonl", + ], +) +def test_seed_dataset_generator_partition_block_selection_strategy(fixture_name, stub_resource_provider, request): + """Test that generator correctly applies partition block selection strategy.""" + file_path = request.getfixturevalue(fixture_name) + generator = create_generator_with_real_file( + file_path, + stub_resource_provider, + sampling_strategy=SamplingStrategy.ORDERED, + selection_strategy=PartitionBlock(index=1, num_partitions=3), + ) + result = generator.generate_from_scratch(5) + assert len(result) == 5 + # Requesting 5 items from a 3-item partition should cycle: + assert list(result["name"]) == ["David", "Eve", "Frank", "David", "Eve"] + + generator = create_generator_with_real_file( + file_path, + stub_resource_provider, + sampling_strategy=SamplingStrategy.SHUFFLE, + selection_strategy=PartitionBlock(index=4, num_partitions=5), + ) + result = generator.generate_from_scratch(10) + assert len(result) == 10 + assert set(result["name"]).issubset({"Jack", "Ivy"}) + + +@pytest.mark.parametrize( + "fixture_name", + [ + "seed_dataset_parquet", + "seed_dataset_csv", + "seed_dataset_json", + "seed_dataset_jsonl", + ], +) +def test_seed_dataset_generator_invalid_selection_strategies(fixture_name, stub_resource_provider, request): + """Test that generator raises an error for invalid selection strategies.""" + file_path = request.getfixturevalue(fixture_name) + with pytest.raises( + SeedDatasetError, match="Selection strategy 'end' index 10 is out of bounds for dataset size 10" + ): + generator = create_generator_with_real_file( + file_path, stub_resource_provider, selection_strategy=IndexRange(start=1, end=10) + ) + generator.generate_from_scratch(1) + with pytest.raises( + SeedDatasetError, match="Selection strategy 'num_partitions' 11 is out of bounds for dataset size 10" + ): + generator = create_generator_with_real_file( + file_path, stub_resource_provider, selection_strategy=PartitionBlock(index=0, num_partitions=11) + ) + generator.generate_from_scratch(1)