Skip to content

Commit bfc32d8

Browse files
authored
Feat:change deafult config for product (#108)
* feat: update config * fix:dim * change dim * fix:change default db * fix:delay * fix:len
1 parent 2387f6d commit bfc32d8

File tree

4 files changed

+126
-104
lines changed

4 files changed

+126
-104
lines changed

src/memos/api/config.py

Lines changed: 59 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -90,13 +90,64 @@ def get_activation_vllm_config() -> dict[str, Any]:
9090
}
9191

9292
@staticmethod
93-
def get_neo4j_config() -> dict[str, Any]:
93+
def get_embedder_config() -> dict[str, Any]:
94+
"""Get embedder configuration."""
95+
embedder_backend = os.getenv("MOS_EMBEDDER_BACKEND", "ollama")
96+
97+
if embedder_backend == "universal_api":
98+
return {
99+
"backend": "universal_api",
100+
"config": {
101+
"provider": os.getenv("MOS_EMBEDDER_PROVIDER", "openai"),
102+
"api_key": os.getenv("OPENAI_API_KEY", "sk-xxxx"),
103+
"model_name_or_path": os.getenv("MOS_EMBEDDER_MODEL", "text-embedding-3-large"),
104+
"base_url": os.getenv("OPENAI_API_BASE", "http://openai.com"),
105+
},
106+
}
107+
else: # ollama
108+
return {
109+
"backend": "ollama",
110+
"config": {
111+
"model_name_or_path": os.getenv(
112+
"MOS_EMBEDDER_MODEL", "nomic-embed-text:latest"
113+
),
114+
"api_base": os.getenv("OLLAMA_API_BASE", "http://localhost:11434"),
115+
},
116+
}
117+
118+
@staticmethod
119+
def get_neo4j_config(user_id: str | None = None) -> dict[str, Any]:
120+
"""Get Neo4j configuration."""
121+
if os.getenv("MOS_NEO4J_SHARED_DB", "false").lower() == "true":
122+
return APIConfig.get_neo4j_shared_config(user_id)
123+
else:
124+
return APIConfig.get_noshared_neo4j_config(user_id)
125+
126+
@staticmethod
127+
def get_noshared_neo4j_config(user_id) -> dict[str, Any]:
94128
"""Get Neo4j configuration."""
95129
return {
96130
"uri": os.getenv("NEO4J_URI", "bolt://localhost:7687"),
97131
"user": os.getenv("NEO4J_USER", "neo4j"),
132+
"db_name": f"memos{user_id.replace('-', '')}",
98133
"password": os.getenv("NEO4J_PASSWORD", "12345678"),
99134
"auto_create": True,
135+
"use_multi_db": True,
136+
"embedding_dimension": 3072,
137+
}
138+
139+
@staticmethod
140+
def get_neo4j_shared_config(user_id: str | None = None) -> dict[str, Any]:
141+
"""Get Neo4j configuration."""
142+
return {
143+
"uri": os.getenv("NEO4J_URI", "bolt://localhost:7687"),
144+
"user": os.getenv("NEO4J_USER", "neo4j"),
145+
"db_name": os.getenv("NEO4J_DB_NAME", "shared-tree-textual-memory"),
146+
"password": os.getenv("NEO4J_PASSWORD", "12345678"),
147+
"user_name": f"memos{user_id.replace('-', '')}",
148+
"auto_create": True,
149+
"use_multi_db": False,
150+
"embedding_dimension": 3072,
100151
}
101152

102153
@staticmethod
@@ -157,13 +208,7 @@ def get_product_default_config() -> dict[str, Any]:
157208
"backend": "openai",
158209
"config": openai_config,
159210
},
160-
"embedder": {
161-
"backend": "ollama",
162-
"config": {
163-
"model_name_or_path": "nomic-embed-text:latest",
164-
"api_base": os.getenv("OLLAMA_API_BASE", "http://localhost:11434"),
165-
},
166-
},
211+
"embedder": APIConfig.get_embedder_config(),
167212
"chunker": {
168213
"backend": "sentence",
169214
"config": {
@@ -229,7 +274,7 @@ def get_start_default_config() -> dict[str, Any]:
229274
def create_user_config(user_name: str, user_id: str) -> tuple[MOSConfig, GeneralMemCube]:
230275
"""Create configuration for a specific user."""
231276
openai_config = APIConfig.get_openai_config()
232-
neo4j_config = APIConfig.get_neo4j_config()
277+
neo4j_config = APIConfig.get_neo4j_config(user_id)
233278
qwen_config = APIConfig.qwen_config()
234279
vllm_config = APIConfig.vllm_config()
235280
backend = os.getenv("MOS_CHAT_MODEL_PROVIDER", "openai")
@@ -252,13 +297,7 @@ def create_user_config(user_name: str, user_id: str) -> tuple[MOSConfig, General
252297
"backend": "openai",
253298
"config": openai_config,
254299
},
255-
"embedder": {
256-
"backend": "ollama",
257-
"config": {
258-
"model_name_or_path": "nomic-embed-text:latest",
259-
"api_base": os.getenv("OLLAMA_API_BASE", "http://localhost:11434"),
260-
},
261-
},
300+
"embedder": APIConfig.get_embedder_config(),
262301
"chunker": {
263302
"backend": "sentence",
264303
"config": {
@@ -298,23 +337,9 @@ def create_user_config(user_name: str, user_id: str) -> tuple[MOSConfig, General
298337
"dispatcher_llm": {"backend": "openai", "config": openai_config},
299338
"graph_db": {
300339
"backend": "neo4j",
301-
"config": {
302-
"uri": neo4j_config["uri"],
303-
"user": neo4j_config["user"],
304-
"password": neo4j_config["password"],
305-
"db_name": os.getenv(
306-
"NEO4J_DB_NAME", f"memos{user_id.replace('-', '')}"
307-
), # , replace with
308-
"auto_create": neo4j_config["auto_create"],
309-
},
310-
},
311-
"embedder": {
312-
"backend": "ollama",
313-
"config": {
314-
"model_name_or_path": "nomic-embed-text:latest",
315-
"api_base": os.getenv("OLLAMA_API_BASE", "http://localhost:11434"),
316-
},
340+
"config": neo4j_config,
317341
},
342+
"embedder": APIConfig.get_embedder_config(),
318343
},
319344
},
320345
"act_mem": {}
@@ -338,7 +363,7 @@ def get_default_cube_config() -> GeneralMemCubeConfig | None:
338363
return None
339364

340365
openai_config = APIConfig.get_openai_config()
341-
neo4j_config = APIConfig.get_neo4j_config()
366+
neo4j_config = APIConfig.get_neo4j_config(user_id="default")
342367

343368
return GeneralMemCubeConfig.model_validate(
344369
{
@@ -353,13 +378,7 @@ def get_default_cube_config() -> GeneralMemCubeConfig | None:
353378
"backend": "neo4j",
354379
"config": neo4j_config,
355380
},
356-
"embedder": {
357-
"backend": "ollama",
358-
"config": {
359-
"model_name_or_path": "nomic-embed-text:latest",
360-
"api_base": os.getenv("OLLAMA_API_BASE", "http://localhost:11434"),
361-
},
362-
},
381+
"embedder": APIConfig.get_embedder_config(),
363382
"reorganize": os.getenv("MOS_ENABLE_REORGANIZE", "false").lower() == "true",
364383
},
365384
},

src/memos/api/routers/product_router.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,7 @@ async def generate_chat_response():
222222
history=chat_req.history,
223223
):
224224
yield chunk
225-
await asyncio.sleep(0.05) # 50ms delay between chunks
225+
await asyncio.sleep(0.00001) # 50ms delay between chunks
226226
except Exception as e:
227227
logger.error(f"Error in chat stream: {e}")
228228
error_data = f"data: {json.dumps({'type': 'error', 'content': str(traceback.format_exc())})}\n\n"

src/memos/mem_cube/utils.py

Lines changed: 61 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,6 @@
33
import subprocess
44
import tempfile
55

6-
from typing import Any
7-
86
from memos.configs.mem_cube import GeneralMemCubeConfig
97

108

@@ -50,77 +48,79 @@ def merge_config_with_default(
5048
GeneralMemCubeConfig: Merged configuration
5149
"""
5250

53-
def deep_merge_dicts(
54-
existing: dict[str, Any], default: dict[str, Any], preserve_keys: set[str] | None = None
55-
) -> dict[str, Any]:
56-
"""Recursively merge dictionaries, preserving specified keys from existing dict."""
57-
if preserve_keys is None:
58-
preserve_keys = set()
59-
60-
result = copy.deepcopy(existing)
61-
62-
for key, default_value in default.items():
63-
if key in preserve_keys:
64-
# Preserve existing value for critical keys
65-
continue
66-
67-
if key in result and isinstance(result[key], dict) and isinstance(default_value, dict):
68-
# Recursively merge nested dictionaries
69-
result[key] = deep_merge_dicts(result[key], default_value, preserve_keys)
70-
elif key not in result or result[key] is None:
71-
# Use default value if key doesn't exist or is None
72-
result[key] = copy.deepcopy(default_value)
73-
# For non-dict values, keep existing value unless it's None
74-
75-
return result
76-
7751
# Convert configs to dictionaries
7852
existing_dict = existing_config.model_dump(mode="json")
7953
default_dict = default_config.model_dump(mode="json")
8054

81-
# Merge text_mem config
55+
logger.info(
56+
f"Starting config merge for user {existing_config.user_id}, cube {existing_config.cube_id}"
57+
)
58+
59+
# Define fields that should be preserved from existing config
60+
preserve_fields = {"user_id", "cube_id", "config_filename", "model_schema"}
61+
62+
# Preserve graph_db from existing config if it exists, but merge some fields
63+
preserved_graph_db = None
8264
if "text_mem" in existing_dict and "text_mem" in default_dict:
8365
existing_text_config = existing_dict["text_mem"].get("config", {})
8466
default_text_config = default_dict["text_mem"].get("config", {})
8567

86-
# Handle nested graph_db config specially
8768
if "graph_db" in existing_text_config and "graph_db" in default_text_config:
88-
existing_graph_config = existing_text_config["graph_db"].get("config", {})
89-
default_graph_config = default_text_config["graph_db"].get("config", {})
90-
91-
# Merge graph_db config, preserving critical keys
92-
merged_graph_config = deep_merge_dicts(
93-
existing_graph_config,
94-
default_graph_config,
95-
preserve_keys={"uri", "user", "password", "db_name", "auto_create"},
96-
)
97-
98-
# Update the configs
99-
existing_text_config["graph_db"]["config"] = merged_graph_config
100-
default_text_config["graph_db"]["config"] = merged_graph_config
101-
102-
# Merge other text_mem config fields
103-
merged_text_config = deep_merge_dicts(existing_text_config, default_text_config)
104-
existing_dict["text_mem"]["config"] = merged_text_config
105-
106-
# Merge act_mem config
107-
if "act_mem" in existing_dict and "act_mem" in default_dict:
108-
existing_act_config = existing_dict["act_mem"].get("config", {})
109-
default_act_config = default_dict["act_mem"].get("config", {})
110-
merged_act_config = deep_merge_dicts(existing_act_config, default_act_config)
111-
existing_dict["act_mem"]["config"] = merged_act_config
112-
113-
# Merge para_mem config
114-
if "para_mem" in existing_dict and "para_mem" in default_dict:
115-
existing_para_config = existing_dict["para_mem"].get("config", {})
116-
default_para_config = default_dict["para_mem"].get("config", {})
117-
merged_para_config = deep_merge_dicts(existing_para_config, default_para_config)
118-
existing_dict["para_mem"]["config"] = merged_para_config
69+
existing_graph_config = existing_text_config["graph_db"]["config"]
70+
default_graph_config = default_text_config["graph_db"]["config"]
71+
72+
# Define graph_db fields to preserve (user-specific)
73+
preserve_graph_fields = {
74+
"uri",
75+
"user",
76+
"password",
77+
"db_name",
78+
"auto_create",
79+
"user_name",
80+
"use_multi_db",
81+
}
82+
83+
# Create merged graph_db config
84+
merged_graph_config = copy.deepcopy(existing_graph_config)
85+
for key, value in default_graph_config.items():
86+
if key not in preserve_graph_fields:
87+
merged_graph_config[key] = value
88+
logger.debug(
89+
f"Updated graph_db field '{key}': {existing_graph_config.get(key)} -> {value}"
90+
)
91+
if not default_graph_config.get("use_multi_db", True):
92+
# set original use_multi_db to False if default_graph_config.use_multi_db is False
93+
if merged_graph_config.get("use_multi_db", True):
94+
merged_graph_config["use_multi_db"] = False
95+
merged_graph_config["user_name"] = merged_graph_config.get("db_name")
96+
merged_graph_config["db_name"] = default_graph_config.get("db_name")
97+
else:
98+
logger.info("use_multi_db is already False, no need to change")
99+
100+
preserved_graph_db = {
101+
"backend": existing_text_config["graph_db"]["backend"],
102+
"config": merged_graph_config,
103+
}
104+
105+
# Use default config as base
106+
merged_dict = copy.deepcopy(default_dict)
107+
108+
# Restore preserved fields from existing config
109+
for field in preserve_fields:
110+
if field in existing_dict:
111+
merged_dict[field] = existing_dict[field]
112+
logger.debug(f"Preserved field '{field}': {existing_dict[field]}")
113+
114+
# Restore graph_db if it was preserved
115+
if preserved_graph_db and "text_mem" in merged_dict:
116+
merged_dict["text_mem"]["config"]["graph_db"] = preserved_graph_db
117+
logger.debug(f"Preserved graph_db with merged config: {preserved_graph_db}")
119118

120119
# Create new config from merged dictionary
121-
merged_config = GeneralMemCubeConfig.model_validate(existing_dict)
120+
merged_config = GeneralMemCubeConfig.model_validate(merged_dict)
121+
122122
logger.info(
123-
f"Merged cube config for user {merged_config.user_id}, cube {merged_config.cube_id}"
123+
f"Successfully merged cube config for user {merged_config.user_id}, cube {merged_config.cube_id}"
124124
)
125125

126126
return merged_config

src/memos/mem_os/product.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -544,6 +544,9 @@ def register_mem_cube(
544544
)
545545

546546
# Register the MemCube
547+
logger.info(
548+
f"Registering MemCube {mem_cube_id} with cube config {mem_cube.config.model_dump(mode='json')}"
549+
)
547550
self.mem_cubes[mem_cube_id] = mem_cube
548551

549552
def user_register(
@@ -769,7 +772,7 @@ def chat_with_references(
769772
current_messages = [
770773
{"role": "system", "content": system_prompt},
771774
*chat_history.chat_history,
772-
{"role": "user", "content": query + "/nothink"},
775+
{"role": "user", "content": query},
773776
]
774777

775778
# Generate response with custom prompt
@@ -879,7 +882,7 @@ def chat_with_references(
879882
mem_cube_id=cube_id,
880883
)
881884
# Keep chat history under 30 messages by removing oldest conversation pair
882-
if len(self.chat_history_manager[user_id].chat_history) > 30:
885+
if len(self.chat_history_manager[user_id].chat_history) > 10:
883886
self.chat_history_manager[user_id].chat_history.pop(0) # Remove oldest user message
884887
self.chat_history_manager[user_id].chat_history.pop(
885888
0

0 commit comments

Comments
 (0)