1010import pytest
1111
1212from data_designer .config .columns import SeedDatasetColumnConfig
13- from data_designer .config .seed import SamplingStrategy
13+ from data_designer .config .seed import SamplingStrategy , IndexRange , PartitionBlock
1414from data_designer .engine .column_generators .generators .base import GenerationStrategy
1515from data_designer .engine .column_generators .generators .seed_dataset import (
1616 MAX_ZERO_RECORD_RESPONSE_FACTOR ,
1717 SeedDatasetColumnGenerator ,
1818)
1919from data_designer .engine .dataset_builders .multi_column_configs import SeedDatasetMultiColumnConfig
20- from data_designer .engine .resources .resource_provider import ResourceType
20+ from data_designer .engine .resources .resource_provider import ResourceType , ResourceProvider
21+ from data_designer .engine .column_generators .utils .errors import SeedDatasetError
2122
2223
2324@pytest .fixture
@@ -333,7 +334,11 @@ def test_seed_dataset_column_generator_sample_records_multiple_batches(stub_seed
333334# ============================================================================
334335
335336
336- def create_generator_with_real_file (file_path : str , stub_resource_provider ) -> SeedDatasetColumnGenerator :
337+ def create_generator_with_real_file (
338+ file_path : str ,
339+ stub_resource_provider : ResourceProvider ,
340+ sampling_strategy : SamplingStrategy = SamplingStrategy .ORDERED ,
341+ selection_strategy : IndexRange | PartitionBlock | None = None ) -> SeedDatasetColumnGenerator :
337342 """Helper function to create a generator with a real file and DuckDB connection."""
338343 config = SeedDatasetMultiColumnConfig (
339344 columns = [
@@ -344,7 +349,8 @@ def create_generator_with_real_file(file_path: str, stub_resource_provider) -> S
344349 SeedDatasetColumnConfig (name = "score" ),
345350 ],
346351 dataset = f"test/{ os .path .basename (file_path )} " ,
347- sampling_strategy = SamplingStrategy .ORDERED ,
352+ sampling_strategy = sampling_strategy ,
353+ selection_strategy = selection_strategy ,
348354 )
349355
350356 # Create a real DuckDB connection (in-memory by default)
@@ -605,3 +611,112 @@ def test_seed_dataset_generator_uses_real_duckdb_connection(fixture_name, stub_r
605611 # Verify the connection can execute count queries
606612 count_result = generator .duckdb_conn .execute (f"SELECT COUNT(*) FROM '{ file_path } '" ).fetchone ()[0 ]
607613 assert count_result == 10
614+
615+
616+ # ============================================================================
617+ # Tests for SeedConfig selection strategies
618+ # ============================================================================
@pytest.mark.parametrize(
    "fixture_name",
    [
        "seed_dataset_parquet",
        "seed_dataset_csv",
        "seed_dataset_json",
        "seed_dataset_jsonl",
    ],
)
def test_seed_dataset_generator_index_range_selection_strategy(fixture_name, stub_resource_provider, request):
    """Check that an IndexRange selection restricts which seed rows are sampled.

    Every case builds a fresh generator over the same seed file with an
    ``IndexRange`` selection, generates a number of records, and inspects the
    ``name`` column. Ordered sampling is checked against an exact sequence
    (including wrap-around when more records are requested than the range
    holds); shuffle sampling is checked only for membership in the range.
    """
    file_path = request.getfixturevalue(fixture_name)

    # The ten names the assertions below expect for the full range 0..9.
    every_name = ["Alice", "Bob", "Charlie", "David", "Eve", "Frank", "Grace", "Henry", "Ivy", "Jack"]

    def sampled_names(strategy, start, end, num_records):
        """Generate ``num_records`` rows from the range and return their names."""
        generator = create_generator_with_real_file(
            file_path,
            stub_resource_provider,
            sampling_strategy=strategy,
            selection_strategy=IndexRange(start=start, end=end),
        )
        result = generator.generate_from_scratch(num_records)
        assert len(result) == num_records
        return list(result["name"])

    # Ordered sampling: (start, end, records requested, exact expected names).
    ordered_cases = (
        # Subset of rows; the sixth record wraps back to the start of the range.
        (4, 8, 6, ["Eve", "Frank", "Grace", "Henry", "Ivy", "Eve"]),
        # Range holding a single row.
        (4, 4, 1, ["Eve"]),
        # Range spanning every row.
        (0, 9, 10, every_name),
    )
    for start, end, num_records, expected_names in ordered_cases:
        names = sampled_names(SamplingStrategy.ORDERED, start, end, num_records)
        assert names == expected_names

    # Shuffle sampling: order is not pinned, so only membership is asserted,
    # except for the one-row range where the outcome is fully determined.
    names = sampled_names(SamplingStrategy.SHUFFLE, 4, 8, 10)
    assert set(names).issubset({"Eve", "Frank", "Grace", "Henry", "Ivy"})

    names = sampled_names(SamplingStrategy.SHUFFLE, 4, 4, 1)
    assert names == ["Eve"]

    names = sampled_names(SamplingStrategy.SHUFFLE, 0, 9, 10)
    assert set(names).issubset(set(every_name))
670+
671+
@pytest.mark.parametrize(
    "fixture_name",
    [
        "seed_dataset_parquet",
        "seed_dataset_csv",
        "seed_dataset_json",
        "seed_dataset_jsonl",
    ],
)
def test_seed_dataset_generator_partition_block_selection_strategy(fixture_name, stub_resource_provider, request):
    """Check that a PartitionBlock selection restricts sampling to one partition.

    Two scenarios run against the same seed file: ordered sampling from
    partition 1 of 3 (exact sequence expected) and shuffle sampling from
    partition 4 of 5 (only membership expected).
    """
    file_path = request.getfixturevalue(fixture_name)

    # Ordered sampling from the second of three partitions.
    middle_partition = PartitionBlock(partition_index=1, num_partitions=3)
    ordered_generator = create_generator_with_real_file(
        file_path,
        stub_resource_provider,
        sampling_strategy=SamplingStrategy.ORDERED,
        selection_strategy=middle_partition,
    )
    ordered_result = ordered_generator.generate_from_scratch(5)
    assert len(ordered_result) == 5
    # Five records are requested from a three-row partition, so the expected
    # sequence wraps around to the first row of the partition.
    assert list(ordered_result["name"]) == ["David", "Eve", "Frank", "David", "Eve"]

    # Shuffle sampling from the last of five partitions.
    last_partition = PartitionBlock(partition_index=4, num_partitions=5)
    shuffled_generator = create_generator_with_real_file(
        file_path,
        stub_resource_provider,
        sampling_strategy=SamplingStrategy.SHUFFLE,
        selection_strategy=last_partition,
    )
    shuffled_result = shuffled_generator.generate_from_scratch(10)
    assert len(shuffled_result) == 10
    # Order is not pinned under shuffling; only membership is checked.
    assert set(shuffled_result["name"]).issubset({"Jack", "Ivy"})
703+
704+
@pytest.mark.parametrize(
    "fixture_name",
    [
        "seed_dataset_parquet",
        "seed_dataset_csv",
        "seed_dataset_json",
        "seed_dataset_jsonl",
    ],
)
def test_seed_dataset_generator_invalid_selection_strategies(fixture_name, stub_resource_provider, request):
    """Check that out-of-bounds selection strategies raise SeedDatasetError.

    Covers an ``IndexRange`` whose ``end`` equals the dataset size and a
    ``PartitionBlock`` asking for more partitions than the dataset size named
    in the expected message.
    """
    file_path = request.getfixturevalue(fixture_name)

    # Each entry pairs a factory for the offending strategy with the message
    # the error must match. Factories keep strategy construction inside the
    # ``pytest.raises`` block, alongside generator creation and generation.
    failing_cases = (
        (
            lambda: IndexRange(start=1, end=10),
            "Selection strategy 'end' index 10 is out of bounds for dataset size 10",
        ),
        (
            lambda: PartitionBlock(partition_index=0, num_partitions=11),
            "Selection strategy 'num_partitions' 11 is out of bounds for dataset size 10",
        ),
    )

    for build_strategy, expected_message in failing_cases:
        with pytest.raises(SeedDatasetError, match=expected_message):
            generator = create_generator_with_real_file(
                file_path,
                stub_resource_provider,
                selection_strategy=build_strategy(),
            )
            generator.generate_from_scratch(1)
0 commit comments