Skip to content

Commit bd450bd

Browse files
authored
feat: update neublar config and change user manager dup field for mysql (#171)
* feat: add user manager factory pattern and product API enhancements - Add user manager factory pattern with SQLite and MySQL backends - Add user manager configuration to MOSConfig - Add product API router and configuration - Add DingDing notification integration - Add notification service utilities - Update OpenAPI documentation * fix: change user manager default * fix: update config * fix: remove db_name for neublar * fix:host * feat:update db * fix:test * fix: reomve dup file * feat: add logs * feat: add users profile field
1 parent f605f79 commit bd450bd

File tree

11 files changed

+164
-57
lines changed

11 files changed

+164
-57
lines changed

src/memos/api/config.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import json
12
import os
23

34
from typing import Any
@@ -218,6 +219,32 @@ def get_neo4j_shared_config(user_id: str | None = None) -> dict[str, Any]:
218219
"embedding_dimension": 3072,
219220
}
220221

222+
@staticmethod
223+
def get_nebular_config(user_id: str | None = None) -> dict[str, Any]:
224+
"""Get Nebular configuration."""
225+
return {
226+
"uri": json.loads(os.getenv("NEBULAR_HOSTS", '["localhost"]')),
227+
"user": os.getenv("NEBULAR_USER", "root"),
228+
"password": os.getenv("NEBULAR_PASSWORD", "xxxxxx"),
229+
"space": os.getenv("NEBULAR_SPACE", "shared-tree-textual-memory"),
230+
"user_name": f"memos{user_id.replace('-', '')}",
231+
"use_multi_db": False,
232+
"auto_create": True,
233+
"embedding_dimension": 3072,
234+
}
235+
236+
@staticmethod
237+
def get_mysql_config() -> dict[str, Any]:
238+
"""Get MySQL configuration."""
239+
return {
240+
"host": os.getenv("MYSQL_HOST", "localhost"),
241+
"port": int(os.getenv("MYSQL_PORT", "3306")),
242+
"username": os.getenv("MYSQL_USERNAME", "root"),
243+
"password": os.getenv("MYSQL_PASSWORD", "12345678"),
244+
"database": os.getenv("MYSQL_DATABASE", "memos_users"),
245+
"charset": os.getenv("MYSQL_CHARSET", "utf8mb4"),
246+
}
247+
221248
@staticmethod
222249
def get_scheduler_config() -> dict[str, Any]:
223250
"""Get scheduler configuration."""
@@ -294,6 +321,7 @@ def get_product_default_config() -> dict[str, Any]:
294321
"vllm": vllm_config,
295322
}
296323
backend = os.getenv("MOS_CHAT_MODEL_PROVIDER", "openai")
324+
mysql_config = APIConfig.get_mysql_config()
297325
config = {
298326
"user_id": os.getenv("MOS_USER_ID", "root"),
299327
"chat_model": {"backend": backend, "config": backend_model[backend]},
@@ -330,6 +358,13 @@ def get_product_default_config() -> dict[str, Any]:
330358
else:
331359
config["enable_mem_scheduler"] = False
332360

361+
# Add user manager configuration if enabled
362+
if os.getenv("MOS_USER_MANAGER_BACKEND", "sqlite").lower() == "mysql":
363+
config["user_manager"] = {
364+
"backend": "mysql",
365+
"config": mysql_config,
366+
}
367+
333368
return config
334369

335370
@staticmethod
@@ -372,6 +407,7 @@ def create_user_config(user_name: str, user_id: str) -> tuple[MOSConfig, General
372407
openai_config = APIConfig.get_openai_config()
373408
qwen_config = APIConfig.qwen_config()
374409
vllm_config = APIConfig.vllm_config()
410+
mysql_config = APIConfig.get_mysql_config()
375411
backend = os.getenv("MOS_CHAT_MODEL_PROVIDER", "openai")
376412
backend_model = {
377413
"openai": openai_config,
@@ -417,10 +453,18 @@ def create_user_config(user_name: str, user_id: str) -> tuple[MOSConfig, General
417453
else:
418454
config_dict["enable_mem_scheduler"] = False
419455

456+
# Add user manager configuration if enabled
457+
if os.getenv("MOS_USER_MANAGER_BACKEND", "sqlite").lower() == "mysql":
458+
config_dict["user_manager"] = {
459+
"backend": "mysql",
460+
"config": mysql_config,
461+
}
462+
420463
default_config = MOSConfig(**config_dict)
421464

422465
neo4j_community_config = APIConfig.get_neo4j_community_config(user_id)
423466
neo4j_config = APIConfig.get_neo4j_config(user_id)
467+
nebular_config = APIConfig.get_nebular_config(user_id)
424468
internet_config = (
425469
APIConfig.get_internet_config()
426470
if os.getenv("ENABLE_INTERNET", "false").lower() == "true"
@@ -429,6 +473,7 @@ def create_user_config(user_name: str, user_id: str) -> tuple[MOSConfig, General
429473
graph_db_backend_map = {
430474
"neo4j-community": neo4j_community_config,
431475
"neo4j": neo4j_config,
476+
"nebular": nebular_config,
432477
}
433478
graph_db_backend = os.getenv("NEO4J_BACKEND", "neo4j-community").lower()
434479
if graph_db_backend in graph_db_backend_map:
@@ -475,9 +520,11 @@ def get_default_cube_config() -> GeneralMemCubeConfig | None:
475520
openai_config = APIConfig.get_openai_config()
476521
neo4j_community_config = APIConfig.get_neo4j_community_config(user_id="default")
477522
neo4j_config = APIConfig.get_neo4j_config(user_id="default")
523+
nebular_config = APIConfig.get_nebular_config(user_id="default")
478524
graph_db_backend_map = {
479525
"neo4j-community": neo4j_community_config,
480526
"neo4j": neo4j_config,
527+
"nebular": nebular_config,
481528
}
482529
internet_config = (
483530
APIConfig.get_internet_config()

src/memos/api/product_models.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ class ChatRequest(BaseRequest):
8888

8989
class UserCreate(BaseRequest):
9090
user_name: str | None = Field(None, description="Name of the user")
91-
role: str = Field("user", description="Role of the user")
91+
role: str = Field("USER", description="Role of the user")
9292
user_id: str = Field(..., description="User ID")
9393

9494

@@ -143,6 +143,8 @@ class MemoryCreateRequest(BaseRequest):
143143
memory_content: str | None = Field(None, description="Memory content to store")
144144
doc_path: str | None = Field(None, description="Path to document to store")
145145
mem_cube_id: str | None = Field(None, description="Cube ID")
146+
source: str | None = Field(None, description="Source of the memory")
147+
user_profile: bool = Field(False, description="User profile memory")
146148

147149

148150
class SearchRequest(BaseRequest):

src/memos/api/routers/product_router.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,8 @@ def create_memory(memory_req: MemoryCreateRequest):
198198
messages=memory_req.messages,
199199
doc_path=memory_req.doc_path,
200200
mem_cube_id=memory_req.mem_cube_id,
201+
source=memory_req.source,
202+
user_profile=memory_req.user_profile,
201203
)
202204
return SimpleResponse(message="Memory created successfully")
203205

src/memos/mem_cube/utils.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,9 +92,20 @@ def merge_config_with_default(
9292
merged_graph_config["db_name"] = default_graph_config.get("db_name")
9393
else:
9494
logger.info("use_multi_db is already False, no need to change")
95-
95+
if "neo4j" not in default_text_config["graph_db"]["backend"]:
96+
if "db_name" in merged_graph_config:
97+
merged_graph_config.pop("db_name")
98+
logger.info("neo4j is not supported, remove db_name")
99+
else:
100+
logger.info("db_name is not in merged_graph_config, no need to remove")
101+
else:
102+
if "space" in merged_graph_config:
103+
merged_graph_config.pop("space")
104+
logger.info("neo4j is not supported, remove db_name")
105+
else:
106+
logger.info("space is not in merged_graph_config, no need to remove")
96107
preserved_graph_db = {
97-
"backend": existing_text_config["graph_db"]["backend"],
108+
"backend": default_text_config["graph_db"]["backend"],
98109
"config": merged_graph_config,
99110
}
100111

src/memos/mem_os/core.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import json
22
import os
3+
import time
34
import uuid
45

56
from datetime import datetime
@@ -568,6 +569,7 @@ def search(
568569
and (mem_cube.text_mem is not None)
569570
and self.config.enable_textual_memory
570571
):
572+
time_start = time.time()
571573
memories = mem_cube.text_mem.search(
572574
query,
573575
top_k=top_k if top_k else self.config.top_k,
@@ -579,6 +581,10 @@ def search(
579581
logger.info(
580582
f"🧠 [Memory] Searched memories from {mem_cube_id}:\n{self._str_memories(memories)}\n"
581583
)
584+
search_time_end = time.time()
585+
logger.info(
586+
f"time search graph: search graph time user_id: {target_user_id} time is: {search_time_end - time_start}"
587+
)
582588
return result
583589

584590
def add(

0 commit comments

Comments
 (0)