Skip to content

Commit 232be6f

Browse files
committed
refactor: add searcher to handler_init; remove info log from task_queue
1 parent d1a7261 commit 232be6f

File tree

7 files changed

+71
-34
lines changed

7 files changed

+71
-34
lines changed

examples/mem_scheduler/api_w_scheduler.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,10 @@ def my_test_handler(messages: list[ScheduleMessageItem]):
2525
print(f"My test handler received {len(messages)} messages:")
2626
for msg in messages:
2727
print(f" my_test_handler - {msg.item_id}: {msg.content}")
28+
user_status_running = handle_scheduler_status(
29+
user_name=USER_MEM_CUBE, mem_scheduler=mem_scheduler, instance_id="api_w_scheduler"
30+
)
31+
print(f"[Monitor] Status for {USER_MEM_CUBE} after submit:", user_status_running)
2832

2933

3034
# 2. Register the handler
@@ -56,10 +60,6 @@ def my_test_handler(messages: list[ScheduleMessageItem]):
5660

5761
# 5.1 mem_cube whose status is monitored (the status check itself now runs inside my_test_handler)
5862
USER_MEM_CUBE = "test_mem_cube"
59-
user_status_running = handle_scheduler_status(
60-
user_name=USER_MEM_CUBE, mem_scheduler=mem_scheduler, instance_id="api_w_scheduler"
61-
)
62-
print(f"[Monitor] Status for {USER_MEM_CUBE} after submit:", user_status_running)
6363

6464
# 6. Wait for messages to be processed (limited to 100 checks)
6565
print("Waiting for messages to be consumed (max 100 checks)...")

src/memos/api/handlers/base_handler.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
from memos.log import get_logger
1111
from memos.mem_scheduler.base_scheduler import BaseScheduler
12+
from memos.memories.textual.tree_text_memory.retrieve.searcher import Searcher
1213

1314

1415
logger = get_logger(__name__)
@@ -28,6 +29,7 @@ def __init__(
2829
naive_mem_cube: Any | None = None,
2930
mem_reader: Any | None = None,
3031
mem_scheduler: Any | None = None,
32+
searcher: Any | None = None,
3133
embedder: Any | None = None,
3234
reranker: Any | None = None,
3335
graph_db: Any | None = None,
@@ -58,6 +60,7 @@ def __init__(
5860
self.naive_mem_cube = naive_mem_cube
5961
self.mem_reader = mem_reader
6062
self.mem_scheduler = mem_scheduler
63+
self.searcher = searcher
6164
self.embedder = embedder
6265
self.reranker = reranker
6366
self.graph_db = graph_db
@@ -128,6 +131,11 @@ def mem_scheduler(self) -> BaseScheduler:
128131
"""Get scheduler instance."""
129132
return self.deps.mem_scheduler
130133

134+
@property
135+
def searcher(self) -> Searcher:
136+
"""Get searcher instance."""
137+
return self.deps.searcher
138+
131139
@property
132140
def embedder(self):
133141
"""Get embedder instance."""

src/memos/api/handlers/component_init.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
including databases, LLMs, memory systems, and schedulers.
66
"""
77

8+
import os
9+
810
from typing import TYPE_CHECKING, Any
911

1012
from memos.api.config import APIConfig
@@ -38,6 +40,10 @@
3840
from memos.memories.textual.simple_preference import SimplePreferenceTextMemory
3941
from memos.memories.textual.simple_tree import SimpleTreeTextMemory
4042
from memos.memories.textual.tree_text_memory.organize.manager import MemoryManager
43+
44+
45+
if TYPE_CHECKING:
46+
from memos.memories.textual.tree import TreeTextMemory
4147
from memos.memories.textual.tree_text_memory.retrieve.internet_retriever_factory import (
4248
InternetRetrieverFactory,
4349
)
@@ -47,7 +53,7 @@
4753

4854
if TYPE_CHECKING:
4955
from memos.mem_scheduler.optimized_scheduler import OptimizedScheduler
50-
56+
from memos.memories.textual.tree_text_memory.retrieve.searcher import Searcher
5157
logger = get_logger(__name__)
5258

5359

@@ -205,6 +211,13 @@ def init_server() -> dict[str, Any]:
205211

206212
logger.debug("MemCube created")
207213

214+
tree_mem: TreeTextMemory = naive_mem_cube.text_mem
215+
searcher: Searcher = tree_mem.get_searcher(
216+
manual_close_internet=os.getenv("ENABLE_INTERNET", "true").lower() == "false",
217+
moscube=False,
218+
)
219+
logger.debug("Searcher created")
220+
208221
# Initialize Scheduler
209222
scheduler_config_dict = APIConfig.get_scheduler_config()
210223
scheduler_config = SchedulerConfigFactory(
@@ -217,16 +230,14 @@ def init_server() -> dict[str, Any]:
217230
db_engine=BaseDBManager.create_default_sqlite_engine(),
218231
mem_reader=mem_reader,
219232
)
220-
mem_scheduler.init_mem_cube(mem_cube=naive_mem_cube)
233+
mem_scheduler.init_mem_cube(mem_cube=naive_mem_cube, searcher=searcher)
221234
logger.debug("Scheduler initialized")
222235

223236
# Initialize SchedulerAPIModule
224237
api_module = mem_scheduler.api_module
225238

226239
# Start scheduler if enabled
227-
import os
228-
229-
if os.getenv("API_SCHEDULER_ON", True):
240+
if os.getenv("API_SCHEDULER_ON", "true").lower() == "true":
230241
mem_scheduler.start()
231242
logger.info("Scheduler started")
232243

@@ -253,6 +264,7 @@ def init_server() -> dict[str, Any]:
253264
"mos_server": mos_server,
254265
"mem_scheduler": mem_scheduler,
255266
"naive_mem_cube": naive_mem_cube,
267+
"searcher": searcher,
256268
"api_module": api_module,
257269
"vector_db": vector_db,
258270
"pref_extractor": pref_extractor,

src/memos/api/handlers/search_handler.py

Lines changed: 24 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from memos.api.product_models import APISearchRequest, SearchResponse
1919
from memos.context.context import ContextThreadPoolExecutor
2020
from memos.log import get_logger
21-
from memos.mem_scheduler.schemas.general_schemas import SearchMode
21+
from memos.mem_scheduler.schemas.general_schemas import FINE_STRATEGY, FineStrategy, SearchMode
2222
from memos.types import MOSSearchResult, UserContext
2323

2424

@@ -40,7 +40,7 @@ def __init__(self, dependencies: HandlerDependencies):
4040
dependencies: HandlerDependencies instance
4141
"""
4242
super().__init__(dependencies)
43-
self._validate_dependencies("naive_mem_cube", "mem_scheduler")
43+
self._validate_dependencies("naive_mem_cube", "mem_scheduler", "searcher")
4444

4545
def handle_search_memories(self, search_req: APISearchRequest) -> SearchResponse:
4646
"""
@@ -211,11 +211,17 @@ def _fast_search(
211211

212212
return formatted_memories
213213

214+
def _deep_search(
215+
self, search_req: APISearchRequest, user_context: UserContext, max_thinking_depth: int
216+
) -> list:
217+
logger.error("waiting to be implemented")
218+
return []
219+
214220
def _fine_search(
215221
self,
216222
search_req: APISearchRequest,
217223
user_context: UserContext,
218-
) -> list:
224+
) -> list[str]:
219225
"""
220226
Fine-grained search with query enhancement.
221227
@@ -226,19 +232,22 @@ def _fine_search(
226232
Returns:
227233
List of enhanced search results
228234
"""
235+
if FINE_STRATEGY == FineStrategy.DEEP_SEARCH:
236+
return self._deep_search(
237+
search_req=search_req, user_context=user_context, max_thinking_depth=3
238+
)
239+
229240
target_session_id = search_req.session_id or "default_session"
230241
search_filter = {"session_id": search_req.session_id} if search_req.session_id else None
231242

232-
searcher = self.mem_scheduler.searcher
233-
234243
info = {
235244
"user_id": search_req.user_id,
236245
"session_id": target_session_id,
237246
"chat_history": search_req.chat_history,
238247
}
239248

240249
# Fine retrieve
241-
fast_retrieved_memories = searcher.retrieve(
250+
raw_retrieved_memories = self.searcher.retrieve(
242251
query=search_req.query,
243252
user_name=user_context.mem_cube_id,
244253
top_k=search_req.top_k,
@@ -250,8 +259,8 @@ def _fine_search(
250259
)
251260

252261
# Post retrieve
253-
fast_memories = searcher.post_retrieve(
254-
retrieved_results=fast_retrieved_memories,
262+
raw_memories = self.searcher.post_retrieve(
263+
retrieved_results=raw_retrieved_memories,
255264
top_k=search_req.top_k,
256265
user_name=user_context.mem_cube_id,
257266
info=info,
@@ -260,22 +269,22 @@ def _fine_search(
260269
# Enhance with query
261270
enhanced_memories, _ = self.mem_scheduler.retriever.enhance_memories_with_query(
262271
query_history=[search_req.query],
263-
memories=fast_memories,
272+
memories=raw_memories,
264273
)
265274

266-
if len(enhanced_memories) < len(fast_memories):
275+
if len(enhanced_memories) < len(raw_memories):
267276
logger.info(
268-
f"Enhanced memories ({len(enhanced_memories)}) are less than fast memories ({len(fast_memories)}). Recalling for more."
277+
f"Enhanced memories ({len(enhanced_memories)}) are less than raw memories ({len(raw_memories)}). Recalling for more."
269278
)
270279
missing_info_hint, trigger = self.mem_scheduler.retriever.recall_for_missing_memories(
271280
query=search_req.query,
272-
memories=fast_memories,
281+
memories=raw_memories,
273282
)
274-
retrieval_size = len(fast_memories) - len(enhanced_memories)
283+
retrieval_size = len(raw_memories) - len(enhanced_memories)
275284
logger.info(f"Retrieval size: {retrieval_size}")
276285
if trigger:
277286
logger.info(f"Triggering additional search with hint: {missing_info_hint}")
278-
additional_memories = searcher.search(
287+
additional_memories = self.searcher.search(
279288
query=missing_info_hint,
280289
user_name=user_context.mem_cube_id,
281290
top_k=retrieval_size,
@@ -286,7 +295,7 @@ def _fine_search(
286295
)
287296
else:
288297
logger.info("Not triggering additional search, using fast memories.")
289-
additional_memories = fast_memories[:retrieval_size]
298+
additional_memories = raw_memories[:retrieval_size]
290299

291300
enhanced_memories += additional_memories
292301
logger.info(

src/memos/mem_scheduler/base_scheduler.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -54,11 +54,11 @@
5454
from memos.memories.activation.kv import KVCacheMemory
5555
from memos.memories.activation.vllmkv import VLLMKVCacheItem, VLLMKVCacheMemory
5656
from memos.memories.textual.tree import TextualMemoryItem, TreeTextMemory
57+
from memos.memories.textual.tree_text_memory.retrieve.searcher import Searcher
5758
from memos.templates.mem_scheduler_prompts import MEMORY_ASSEMBLY_TEMPLATE
5859

5960

6061
if TYPE_CHECKING:
61-
from memos.memories.textual.tree_text_memory.retrieve.searcher import Searcher
6262
from memos.reranker.http_bge import HTTPBGEReranker
6363

6464

@@ -141,14 +141,21 @@ def __init__(self, config: BaseSchedulerConfig):
141141
self.auth_config = None
142142
self.rabbitmq_config = None
143143

144-
def init_mem_cube(self, mem_cube):
144+
def init_mem_cube(
145+
self,
146+
mem_cube: BaseMemCube,
147+
searcher: Searcher | None = None,
148+
):
145149
self.mem_cube = mem_cube
146150
self.text_mem: TreeTextMemory = self.mem_cube.text_mem
147-
self.searcher: Searcher = self.text_mem.get_searcher(
148-
manual_close_internet=os.getenv("ENABLE_INTERNET", "true").lower() == "false",
149-
moscube=False,
150-
)
151151
self.reranker: HTTPBGEReranker = self.text_mem.reranker
152+
if searcher is None:
153+
self.searcher: Searcher = self.text_mem.get_searcher(
154+
manual_close_internet=os.getenv("ENABLE_INTERNET", "true").lower() == "false",
155+
moscube=False,
156+
)
157+
else:
158+
self.searcher = searcher
152159

153160
def initialize_modules(
154161
self,

src/memos/mem_scheduler/schemas/general_schemas.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ class FineStrategy(str, Enum):
1818

1919
REWRITE = "rewrite"
2020
RECREATE = "recreate"
21+
DEEP_SEARCH = "deep_search"
2122

2223

2324
FILE_PATH = Path(__file__).absolute()

src/memos/mem_scheduler/task_schedule_modules/task_queue.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -112,10 +112,10 @@ def get_messages(self, batch_size: int) -> list[ScheduleMessageItem]:
112112
)
113113

114114
messages.extend(fetched)
115-
116-
logger.info(
117-
f"Fetched {len(messages)} messages across users with per-user batch_size={batch_size}"
118-
)
115+
if len(messages) > 0:
116+
logger.debug(
117+
f"Fetched {len(messages)} messages across users with per-user batch_size={batch_size}"
118+
)
119119
return messages
120120

121121
def clear(self):

0 commit comments

Comments
 (0)