Skip to content

Commit 094c3ca

Browse files
committed
feat: update zh en and mem
1 parent afb1bdb commit 094c3ca

File tree

4 files changed

+68
-21
lines changed

4 files changed

+68
-21
lines changed

src/memos/api/product_models.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,3 +150,10 @@ class SearchRequest(BaseRequest):
150150
user_id: str = Field(..., description="User ID")
151151
query: str = Field(..., description="Search query")
152152
mem_cube_id: str | None = Field(None, description="Cube ID to search in")
153+
154+
155+
class SuggestionRequest(BaseRequest):
156+
"""Request model for getting suggestion queries."""
157+
158+
user_id: str = Field(..., description="User ID")
159+
language: Literal["zh", "en"] = Field("zh", description="Language for suggestions")

src/memos/api/routers/product_router.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
SearchRequest,
1616
SearchResponse,
1717
SimpleResponse,
18+
SuggestionRequest,
1819
SuggestionResponse,
1920
UserRegisterRequest,
2021
UserRegisterResponse,
@@ -86,7 +87,6 @@ async def register_user(user_req: UserRegisterRequest):
8687
logger.error(f"Failed to register user: {traceback.format_exc()}")
8788
raise HTTPException(status_code=500, detail=str(traceback.format_exc())) from err
8889

89-
9090
@router.get(
9191
"/suggestions/{user_id}", summary="Get suggestion queries", response_model=SuggestionResponse
9292
)
@@ -105,6 +105,25 @@ async def get_suggestion_queries(user_id: str):
105105
raise HTTPException(status_code=500, detail=str(traceback.format_exc())) from err
106106

107107

108+
@router.post("/suggestions", summary="Get suggestion queries with language", response_model=SuggestionResponse)
109+
async def get_suggestion_queries_post(suggestion_req: SuggestionRequest):
110+
"""Get suggestion queries for a specific user with language preference."""
111+
try:
112+
mos_product = get_mos_product_instance()
113+
suggestions = mos_product.get_suggestion_query(
114+
user_id=suggestion_req.user_id,
115+
language=suggestion_req.language
116+
)
117+
return SuggestionResponse(
118+
message="Suggestions retrieved successfully", data={"query": suggestions}
119+
)
120+
except ValueError as err:
121+
raise HTTPException(status_code=404, detail=str(traceback.format_exc())) from err
122+
except Exception as err:
123+
logger.error(f"Failed to get suggestions: {traceback.format_exc()}")
124+
raise HTTPException(status_code=500, detail=str(traceback.format_exc())) from err
125+
126+
108127
@router.post("/get_all", summary="Get all memories for user", response_model=MemoryResponse)
109128
async def get_all_memories(memory_req: GetMemoryRequest):
110129
"""Get all memories for a specific user."""

src/memos/mem_os/product.py

Lines changed: 36 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -579,29 +579,46 @@ def user_register(
579579
except Exception as e:
580580
return {"status": "error", "message": f"Failed to register user: {e!s}"}
581581

582-
def get_suggestion_query(self, user_id: str) -> list[str]:
582+
def get_suggestion_query(self, user_id: str, language: str = "zh") -> list[str]:
583583
"""Get suggestion query from LLM.
584584
Args:
585585
user_id (str): User ID.
586+
language (str): Language for suggestions ("zh" or "en").
586587
587588
Returns:
588589
list[str]: The suggestion query list.
589590
"""
590591

591-
suggestion_prompt = """
592-
You are a helpful assistant that can help users to generate suggestion query
593-
I will get some user recently memories,
594-
you should generate some suggestion query , the query should be user what to query,
595-
user recently memories is :
596-
{memories}
597-
please generate 3 suggestion query,
598-
output should be a json format, the key is "query", the value is a list of suggestion query.
599-
600-
example:
601-
{{
602-
"query": ["query1", "query2", "query3"]
603-
}}
604-
"""
592+
if language == "zh":
593+
suggestion_prompt = """
594+
你是一个有用的助手,可以帮助用户生成建议查询。
595+
我将获取用户最近的一些记忆,
596+
你应该生成一些建议查询,这些查询应该是用户想要查询的内容,
597+
用户最近的记忆是:
598+
{memories}
599+
请生成3个建议查询用中文,
600+
输出应该是json格式,键是"query",值是一个建议查询列表。
601+
602+
示例:
603+
{{
604+
"query": ["查询1", "查询2", "查询3"]
605+
}}
606+
"""
607+
else: # English
608+
suggestion_prompt = """
609+
You are a helpful assistant that can help users to generate suggestion query.
610+
I will get some user recently memories,
611+
you should generate some suggestion query, the query should be user what to query,
612+
user recently memories is:
613+
{memories}
614+
please generate 3 suggestion query in English,
615+
output should be a json format, the key is "query", the value is a list of suggestion query.
616+
617+
example:
618+
{{
619+
"query": ["query1", "query2", "query3"]
620+
}}
621+
"""
605622
text_mem_result = super().search("my recently memories", user_id=user_id, top_k=10)["text_mem"]
606623
if text_mem_result:
607624
memories = "\n".join(
@@ -842,7 +859,7 @@ def get_all(
842859
"LongTermMemory": 0.40,
843860
"UserMemory": 0.40,
844861
}
845-
tree_result = convert_graph_to_tree_forworkmem(
862+
tree_result, node_type_count = convert_graph_to_tree_forworkmem(
846863
memories, target_node_count=150, type_ratios=custom_type_ratios
847864
)
848865
memories_filtered = filter_nodes_by_tree_ids(tree_result, memories)
@@ -851,7 +868,7 @@ def get_all(
851868
tree_result["children"] = children_sort
852869
memories_filtered["tree_structure"] = tree_result
853870
reformat_memory_list.append(
854-
{"cube_id": memory["cube_id"], "memories": [memories_filtered]}
871+
{"cube_id": memory["cube_id"], "memories": [memories_filtered], "memory_statistics": node_type_count}
855872
)
856873
elif memory_type == "act_mem":
857874
reformat_memory_list.append(
@@ -915,7 +932,7 @@ def get_subgraph(
915932
for memory in memory_list:
916933
memories = remove_embedding_recursive(memory["memories"])
917934
custom_type_ratios = {"WorkingMemory": 0.20, "LongTermMemory": 0.40, "UserMemory": 0.4}
918-
tree_result = convert_graph_to_tree_forworkmem(
935+
tree_result, node_type_count = convert_graph_to_tree_forworkmem(
919936
memories, target_node_count=150, type_ratios=custom_type_ratios
920937
)
921938
memories_filtered = filter_nodes_by_tree_ids(tree_result, memories)
@@ -924,7 +941,7 @@ def get_subgraph(
924941
tree_result["children"] = children_sort
925942
memories_filtered["tree_structure"] = tree_result
926943
reformat_memory_list.append(
927-
{"cube_id": memory["cube_id"], "memories": [memories_filtered]}
944+
{"cube_id": memory["cube_id"], "memories": [memories_filtered], "memory_statistics": node_type_count}
928945
)
929946

930947
return reformat_memory_list

src/memos/mem_os/utils/format_utils.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -508,6 +508,10 @@ def convert_graph_to_tree_forworkmem(
508508
for original_edge in original_edges:
509509
if original_edge["type"] == "PARENT":
510510
filter_original_edges.append(original_edge)
511+
node_type_count = {}
512+
for node in original_nodes:
513+
node_type = node.get("metadata", {}).get("memory_type", "Unknown")
514+
node_type_count[node_type] = node_type_count.get(node_type, 0) + 1
511515
original_edges = filter_original_edges
512516
# Use enhanced type-balanced sampling
513517
if len(original_nodes) > target_node_count:
@@ -625,7 +629,7 @@ def build_tree(node_id: str) -> dict[str, Any]:
625629
"frequency": 0,
626630
}
627631

628-
return result
632+
return result, node_type_count
629633

630634

631635
def print_tree_structure(node: dict[str, Any], level: int = 0, max_level: int = 5):

0 commit comments

Comments
 (0)