@@ -95,6 +95,10 @@ def __init__(
9595 # the oldest partitions can be efficiently removed, maintaining the most recent partitions.
9696 self ._cursor_per_partition : OrderedDict [str , ConcurrentCursor ] = OrderedDict ()
9797 self ._semaphore_per_partition : OrderedDict [str , threading .Semaphore ] = OrderedDict ()
98+
99+ # Parent-state tracking: store each partition’s parent state in creation order
100+ self ._partition_parent_state_map : OrderedDict [str , Mapping [str , Any ]] = OrderedDict ()
101+
98102 self ._finished_partitions : set [str ] = set ()
99103 self ._lock = threading .Lock ()
100104 self ._timer = Timer ()
@@ -154,10 +158,32 @@ def close_partition(self, partition: Partition) -> None:
154158 and self ._semaphore_per_partition [partition_key ]._value == 0
155159 ):
156160 self ._update_global_cursor (cursor .state [self .cursor_field .cursor_field_key ])
157- self ._emit_state_message ()
161+
162+ self ._check_and_update_parent_state ()
163+
164+ self ._emit_state_message ()
158165
159166 self ._semaphore_per_partition [partition_key ].acquire ()
160167
def _check_and_update_parent_state(self) -> None:
    """
    Advance ``_parent_state`` past the leading run of fully closed partitions.

    Partitions are examined in the insertion order of
    ``_partition_parent_state_map``. A partition counts as fully closed when its
    key is in ``_finished_partitions`` and the internal counter of its semaphore
    in ``_semaphore_per_partition`` is 0. Every fully closed partition at the
    front of the map is removed; the scan stops at the first partition that is
    not fully closed, so later entries are left untouched even if they are
    already closed.

    If the state of the last removed partition is not ``None``,
    ``_parent_state`` is replaced with it. Otherwise ``_parent_state`` is left
    unchanged.

    Raises:
        KeyError: if a finished partition tracked in the map has no entry in
            ``_semaphore_per_partition``.
    """
    tracked_states = self._partition_parent_state_map
    last_closed_state = None

    # Peek at the oldest entry instead of snapshotting every key up front: a call
    # then costs only as much as the number of entries it actually removes, rather
    # than the total number of tracked partitions.
    while tracked_states:
        oldest_key = next(iter(tracked_states))
        # NOTE: `_value` is a private attribute of threading.Semaphore (the current
        # counter). A value of 0 is taken to mean no slice of this partition is
        # still outstanding.
        if (
            oldest_key not in self._finished_partitions
            or self._semaphore_per_partition[oldest_key]._value != 0
        ):
            break
        # Remove from the left (oldest first) and remember the newest state removed.
        _, last_closed_state = tracked_states.popitem(last=False)

    if last_closed_state is not None:
        self._parent_state = last_closed_state
161187 def ensure_at_least_one_state_emitted (self ) -> None :
162188 """
163189 The platform expect to have at least one state message on successful syncs. Hence, whatever happens, we expect this method to be
@@ -202,32 +228,39 @@ def stream_slices(self) -> Iterable[StreamSlice]:
202228
203229 slices = self ._partition_router .stream_slices ()
204230 self ._timer .start ()
205- for partition in slices :
206- yield from self ._generate_slices_from_partition (partition )
231+ for partition , last , parent_state in iterate_with_last_flag_and_state (
232+ slices , self ._partition_router .get_stream_state
233+ ):
234+ yield from self ._generate_slices_from_partition (partition , parent_state )
207235
208- def _generate_slices_from_partition (self , partition : StreamSlice ) -> Iterable [StreamSlice ]:
236+ def _generate_slices_from_partition (self , partition : StreamSlice , parent_state : Mapping [ str , Any ] ) -> Iterable [StreamSlice ]:
209237 # Ensure the maximum number of partitions is not exceeded
210238 self ._ensure_partition_limit ()
211239
240+ partition_key = self ._to_partition_key (partition .partition )
241+
212242 cursor = self ._cursor_per_partition .get (self ._to_partition_key (partition .partition ))
213243 if not cursor :
214244 cursor = self ._create_cursor (
215245 self ._global_cursor ,
216246 self ._lookback_window if self ._global_cursor else 0 ,
217247 )
218248 with self ._lock :
219- self ._cursor_per_partition [self . _to_partition_key ( partition . partition ) ] = cursor
220- self ._semaphore_per_partition [self . _to_partition_key ( partition . partition ) ] = (
249+ self ._cursor_per_partition [partition_key ] = cursor
250+ self ._semaphore_per_partition [partition_key ] = (
221251 threading .Semaphore (0 )
222252 )
223253
254+ with self ._lock :
255+ self ._partition_parent_state_map [partition_key ] = deepcopy (parent_state )
256+
224257 for cursor_slice , is_last_slice , _ in iterate_with_last_flag_and_state (
225258 cursor .stream_slices (),
226259 lambda : None ,
227260 ):
228- self ._semaphore_per_partition [self . _to_partition_key ( partition . partition ) ].release ()
261+ self ._semaphore_per_partition [partition_key ].release ()
229262 if is_last_slice :
230- self ._finished_partitions .add (self . _to_partition_key ( partition . partition ) )
263+ self ._finished_partitions .add (partition_key )
231264 yield StreamSlice (
232265 partition = partition , cursor_slice = cursor_slice , extra_fields = partition .extra_fields
233266 )
0 commit comments