Skip to content

Commit ab72f1c

Browse files
committed
feat: add more log in tree text mem retriever
1 parent a172146 commit ab72f1c

File tree

2 files changed

+107
-22
lines changed

src/memos/memories/textual/tree_text_memory/retrieve/searcher.py

Lines changed: 100 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import concurrent.futures
22
import json
3+
import time
34

45
from datetime import datetime
56

@@ -57,67 +58,110 @@ def search(
5758
Returns:
5859
list[TextualMemoryItem]: List of matching memories.
5960
"""
61+
overall_start = time.perf_counter()
62+
logger.info(
63+
f"[SEARCH] 🚀 Starting search for query='{query}', top_k={top_k}, mode={mode}, memory_type={memory_type}"
64+
)
65+
6066
if not info:
6167
logger.warning(
6268
"Please input 'info' when use tree.search so that "
6369
"the database would store the consume history."
6470
)
6571
info = {"user_id": "", "session_id": ""}
66-
# Step 1: Parse task structure into topic, concept, and fact levels
72+
else:
73+
logger.debug(f"[SEARCH] Received info dict: {info}")
74+
75+
# ===== Step 1: Parse task structure =====
76+
step_start = time.perf_counter()
6777
context = []
6878
if mode == "fine":
79+
logger.info("[SEARCH] Fine mode enabled, performing initial embedding search...")
80+
embed_start = time.perf_counter()
6981
query_embedding = self.embedder.embed([query])[0]
82+
logger.debug(f"[SEARCH] Query embedding vector length: {len(query_embedding)}")
83+
logger.info(
84+
f"[TIMER] Embedding query took {(time.perf_counter() - embed_start) * 1000:.2f} ms"
85+
)
86+
87+
search_start = time.perf_counter()
7088
related_node_ids = self.graph_store.search_by_embedding(query_embedding, top_k=top_k)
7189
related_nodes = [
7290
self.graph_store.get_node(related_node["id"]) for related_node in related_node_ids
7391
]
74-
7592
context = [related_node["memory"] for related_node in related_nodes]
7693
context = list(set(context))
94+
logger.info(f"[SEARCH] Found {len(related_nodes)} related nodes from graph_store.")
95+
logger.info(
96+
f"[TIMER] Graph embedding search took {(time.perf_counter() - search_start) * 1000:.2f} ms"
97+
)
7798

78-
# Step 1a: Parse task structure into topic, concept, and fact levels
99+
parse_start = time.perf_counter()
79100
parsed_goal = self.task_goal_parser.parse(
80101
task_description=query,
81102
context="\n".join(context),
82103
conversation=info.get("chat_history", []),
83104
mode=mode,
84105
)
85-
86-
query = (
87-
parsed_goal.rephrased_query
88-
if parsed_goal.rephrased_query and len(parsed_goal.rephrased_query) > 0
89-
else query
106+
logger.info(
107+
f"[TIMER] TaskGoalParser took {(time.perf_counter() - parse_start) * 1000:.2f} ms"
90108
)
109+
logger.info(f"TaskGoalParser result is {parsed_goal}")
91110

111+
query = parsed_goal.rephrased_query or query
92112
if parsed_goal.memories:
113+
embed_extra_start = time.perf_counter()
93114
query_embedding = self.embedder.embed(list({query, *parsed_goal.memories}))
115+
logger.info(
116+
f"[TIMER] Embedding parsed_goal memories took {(time.perf_counter() - embed_extra_start) * 1000:.2f} ms"
117+
)
118+
step_end = time.perf_counter()
119+
logger.info(f"[TIMER] Step 1 (Parsing & Embedding) took {(step_end - step_start):.2f} s")
120+
121+
# ===== Step 2: Define retrieval paths =====
122+
def timed(func):
123+
"""Decorator to measure and log time of retrieval steps."""
124+
125+
def wrapper(*args, **kwargs):
126+
start = time.perf_counter()
127+
result = func(*args, **kwargs)
128+
elapsed = time.perf_counter() - start
129+
logger.info(f"[TIMER] {func.__name__} took {elapsed:.2f} s")
130+
return result
94131

95-
# Step 2a: Working memory retrieval (Path A)
132+
return wrapper
133+
134+
@timed
96135
def retrieve_from_working_memory():
97136
"""
98137
Direct structure-based retrieval from working memory.
99138
"""
139+
logger.info("[PATH-A] Retrieving from WorkingMemory...")
100140
if memory_type not in ["All", "WorkingMemory"]:
141+
logger.info("[PATH-A] Skipped (memory_type does not match)")
101142
return []
102-
103143
working_memory = self.graph_retriever.retrieve(
104144
query=query, parsed_goal=parsed_goal, top_k=top_k, memory_scope="WorkingMemory"
105145
)
146+
147+
logger.debug(f"[PATH-A] Retrieved {len(working_memory)} items.")
106148
# Rerank working_memory results
149+
rerank_start = time.perf_counter()
107150
ranked_memories = self.reranker.rerank(
108151
query=query,
109152
query_embedding=query_embedding[0],
110153
graph_results=working_memory,
111154
top_k=top_k,
112155
parsed_goal=parsed_goal,
113156
)
157+
logger.info(
158+
f"[TIMER] PATH-A rerank took {(time.perf_counter() - rerank_start) * 1000:.2f} ms"
159+
)
114160
return ranked_memories
115161

116-
# Step 2b: Parallel long-term and user memory retrieval (Path B)
162+
@timed
117163
def retrieve_ranked_long_term_and_user():
118-
"""
119-
Retrieve from both long-term and user memory, then rank and merge results.
120-
"""
164+
logger.info("[PATH-B] Retrieving from LongTermMemory & UserMemory...")
121165
long_term_items = (
122166
self.graph_retriever.retrieve(
123167
query=query,
@@ -140,7 +184,10 @@ def retrieve_ranked_long_term_and_user():
140184
if memory_type in ["All", "UserMemory"]
141185
else []
142186
)
143-
187+
logger.debug(
188+
f"[PATH-B] Retrieved {len(long_term_items)} LongTerm + {len(user_items)} UserMemory items."
189+
)
190+
rerank_start = time.perf_counter()
144191
# Rerank combined results
145192
ranked_memories = self.reranker.rerank(
146193
query=query,
@@ -149,21 +196,30 @@ def retrieve_ranked_long_term_and_user():
149196
top_k=top_k * 2,
150197
parsed_goal=parsed_goal,
151198
)
199+
logger.info(
200+
f"[TIMER] PATH-B rerank took {(time.perf_counter() - rerank_start) * 1000:.2f} ms"
201+
)
152202
return ranked_memories
153203

154-
# Step 2c: Internet retrieval (Path C)
204+
@timed
155205
def retrieve_from_internet():
156206
"""
157207
Retrieve information from the internet using Google Custom Search API.
158208
"""
209+
logger.info("[PATH-C] Retrieving from Internet...")
159210
if not self.internet_retriever or mode == "fast" or not parsed_goal.internet_search:
211+
logger.info(
212+
"[PATH-C] Skipped (no retriever, fast mode, or no internet_search flag)"
213+
)
160214
return []
161215
if memory_type not in ["All"]:
162216
return []
163217
internet_items = self.internet_retriever.retrieve_from_internet(
164218
query=query, top_k=top_k, parsed_goal=parsed_goal, info=info
165219
)
166220

221+
logger.debug(f"[PATH-C] Retrieved {len(internet_items)} internet items.")
222+
rerank_start = time.perf_counter()
167223
# Convert to the format expected by reranker
168224
ranked_memories = self.reranker.rerank(
169225
query=query,
@@ -172,9 +228,13 @@ def retrieve_from_internet():
172228
top_k=min(top_k, 5),
173229
parsed_goal=parsed_goal,
174230
)
231+
logger.info(
232+
f"[TIMER] PATH-C rerank took {(time.perf_counter() - rerank_start) * 1000:.2f} ms"
233+
)
175234
return ranked_memories
176235

177-
# Step 3: Parallel execution of all paths (enable internet search accoeding to parameter in the parsed goal)
236+
# ===== Step 3: Run retrieval in parallel =====
237+
path_start = time.perf_counter()
178238
if parsed_goal.internet_search:
179239
with concurrent.futures.ThreadPoolExecutor(max_workers=3) as executor:
180240
future_working = executor.submit(retrieve_from_working_memory)
@@ -193,14 +253,24 @@ def retrieve_from_internet():
193253
working_results = future_working.result()
194254
hybrid_results = future_hybrid.result()
195255
searched_res = working_results + hybrid_results
256+
logger.info(
257+
f"[TIMER] Step 3 (Retrieval paths) took {(time.perf_counter() - path_start):.2f} s"
258+
)
259+
logger.info(f"[SEARCH] Total results before deduplication: {len(searched_res)}")
196260

197-
# Deduplicate by item.memory, keep higher score
261+
# ===== Step 4: Deduplication =====
262+
dedup_start = time.perf_counter()
198263
deduped_result = {}
199264
for item, score in searched_res:
200265
mem_key = item.memory
201266
if mem_key not in deduped_result or score > deduped_result[mem_key][1]:
202267
deduped_result[mem_key] = (item, score)
268+
logger.info(
269+
f"[TIMER] Deduplication took {(time.perf_counter() - dedup_start) * 1000:.2f} ms"
270+
)
203271

272+
# ===== Step 5: Sorting & trimming =====
273+
sort_start = time.perf_counter()
204274
searched_res = []
205275
for item, score in sorted(deduped_result.values(), key=lambda pair: pair[1], reverse=True)[
206276
:top_k
@@ -212,15 +282,18 @@ def retrieve_from_internet():
212282
searched_res.append(
213283
TextualMemoryItem(id=item.id, memory=item.memory, metadata=new_meta)
214284
)
285+
logger.info(
286+
f"[TIMER] Sorting & trimming took {(time.perf_counter() - sort_start) * 1000:.2f} ms"
287+
)
215288

216-
# Step 5: Update usage history with current timestamp
289+
# ===== Step 6: Update usage history =====
290+
usage_start = time.perf_counter()
217291
now_time = datetime.now().isoformat()
218292
if "chat_history" in info:
219293
info.pop("chat_history")
220294
usage_record = json.dumps(
221295
{"time": now_time, "info": info}
222296
) # `info` should be a serializable dict or string
223-
224297
for item in searched_res:
225298
if (
226299
hasattr(item, "id")
@@ -229,4 +302,11 @@ def retrieve_from_internet():
229302
):
230303
item.metadata.usage.append(usage_record)
231304
self.graph_store.update_node(item.id, {"usage": item.metadata.usage})
305+
logger.info(
306+
f"[TIMER] Usage history update took {(time.perf_counter() - usage_start) * 1000:.2f} ms"
307+
)
308+
309+
# ===== Finish =====
310+
logger.info(f"[SEARCH] ✅ Final top_k results: {len(searched_res)}")
311+
logger.info(f"[SEARCH] 🔚 Total search took {(time.perf_counter() - overall_start):.2f} s")
232312
return searched_res

src/memos/memories/textual/tree_text_memory/retrieve/task_goal_parser.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,16 @@
1-
import logging
21
import traceback
32

43
from string import Template
54

65
from memos.llms.base import BaseLLM
6+
from memos.log import get_logger
77
from memos.memories.textual.tree_text_memory.retrieve.retrieval_mid_structs import ParsedTaskGoal
88
from memos.memories.textual.tree_text_memory.retrieve.utils import TASK_PARSE_PROMPT
99

1010

11+
logger = get_logger(__name__)
12+
13+
1114
class TaskGoalParser:
1215
"""
1316
Unified TaskGoalParser:
@@ -70,10 +73,12 @@ def _parse_fine(
7073
prompt = Template(TASK_PARSE_PROMPT).substitute(
7174
task=query.strip(), context=context, conversation=conversation_prompt
7275
)
76+
logger.info(f"Parsing Goal... LLM input is {prompt}")
7377
response = self.llm.generate(messages=[{"role": "user", "content": prompt}])
78+
logger.info(f"Parsing Goal... LLM Response is {response}")
7479
return self._parse_response(response)
7580
except Exception:
76-
logging.warning(f"Fail to fine-parse query {query}: {traceback.format_exc()}")
81+
logger.warning(f"Fail to fine-parse query {query}: {traceback.format_exc()}")
7782
return self._parse_fast(query)
7883

7984
def _parse_response(self, response: str) -> ParsedTaskGoal:

0 commit comments

Comments (0)