Skip to content

Commit 80fca51

Browse files
committed
partition_index -> index
1 parent 6da2c3c commit 80fca51

File tree

3 files changed

+19
-19
lines changed

3 files changed

+19
-19
lines changed

src/data_designer/config/seed.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,24 +34,24 @@ def size(self) -> int:
3434

3535

3636
class PartitionBlock(ConfigBase):
37-
partition_index: int = Field(default=0, ge=0, description="The index of the partition to sample from")
37+
index: int = Field(default=0, ge=0, description="The index of the partition to sample from")
3838
num_partitions: int = Field(default=1, ge=1, description="The total number of partitions in the dataset")
3939

4040
@model_validator(mode="after")
4141
def _validate_partition_block(self) -> Self:
42-
if self.partition_index >= self.num_partitions:
43-
raise ValueError("'partition_index' must be less than 'num_partitions'")
42+
if self.index >= self.num_partitions:
43+
raise ValueError("'index' must be less than 'num_partitions'")
4444
return self
4545

4646
def to_index_range(self, dataset_size: int) -> IndexRange:
4747
partition_size = dataset_size // self.num_partitions
48-
start = self.partition_index * partition_size
48+
start = self.index * partition_size
4949

5050
# For the last partition, extend to the end of the dataset to include remainder rows
51-
if self.partition_index == self.num_partitions - 1:
51+
if self.index == self.num_partitions - 1:
5252
end = dataset_size - 1
5353
else:
54-
end = ((self.partition_index + 1) * partition_size) - 1
54+
end = ((self.index + 1) * partition_size) - 1
5555
return IndexRange(start=start, end=end)
5656

5757

tests/config/test_seed.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -25,32 +25,32 @@ def test_index_range_size():
2525

2626
def test_partition_block_validation():
2727
with pytest.raises(ValueError, match="should be greater than or equal to 0"):
28-
PartitionBlock(partition_index=-1, num_partitions=10)
28+
PartitionBlock(index=-1, num_partitions=10)
2929

3030
with pytest.raises(ValueError, match="should be greater than or equal to 1"):
31-
PartitionBlock(partition_index=0, num_partitions=0)
31+
PartitionBlock(index=0, num_partitions=0)
3232

33-
with pytest.raises(ValueError, match="'partition_index' must be less than 'num_partitions'"):
34-
PartitionBlock(partition_index=10, num_partitions=10)
33+
with pytest.raises(ValueError, match="'index' must be less than 'num_partitions'"):
34+
PartitionBlock(index=10, num_partitions=10)
3535

3636

3737
def test_partition_block_to_index_range():
38-
index_range = PartitionBlock(partition_index=0, num_partitions=10).to_index_range(101)
38+
index_range = PartitionBlock(index=0, num_partitions=10).to_index_range(101)
3939
assert index_range.start == 0
4040
assert index_range.end == 9
4141
assert index_range.size == 10
4242

43-
index_range = PartitionBlock(partition_index=1, num_partitions=10).to_index_range(105)
43+
index_range = PartitionBlock(index=1, num_partitions=10).to_index_range(105)
4444
assert index_range.start == 10
4545
assert index_range.end == 19
4646
assert index_range.size == 10
4747

48-
index_range = PartitionBlock(partition_index=2, num_partitions=10).to_index_range(105)
48+
index_range = PartitionBlock(index=2, num_partitions=10).to_index_range(105)
4949
assert index_range.start == 20
5050
assert index_range.end == 29
5151
assert index_range.size == 10
5252

53-
index_range = PartitionBlock(partition_index=9, num_partitions=10).to_index_range(105)
53+
index_range = PartitionBlock(index=9, num_partitions=10).to_index_range(105)
5454
assert index_range.start == 90
5555
assert index_range.end == 104
5656
assert index_range.size == 15

tests/engine/column_generators/generators/test_seed_dataset.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -131,10 +131,10 @@ def test_seed_dataset_column_generator_config_structure():
131131
columns=[SeedDatasetColumnConfig(name="col1"), SeedDatasetColumnConfig(name="col2")],
132132
dataset="test/dataset",
133133
sampling_strategy=SamplingStrategy.SHUFFLE,
134-
selection_strategy=PartitionBlock(partition_index=1, num_partitions=3),
134+
selection_strategy=PartitionBlock(index=1, num_partitions=3),
135135
)
136136
assert isinstance(config.selection_strategy, PartitionBlock)
137-
assert config.selection_strategy.partition_index == 1
137+
assert config.selection_strategy.index == 1
138138
assert config.selection_strategy.num_partitions == 3
139139

140140
# Test IndexRange selection strategy
@@ -742,7 +742,7 @@ def test_seed_dataset_generator_partition_block_selection_strategy(fixture_name,
742742
file_path,
743743
stub_resource_provider,
744744
sampling_strategy=SamplingStrategy.ORDERED,
745-
selection_strategy=PartitionBlock(partition_index=1, num_partitions=3),
745+
selection_strategy=PartitionBlock(index=1, num_partitions=3),
746746
)
747747
result = generator.generate_from_scratch(5)
748748
assert len(result) == 5
@@ -753,7 +753,7 @@ def test_seed_dataset_generator_partition_block_selection_strategy(fixture_name,
753753
file_path,
754754
stub_resource_provider,
755755
sampling_strategy=SamplingStrategy.SHUFFLE,
756-
selection_strategy=PartitionBlock(partition_index=4, num_partitions=5),
756+
selection_strategy=PartitionBlock(index=4, num_partitions=5),
757757
)
758758
result = generator.generate_from_scratch(10)
759759
assert len(result) == 10
@@ -783,6 +783,6 @@ def test_seed_dataset_generator_invalid_selection_strategies(fixture_name, stub_
783783
SeedDatasetError, match="Selection strategy 'num_partitions' 11 is out of bounds for dataset size 10"
784784
):
785785
generator = create_generator_with_real_file(
786-
file_path, stub_resource_provider, selection_strategy=PartitionBlock(partition_index=0, num_partitions=11)
786+
file_path, stub_resource_provider, selection_strategy=PartitionBlock(index=0, num_partitions=11)
787787
)
788788
generator.generate_from_scratch(1)

0 commit comments

Comments
 (0)