@@ -28,10 +28,7 @@ def __init__(
2828 maxsize : int = 0 ,
2929 stream_key_prefix : str = DEFAULT_STREAM_KEY_PREFIX ,
3030 orchestrator : SchedulerOrchestrator | None = None ,
31- consumer_group : str = "scheduler_group" ,
32- consumer_name : str | None = "scheduler_consumer" ,
3331 max_len : int | None = None ,
34- auto_delete_acked : bool = True ,
3532 status_tracker : TaskStatusTracker | None = None ,
3633 ):
3734 """
@@ -42,10 +39,7 @@ def __init__(
4239 maxsize (int): Maximum number of messages allowed in each individual queue.
4340 stream_key_prefix (str): Prefix for stream keys (simulated).
4441 orchestrator: SchedulerOrchestrator instance (ignored).
45- consumer_group: Consumer group name (ignored).
46- consumer_name: Consumer name (ignored).
4742 max_len: Alias for maxsize if maxsize is 0.
48- auto_delete_acked: Whether to delete acked messages (ignored).
4943 status_tracker: TaskStatusTracker instance (ignored).
5044 """
5145 super ().__init__ ()
@@ -62,9 +56,6 @@ def __init__(
6256 self .queue_streams : dict [str , Queue [ScheduleMessageItem ]] = {}
6357
6458 self .orchestrator = orchestrator
65- self .consumer_group = consumer_group
66- self .consumer_name = consumer_name
67- self .auto_delete_acked = auto_delete_acked
6859 self .status_tracker = status_tracker
6960
7061 self ._is_listening = False
@@ -122,7 +113,7 @@ def get(
122113 stream_key : str ,
123114 block : bool = True ,
124115 timeout : float | None = None ,
125- batch_size : int | None = None ,
116+ batch_size : int | None = 1 ,
126117 ) -> list [ScheduleMessageItem ]:
127118 if batch_size is not None and batch_size <= 0 :
128119 logger .warning (
@@ -135,18 +126,19 @@ def get(
135126 logger .error (f"Stream { stream_key } does not exist when trying to get messages." )
136127 return []
137128
129+ # Ensure we always request a batch so we get a list back
130+ effective_batch_size = batch_size if batch_size is not None else 1
131+
138132 # Note: Assumes custom Queue implementation supports batch_size parameter
139133 res = self .queue_streams [stream_key ].get (
140- block = block , timeout = timeout , batch_size = batch_size
134+ block = block , timeout = timeout , batch_size = effective_batch_size
141135 )
142136 logger .debug (
143137 f"Retrieved { len (res )} messages from queue '{ stream_key } '. Current size: { self .queue_streams [stream_key ].qsize ()} "
144138 )
145139 return res
146140
147- def get_nowait (
148- self , stream_key : str , batch_size : int | None = None
149- ) -> list [ScheduleMessageItem ]:
141+ def get_nowait (self , stream_key : str , batch_size : int | None = 1 ) -> list [ScheduleMessageItem ]:
150142 """
151143 Non-blocking version of get(). Equivalent to get(stream_key, block=False, batch_size=batch_size).
152144
@@ -195,17 +187,24 @@ def qsize(self) -> dict:
195187 Return the current size of all internal queues as a dictionary.
196188
197189 Each key is the stream name, and each value is the number of messages in that queue.
190+ Also includes 'total_size'.
198191
199192 Returns:
200193 Dict[str, int]: Mapping from stream name to current queue size.
201194 """
202195 sizes = {stream : queue .qsize () for stream , queue in self .queue_streams .items ()}
196+ total_size = sum (sizes .values ())
197+ sizes ["total_size" ] = total_size
203198 logger .debug (f"Current queue sizes: { sizes } " )
204199 return sizes
205200
206- def clear (self ) -> None :
207- for queue in self .queue_streams .values ():
208- queue .clear ()
201+ def clear (self , stream_key : str | None = None ) -> None :
202+ if stream_key :
203+ if stream_key in self .queue_streams :
204+ self .queue_streams [stream_key ].clear ()
205+ else :
206+ for queue in self .queue_streams .values ():
207+ queue .clear ()
209208
210209 @property
211210 def unfinished_tasks (self ) -> int :
@@ -217,7 +216,10 @@ def unfinished_tasks(self) -> int:
217216 Returns:
218217 int: Sum of all message counts in all internal queues.
219218 """
220- total = sum (self .qsize ().values ())
219+        # qsize() now includes a synthetic "total_size" entry, so summing
220+        # its values would double-count every message. Sum the underlying
221+        # queues directly instead.
222+ total = sum (queue .qsize () for queue in self .queue_streams .values ())
221223 logger .debug (f"Total unfinished tasks across all queues: { total } " )
222224 return total
223225
@@ -249,3 +251,15 @@ def full(self) -> bool:
249251 return any (
250252 q .qsize () >= self .max_internal_message_queue_size for q in self .queue_streams .values ()
251253 )
254+
def ack_message(
    self,
    user_id: str,
    mem_cube_id: str,
    task_label: str,
    redis_message_id,  # NOTE(review): type not annotated upstream — presumably a Redis stream entry id; confirm before annotating
    message: ScheduleMessageItem | None,
) -> None:
    """
    Acknowledge a message (no-op for local queue as messages are popped immediately).

    The in-memory queue removes a message from its stream as soon as ``get()``
    returns it, so there is nothing to acknowledge afterwards. This method
    exists so callers written against an ack-based queue interface can run
    unchanged against this implementation. All arguments are accepted for
    signature compatibility and ignored.

    Args:
        user_id: User identifier associated with the message (unused).
        mem_cube_id: Memory-cube identifier associated with the message (unused).
        task_label: Label of the task being acknowledged (unused).
        redis_message_id: Id of the message in the (simulated) stream (unused).
        message: The message item being acknowledged, if available (unused).
    """
0 commit comments