Skip to content

Commit e2d34d8

Browse files
authored
feat: reorganize config code and add remove dup nodes for playground-demo (#135)
* feat: update config * fix: dim * change dim * fix: change default db * fix: delay * fix: len * fix: change recently mem size * fix: dup node error * fix: remove mock_data * fix: change config * feat: reorganize code * add: add json parse for en * fix: change user_id * fix: logger info
1 parent 727dc18 commit e2d34d8

File tree

4 files changed

+308
-142
lines changed

4 files changed

+308
-142
lines changed

src/memos/api/config.py

Lines changed: 23 additions & 59 deletions
Original file line number | Diff line number | Diff line change
@@ -300,7 +300,6 @@ def get_start_default_config() -> dict[str, Any]:
300300
def create_user_config(user_name: str, user_id: str) -> tuple[MOSConfig, GeneralMemCube]:
301301
"""Create configuration for a specific user."""
302302
openai_config = APIConfig.get_openai_config()
303-
304303
qwen_config = APIConfig.qwen_config()
305304
vllm_config = APIConfig.vllm_config()
306305
backend = os.getenv("MOS_CHAT_MODEL_PROVIDER", "openai")
@@ -351,8 +350,15 @@ def create_user_config(user_name: str, user_id: str) -> tuple[MOSConfig, General
351350

352351
default_config = MOSConfig(**config_dict)
353352

354-
if os.getenv("NEO4J_BACKEND", "neo4j_community").lower() == "neo4j_community":
355-
neo4j_community_config = APIConfig.get_neo4j_community_config(user_id)
353+
neo4j_community_config = APIConfig.get_neo4j_community_config(user_id)
354+
neo4j_config = APIConfig.get_neo4j_config(user_id)
355+
356+
graph_db_backend_map = {
357+
"neo4j-community": neo4j_community_config,
358+
"neo4j": neo4j_config,
359+
}
360+
graph_db_backend = os.getenv("NEO4J_BACKEND", "neo4j-community").lower()
361+
if graph_db_backend in graph_db_backend_map:
356362
# Create MemCube config
357363
default_cube_config = GeneralMemCubeConfig.model_validate(
358364
{
@@ -364,8 +370,8 @@ def create_user_config(user_name: str, user_id: str) -> tuple[MOSConfig, General
364370
"extractor_llm": {"backend": "openai", "config": openai_config},
365371
"dispatcher_llm": {"backend": "openai", "config": openai_config},
366372
"graph_db": {
367-
"backend": "neo4j-community",
368-
"config": neo4j_community_config,
373+
"backend": graph_db_backend,
374+
"config": graph_db_backend_map[graph_db_backend],
369375
},
370376
"embedder": APIConfig.get_embedder_config(),
371377
},
@@ -377,30 +383,7 @@ def create_user_config(user_name: str, user_id: str) -> tuple[MOSConfig, General
377383
}
378384
)
379385
else:
380-
neo4j_config = APIConfig.get_neo4j_config(user_id)
381-
# Create MemCube config
382-
default_cube_config = GeneralMemCubeConfig.model_validate(
383-
{
384-
"user_id": user_id,
385-
"cube_id": f"{user_name}_default_cube",
386-
"text_mem": {
387-
"backend": "tree_text",
388-
"config": {
389-
"extractor_llm": {"backend": "openai", "config": openai_config},
390-
"dispatcher_llm": {"backend": "openai", "config": openai_config},
391-
"graph_db": {
392-
"backend": "neo4j",
393-
"config": neo4j_config,
394-
},
395-
"embedder": APIConfig.get_embedder_config(),
396-
},
397-
},
398-
"act_mem": {}
399-
if os.getenv("ENABLE_ACTIVATION_MEMORY", "false").lower() == "false"
400-
else APIConfig.get_activation_vllm_config(),
401-
"para_mem": {},
402-
}
403-
)
386+
raise ValueError(f"Invalid Neo4j backend: {graph_db_backend}")
404387

405388
default_mem_cube = GeneralMemCube(default_cube_config)
406389
return default_config, default_mem_cube
@@ -416,9 +399,14 @@ def get_default_cube_config() -> GeneralMemCubeConfig | None:
416399
return None
417400

418401
openai_config = APIConfig.get_openai_config()
419-
420-
if os.getenv("NEO4J_BACKEND", "neo4j_community").lower() == "neo4j_community":
421-
neo4j_community_config = APIConfig.get_neo4j_community_config(user_id="default")
402+
neo4j_community_config = APIConfig.get_neo4j_community_config(user_id="default")
403+
neo4j_config = APIConfig.get_neo4j_config(user_id="default")
404+
graph_db_backend_map = {
405+
"neo4j-community": neo4j_community_config,
406+
"neo4j": neo4j_config,
407+
}
408+
graph_db_backend = os.getenv("NEO4J_BACKEND", "neo4j-community").lower()
409+
if graph_db_backend in graph_db_backend_map:
422410
return GeneralMemCubeConfig.model_validate(
423411
{
424412
"user_id": "default",
@@ -429,8 +417,8 @@ def get_default_cube_config() -> GeneralMemCubeConfig | None:
429417
"extractor_llm": {"backend": "openai", "config": openai_config},
430418
"dispatcher_llm": {"backend": "openai", "config": openai_config},
431419
"graph_db": {
432-
"backend": "neo4j-community",
433-
"config": neo4j_community_config,
420+
"backend": graph_db_backend,
421+
"config": graph_db_backend_map[graph_db_backend],
434422
},
435423
"embedder": APIConfig.get_embedder_config(),
436424
"reorganize": os.getenv("MOS_ENABLE_REORGANIZE", "false").lower()
@@ -444,28 +432,4 @@ def get_default_cube_config() -> GeneralMemCubeConfig | None:
444432
}
445433
)
446434
else:
447-
neo4j_config = APIConfig.get_neo4j_config(user_id="default")
448-
return GeneralMemCubeConfig.model_validate(
449-
{
450-
"user_id": "default",
451-
"cube_id": "default_cube",
452-
"text_mem": {
453-
"backend": "tree_text",
454-
"config": {
455-
"extractor_llm": {"backend": "openai", "config": openai_config},
456-
"dispatcher_llm": {"backend": "openai", "config": openai_config},
457-
"graph_db": {
458-
"backend": "neo4j",
459-
"config": neo4j_config,
460-
},
461-
"embedder": APIConfig.get_embedder_config(),
462-
"reorganize": os.getenv("MOS_ENABLE_REORGANIZE", "false").lower()
463-
== "true",
464-
},
465-
},
466-
"act_mem": {}
467-
if os.getenv("ENABLE_ACTIVATION_MEMORY", "false").lower() == "false"
468-
else APIConfig.get_activation_vllm_config(),
469-
"para_mem": {},
470-
}
471-
)
435+
raise ValueError(f"Invalid Neo4j backend: {graph_db_backend}")

src/memos/mem_os/core.py

Lines changed: 2 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -646,14 +646,8 @@ def add(
646646
)
647647
else:
648648
messages_list = [
649-
[
650-
{"role": "user", "content": memory_content},
651-
{
652-
"role": "assistant",
653-
"content": "",
654-
}, # add by str to keep the format,assistant role is empty
655-
]
656-
]
649+
[{"role": "user", "content": memory_content}]
650+
] # for only user-str input and convert message
657651
memories = self.mem_reader.get_memory(
658652
messages_list,
659653
type="chat",

src/memos/mem_os/product.py

Lines changed: 14 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -16,7 +16,9 @@
1616
from memos.mem_cube.general import GeneralMemCube
1717
from memos.mem_os.core import MOSCore
1818
from memos.mem_os.utils.format_utils import (
19+
clean_json_response,
1920
convert_graph_to_tree_forworkmem,
21+
ensure_unique_tree_ids,
2022
filter_nodes_by_tree_ids,
2123
remove_embedding_recursive,
2224
sort_children_by_memory_type,
@@ -656,15 +658,15 @@ def get_suggestion_query(self, user_id: str, language: str = "zh") -> list[str]:
656658
you should generate some suggestion query, the query should be user what to query,
657659
user recently memories is:
658660
{memories}
659-
please generate 3 suggestion query in English,
661+
if the user recently memories is empty, please generate 3 suggestion query in English,
660662
output should be a json format, the key is "query", the value is a list of suggestion query.
661663
662664
example:
663665
{{
664666
"query": ["query1", "query2", "query3"]
665667
}}
666668
"""
667-
text_mem_result = super().search("my recently memories", user_id=user_id, top_k=10)[
669+
text_mem_result = super().search("my recently memories", user_id=user_id, top_k=3)[
668670
"text_mem"
669671
]
670672
if text_mem_result:
@@ -673,8 +675,8 @@ def get_suggestion_query(self, user_id: str, language: str = "zh") -> list[str]:
673675
memories = ""
674676
message_list = [{"role": "system", "content": suggestion_prompt.format(memories=memories)}]
675677
response = self.chat_llm.generate(message_list)
676-
response_json = json.loads(response)
677-
678+
clean_response = clean_json_response(response)
679+
response_json = json.loads(clean_response)
678680
return response_json["query"]
679681

680682
def chat(
@@ -762,11 +764,10 @@ def chat_with_references(
762764
system_prompt = self._build_system_prompt(user_id, memories_list)
763765

764766
# Get chat history
765-
target_user_id = user_id if user_id is not None else self.user_id
766-
if target_user_id not in self.chat_history_manager:
767-
self._register_chat_history(target_user_id)
767+
if user_id not in self.chat_history_manager:
768+
self._register_chat_history(user_id)
768769

769-
chat_history = self.chat_history_manager[target_user_id]
770+
chat_history = self.chat_history_manager[user_id]
770771
current_messages = [
771772
{"role": "system", "content": system_prompt},
772773
*chat_history.chat_history,
@@ -918,8 +919,10 @@ def get_all(
918919
"UserMemory": 0.40,
919920
}
920921
tree_result, node_type_count = convert_graph_to_tree_forworkmem(
921-
memories, target_node_count=150, type_ratios=custom_type_ratios
922+
memories, target_node_count=200, type_ratios=custom_type_ratios
922923
)
924+
# Ensure all node IDs are unique in the tree structure
925+
tree_result = ensure_unique_tree_ids(tree_result)
923926
memories_filtered = filter_nodes_by_tree_ids(tree_result, memories)
924927
children = tree_result["children"]
925928
children_sort = sort_children_by_memory_type(children)
@@ -1009,6 +1012,8 @@ def get_subgraph(
10091012
tree_result, node_type_count = convert_graph_to_tree_forworkmem(
10101013
memories, target_node_count=150, type_ratios=custom_type_ratios
10111014
)
1015+
# Ensure all node IDs are unique in the tree structure
1016+
tree_result = ensure_unique_tree_ids(tree_result)
10121017
memories_filtered = filter_nodes_by_tree_ids(tree_result, memories)
10131018
children = tree_result["children"]
10141019
children_sort = sort_children_by_memory_type(children)

0 commit comments

Comments (0)