|
6 | 6 | import logging |
7 | 7 | import threading |
8 | 8 | import time |
9 | | -from collections import OrderedDict |
| 9 | +from collections import deque, OrderedDict |
10 | 10 | from copy import deepcopy |
11 | 11 | from datetime import timedelta |
12 | 12 | from typing import Any, Callable, Iterable, Mapping, MutableMapping, Optional |
@@ -99,9 +99,13 @@ def __init__( |
99 | 99 | self._semaphore_per_partition: OrderedDict[str, threading.Semaphore] = OrderedDict() |
100 | 100 |
|
101 | 101 | # Parent-state tracking: store each partition’s parent state in creation order |
102 | | - self._partition_parent_state_map: OrderedDict[str, Mapping[str, Any]] = OrderedDict() |
| 102 | + self._partition_parent_state_map: OrderedDict[str, tuple[Mapping[str, Any], int]] = OrderedDict() |
103 | 103 |
|
104 | 104 | self._finished_partitions: set[str] = set() |
| 105 | + self._open_seqs: deque[int] = deque() |
| 106 | + self._next_seq: int = 0 |
| 107 | + self._seq_by_partition: dict[str, int] = {} |
| 108 | + |
105 | 109 | self._lock = threading.Lock() |
106 | 110 | self._timer = Timer() |
107 | 111 | self._new_global_cursor: Optional[StreamState] = None |
@@ -162,55 +166,28 @@ def close_partition(self, partition: Partition) -> None: |
162 | 166 | ): |
163 | 167 | self._update_global_cursor(cursor.state[self.cursor_field.cursor_field_key]) |
164 | 168 |
|
| 169 | + # Clean up the partition if it is fully processed |
| 170 | + self._cleanup_if_done(partition_key) |
| 171 | + |
165 | 172 | self._check_and_update_parent_state() |
166 | 173 |
|
167 | 174 | self._emit_state_message() |
168 | 175 |
|
169 | 176 | def _check_and_update_parent_state(self) -> None: |
170 | | - """ |
171 | | - Pop the leftmost partition state from _partition_parent_state_map only if |
172 | | - *all partitions* up to (and including) that partition key in _semaphore_per_partition |
173 | | - are fully finished (i.e. in _finished_partitions and semaphore._value == 0). |
174 | | - Additionally, delete finished semaphores with a value of 0 to free up memory, |
175 | | - as they are only needed to track errors and completion status. |
176 | | - """ |
177 | 177 | last_closed_state = None |
178 | 178 |
|
179 | 179 | while self._partition_parent_state_map: |
180 | | - # Look at the earliest partition key in creation order |
181 | | - earliest_key = next(iter(self._partition_parent_state_map)) |
182 | | - |
183 | | - # Verify ALL partitions from the left up to earliest_key are finished |
184 | | - all_left_finished = True |
185 | | - for p_key, sem in list( |
186 | | - self._semaphore_per_partition.items() |
187 | | - ): # Use list to allow modification during iteration |
188 | | - # If any earlier partition is still not finished, we must stop |
189 | | - if p_key not in self._finished_partitions or sem._value != 0: |
190 | | - all_left_finished = False |
191 | | - break |
192 | | - # Once we've reached earliest_key in the semaphore order, we can stop checking |
193 | | - if p_key == earliest_key: |
194 | | - break |
195 | | - |
196 | | - # If the partitions up to earliest_key are not all finished, break the while-loop |
197 | | - if not all_left_finished: |
198 | | - break |
| 180 | + earliest_key, (candidate_state, candidate_seq) = \ |
| 181 | + next(iter(self._partition_parent_state_map.items())) |
199 | 182 |
|
200 | | - # Pop the leftmost entry from parent-state map |
201 | | - _, closed_parent_state = self._partition_parent_state_map.popitem(last=False) |
202 | | - last_closed_state = closed_parent_state |
| 183 | + # if any partition that started <= candidate_seq is still open, we must wait |
| 184 | + if self._open_seqs and self._open_seqs[0] <= candidate_seq: |
| 185 | + break |
203 | 186 |
|
204 | | - # Clean up finished semaphores with value 0 up to and including earliest_key |
205 | | - for p_key in list(self._semaphore_per_partition.keys()): |
206 | | - sem = self._semaphore_per_partition[p_key] |
207 | | - if p_key in self._finished_partitions and sem._value == 0: |
208 | | - del self._semaphore_per_partition[p_key] |
209 | | - logger.debug(f"Deleted finished semaphore for partition {p_key} with value 0") |
210 | | - if p_key == earliest_key: |
211 | | - break |
| 187 | + # safe to pop |
| 188 | + self._partition_parent_state_map.popitem(last=False) |
| 189 | + last_closed_state = candidate_state |
212 | 190 |
|
213 | | - # Update _parent_state if we popped at least one partition |
214 | 191 | if last_closed_state is not None: |
215 | 192 | self._parent_state = last_closed_state |
216 | 193 |
|
@@ -293,14 +270,19 @@ def _generate_slices_from_partition( |
293 | 270 | self._semaphore_per_partition[partition_key] = threading.Semaphore(0) |
294 | 271 |
|
295 | 272 | with self._lock: |
| 273 | + seq = self._next_seq |
| 274 | + self._next_seq += 1 |
| 275 | + self._open_seqs.append(seq) |
| 276 | + self._seq_by_partition[partition_key] = seq |
| 277 | + |
296 | 278 | if ( |
297 | 279 | len(self._partition_parent_state_map) == 0 |
298 | 280 | or self._partition_parent_state_map[ |
299 | 281 | next(reversed(self._partition_parent_state_map)) |
300 | | - ] |
| 282 | + ][0] |
301 | 283 | != parent_state |
302 | 284 | ): |
303 | | - self._partition_parent_state_map[partition_key] = deepcopy(parent_state) |
| 285 | + self._partition_parent_state_map[partition_key] = (deepcopy(parent_state), seq) |
304 | 286 |
|
305 | 287 | for cursor_slice, is_last_slice, _ in iterate_with_last_flag_and_state( |
306 | 288 | cursor.stream_slices(), |
@@ -338,10 +320,7 @@ def _ensure_partition_limit(self) -> None: |
338 | 320 | while len(self._cursor_per_partition) > self.DEFAULT_MAX_PARTITIONS_NUMBER - 1: |
339 | 321 | # Try removing finished partitions first |
340 | 322 | for partition_key in list(self._cursor_per_partition.keys()): |
341 | | - if partition_key in self._finished_partitions and ( |
342 | | - partition_key not in self._semaphore_per_partition |
343 | | - or self._semaphore_per_partition[partition_key]._value == 0 |
344 | | - ): |
| 323 | + if partition_key not in self._seq_by_partition: |
345 | 324 | oldest_partition = self._cursor_per_partition.pop( |
346 | 325 | partition_key |
347 | 326 | ) # Remove the oldest partition |
@@ -474,6 +453,25 @@ def _update_global_cursor(self, value: Any) -> None: |
474 | 453 | ): |
475 | 454 | self._new_global_cursor = {self.cursor_field.cursor_field_key: copy.deepcopy(value)} |
476 | 455 |
|
| 456 | + def _cleanup_if_done(self, partition_key: str) -> None: |
| 457 | + """ |
| 458 | + Free every in-memory structure that belonged to a completed partition: |
| 459 | + cursor, semaphore, flag inside `_finished_partitions` |
| 460 | + """ |
| 461 | + if not ( |
| 462 | + partition_key in self._finished_partitions |
| 463 | + and self._semaphore_per_partition[partition_key]._value == 0 |
| 464 | + ): |
| 465 | + return |
| 466 | + |
| 467 | + self._semaphore_per_partition.pop(partition_key, None) |
| 468 | + self._finished_partitions.discard(partition_key) |
| 469 | + |
| 470 | + seq = self._seq_by_partition.pop(partition_key) |
| 471 | + self._open_seqs.remove(seq) |
| 472 | + |
| 473 | + logger.debug(f"Partition {partition_key} fully processed and cleaned up.") |
| 474 | + |
477 | 475 | def _to_partition_key(self, partition: Mapping[str, Any]) -> str: |
478 | 476 | return self._partition_serializer.to_partition_key(partition) |
479 | 477 |
|
|
0 commit comments