Skip to content

Commit 86e16df

Browse files
authored
Feat: add act mem and chang deafult config (#91)
## Description <!-- Please include a summary of the changes below; Fill in the issue number that this PR addresses (if applicable); Mention the person who will review this PR (if you know who it is); Replace (summary), (issue), and (reviewer) with the appropriate information (No parentheses). 请在下方填写更改的摘要; 填写此 PR 解决的问题编号(如果适用); 提及将审查此 PR 的人(如果您知道是谁); 替换 (summary)、(issue) 和 (reviewer) 为适当的信息(不带括号)。 --> Summary: (summary) Fix: #(issue) Reviewer: @(reviewer) ## Checklist: - [ ] I have performed a self-review of my own code | 我已自行检查了自己的代码 - [ ] I have commented my code in hard-to-understand areas | 我已在难以理解的地方对代码进行了注释 - [ ] I have added tests that prove my fix is effective or that my feature works | 我已添加测试以证明我的修复有效或功能正常 - [ ] I have added necessary documentation (if applicable) | 我已添加必要的文档(如果适用) - [ ] I have linked the issue to this PR (if applicable) | 我已将 issue 链接到此 PR(如果适用) - [ ] I have mentioned the person who will review this PR | 我已提及将审查此 PR 的人
2 parents 4e68899 + 1a73dbd commit 86e16df

File tree

10 files changed

+270
-48
lines changed

10 files changed

+270
-48
lines changed

src/memos/api/config.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,11 @@ def is_scheduler_enabled() -> bool:
130130
"""Check if scheduler is enabled via environment variable."""
131131
return os.getenv("MOS_ENABLE_SCHEDULER", "false").lower() == "true"
132132

133+
@staticmethod
134+
def is_default_cube_config_enabled() -> bool:
135+
"""Check if default cube config is enabled via environment variable."""
136+
return os.getenv("MOS_ENABLE_DEFAULT_CUBE_CONFIG", "false").lower() == "true"
137+
133138
@staticmethod
134139
def get_product_default_config() -> dict[str, Any]:
135140
"""Get default configuration for Product API."""
@@ -321,3 +326,46 @@ def create_user_config(user_name: str, user_id: str) -> tuple[MOSConfig, General
321326

322327
default_mem_cube = GeneralMemCube(default_cube_config)
323328
return default_config, default_mem_cube
329+
330+
@staticmethod
331+
def get_default_cube_config() -> GeneralMemCubeConfig | None:
332+
"""Get default cube configuration for product initialization.
333+
334+
Returns:
335+
GeneralMemCubeConfig | None: Default cube configuration if enabled, None otherwise.
336+
"""
337+
if not APIConfig.is_default_cube_config_enabled():
338+
return None
339+
340+
openai_config = APIConfig.get_openai_config()
341+
neo4j_config = APIConfig.get_neo4j_config()
342+
343+
return GeneralMemCubeConfig.model_validate(
344+
{
345+
"user_id": "default",
346+
"cube_id": "default_cube",
347+
"text_mem": {
348+
"backend": "tree_text",
349+
"config": {
350+
"extractor_llm": {"backend": "openai", "config": openai_config},
351+
"dispatcher_llm": {"backend": "openai", "config": openai_config},
352+
"graph_db": {
353+
"backend": "neo4j",
354+
"config": neo4j_config,
355+
},
356+
"embedder": {
357+
"backend": "ollama",
358+
"config": {
359+
"model_name_or_path": "nomic-embed-text:latest",
360+
"api_base": os.getenv("OLLAMA_API_BASE", "http://localhost:11434"),
361+
},
362+
},
363+
"reorganize": os.getenv("MOS_ENABLE_REORGANIZE", "false").lower() == "true",
364+
},
365+
},
366+
"act_mem": {}
367+
if os.getenv("ENABLE_ACTIVATION_MEMORY", "false").lower() == "false"
368+
else APIConfig.get_activation_vllm_config(),
369+
"para_mem": {},
370+
}
371+
)

src/memos/api/routers/product_router.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,13 @@ def get_mos_product_instance():
4141
from memos.configs.mem_os import MOSConfig
4242

4343
mos_config = MOSConfig(**default_config)
44-
MOS_PRODUCT_INSTANCE = MOSProduct(default_config=mos_config)
44+
45+
# Get default cube config from APIConfig (may be None if disabled)
46+
default_cube_config = APIConfig.get_default_cube_config()
47+
print("*********default_cube_config*********", default_cube_config)
48+
MOS_PRODUCT_INSTANCE = MOSProduct(
49+
default_config=mos_config, default_cube_config=default_cube_config
50+
)
4551
logger.info("MOSProduct instance created successfully with inheritance architecture")
4652
return MOS_PRODUCT_INSTANCE
4753

@@ -68,6 +74,7 @@ async def register_user(user_req: UserRegisterRequest):
6874
logger.info(f"user_config: {user_config.model_dump(mode='json')}")
6975
logger.info(f"default_mem_cube: {default_mem_cube.config.model_dump(mode='json')}")
7076
mos_product = get_mos_product_instance()
77+
7178
# Register user with default config and mem cube
7279
result = mos_product.user_register(
7380
user_id=user_req.user_id,

src/memos/mem_cube/general.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from memos.exceptions import ConfigurationError, MemCubeError
88
from memos.log import get_logger
99
from memos.mem_cube.base import BaseMemCube
10-
from memos.mem_cube.utils import download_repo
10+
from memos.mem_cube.utils import download_repo, merge_config_with_default
1111
from memos.memories.activation.base import BaseActMemory
1212
from memos.memories.factory import MemoryFactory
1313
from memos.memories.parametric.base import BaseParaMemory
@@ -114,20 +114,30 @@ def dump(
114114

115115
@staticmethod
116116
def init_from_dir(
117-
dir: str, memory_types: list[Literal["text_mem", "act_mem", "para_mem"]] | None = None
117+
dir: str,
118+
memory_types: list[Literal["text_mem", "act_mem", "para_mem"]] | None = None,
119+
default_config: GeneralMemCubeConfig | None = None,
118120
) -> "GeneralMemCube":
119121
"""Create a MemCube instance from a MemCube directory.
120122
121123
Args:
122124
dir (str): The directory containing the memory files.
123125
memory_types (list[str], optional): List of memory types to load.
124126
If None, loads all available memory types.
127+
default_config (GeneralMemCubeConfig, optional): Default configuration to merge with existing config.
128+
If provided, will merge general settings while preserving critical user-specific fields.
125129
126130
Returns:
127131
MemCube: An instance of MemCube loaded with memories from the specified directory.
128132
"""
129133
config_path = os.path.join(dir, "config.json")
130134
config = GeneralMemCubeConfig.from_json_file(config_path)
135+
136+
# Merge with default config if provided
137+
if default_config is not None:
138+
config = merge_config_with_default(config, default_config)
139+
logger.info(f"Applied default config to cube {config.cube_id}")
140+
131141
mem_cube = GeneralMemCube(config)
132142
mem_cube.load(dir, memory_types)
133143
return mem_cube
@@ -137,6 +147,7 @@ def init_from_remote_repo(
137147
cube_id: str,
138148
base_url: str = "https://huggingface.co/datasets",
139149
memory_types: list[Literal["text_mem", "act_mem", "para_mem"]] | None = None,
150+
default_config: GeneralMemCubeConfig | None = None,
140151
) -> "GeneralMemCube":
141152
"""Create a MemCube instance from a remote repository.
142153
@@ -145,12 +156,13 @@ def init_from_remote_repo(
145156
base_url (str): The base URL of the remote repository.
146157
memory_types (list[str], optional): List of memory types to load.
147158
If None, loads all available memory types.
159+
default_config (GeneralMemCubeConfig, optional): Default configuration to merge with existing config.
148160
149161
Returns:
150162
MemCube: An instance of MemCube loaded with memories from the specified remote repository.
151163
"""
152164
dir = download_repo(cube_id, base_url)
153-
return GeneralMemCube.init_from_dir(dir, memory_types)
165+
return GeneralMemCube.init_from_dir(dir, memory_types, default_config)
154166

155167
@property
156168
def text_mem(self) -> "BaseTextMemory | None":

src/memos/mem_cube/utils.py

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,15 @@
1+
import copy
2+
import logging
13
import subprocess
24
import tempfile
35

6+
from typing import Any
7+
8+
from memos.configs.mem_cube import GeneralMemCubeConfig
9+
10+
11+
logger = logging.getLogger(__name__)
12+
413

514
def download_repo(repo: str, base_url: str, dir: str | None = None) -> str:
615
"""Download a repository from a remote source.
@@ -22,3 +31,96 @@ def download_repo(repo: str, base_url: str, dir: str | None = None) -> str:
2231
subprocess.run(["git", "clone", repo_url, dir], check=True)
2332

2433
return dir
34+
35+
36+
def merge_config_with_default(
37+
existing_config: GeneralMemCubeConfig, default_config: GeneralMemCubeConfig
38+
) -> GeneralMemCubeConfig:
39+
"""
40+
Merge existing cube config with default config, preserving critical fields.
41+
42+
This method updates general configuration fields (like API keys, model parameters)
43+
while preserving critical user-specific fields (like user_id, cube_id, graph_db settings).
44+
45+
Args:
46+
existing_config (GeneralMemCubeConfig): The existing cube configuration loaded from file
47+
default_config (GeneralMemCubeConfig): The default configuration to merge from
48+
49+
Returns:
50+
GeneralMemCubeConfig: Merged configuration
51+
"""
52+
53+
def deep_merge_dicts(
54+
existing: dict[str, Any], default: dict[str, Any], preserve_keys: set[str] | None = None
55+
) -> dict[str, Any]:
56+
"""Recursively merge dictionaries, preserving specified keys from existing dict."""
57+
if preserve_keys is None:
58+
preserve_keys = set()
59+
60+
result = copy.deepcopy(existing)
61+
62+
for key, default_value in default.items():
63+
if key in preserve_keys:
64+
# Preserve existing value for critical keys
65+
continue
66+
67+
if key in result and isinstance(result[key], dict) and isinstance(default_value, dict):
68+
# Recursively merge nested dictionaries
69+
result[key] = deep_merge_dicts(result[key], default_value, preserve_keys)
70+
elif key not in result or result[key] is None:
71+
# Use default value if key doesn't exist or is None
72+
result[key] = copy.deepcopy(default_value)
73+
# For non-dict values, keep existing value unless it's None
74+
75+
return result
76+
77+
# Convert configs to dictionaries
78+
existing_dict = existing_config.model_dump(mode="json")
79+
default_dict = default_config.model_dump(mode="json")
80+
81+
# Merge text_mem config
82+
if "text_mem" in existing_dict and "text_mem" in default_dict:
83+
existing_text_config = existing_dict["text_mem"].get("config", {})
84+
default_text_config = default_dict["text_mem"].get("config", {})
85+
86+
# Handle nested graph_db config specially
87+
if "graph_db" in existing_text_config and "graph_db" in default_text_config:
88+
existing_graph_config = existing_text_config["graph_db"].get("config", {})
89+
default_graph_config = default_text_config["graph_db"].get("config", {})
90+
91+
# Merge graph_db config, preserving critical keys
92+
merged_graph_config = deep_merge_dicts(
93+
existing_graph_config,
94+
default_graph_config,
95+
preserve_keys={"uri", "user", "password", "db_name", "auto_create"},
96+
)
97+
98+
# Update the configs
99+
existing_text_config["graph_db"]["config"] = merged_graph_config
100+
default_text_config["graph_db"]["config"] = merged_graph_config
101+
102+
# Merge other text_mem config fields
103+
merged_text_config = deep_merge_dicts(existing_text_config, default_text_config)
104+
existing_dict["text_mem"]["config"] = merged_text_config
105+
106+
# Merge act_mem config
107+
if "act_mem" in existing_dict and "act_mem" in default_dict:
108+
existing_act_config = existing_dict["act_mem"].get("config", {})
109+
default_act_config = default_dict["act_mem"].get("config", {})
110+
merged_act_config = deep_merge_dicts(existing_act_config, default_act_config)
111+
existing_dict["act_mem"]["config"] = merged_act_config
112+
113+
# Merge para_mem config
114+
if "para_mem" in existing_dict and "para_mem" in default_dict:
115+
existing_para_config = existing_dict["para_mem"].get("config", {})
116+
default_para_config = default_dict["para_mem"].get("config", {})
117+
merged_para_config = deep_merge_dicts(existing_para_config, default_para_config)
118+
existing_dict["para_mem"]["config"] = merged_para_config
119+
120+
# Create new config from merged dictionary
121+
merged_config = GeneralMemCubeConfig.model_validate(existing_dict)
122+
logger.info(
123+
f"Merged cube config for user {merged_config.user_id}, cube {merged_config.cube_id}"
124+
)
125+
126+
return merged_config

0 commit comments

Comments
 (0)