Skip to content

Commit ec22dca

Browse files
committed
refactor: remove unnecessary parameters for local queue
1 parent c0fc509 commit ec22dca

File tree

1 file changed

+32
-18
lines changed

1 file changed

+32
-18
lines changed

src/memos/mem_scheduler/task_schedule_modules/local_queue.py

Lines changed: 32 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,7 @@ def __init__(
2828
maxsize: int = 0,
2929
stream_key_prefix: str = DEFAULT_STREAM_KEY_PREFIX,
3030
orchestrator: SchedulerOrchestrator | None = None,
31-
consumer_group: str = "scheduler_group",
32-
consumer_name: str | None = "scheduler_consumer",
3331
max_len: int | None = None,
34-
auto_delete_acked: bool = True,
3532
status_tracker: TaskStatusTracker | None = None,
3633
):
3734
"""
@@ -42,10 +39,7 @@ def __init__(
4239
maxsize (int): Maximum number of messages allowed in each individual queue.
4340
stream_key_prefix (str): Prefix for stream keys (simulated).
4441
orchestrator: SchedulerOrchestrator instance (ignored).
45-
consumer_group: Consumer group name (ignored).
46-
consumer_name: Consumer name (ignored).
4742
max_len: Alias for maxsize if maxsize is 0.
48-
auto_delete_acked: Whether to delete acked messages (ignored).
4943
status_tracker: TaskStatusTracker instance (ignored).
5044
"""
5145
super().__init__()
@@ -62,9 +56,6 @@ def __init__(
6256
self.queue_streams: dict[str, Queue[ScheduleMessageItem]] = {}
6357

6458
self.orchestrator = orchestrator
65-
self.consumer_group = consumer_group
66-
self.consumer_name = consumer_name
67-
self.auto_delete_acked = auto_delete_acked
6859
self.status_tracker = status_tracker
6960

7061
self._is_listening = False
@@ -122,7 +113,7 @@ def get(
122113
stream_key: str,
123114
block: bool = True,
124115
timeout: float | None = None,
125-
batch_size: int | None = None,
116+
batch_size: int | None = 1,
126117
) -> list[ScheduleMessageItem]:
127118
if batch_size is not None and batch_size <= 0:
128119
logger.warning(
@@ -135,18 +126,19 @@ def get(
135126
logger.error(f"Stream {stream_key} does not exist when trying to get messages.")
136127
return []
137128

129+
# Ensure we always request a batch so we get a list back
130+
effective_batch_size = batch_size if batch_size is not None else 1
131+
138132
# Note: Assumes custom Queue implementation supports batch_size parameter
139133
res = self.queue_streams[stream_key].get(
140-
block=block, timeout=timeout, batch_size=batch_size
134+
block=block, timeout=timeout, batch_size=effective_batch_size
141135
)
142136
logger.debug(
143137
f"Retrieved {len(res)} messages from queue '{stream_key}'. Current size: {self.queue_streams[stream_key].qsize()}"
144138
)
145139
return res
146140

147-
def get_nowait(
148-
self, stream_key: str, batch_size: int | None = None
149-
) -> list[ScheduleMessageItem]:
141+
def get_nowait(self, stream_key: str, batch_size: int | None = 1) -> list[ScheduleMessageItem]:
150142
"""
151143
Non-blocking version of get(). Equivalent to get(stream_key, block=False, batch_size=batch_size).
152144
@@ -195,17 +187,24 @@ def qsize(self) -> dict:
195187
Return the current size of all internal queues as a dictionary.
196188
197189
Each key is the stream name, and each value is the number of messages in that queue.
190+
Also includes 'total_size'.
198191
199192
Returns:
200193
Dict[str, int]: Mapping from stream name to current queue size.
201194
"""
202195
sizes = {stream: queue.qsize() for stream, queue in self.queue_streams.items()}
196+
total_size = sum(sizes.values())
197+
sizes["total_size"] = total_size
203198
logger.debug(f"Current queue sizes: {sizes}")
204199
return sizes
205200

206-
def clear(self) -> None:
207-
for queue in self.queue_streams.values():
208-
queue.clear()
201+
def clear(self, stream_key: str | None = None) -> None:
202+
if stream_key:
203+
if stream_key in self.queue_streams:
204+
self.queue_streams[stream_key].clear()
205+
else:
206+
for queue in self.queue_streams.values():
207+
queue.clear()
209208

210209
@property
211210
def unfinished_tasks(self) -> int:
@@ -217,7 +216,10 @@ def unfinished_tasks(self) -> int:
217216
Returns:
218217
int: Sum of all message counts in all internal queues.
219218
"""
220-
total = sum(self.qsize().values())
219+
# qsize() now includes "total_size", so we need to be careful not to double count if we use qsize() values
220+
# But qsize() implementation above sums values from queue_streams, then adds total_size.
221+
# So sum(self.queue_streams.values().qsize()) is safer.
222+
total = sum(queue.qsize() for queue in self.queue_streams.values())
221223
logger.debug(f"Total unfinished tasks across all queues: {total}")
222224
return total
223225

@@ -249,3 +251,15 @@ def full(self) -> bool:
249251
return any(
250252
q.qsize() >= self.max_internal_message_queue_size for q in self.queue_streams.values()
251253
)
254+
255+
def ack_message(
256+
self,
257+
user_id: str,
258+
mem_cube_id: str,
259+
task_label: str,
260+
redis_message_id,
261+
message: ScheduleMessageItem | None,
262+
) -> None:
263+
"""
264+
Acknowledge a message (no-op for local queue as messages are popped immediately).
265+
"""

0 commit comments

Comments
 (0)