Skip to content

Commit b429464

Browse files
committed
add IndexRange and PartitionBlock
1 parent 19e5e46 commit b429464

File tree

2 files changed

+29
-3
lines changed

2 files changed

+29
-3
lines changed

src/data_designer/config/seed.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@
44
from abc import ABC
55
from enum import Enum
66

7-
from pydantic import field_validator
8-
7+
from pydantic import field_validator, Field, model_validator
8+
from typing_extensions import Self
9+
from typing import Union, Optional
910
from .base import ConfigBase
1011
from .datastore import DatastoreSettings
1112
from .utils.io_helpers import validate_dataset_file_path
@@ -16,9 +17,32 @@ class SamplingStrategy(str, Enum):
1617
SHUFFLE = "shuffle"
1718

1819

20+
class IndexRange(ConfigBase):
    # A pair of integer indices with ``start`` strictly below ``end``.
    # Both fields are required; ``start`` must be >= 0 and ``end`` must be >= 1.
    start: int = Field(..., ge=0)
    end: int = Field(..., ge=1)

    @model_validator(mode="after")
    def _validate_index_range(self) -> Self:
        """Check the ordering of the two indices after field validation.

        Returns:
            The validated instance, unchanged.

        Raises:
            ValueError: If ``start`` is greater than or equal to ``end``.
        """
        if self.start < self.end:
            return self
        raise ValueError("'start' index must be less than 'end' index")
29+
30+
31+
class PartitionBlock(ConfigBase):
    # Identifies one partition by index out of ``num_partitions`` partitions.
    # Defaults: ``partition_index`` is 0 (must be >= 0) and ``num_partitions``
    # is 1 (must be >= 1).
    #
    # NOTE: the defaults are passed by keyword only. Writing
    # ``Field(..., default=0)`` supplies ``default`` twice (positionally as
    # ``...`` and again by keyword) and raises ``TypeError`` as soon as the
    # class body is executed, which makes the module unimportable.
    partition_index: int = Field(default=0, ge=0)
    num_partitions: int = Field(default=1, ge=1)

    @model_validator(mode="after")
    def _validate_partition_block(self) -> Self:
        """Check that the partition index falls below the partition count.

        Returns:
            The validated instance, unchanged.

        Raises:
            ValueError: If ``partition_index`` is greater than or equal to
                ``num_partitions``.
        """
        if self.partition_index >= self.num_partitions:
            raise ValueError("'partition_index' must be less than 'num_partitions'")
        return self
40+
41+
1942
class SeedConfig(ConfigBase):
    # ``dataset`` is a required string.
    dataset: str
    # Defaults to SamplingStrategy.ORDERED when not given.
    sampling_strategy: SamplingStrategy = SamplingStrategy.ORDERED
    # Accepts an IndexRange, a PartitionBlock, or None; defaults to None.
    selection_strategy: Union[IndexRange, PartitionBlock, None] = None
2246

2347

2448
class SeedDatasetReference(ABC, ConfigBase):

src/data_designer/essentials/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@
4747
UniformSamplerParams,
4848
UUIDSamplerParams,
4949
)
50-
from ..config.seed import DatastoreSeedDatasetReference, SamplingStrategy, SeedConfig
50+
from ..config.seed import DatastoreSeedDatasetReference, SamplingStrategy, SeedConfig, IndexRange, PartitionBlock
5151
from ..config.utils.code_lang import CodeLang
5252
from ..config.utils.misc import can_run_data_designer_locally
5353
from ..config.validator_params import (
@@ -85,6 +85,7 @@
8585
"DatetimeSamplerParams",
8686
"ExpressionColumnConfig",
8787
"GaussianSamplerParams",
88+
"IndexRange",
8889
"ImageContext",
8990
"ImageFormat",
9091
"InferenceParameters",
@@ -100,6 +101,7 @@
100101
"ModalityContext",
101102
"ModalityDataType",
102103
"ModelConfig",
104+
"PartitionBlock",
103105
"PersonSamplerParams",
104106
"PoissonSamplerParams",
105107
"RemoteValidatorParams",

0 commit comments

Comments
 (0)