Skip to content

Commit aff2932

Browse files
committed
refactor: rework deep search feature, now allowing only one round of deep search
1 parent 2097eae commit aff2932

File tree

7 files changed

+266
-57
lines changed

7 files changed

+266
-57
lines changed

src/memos/api/handlers/component_init.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -237,9 +237,6 @@ def init_server() -> dict[str, Any]:
237237
# Initialize SchedulerAPIModule
238238
api_module = mem_scheduler.api_module
239239

240-
# TODO: must remove!
241-
mem_scheduler.memos_message_queue.debug_mode_on()
242-
243240
# Start scheduler if enabled
244241
if os.getenv("API_SCHEDULER_ON", "true").lower() == "true":
245242
mem_scheduler.start()

src/memos/api/handlers/search_handler.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -224,7 +224,7 @@ def _deep_search(
224224
"chat_history": search_req.chat_history,
225225
}
226226

227-
return self.searcher.deep_search(
227+
enhanced_memories = self.searcher.deep_search(
228228
query=search_req.query,
229229
user_name=user_context.mem_cube_id,
230230
top_k=search_req.top_k,
@@ -234,12 +234,14 @@ def _deep_search(
234234
search_filter=search_filter,
235235
info=info,
236236
)
237+
formatted_memories = [format_memory_item(data) for data in enhanced_memories]
238+
return formatted_memories
237239

238240
def _fine_search(
239241
self,
240242
search_req: APISearchRequest,
241243
user_context: UserContext,
242-
) -> list[str]:
244+
) -> list:
243245
"""
244246
Fine-grained search with query enhancement.
245247

src/memos/mem_scheduler/base_scheduler.py

Lines changed: 51 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
)
4242
from memos.mem_scheduler.schemas.monitor_schemas import MemoryMonitorItem
4343
from memos.mem_scheduler.task_schedule_modules.dispatcher import SchedulerDispatcher
44+
from memos.mem_scheduler.task_schedule_modules.local_queue import SchedulerLocalQueue
4445
from memos.mem_scheduler.task_schedule_modules.redis_queue import SchedulerRedisQueue
4546
from memos.mem_scheduler.task_schedule_modules.task_queue import ScheduleTaskQueue
4647
from memos.mem_scheduler.utils.db_utils import get_utc_now
@@ -824,25 +825,65 @@ def get_running_tasks(self, filter_func: Callable | None = None) -> dict[str, di
824825

825826
return result
826827

828+
@staticmethod
829+
def init_task_status():
830+
return {
831+
"running": 0,
832+
"remaining": 0,
833+
"completed": 0,
834+
}
835+
836+
def get_tasks_status(self):
837+
task_status = self.init_task_status()
838+
memos_message_queue = self.memos_message_queue.memos_message_queue
839+
if isinstance(memos_message_queue, SchedulerRedisQueue):
840+
stream_keys = memos_message_queue.get_stream_keys(
841+
stream_key_prefix=memos_message_queue.stream_key_prefix
842+
)
843+
for stream_key in stream_keys:
844+
if stream_key not in task_status:
845+
task_status[stream_key] = self.init_task_status()
846+
# For Redis queue, prefer XINFO GROUPS to compute pending
847+
groups_info = memos_message_queue.redis.xinfo_groups(stream_key)
848+
if groups_info:
849+
for group in groups_info:
850+
if group.get("name") == memos_message_queue.consumer_group:
851+
task_status[stream_key]["running"] += int(group.get("pending", 0))
852+
task_status[stream_key]["remaining"] += int(group.get("remaining", 0))
853+
task_status["running"] += int(group.get("pending", 0))
854+
task_status["remaining"] += int(group.get("remaining", 0))
855+
break
856+
857+
elif isinstance(memos_message_queue, SchedulerLocalQueue):
858+
running_task_count = self.dispatcher.get_running_task_count()
859+
task_status["running"] = running_task_count
860+
task_status["remaining"] = sum(memos_message_queue.qsize().values())
861+
else:
862+
logger.error(
863+
f"type of self.memos_message_queue is {memos_message_queue}, which is not supported"
864+
)
865+
raise NotImplementedError()
866+
827867
def mem_scheduler_wait(
828868
self, timeout: float = 180.0, poll: float = 0.1, log_every: float = 0.01
829869
) -> bool:
830870
"""
831871
Uses EWMA throughput, detects leaked `unfinished_tasks`, and waits for dispatcher.
832872
"""
833873
deadline = time.monotonic() + timeout
874+
memos_message_queue = self.memos_message_queue.memos_message_queue
834875

835876
# --- helpers (local, no external deps) ---
836877
def _unfinished() -> int:
837878
"""Prefer `unfinished_tasks`; fallback to `qsize()`."""
838879
try:
839-
u = getattr(self.memos_message_queue, "unfinished_tasks", None)
880+
u = getattr(memos_message_queue, "unfinished_tasks", None)
840881
if u is not None:
841882
return int(u)
842883
except Exception:
843884
pass
844885
try:
845-
return int(self.memos_message_queue.qsize())
886+
return int(memos_message_queue.qsize())
846887
except Exception:
847888
return 0
848889

@@ -876,7 +917,7 @@ def _fmt_eta(seconds: float | None) -> str:
876917
# 1) read counters
877918
curr_unfinished = _unfinished()
878919
try:
879-
qsz = int(self.memos_message_queue.qsize())
920+
qsz = int(memos_message_queue.qsize())
880921
except Exception:
881922
qsz = -1
882923

@@ -892,14 +933,14 @@ def _fmt_eta(seconds: float | None) -> str:
892933
except Exception:
893934
pass
894935

895-
if isinstance(self.memos_message_queue, SchedulerRedisQueue):
936+
if isinstance(memos_message_queue, SchedulerRedisQueue):
896937
# For Redis queue, prefer XINFO GROUPS to compute pending
897-
groups_info = self.memos_message_queue.redis.xinfo_groups(
898-
self.memos_message_queue.stream_key_prefix
938+
groups_info = memos_message_queue.redis.xinfo_groups(
939+
memos_message_queue.stream_key_prefix
899940
)
900941
if groups_info:
901942
for group in groups_info:
902-
if group.get("name") == self.memos_message_queue.consumer_group:
943+
if group.get("name") == memos_message_queue.consumer_group:
903944
pend = int(group.get("pending", pend))
904945
break
905946
else:
@@ -975,18 +1016,19 @@ def _fmt_eta(seconds: float | None) -> str:
9751016

9761017
def _gather_queue_stats(self) -> dict:
9771018
"""Collect queue/dispatcher stats for reporting."""
1019+
memos_message_queue = self.memos_message_queue.memos_message_queue
9781020
stats: dict[str, int | float | str] = {}
9791021
stats["use_redis_queue"] = bool(self.use_redis_queue)
9801022
# local queue metrics
9811023
if not self.use_redis_queue:
9821024
try:
983-
stats["qsize"] = int(self.memos_message_queue.qsize())
1025+
stats["qsize"] = int(memos_message_queue.qsize())
9841026
except Exception:
9851027
stats["qsize"] = -1
9861028
# unfinished_tasks if available
9871029
try:
9881030
stats["unfinished_tasks"] = int(
989-
getattr(self.memos_message_queue, "unfinished_tasks", 0) or 0
1031+
getattr(memos_message_queue, "unfinished_tasks", 0) or 0
9901032
)
9911033
except Exception:
9921034
stats["unfinished_tasks"] = -1

src/memos/mem_scheduler/task_schedule_modules/redis_queue.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
the local memos_message_queue functionality in BaseScheduler.
66
"""
77

8+
import os
89
import re
910
import time
1011

@@ -33,7 +34,9 @@ class SchedulerRedisQueue(RedisSchedulerModule):
3334

3435
def __init__(
3536
self,
36-
stream_key_prefix: str = "scheduler:messages:stream",
37+
stream_key_prefix: str = os.getenv(
38+
"MEMSCHEDULER_REDIS_STREAM_KEY_PREFIX", "scheduler:messages:stream"
39+
),
3740
consumer_group: str = "scheduler_group",
3841
consumer_name: str | None = "scheduler_consumer",
3942
max_len: int = 10000,
@@ -283,7 +286,7 @@ def qsize(self) -> int:
283286
logger.error(f"Failed to get Redis queue size: {e}")
284287
return 0
285288

286-
def get_stream_keys(self) -> list[str]:
289+
def get_stream_keys(self, stream_key_prefix: str | None = None) -> list[str]:
287290
"""
288291
List all Redis stream keys that match this queue's prefix.
289292
@@ -293,16 +296,18 @@ def get_stream_keys(self) -> list[str]:
293296
if not self._redis_conn:
294297
return []
295298

299+
if stream_key_prefix is None:
300+
stream_key_prefix = self.stream_key_prefix
296301
# First, get all keys that might match (using Redis pattern matching)
297-
redis_pattern = f"{self.stream_key_prefix}:*"
302+
redis_pattern = f"{stream_key_prefix}:*"
298303
raw_keys = [
299304
key.decode("utf-8") if isinstance(key, bytes) else key
300305
for key in self._redis_conn.scan_iter(match=redis_pattern)
301306
]
302307

303308
# Second, filter using Python regex to ensure exact prefix match
304309
# Escape special regex characters in the prefix, then add :.*
305-
escaped_prefix = re.escape(self.stream_key_prefix)
310+
escaped_prefix = re.escape(stream_key_prefix)
306311
regex_pattern = f"^{escaped_prefix}:"
307312
stream_keys = [key for key in raw_keys if re.match(regex_pattern, key)]
308313

0 commit comments

Comments (0)