@@ -28,6 +28,9 @@ class GroupingPartitionRouter(PartitionRouter):
2828 config : Config
2929 deduplicate : bool = True
3030
31+ def __post_init__ (self ) -> None :
32+ self ._state : Optional [Mapping [str , StreamState ]] = {}
33+
3134 def stream_slices (self ) -> Iterable [StreamSlice ]:
3235 """
3336 Lazily groups partitions from the underlying partition router into batches of size `group_size`.
@@ -58,9 +61,11 @@ def stream_slices(self) -> Iterable[StreamSlice]:
5861
5962 # Yield the batch when it reaches the group_size
6063 if len (batch ) == self .group_size :
64+ self ._state = self .underlying_partition_router .get_stream_state ()
6165 yield self ._create_grouped_slice (batch )
6266 batch = [] # Reset the batch
6367
68+ self ._state = self .underlying_partition_router .get_stream_state ()
6469 # Yield any remaining partitions if the batch isn't empty
6570 if batch :
6671 yield self ._create_grouped_slice (batch )
@@ -130,7 +135,8 @@ def get_request_body_json(
130135 def set_initial_state (self , stream_state : StreamState ) -> None :
131136 """Delegate state initialization to the underlying partition router."""
132137 self .underlying_partition_router .set_initial_state (stream_state )
138+ self ._state = self .underlying_partition_router .get_stream_state ()
133139
134140 def get_stream_state (self ) -> Optional [Mapping [str , StreamState ]]:
135141 """Delegate state retrieval to the underlying partition router."""
136- return self .underlying_partition_router . get_stream_state ()
142+ return self ._state
0 commit comments