Skip to content

Commit 25c7642

Browse files
Wang-Daoji and yuan.wang authored
Feat/pref optimize update (#422)
* add hybrid search and fine extractor * add dialog and modify splitter chunk * optimize the update and retriever code * modify pref field * add pref mem update strategy * add pref mem update strategy * fix bug in pre_commit * modify pref field * fix bug * fix pre_commit --------- Co-authored-by: yuan.wang <[email protected]>
1 parent 81c7ad9 commit 25c7642

File tree

9 files changed

+36
-23
lines changed

9 files changed

+36
-23
lines changed

evaluation/scripts/PrefEval/pref_memos.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ def search_memory_for_line(line_data, mem_client, top_k_value):
103103
f"- {entry.get('memory', '')}"
104104
for entry in relevant_memories["text_mem"][0]["memories"]
105105
)
106-
+ f"\n{relevant_memories['pref_string']}"
106+
+ f"\n{relevant_memories.get('pref_string', '')}"
107107
)
108108

109109
memory_tokens_used = len(tokenizer.encode(memories_str))

evaluation/scripts/locomo/locomo_search.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,11 +107,11 @@ def memos_api_search(
107107

108108
speaker_a_context = (
109109
"\n".join([i["memory"] for i in search_a_results["text_mem"][0]["memories"]])
110-
+ f"\n{search_a_results['pref_string']}"
110+
+ f"\n{search_a_results.get('pref_string', '')}"
111111
)
112112
speaker_b_context = (
113113
"\n".join([i["memory"] for i in search_b_results["text_mem"][0]["memories"]])
114-
+ f"\n{search_b_results['pref_string']}"
114+
+ f"\n{search_b_results.get('pref_string', '')}"
115115
)
116116

117117
context = TEMPLATE_MEMOS.format(

evaluation/scripts/longmemeval/lme_search.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def memos_search(client, query, user_id, top_k):
4646
results = client.search(query=query, user_id=user_id, top_k=top_k)
4747
context = (
4848
"\n".join([i["memory"] for i in results["text_mem"][0]["memories"]])
49-
+ f"\n{results['pref_string']}"
49+
+ f"\n{results.get('pref_string', '')}"
5050
)
5151
context = MEMOS_CONTEXT_TEMPLATE.format(user_id=user_id, memories=context)
5252
duration_ms = (time() - start) * 1000

evaluation/scripts/personamem/pm_search.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,8 @@ def memos_search(client, user_id, query, top_k):
8181
start = time()
8282
results = client.search(query=query, user_id=user_id, top_k=top_k)
8383
search_memories = (
84-
"\n".join(item["memory"] for cube in results["text_mem"] for item in cube["memories"])
85-
+ f"\n{results['pref_string']}"
84+
"\n".join(item["memory"] for cube in results["text_mem"] for item in cube["memories"])
85+
+ f"\n{results.get('pref_string', '')}"
8686
)
8787
context = MEMOS_CONTEXT_TEMPLATE.format(user_id=user_id, memories=search_memories)
8888

evaluation/scripts/utils/client.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,8 @@ def search(self, query, user_id, top_k):
182182
"conversation_id": "",
183183
"top_k": top_k,
184184
"mode": os.getenv("SEARCH_MODE", "fast"),
185-
"handle_pref_mem": False,
185+
"include_preference": True,
186+
"pref_top_k": 6,
186187
},
187188
ensure_ascii=False,
188189
)
@@ -344,9 +345,10 @@ def wait_for_completion(self, task_id):
344345
query = "杭州西湖有什么"
345346
top_k = 5
346347

347-
# MEMOBASE
348-
client = MemobaseClient()
348+
# MEMOS-API
349+
client = MemosApiClient()
349350
for m in messages:
350351
m["created_at"] = iso_date
351-
client.add(messages, user_id)
352+
client.add(messages, user_id, user_id)
352353
memories = client.search(query, user_id, top_k)
354+
print(memories)

src/memos/api/config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -354,7 +354,7 @@ def get_preference_memory_config() -> dict[str, Any]:
354354
return {
355355
"backend": "pref_text",
356356
"config": {
357-
"extractor_llm": {"backend": "openai", "config": APIConfig.get_openai_config()},
357+
"extractor_llm": APIConfig.get_memreader_config(),
358358
"vector_db": {
359359
"backend": "milvus",
360360
"config": APIConfig.get_milvus_config(),

src/memos/api/product_models.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,8 @@ class APISearchRequest(BaseRequest):
180180
operation: list[PermissionDict] | None = Field(
181181
None, description="operation ids for multi cubes"
182182
)
183-
handle_pref_mem: bool = Field(False, description="Whether to handle preference memory")
183+
include_preference: bool = Field(True, description="Whether to handle preference memory")
184+
pref_top_k: int = Field(6, description="Number of preference results to return")
184185

185186

186187
class APIADDRequest(BaseRequest):

src/memos/api/routers/server_router.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -324,17 +324,18 @@ def _post_process_pref_mem(
324324
memories_result: list[dict[str, Any]],
325325
pref_formatted_mem: list[dict[str, Any]],
326326
mem_cube_id: str,
327-
handle_pref_mem: bool,
327+
include_preference: bool,
328328
):
329-
if handle_pref_mem:
329+
if include_preference:
330330
memories_result["pref_mem"].append(
331331
{
332332
"cube_id": mem_cube_id,
333333
"memories": pref_formatted_mem,
334334
}
335335
)
336-
pref_instruction: str = instruct_completion(pref_formatted_mem)
336+
pref_instruction, pref_note = instruct_completion(pref_formatted_mem)
337337
memories_result["pref_string"] = pref_instruction
338+
memories_result["pref_note"] = pref_note
338339

339340
return memories_result
340341

@@ -354,7 +355,7 @@ def search_memories(search_req: APISearchRequest):
354355
"act_mem": [],
355356
"para_mem": [],
356357
"pref_mem": [],
357-
"pref_string": "",
358+
"pref_note": "",
358359
}
359360

360361
search_mode = search_req.mode
@@ -382,7 +383,7 @@ def _search_pref():
382383
return []
383384
results = naive_mem_cube.pref_mem.search(
384385
query=search_req.query,
385-
top_k=search_req.top_k,
386+
top_k=search_req.pref_top_k,
386387
info={
387388
"user_id": search_req.user_id,
388389
"session_id": search_req.session_id,
@@ -405,7 +406,10 @@ def _search_pref():
405406
)
406407

407408
memories_result = _post_process_pref_mem(
408-
memories_result, pref_formatted_memories, search_req.mem_cube_id, search_req.handle_pref_mem
409+
memories_result,
410+
pref_formatted_memories,
411+
search_req.mem_cube_id,
412+
search_req.include_preference,
409413
)
410414

411415
return SearchResponse(

src/memos/templates/instruction_completion.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
def instruct_completion(
88
memories: list[dict[str, Any]] | None = None,
9-
) -> str:
9+
) -> [str, str]:
1010
"""Create instruction following the preferences."""
1111
explicit_pref = []
1212
implicit_pref = []
@@ -49,10 +49,16 @@ def instruct_completion(
4949
lang = detect_lang(explicit_pref_str + implicit_pref_str)
5050

5151
if not explicit_pref_str and not implicit_pref_str:
52-
return ""
52+
return "", ""
5353
if not explicit_pref_str:
54-
return implicit_pref_str + "\n" + _prompt_map[lang].replace(_remove_exp_map[lang], "")
54+
pref_note = _prompt_map[lang].replace(_remove_exp_map[lang], "")
55+
pref_string = implicit_pref_str + "\n" + pref_note
56+
return pref_string, pref_note
5557
if not implicit_pref_str:
56-
return explicit_pref_str + "\n" + _prompt_map[lang].replace(_remove_imp_map[lang], "")
58+
pref_note = _prompt_map[lang].replace(_remove_imp_map[lang], "")
59+
pref_string = explicit_pref_str + "\n" + pref_note
60+
return pref_string, pref_note
5761

58-
return explicit_pref_str + "\n" + implicit_pref_str + "\n" + _prompt_map[lang]
62+
pref_note = _prompt_map[lang]
63+
pref_string = explicit_pref_str + "\n" + implicit_pref_str + "\n" + pref_note
64+
return pref_string, pref_note

0 commit comments

Comments
 (0)