Skip to content

Commit 886c7ba

Browse files
committed
Small optimization of the slice_flatter iterator
1 parent 07c7001 commit 886c7ba

File tree

1 file changed

+48
-15
lines changed

1 file changed

+48
-15
lines changed

src/blosc2/blosc2_ext.pyx

Lines changed: 48 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2243,6 +2243,7 @@ cdef class slice_flatter:
22432243
cdef long long[:] indices
22442244
cdef long long current_slice_start
22452245
cdef long long current_slice_end
2246+
cdef long long current_flat_idx # Track the current flat index
22462247

22472248
def __cinit__(self, long long[:] start not None, long long[:] stop not None, long long[:] strides not None):
22482249
self.ndim = start.shape[0]
@@ -2255,41 +2256,73 @@ cdef class slice_flatter:
22552256
shape = tuple(stop[i] - start[i] for i in range(self.ndim))
22562257
self.shape = np.array(shape, dtype=np.int64)
22572258
self.indices = np.zeros(self.ndim, dtype=np.int64)
2259+
# Initialize the flat index
2260+
self.current_flat_idx = 0
2261+
for j in range(self.ndim):
2262+
self.current_flat_idx += self.start[j] * self.strides[j]
22582263

22592264
def __iter__(self):
22602265
return self
22612266

22622267
@cython.boundscheck(False)
22632268
@cython.wraparound(False)
22642269
def __next__(self):
2265-
cdef long long i, j, flat_idx
2270+
cdef long long j, next_flat_idx
2271+
cdef int extended_slice = 0
22662272

2273+
# Check if we're done
2274+
if self.done:
2275+
if self.current_slice_start != -1:
2276+
result = slice(self.current_slice_start, self.current_slice_end + 1)
2277+
self.current_slice_start = -1
2278+
return result
2279+
raise StopIteration
2280+
2281+
# Initialize first slice point if needed
2282+
if self.current_slice_start == -1:
2283+
next_flat_idx = 0
2284+
for j in range(self.ndim):
2285+
next_flat_idx += (self.start[j] + self.indices[j]) * self.strides[j]
2286+
self.current_slice_start = next_flat_idx
2287+
self.current_slice_end = next_flat_idx
2288+
self.current_flat_idx = next_flat_idx
2289+
self.incr_indices()
2290+
2291+
# If we're done after the first element, return it
2292+
if self.done:
2293+
result = slice(self.current_slice_start, self.current_slice_end + 1)
2294+
self.current_slice_start = -1
2295+
return result
2296+
2297+
# Extend slice as long as indices remain contiguous
22672298
while not self.done:
2268-
flat_idx = 0
2299+
# Calculate next flat index
2300+
next_flat_idx = 0
22692301
for j in range(self.ndim):
2270-
flat_idx += (self.start[j] + self.indices[j]) * self.strides[j]
2302+
next_flat_idx += (self.start[j] + self.indices[j]) * self.strides[j]
22712303

2272-
if self.current_slice_start == -1:
2273-
self.current_slice_start = flat_idx
2274-
self.current_slice_end = flat_idx
2275-
elif flat_idx == self.current_slice_end + 1:
2276-
self.current_slice_end = flat_idx
2304+
# If indices are contiguous, extend current slice
2305+
if next_flat_idx == self.current_slice_end + 1:
2306+
self.current_slice_end = next_flat_idx
2307+
self.current_flat_idx = next_flat_idx
2308+
self.incr_indices()
2309+
extended_slice = 1
22772310
else:
2311+
# Non-contiguous index found, return current slice
22782312
result = slice(self.current_slice_start, self.current_slice_end + 1)
2279-
self.current_slice_start = flat_idx
2280-
self.current_slice_end = flat_idx
2281-
# Increment the indices (TODO: remove duplicated code)
2313+
self.current_slice_start = next_flat_idx
2314+
self.current_slice_end = next_flat_idx
2315+
self.current_flat_idx = next_flat_idx
22822316
self.incr_indices()
22832317
return result
22842318

2285-
# Increment the indices
2286-
self.incr_indices()
2287-
2288-
if self.current_slice_start != -1:
2319+
# If we've reached the end after extending the slice
2320+
if extended_slice:
22892321
result = slice(self.current_slice_start, self.current_slice_end + 1)
22902322
self.current_slice_start = -1
22912323
return result
22922324

2325+
# Should never reach here
22932326
raise StopIteration
22942327

22952328
@cython.boundscheck(False)

0 commit comments

Comments
 (0)