
Commit 3106da1

endxxxxfriday, muzzlol, and CaralHsi authored
Feat/rewrite query (#149)
* fix: align MOSProduct._build_system_prompt signature with MOSCore
  - Fix TypeError when calling MOSProduct.chat() method
  - MOSCore.chat() expects _build_system_prompt(memories, base_prompt=...)
  - MOSProduct._build_system_prompt had incompatible signature (user_id, memories)
  - Updated signature to match parent class interface
  - Removed unused user_id parameter from method
* feat: rewrite query before searching memories
* fix: JSONDecodeError
* fix: fix prompt
* ran format
* feat: rewrite and reinforce query
* feat: rewrite and reinforce query
* feat: make internet search optional
* fix: disable rewrite query separately
* feat: finish rewrite query
* fix: avoid sensitive content in the memories
* fix: test bug in tree_task_goal_parser
* fix: info is None bug
* fix: bug in nebula get_by_metadata
* fix: add info to memos_core unittest

---------

Co-authored-by: chunyu li <[email protected]>
Co-authored-by: muzzlol <[email protected]>
Co-authored-by: CaralHsi <[email protected]>
1 parent bd450bd commit 3106da1
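
The core idea of the PR, reduced to a minimal standalone sketch: before searching memories, the current query is sent to an LLM together with the prior dialogue, and the search then uses the rewritten, self-contained question when the model says the dialogue is relevant, falling back to the original query otherwise. Everything below (the prompt wording, rewrite_query, fake_llm) is illustrative and hypothetical, not the repository's actual QUERY_REWRITING_PROMPT or LLM client.

import json


def rewrite_query(llm_generate, chat_history: list[dict], query: str) -> str:
    """Ask an LLM to resolve references in `query` using the prior dialogue.

    Falls back to the original query if the model says the dialogue is
    unrelated or returns malformed JSON.
    """
    dialogue = "\n".join(f"{m['role']}: {m['content']}" for m in chat_history)
    prompt = (
        "Given the dialogue:\n"
        f"{dialogue}\n"
        f"Rewrite the question so it is self-contained: {query}\n"
        'Reply as JSON: {"former_dialogue_related": bool, "rewritten_question": str}'
    )
    try:
        result = json.loads(llm_generate([{"role": "user", "content": prompt}]))
    except (json.JSONDecodeError, TypeError):
        return query
    if result.get("former_dialogue_related") and result.get("rewritten_question"):
        return result["rewritten_question"]
    return query


# Stubbed LLM call so the sketch runs without a real model.
def fake_llm(messages):
    return '{"former_dialogue_related": true, "rewritten_question": "Recent news in New York"}'


history = [
    {"role": "user", "content": "I want to know three beautiful cities"},
    {"role": "assistant", "content": "New York, London, and Shanghai"},
]
print(rewrite_query(fake_llm, history, "Recent news in the first city you've mentioned."))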

File tree

13 files changed, +227 -57 lines changed


examples/core_memories/tree_textual_memory.py

Lines changed: 23 additions & 5 deletions
@@ -193,27 +193,45 @@ def embed_memory_item(memory: str) -> list[float]:

 time.sleep(60)

+init_time = time.time()
 results = my_tree_textual_memory.search(
     "Talk about the user's childhood story?",
     top_k=10,
-    info={"query": "Talk about the user's childhood story?", "user_id": "111", "session": "2234"},
+    info={
+        "query": "Talk about the user's childhood story?",
+        "user_id": "111",
+        "session_id": "2234",
+        "chat_history": [{"role": "user", "content": "xxxxx"}],
+    },
 )
 for i, r in enumerate(results):
     r = r.to_dict()
     print(f"{i}'th similar result is: " + str(r["memory"]))
-print(f"Successfully search {len(results)} memories")
+print(f"Successfully search {len(results)} memories in {round(time.time() - init_time)}s")

 # try this when use 'fine' mode (Note that you should pass the internet Config, refer to examples/core_memories/textual_internet_memoy.py)
+init_time = time.time()
 results_fine_search = my_tree_textual_memory.search(
-    "Recent news in NewYork",
+    "Recent news in the first city you've mentioned.",
     top_k=10,
     mode="fine",
-    info={"query": "Recent news in NewYork", "user_id": "111", "session": "2234"},
+    info={
+        "query": "Recent news in NewYork",
+        "user_id": "111",
+        "session_id": "2234",
+        "chat_history": [
+            {"role": "user", "content": "I want to know three beautiful cities"},
+            {"role": "assistant", "content": "New York, London, and Shanghai"},
+        ],
+    },
 )
+
 for i, r in enumerate(results_fine_search):
     r = r.to_dict()
     print(f"{i}'th similar result is: " + str(r["memory"]))
-print(f"Successfully search {len(results_fine_search)} memories")
+print(
+    f"Successfully search {len(results_fine_search)} memories in {round(time.time() - init_time)}s"
+)

 # find related nodes
 related_nodes = my_tree_textual_memory.get_relevant_subgraph("Painting")

src/memos/graph_dbs/nebular.py

Lines changed: 2 additions & 1 deletion
@@ -830,7 +830,8 @@ def get_by_metadata(self, filters: list[dict[str, Any]]) -> list[str]:

        def _escape_value(value):
            if isinstance(value, str):
-                return f'"{value}"'
+                escaped = value.replace('"', '\\"')
+                return f'"{escaped}"'
            elif isinstance(value, list):
                return "[" + ", ".join(_escape_value(v) for v in value) + "]"
            else:
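
For context on why this one-line change matters: get_by_metadata builds a Nebula query by interpolating filter values as quoted string literals, so an unescaped double quote inside a value terminated the literal early and produced a malformed query. A tiny self-contained sketch that mirrors the patched helper (escape_value below is a standalone copy for illustration, not an import from memos):

def escape_value(value):
    # Mirror of the patched helper: escape embedded double quotes before quoting.
    if isinstance(value, str):
        return '"{}"'.format(value.replace('"', '\\"'))
    if isinstance(value, list):
        return "[" + ", ".join(escape_value(v) for v in value) + "]"
    return str(value)


# Before the fix, a value like this would have broken the generated clause:
print(escape_value('He said "hello"'))  # stays a single quoted literal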

src/memos/mem_os/core.py

Lines changed: 45 additions & 5 deletions
@@ -25,6 +25,7 @@
 from memos.memories.activation.item import ActivationMemoryItem
 from memos.memories.parametric.item import ParametricMemoryItem
 from memos.memories.textual.item import TextualMemoryItem, TextualMemoryMetadata
+from memos.templates.mos_prompts import QUERY_REWRITING_PROMPT
 from memos.types import ChatHistory, MessageList, MOSSearchResult


@@ -283,7 +284,15 @@ def chat(self, query: str, user_id: str | None = None, base_prompt: str | None =
                    )
                    self.mem_scheduler.submit_messages(messages=[message_item])

-                memories = mem_cube.text_mem.search(query, top_k=self.config.top_k)
+                memories = mem_cube.text_mem.search(
+                    query,
+                    top_k=self.config.top_k,
+                    info={
+                        "user_id": target_user_id,
+                        "session_id": self.session_id,
+                        "chat_history": chat_history,
+                    },
+                )
                 memories_all.extend(memories)
        logger.info(f"🧠 [Memory] Searched memories:\n{self._str_memories(memories_all)}\n")
        system_prompt = self._build_system_prompt(memories_all, base_prompt=base_prompt)
@@ -556,6 +565,9 @@ def search(
        logger.info(
            f"User {target_user_id} has access to {len(user_cube_ids)} cubes: {user_cube_ids}"
        )
+
+        chat_history = self.chat_history_manager[target_user_id]
+
        result: MOSSearchResult = {
            "text_mem": [],
            "act_mem": [],
@@ -575,7 +587,11 @@ def search(
                    top_k=top_k if top_k else self.config.top_k,
                    mode=mode,
                    manual_close_internet=not internet_search,
-                    info={"user_id": target_user_id, "session_id": str(uuid.uuid4())},
+                    info={
+                        "user_id": target_user_id,
+                        "session_id": self.session_id,
+                        "chat_history": chat_history,
+                    },
                )
                result["text_mem"].append({"cube_id": mem_cube_id, "memories": memories})
                logger.info(
@@ -645,7 +661,7 @@ def add(
            memories = self.mem_reader.get_memory(
                messages_list,
                type="chat",
-                info={"user_id": target_user_id, "session_id": str(uuid.uuid4())},
+                info={"user_id": target_user_id, "session_id": self.session_id},
            )

            mem_ids = []
@@ -689,7 +705,7 @@ def add(
            memories = self.mem_reader.get_memory(
                messages_list,
                type="chat",
-                info={"user_id": target_user_id, "session_id": str(uuid.uuid4())},
+                info={"user_id": target_user_id, "session_id": self.session_id},
            )

            mem_ids = []
@@ -723,7 +739,7 @@ def add(
            doc_memories = self.mem_reader.get_memory(
                documents,
                type="doc",
-                info={"user_id": target_user_id, "session_id": str(uuid.uuid4())},
+                info={"user_id": target_user_id, "session_id": self.session_id},
            )

            mem_ids = []
@@ -977,3 +993,27 @@ def share_cube_with_user(self, cube_id: str, target_user_id: str) -> bool:
            raise ValueError(f"Target user '{target_user_id}' does not exist or is inactive.")

        return self.user_manager.add_user_to_cube(target_user_id, cube_id)
+
+    def get_query_rewrite(self, query: str, user_id: str | None = None):
+        """
+        Rewrite user's query according the context.
+        Args:
+            query (str): The search query that needs rewriting.
+            user_id(str, optional): The identifier of the user that the query belongs to.
+                If None, the default user is used.
+
+        Returns:
+            str: query after rewriting process.
+        """
+        target_user_id = user_id if user_id is not None else self.user_id
+        chat_history = self.chat_history_manager[target_user_id]
+
+        dialogue = "————{}".format("\n————".join(chat_history.chat_history))
+        user_prompt = QUERY_REWRITING_PROMPT.format(dialogue=dialogue, query=query)
+        messages = {"role": "user", "content": user_prompt}
+        rewritten_result = self.chat_llm.generate(messages=messages)
+        rewritten_result = json.loads(rewritten_result)
+        if rewritten_result.get("former_dialogue_related", False):
+            rewritten_query = rewritten_result.get("rewritten_question")
+            return rewritten_query if len(rewritten_query) > 0 else query
+        return query
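
A recurring change across these hunks is that the ad-hoc str(uuid.uuid4()) session ids are replaced by the instance's self.session_id, and the caller's chat history is threaded through info so the retrieval layer can rewrite the query. A minimal, purely illustrative sketch of that pattern (the SessionScoped class below is hypothetical, not MemOS code):

import uuid


class SessionScoped:
    """Illustrative only: mint one session_id per object and reuse it across
    calls, instead of generating a fresh uuid on every add/search."""

    def __init__(self):
        self.session_id = str(uuid.uuid4())

    def build_info(self, user_id: str, chat_history: list[dict]) -> dict:
        return {
            "user_id": user_id,
            "session_id": self.session_id,  # stable for the lifetime of the session
            "chat_history": chat_history,
        }


core = SessionScoped()
a = core.build_info("111", [])
b = core.build_info("111", [{"role": "user", "content": "hi"}])
assert a["session_id"] == b["session_id"]  # same session across calls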

src/memos/memories/textual/tree_text_memory/retrieve/internet_retriever.py

Lines changed: 2 additions & 0 deletions
@@ -141,6 +141,8 @@ def retrieve_from_internet(
        Returns:
            List of TextualMemoryItem
        """
+        if not info:
+            info = {"user_id": "", "session_id": ""}
        # Get search results
        search_results = self.google_api.get_all_results(query, max_results=top_k)

src/memos/memories/textual/tree_text_memory/retrieve/retrieval_mid_structs.py

Lines changed: 2 additions & 0 deletions
@@ -10,4 +10,6 @@ class ParsedTaskGoal:
    memories: list[str] = field(default_factory=list)
    keys: list[str] = field(default_factory=list)
    tags: list[str] = field(default_factory=list)
+    rephrased_query: str | None = None
+    internet_search: bool = False
    goal_type: str | None = None  # e.g., 'default', 'explanation', etc.
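
A short usage sketch of the extended dataclass, assuming the memos package is importable; both new fields have safe defaults, so existing call sites that do not pass them keep working:

from memos.memories.textual.tree_text_memory.retrieve.retrieval_mid_structs import (
    ParsedTaskGoal,
)

# The new fields carry the rewritten query and the parser's decision on
# whether an internet lookup is worthwhile.
goal = ParsedTaskGoal(
    memories=["Recent news in New York"],
    keys=["New York", "news"],
    tags=["news"],
    goal_type="default",
    rephrased_query="Recent news in New York",
    internet_search=True,
)
print(goal.rephrased_query, goal.internet_search)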

src/memos/memories/textual/tree_text_memory/retrieve/searcher.py

Lines changed: 42 additions & 13 deletions
@@ -6,6 +6,7 @@
 from memos.embedders.factory import OllamaEmbedder
 from memos.graph_dbs.factory import Neo4jGraphDB
 from memos.llms.factory import AzureLLM, OllamaLLM, OpenAILLM
+from memos.log import get_logger
 from memos.memories.textual.item import SearchedTreeNodeTextualMemoryMetadata, TextualMemoryItem

 from .internet_retriever_factory import InternetRetrieverFactory
@@ -15,6 +16,9 @@
 from .task_goal_parser import TaskGoalParser


+logger = get_logger(__name__)
+
+
 class Searcher:
     def __init__(
         self,
@@ -53,7 +57,12 @@ def search(
        Returns:
            list[TextualMemoryItem]: List of matching memories.
        """
-
+        if not info:
+            logger.warning(
+                "Please input 'info' when use tree.search so that "
+                "the database would store the consume history."
+            )
+            info = {"user_id": "", "session_id": ""}
        # Step 1: Parse task structure into topic, concept, and fact levels
        context = []
        if mode == "fine":
@@ -67,7 +76,18 @@ def search(
            context = list(set(context))

        # Step 1a: Parse task structure into topic, concept, and fact levels
-        parsed_goal = self.task_goal_parser.parse(query, "\n".join(context))
+        parsed_goal = self.task_goal_parser.parse(
+            task_description=query,
+            context="\n".join(context),
+            conversation=info.get("chat_history", []),
+            mode=mode,
+        )
+
+        query = (
+            parsed_goal.rephrased_query
+            if parsed_goal.rephrased_query and len(parsed_goal.rephrased_query) > 0
+            else query
+        )

        if parsed_goal.memories:
            query_embedding = self.embedder.embed(list({query, *parsed_goal.memories}))
@@ -136,7 +156,7 @@ def retrieve_from_internet():
            """
            Retrieve information from the internet using Google Custom Search API.
            """
-            if not self.internet_retriever or mode == "fast":
+            if not self.internet_retriever or mode == "fast" or not parsed_goal.internet_search:
                return []
            if memory_type not in ["All"]:
                return []
@@ -154,16 +174,25 @@ def retrieve_from_internet():
            )
            return ranked_memories

-        # Step 3: Parallel execution of all paths
-        with concurrent.futures.ThreadPoolExecutor(max_workers=3) as executor:
-            future_working = executor.submit(retrieve_from_working_memory)
-            future_hybrid = executor.submit(retrieve_ranked_long_term_and_user)
-            future_internet = executor.submit(retrieve_from_internet)
-
-            working_results = future_working.result()
-            hybrid_results = future_hybrid.result()
-            internet_results = future_internet.result()
-            searched_res = working_results + hybrid_results + internet_results
+        # Step 3: Parallel execution of all paths (enable internet search accoeding to parameter in the parsed goal)
+        if parsed_goal.internet_search:
+            with concurrent.futures.ThreadPoolExecutor(max_workers=3) as executor:
+                future_working = executor.submit(retrieve_from_working_memory)
+                future_hybrid = executor.submit(retrieve_ranked_long_term_and_user)
+                future_internet = executor.submit(retrieve_from_internet)
+
+                working_results = future_working.result()
+                hybrid_results = future_hybrid.result()
+                internet_results = future_internet.result()
+                searched_res = working_results + hybrid_results + internet_results
+        else:
+            with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
+                future_working = executor.submit(retrieve_from_working_memory)
+                future_hybrid = executor.submit(retrieve_ranked_long_term_and_user)
+
+                working_results = future_working.result()
+                hybrid_results = future_hybrid.result()
+                searched_res = working_results + hybrid_results

        # Deduplicate by item.memory, keep higher score
        deduped_result = {}
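
The step-3 change can be read as: always run the working-memory and long-term/user retrieval paths in a thread pool, and add the internet path only when the parsed goal asks for it. A compressed, self-contained sketch of that pattern (it folds the diff's two explicit branches into a single task list; the retriever functions are stand-ins, not MemOS code):

import concurrent.futures


def fan_out(include_internet: bool) -> list[str]:
    """Run the two local retrieval paths in parallel; add the internet
    path only when requested."""

    def from_working_memory() -> list[str]:
        return ["working-memory hit"]

    def from_long_term_and_user() -> list[str]:
        return ["long-term hit"]

    def from_internet() -> list[str]:
        return ["internet hit"]

    tasks = [from_working_memory, from_long_term_and_user]
    if include_internet:
        tasks.append(from_internet)

    with concurrent.futures.ThreadPoolExecutor(max_workers=len(tasks)) as executor:
        futures = [executor.submit(t) for t in tasks]
        results: list[str] = []
        for f in futures:
            results.extend(f.result())
    return results


print(fan_out(include_internet=False))
print(fan_out(include_internet=True))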

src/memos/memories/textual/tree_text_memory/retrieve/task_goal_parser.py

Lines changed: 42 additions & 15 deletions
@@ -1,4 +1,5 @@
-import json
+import logging
+import traceback

 from string import Template

@@ -14,54 +15,80 @@ class TaskGoalParser:
    - mode == 'fine': use LLM to parse structured topic/keys/tags
    """

-    def __init__(self, llm=BaseLLM, mode: str = "fast"):
+    def __init__(self, llm=BaseLLM):
        self.llm = llm
-        self.mode = mode

-    def parse(self, task_description: str, context: str = "") -> ParsedTaskGoal:
+    def parse(
+        self,
+        task_description: str,
+        context: str = "",
+        conversation: list[dict] | None = None,
+        mode: str = "fast",
+    ) -> ParsedTaskGoal:
        """
        Parse user input into structured semantic layers.
        Returns:
            ParsedTaskGoal: object containing topic/concept/fact levels and optional metadata
        - mode == 'fast': use jieba to split words only
        - mode == 'fine': use LLM to parse structured topic/keys/tags
        """
-        if self.mode == "fast":
+        if mode == "fast":
            return self._parse_fast(task_description)
-        elif self.mode == "fine":
+        elif mode == "fine":
            if not self.llm:
                raise ValueError("LLM not provided for slow mode.")
-            return self._parse_fine(task_description, context)
+            return self._parse_fine(task_description, context, conversation)
        else:
-            raise ValueError(f"Unknown mode: {self.mode}")
+            raise ValueError(f"Unknown mode: {mode}")

    def _parse_fast(self, task_description: str, limit_num: int = 5) -> ParsedTaskGoal:
        """
        Fast mode: simple jieba word split.
        """
        return ParsedTaskGoal(
-            memories=[task_description], keys=[task_description], tags=[], goal_type="default"
+            memories=[task_description],
+            keys=[task_description],
+            tags=[],
+            goal_type="default",
+            rephrased_query=task_description,
+            internet_search=False,
        )

-    def _parse_fine(self, query: str, context: str = "") -> ParsedTaskGoal:
+    def _parse_fine(
+        self, query: str, context: str = "", conversation: list[dict] | None = None
+    ) -> ParsedTaskGoal:
        """
        Slow mode: LLM structured parse.
        """
-        prompt = Template(TASK_PARSE_PROMPT).substitute(task=query.strip(), context=context)
-        response = self.llm.generate(messages=[{"role": "user", "content": prompt}])
-        return self._parse_response(response)
+        try:
+            if conversation:
+                conversation_prompt = "\n".join(
+                    [f"{each['role']}: {each['content']}" for each in conversation]
+                )
+            else:
+                conversation_prompt = ""
+            prompt = Template(TASK_PARSE_PROMPT).substitute(
+                task=query.strip(), context=context, conversation=conversation_prompt
+            )
+            response = self.llm.generate(messages=[{"role": "user", "content": prompt}])
+            return self._parse_response(response)
+        except Exception:
+            logging.warning(f"Fail to fine-parse query {query}: {traceback.format_exc()}")
+            return self._parse_fast(query)

    def _parse_response(self, response: str) -> ParsedTaskGoal:
        """
        Parse LLM JSON output safely.
        """
        try:
-            response = response.replace("```", "").replace("json", "")
-            response_json = json.loads(response.strip())
+            response = response.replace("```", "").replace("json", "").strip()
+            response_json = eval(response)
            return ParsedTaskGoal(
                memories=response_json.get("memories", []),
                keys=response_json.get("keys", []),
                tags=response_json.get("tags", []),
+                rephrased_query=response_json.get("rephrased_instruction", None),
+                internet_search=response_json.get("internet_search", False),
                goal_type=response_json.get("goal_type", "default"),
            )
        except Exception as e:
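
A usage sketch for the reworked parser, assuming the memos package is importable. StubLLM is a hypothetical stand-in whose generate() returns a dict-literal string shaped like the keys _parse_response reads (memories, keys, tags, goal_type, rephrased_instruction, internet_search); a real deployment would pass a configured LLM client instead.

from memos.memories.textual.tree_text_memory.retrieve.task_goal_parser import TaskGoalParser


class StubLLM:
    """Hypothetical stand-in for an LLM client; only .generate() is needed."""

    def generate(self, messages):
        # Shaped like the keys the parser reads; not a real model response.
        return str(
            {
                "memories": ["Recent news in New York"],
                "keys": ["New York"],
                "tags": ["news"],
                "goal_type": "default",
                "rephrased_instruction": "Recent news in New York",
                "internet_search": True,
            }
        )


parser = TaskGoalParser(llm=StubLLM())

# mode is now chosen per call rather than fixed at construction time.
fast_goal = parser.parse("Recent news in the first city you've mentioned.", mode="fast")
fine_goal = parser.parse(
    "Recent news in the first city you've mentioned.",
    conversation=[
        {"role": "user", "content": "I want to know three beautiful cities"},
        {"role": "assistant", "content": "New York, London, and Shanghai"},
    ],
    mode="fine",
)
print(fast_goal.rephrased_query)
print(fine_goal.rephrased_query, fine_goal.internet_search)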
