44from abc import ABC
55from enum import Enum
66
7- from pydantic import field_validator
8-
7+ from pydantic import field_validator , Field , model_validator
8+ from typing_extensions import Self
9+ from typing import Union , Optional
910from .base import ConfigBase
1011from .datastore import DatastoreSettings
1112from .utils .io_helpers import validate_dataset_file_path
@@ -16,9 +17,32 @@ class SamplingStrategy(str, Enum):
1617 SHUFFLE = "shuffle"
1718
1819
20+ class IndexRange (ConfigBase ):
21+ start : int = Field (..., ge = 0 )
22+ end : int = Field (..., ge = 1 )
23+
24+ @model_validator (mode = "after" )
25+ def _validate_index_range (self ) -> Self :
26+ if self .start >= self .end :
27+ raise ValueError ("'start' index must be less than 'end' index" )
28+ return self
29+
30+
31+ class PartitionBlock (ConfigBase ):
32+ partition_index : int = Field (..., default = 0 , ge = 0 )
33+ num_partitions : int = Field (..., default = 1 , ge = 1 )
34+
35+ @model_validator (mode = "after" )
36+ def _validate_partition_block (self ) -> Self :
37+ if self .partition_index >= self .num_partitions :
38+ raise ValueError ("'partition_index' must be less than 'num_partitions'" )
39+ return self
40+
41+
1942class SeedConfig (ConfigBase ):
2043 dataset : str
2144 sampling_strategy : SamplingStrategy = SamplingStrategy .ORDERED
45+ selection_strategy : Optional [Union [IndexRange , PartitionBlock ]] = None
2246
2347
2448class SeedDatasetReference (ABC , ConfigBase ):
0 commit comments