@@ -2243,6 +2243,7 @@ cdef class slice_flatter:
22432243 cdef long long [:] indices
22442244 cdef long long current_slice_start
22452245 cdef long long current_slice_end
2246+ cdef long long current_flat_idx # Track the current flat index
22462247
22472248 def __cinit__ (self , long long[:] start not None , long long[:] stop not None , long long[:] strides not None ):
22482249 self .ndim = start.shape[0 ]
@@ -2255,41 +2256,73 @@ cdef class slice_flatter:
22552256 shape = tuple (stop[i] - start[i] for i in range (self .ndim))
22562257 self .shape = np.array(shape, dtype = np.int64)
22572258 self .indices = np.zeros(self .ndim, dtype = np.int64)
2259+ # Initialize the flat index
2260+ self .current_flat_idx = 0
2261+ for j in range (self .ndim):
2262+ self .current_flat_idx += self .start[j] * self .strides[j]
22582263
22592264 def __iter__ (self ):
22602265 return self
22612266
22622267 @ cython.boundscheck (False )
22632268 @ cython.wraparound (False )
22642269 def __next__ (self ):
2265- cdef long long i, j, flat_idx
2270+ cdef long long j, next_flat_idx
2271+ cdef int extended_slice = 0
22662272
2273+ # Check if we're done
2274+ if self .done:
2275+ if self .current_slice_start != - 1 :
2276+ result = slice (self .current_slice_start, self .current_slice_end + 1 )
2277+ self .current_slice_start = - 1
2278+ return result
2279+ raise StopIteration
2280+
2281+ # Initialize first slice point if needed
2282+ if self .current_slice_start == - 1 :
2283+ next_flat_idx = 0
2284+ for j in range (self .ndim):
2285+ next_flat_idx += (self .start[j] + self .indices[j]) * self .strides[j]
2286+ self .current_slice_start = next_flat_idx
2287+ self .current_slice_end = next_flat_idx
2288+ self .current_flat_idx = next_flat_idx
2289+ self .incr_indices()
2290+
2291+ # If we're done after the first element, return it
2292+ if self .done:
2293+ result = slice (self .current_slice_start, self .current_slice_end + 1 )
2294+ self .current_slice_start = - 1
2295+ return result
2296+
2297+ # Extend slice as long as indices remain contiguous
22672298 while not self .done:
2268- flat_idx = 0
2299+ # Calculate next flat index
2300+ next_flat_idx = 0
22692301 for j in range (self .ndim):
2270- flat_idx += (self .start[j] + self .indices[j]) * self .strides[j]
2302+ next_flat_idx += (self .start[j] + self .indices[j]) * self .strides[j]
22712303
2272- if self .current_slice_start == - 1 :
2273- self .current_slice_start = flat_idx
2274- self .current_slice_end = flat_idx
2275- elif flat_idx == self .current_slice_end + 1 :
2276- self .current_slice_end = flat_idx
2304+ # If indices are contiguous, extend current slice
2305+ if next_flat_idx == self .current_slice_end + 1 :
2306+ self .current_slice_end = next_flat_idx
2307+ self .current_flat_idx = next_flat_idx
2308+ self .incr_indices()
2309+ extended_slice = 1
22772310 else :
2311+ # Non-contiguous index found, return current slice
22782312 result = slice (self .current_slice_start, self .current_slice_end + 1 )
2279- self .current_slice_start = flat_idx
2280- self .current_slice_end = flat_idx
2281- # Increment the indices (TODO: remove duplicated code)
2313+ self .current_slice_start = next_flat_idx
2314+ self .current_slice_end = next_flat_idx
2315+ self .current_flat_idx = next_flat_idx
22822316 self .incr_indices()
22832317 return result
22842318
2285- # Increment the indices
2286- self .incr_indices()
2287-
2288- if self .current_slice_start != - 1 :
2319+ # If we've reached the end after extending the slice
2320+ if extended_slice:
22892321 result = slice (self .current_slice_start, self .current_slice_end + 1 )
22902322 self .current_slice_start = - 1
22912323 return result
22922324
2325+ # Should never reach here
22932326 raise StopIteration
22942327
22952328 @ cython.boundscheck (False )
0 commit comments