@@ -2243,6 +2243,7 @@ cdef class slice_flatter:
2243
2243
cdef long long [:] indices
2244
2244
cdef long long current_slice_start
2245
2245
cdef long long current_slice_end
2246
+ cdef long long current_flat_idx # Track the current flat index
2246
2247
2247
2248
def __cinit__ (self , long long[:] start not None , long long[:] stop not None , long long[:] strides not None ):
2248
2249
self .ndim = start.shape[0 ]
@@ -2255,41 +2256,73 @@ cdef class slice_flatter:
2255
2256
shape = tuple (stop[i] - start[i] for i in range (self .ndim))
2256
2257
self .shape = np.array(shape, dtype = np.int64)
2257
2258
self .indices = np.zeros(self .ndim, dtype = np.int64)
2259
+ # Initialize the flat index
2260
+ self .current_flat_idx = 0
2261
+ for j in range (self .ndim):
2262
+ self .current_flat_idx += self .start[j] * self .strides[j]
2258
2263
2259
2264
def __iter__ (self ):
2260
2265
return self
2261
2266
2262
2267
@ cython.boundscheck (False )
2263
2268
@ cython.wraparound (False )
2264
2269
def __next__ (self ):
2265
- cdef long long i, j, flat_idx
2270
+ cdef long long j, next_flat_idx
2271
+ cdef int extended_slice = 0
2266
2272
2273
+ # Check if we're done
2274
+ if self .done:
2275
+ if self .current_slice_start != - 1 :
2276
+ result = slice (self .current_slice_start, self .current_slice_end + 1 )
2277
+ self .current_slice_start = - 1
2278
+ return result
2279
+ raise StopIteration
2280
+
2281
+ # Initialize first slice point if needed
2282
+ if self .current_slice_start == - 1 :
2283
+ next_flat_idx = 0
2284
+ for j in range (self .ndim):
2285
+ next_flat_idx += (self .start[j] + self .indices[j]) * self .strides[j]
2286
+ self .current_slice_start = next_flat_idx
2287
+ self .current_slice_end = next_flat_idx
2288
+ self .current_flat_idx = next_flat_idx
2289
+ self .incr_indices()
2290
+
2291
+ # If we're done after the first element, return it
2292
+ if self .done:
2293
+ result = slice (self .current_slice_start, self .current_slice_end + 1 )
2294
+ self .current_slice_start = - 1
2295
+ return result
2296
+
2297
+ # Extend slice as long as indices remain contiguous
2267
2298
while not self .done:
2268
- flat_idx = 0
2299
+ # Calculate next flat index
2300
+ next_flat_idx = 0
2269
2301
for j in range (self .ndim):
2270
- flat_idx += (self .start[j] + self .indices[j]) * self .strides[j]
2302
+ next_flat_idx += (self .start[j] + self .indices[j]) * self .strides[j]
2271
2303
2272
- if self .current_slice_start == - 1 :
2273
- self .current_slice_start = flat_idx
2274
- self .current_slice_end = flat_idx
2275
- elif flat_idx == self .current_slice_end + 1 :
2276
- self .current_slice_end = flat_idx
2304
+ # If indices are contiguous, extend current slice
2305
+ if next_flat_idx == self .current_slice_end + 1 :
2306
+ self .current_slice_end = next_flat_idx
2307
+ self .current_flat_idx = next_flat_idx
2308
+ self .incr_indices()
2309
+ extended_slice = 1
2277
2310
else :
2311
+ # Non-contiguous index found, return current slice
2278
2312
result = slice (self .current_slice_start, self .current_slice_end + 1 )
2279
- self .current_slice_start = flat_idx
2280
- self .current_slice_end = flat_idx
2281
- # Increment the indices (TODO: remove duplicated code)
2313
+ self .current_slice_start = next_flat_idx
2314
+ self .current_slice_end = next_flat_idx
2315
+ self .current_flat_idx = next_flat_idx
2282
2316
self .incr_indices()
2283
2317
return result
2284
2318
2285
- # Increment the indices
2286
- self .incr_indices()
2287
-
2288
- if self .current_slice_start != - 1 :
2319
+ # If we've reached the end after extending the slice
2320
+ if extended_slice:
2289
2321
result = slice (self .current_slice_start, self .current_slice_end + 1 )
2290
2322
self .current_slice_start = - 1
2291
2323
return result
2292
2324
2325
+ # Should never reach here
2293
2326
raise StopIteration
2294
2327
2295
2328
@ cython.boundscheck (False )
0 commit comments