the local memos_message_queue functionality in BaseScheduler.
"""

7+ from typing import TYPE_CHECKING
8+
9+
10+ if TYPE_CHECKING :
11+ from collections .abc import Callable
12+
713from memos .log import get_logger
814from memos .mem_scheduler .general_modules .misc import AutoDroppingQueue as Queue
915from memos .mem_scheduler .schemas .message_schemas import ScheduleMessageItem
16+ from memos .mem_scheduler .schemas .task_schemas import DEFAULT_STREAM_KEY_PREFIX
17+ from memos .mem_scheduler .task_schedule_modules .orchestrator import SchedulerOrchestrator
18+ from memos .mem_scheduler .utils .status_tracker import TaskStatusTracker
1019from memos .mem_scheduler .webservice_modules .redis_service import RedisSchedulerModule
1120
1221
class SchedulerLocalQueue(RedisSchedulerModule):
    def __init__(
        self,
        maxsize: int = 0,
        stream_key_prefix: str = DEFAULT_STREAM_KEY_PREFIX,
        orchestrator: SchedulerOrchestrator | None = None,
        status_tracker: TaskStatusTracker | None = None,
    ):
        """
        Initialize the SchedulerLocalQueue with a maximum queue size limit.

        The signature mirrors SchedulerRedisQueue so the two backends are
        drop-in interchangeable.

        Args:
            maxsize (int): Maximum number of messages allowed in each
                individual queue. 0 (the default) means unbounded.
            stream_key_prefix (str): Prefix for the (simulated) stream keys.
            orchestrator: SchedulerOrchestrator instance; stored for interface
                compatibility but not used by the local backend itself.
            status_tracker: TaskStatusTracker instance; stored for interface
                compatibility but not used by the local backend itself.
        """
        super().__init__()

        # Fall back to the historical default if an empty prefix is passed.
        self.stream_key_prefix = stream_key_prefix or "local_queue"

        self.max_internal_message_queue_size = maxsize

        # Per-stream queues: key = stream_key, value = Queue[ScheduleMessageItem]
        self.queue_streams: dict[str, Queue[ScheduleMessageItem]] = {}

        # Kept on the instance so callers written against the Redis backend
        # can still reach them here.
        self.orchestrator = orchestrator
        self.status_tracker = status_tracker

        # Listener state, mirroring the Redis-backed implementation.
        self._is_listening = False
        self._message_handler: Callable[[ScheduleMessageItem], None] | None = None

        logger.info(
            f"SchedulerLocalQueue initialized with max_internal_message_queue_size={self.max_internal_message_queue_size}"
        )
4061
4162 def get_stream_key (self , user_id : str , mem_cube_id : str , task_label : str ) -> str :
@@ -62,7 +83,7 @@ def put(
6283 Exception: Any underlying error during queue.put() operation.
6384 """
6485 stream_key = self .get_stream_key (
65- user_id = message .user_id , mem_cube_id = message .mem_cube_id , task_label = message .task_label
86+ user_id = message .user_id , mem_cube_id = message .mem_cube_id , task_label = message .label
6687 )
6788
6889 message .stream_key = stream_key
@@ -86,7 +107,7 @@ def get(
86107 stream_key : str ,
87108 block : bool = True ,
88109 timeout : float | None = None ,
89- batch_size : int | None = None ,
110+ batch_size : int | None = 1 ,
90111 ) -> list [ScheduleMessageItem ]:
91112 if batch_size is not None and batch_size <= 0 :
92113 logger .warning (
@@ -99,47 +120,85 @@ def get(
99120 logger .error (f"Stream { stream_key } does not exist when trying to get messages." )
100121 return []
101122
123+ # Ensure we always request a batch so we get a list back
124+ effective_batch_size = batch_size if batch_size is not None else 1
125+
102126 # Note: Assumes custom Queue implementation supports batch_size parameter
103127 res = self .queue_streams [stream_key ].get (
104- block = block , timeout = timeout , batch_size = batch_size
128+ block = block , timeout = timeout , batch_size = effective_batch_size
105129 )
106130 logger .debug (
107131 f"Retrieved { len (res )} messages from queue '{ stream_key } '. Current size: { self .queue_streams [stream_key ].qsize ()} "
108132 )
109133 return res
110134
111- def get_nowait (self , batch_size : int | None = None ) -> list [ScheduleMessageItem ]:
135+ def get_nowait (self , stream_key : str , batch_size : int | None = 1 ) -> list [ScheduleMessageItem ]:
112136 """
113- Non-blocking version of get(). Equivalent to get(block=False, batch_size=batch_size).
137+ Non-blocking version of get(). Equivalent to get(stream_key, block=False, batch_size=batch_size).
114138
115139 Returns immediately with available messages or an empty list if queue is empty.
116140
117141 Args:
142+ stream_key (str): The stream/queue identifier.
118143 batch_size (int | None): Number of messages to retrieve in a batch.
119144 If None, retrieves one message.
120145
121146 Returns:
122147 List[ScheduleMessageItem]: Retrieved messages or empty list if queue is empty.
123148 """
124- logger .debug (f"get_nowait() called with batch_size: { batch_size } " )
125- return self .get (block = False , batch_size = batch_size )
149+ logger .debug (f"get_nowait() called for { stream_key } with batch_size: { batch_size } " )
150+ return self .get (stream_key = stream_key , block = False , batch_size = batch_size )
151+
152+ def get_messages (self , batch_size : int ) -> list [ScheduleMessageItem ]:
153+ """
154+ Get messages from all streams in round-robin or sequential fashion.
155+ Equivalent to SchedulerRedisQueue.get_messages.
156+ """
157+ messages = []
158+ # Snapshot keys to avoid runtime modification issues
159+ stream_keys = list (self .queue_streams .keys ())
160+
161+ # Simple strategy: try to get up to batch_size messages across all streams
162+ # We can just iterate and collect.
163+
164+ # Calculate how many to get per stream to be fair?
165+ # Or just greedy? Redis implementation uses a complex logic.
166+ # For local, let's keep it simple: just iterate and take what's available (non-blocking)
167+
168+ for stream_key in stream_keys :
169+ if len (messages ) >= batch_size :
170+ break
171+
172+ needed = batch_size - len (messages )
173+ # Use get_nowait to avoid blocking
174+ fetched = self .get_nowait (stream_key = stream_key , batch_size = needed )
175+ messages .extend (fetched )
176+
177+ return messages
126178
127179 def qsize (self ) -> dict :
128180 """
129181 Return the current size of all internal queues as a dictionary.
130182
131183 Each key is the stream name, and each value is the number of messages in that queue.
184+ Also includes 'total_size'.
132185
133186 Returns:
134187 Dict[str, int]: Mapping from stream name to current queue size.
135188 """
136189 sizes = {stream : queue .qsize () for stream , queue in self .queue_streams .items ()}
190+ total_size = sum (sizes .values ())
191+ sizes ["total_size" ] = total_size
137192 logger .debug (f"Current queue sizes: { sizes } " )
138193 return sizes
139194
140- def clear (self ) -> None :
141- for queue in self .queue_streams .values ():
142- queue .clear ()
195+ def clear (self , stream_key : str | None = None ) -> None :
196+ if stream_key :
197+ if stream_key in self .queue_streams :
198+ self .queue_streams [stream_key ].clear ()
199+ else :
200+ for queue in self .queue_streams .values ():
201+ queue .clear ()
143202
144203 @property
145204 def unfinished_tasks (self ) -> int :
@@ -151,6 +210,50 @@ def unfinished_tasks(self) -> int:
151210 Returns:
152211 int: Sum of all message counts in all internal queues.
153212 """
154- total = sum (self .qsize ().values ())
213+ # qsize() now includes "total_size", so we need to be careful not to double count if we use qsize() values
214+ # But qsize() implementation above sums values from queue_streams, then adds total_size.
215+ # So sum(self.queue_streams.values().qsize()) is safer.
216+ total = sum (queue .qsize () for queue in self .queue_streams .values ())
155217 logger .debug (f"Total unfinished tasks across all queues: { total } " )
156218 return total
219+
220+ def get_stream_keys (self , stream_key_prefix : str | None = None ) -> list [str ]:
221+ """
222+ Return list of active stream keys.
223+ """
224+ prefix = stream_key_prefix or self .stream_key_prefix
225+ return [k for k in self .queue_streams if k .startswith (prefix )]
226+
227+ def size (self ) -> int :
228+ """
229+ Total size of all queues.
230+ """
231+ return sum (q .qsize () for q in self .queue_streams .values ())
232+
233+ def empty (self ) -> bool :
234+ """
235+ Check if all queues are empty.
236+ """
237+ return self .size () == 0
238+
239+ def full (self ) -> bool :
240+ """
241+ Check if any queue is full (approximate).
242+ """
243+ if self .max_internal_message_queue_size <= 0 :
244+ return False
245+ return any (
246+ q .qsize () >= self .max_internal_message_queue_size for q in self .queue_streams .values ()
247+ )
248+
249+ def ack_message (
250+ self ,
251+ user_id : str ,
252+ mem_cube_id : str ,
253+ task_label : str ,
254+ redis_message_id ,
255+ message : ScheduleMessageItem | None ,
256+ ) -> None :
257+ """
258+ Acknowledge a message (no-op for local queue as messages are popped immediately).
259+ """