Skip to content

Commit fd61ec4

Browse files
authored
ChunkIterator: account for smaller last slice (TGSAI#702)
1 parent 7392241 commit fd61ec4

File tree

2 files changed

+11
-9
lines changed

2 files changed

+11
-9
lines changed

src/mdio/core/indexing.py

Lines changed: 5 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -74,11 +74,13 @@ def __next__(self) -> dict[str, slice]:
7474
# We build slices here. It is dimension agnostic
7575
current_start = next(self._ranges)
7676

77-
# TODO (Dmitriy Repin): Enhance ChunkIterator to make the last slice, if needed, smaller
78-
# https://github.com/TGSAI/mdio-python/issues/586
7977
start_indices = tuple(dim * chunk for dim, chunk in zip(current_start, self.len_chunks, strict=True))
8078

81-
stop_indices = tuple((dim + 1) * chunk for dim, chunk in zip(current_start, self.len_chunks, strict=True))
79+
# Calculate stop indices, making the last slice fit the data exactly
80+
stop_indices = tuple(
81+
min((dim + 1) * chunk, self.arr_shape[i])
82+
for i, (dim, chunk) in enumerate(zip(current_start, self.len_chunks, strict=True))
83+
)
8284

8385
slices = tuple(slice(start, stop) for start, stop in zip(start_indices, stop_indices, strict=True))
8486

tests/unit/test_indexing.py

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -25,8 +25,8 @@ def test_chunk_iterator_returning_dict() -> None:
2525
assert iter2.dim_chunks == (2, 3, 4)
2626
assert iter2.num_chunks == 24
2727

28-
# Its purpose is to confirm that all slices are created of the same size,
29-
# even if the last slice should have been smaller.
28+
# Its purpose is to confirm that the last slice is adjusted to fit the data exactly
29+
# when the array size doesn't align perfectly with chunk boundaries.
3030
for _ in range(13): # element index 12
3131
region = iter1.__next__()
3232
assert region == {
@@ -38,7 +38,7 @@ def test_chunk_iterator_returning_dict() -> None:
3838
for _ in range(13): # element index 12
3939
region = iter2.__next__()
4040
assert region == {
41-
"inline": slice(3, 6, None),
41+
"inline": slice(3, 5, None),
4242
"crossline": slice(0, 4, None),
4343
"depth": slice(0, 5, None),
4444
}
@@ -61,15 +61,15 @@ def test_chunk_iterator_returning_tuple() -> None:
6161
assert iter2.dim_chunks == (2, 3, 4)
6262
assert iter2.num_chunks == 24
6363

64-
# Its purpose is to confirm that all slices are created of the same size,
65-
# even if the last slice should have been smaller.
64+
# Its purpose is to confirm that the last slice is adjusted to fit the data exactly
65+
# when the array size doesn't align perfectly with chunk boundaries.
6666
for _ in range(13): # element index 12
6767
region = iter1.__next__()
6868
assert region == (slice(3, 6, None), slice(0, 4, None), slice(0, 5, None))
6969

7070
for _ in range(13): # element index 12
7171
region = iter2.__next__()
72-
assert region == (slice(3, 6, None), slice(0, 4, None), slice(0, 5, None))
72+
assert region == (slice(3, 5, None), slice(0, 4, None), slice(0, 5, None))
7373

7474

7575
def val(shape: tuple[int, int, int], i: int, j: int, k: int) -> int:

0 commit comments

Comments
 (0)