Skip to content

Commit 818ab0d

Browse files
authored
Merge branch 'dev' into feat/demo-reference-change
2 parents f6636df + 084b14e commit 818ab0d

File tree

6 files changed

+127
-49
lines changed

6 files changed

+127
-49
lines changed

examples/mem_os/multi_user_memos_example.py

Lines changed: 49 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
Example demonstrating how to use MOSProduct for multi-user scenarios.
33
"""
44

5+
import os
6+
57
from memos.configs.mem_cube import GeneralMemCubeConfig
68
from memos.configs.mem_os import MOSConfig
79
from memos.mem_cube.general import GeneralMemCube
@@ -16,28 +18,53 @@ def get_config(user_name):
1618
"top_p": 0.9,
1719
"top_k": 50,
1820
"remove_think_prefix": True,
19-
"api_key": "your-api-key-here",
20-
"api_base": "https://api.openai.com/v1",
21+
"api_key": os.getenv("OPENAI_API_KEY"),
22+
"api_base": os.getenv("OPENAI_API_BASE"),
2123
}
2224
# Create a default configuration
2325
default_config = MOSConfig(
2426
user_id="root",
2527
chat_model={"backend": "openai", "config": openapi_config},
2628
mem_reader={
27-
"backend": "naive",
29+
"backend": "simple_struct",
2830
"config": {
2931
"llm": {
3032
"backend": "openai",
3133
"config": openapi_config,
3234
},
3335
"embedder": {
34-
"backend": "ollama",
36+
"backend": "universal_api",
37+
"config": {
38+
"provider": os.getenv("MOS_EMBEDDER_PROVIDER", "openai"),
39+
"api_key": os.getenv("MOS_EMBEDDER_API_KEY", "sk-xxxx"),
40+
"model_name_or_path": os.getenv(
41+
"MOS_EMBEDDER_MODEL", "text-embedding-3-large"
42+
),
43+
"base_url": os.getenv("MOS_EMBEDDER_API_BASE", "http://openai.com"),
44+
},
45+
},
46+
"chunker": {
47+
"backend": "sentence",
3548
"config": {
36-
"model_name_or_path": "nomic-embed-text:latest",
49+
"tokenizer_or_token_counter": "gpt2",
50+
"chunk_size": 512,
51+
"chunk_overlap": 128,
52+
"min_sentences_per_chunk": 1,
3753
},
3854
},
3955
},
4056
},
57+
user_manager={
58+
"backend": "mysql",
59+
"config": {
60+
"host": os.getenv("MYSQL_HOST", "localhost"),
61+
"port": int(os.getenv("MYSQL_PORT", "3306")),
62+
"username": os.getenv("MYSQL_USERNAME", "root"),
63+
"password": os.getenv("MYSQL_PASSWORD", "12345678"),
64+
"database": os.getenv("MYSQL_DATABASE", "memos_users"),
65+
"charset": os.getenv("MYSQL_CHARSET", "utf8mb4"),
66+
},
67+
},
4168
enable_textual_memory=True,
4269
enable_activation_memory=False,
4370
top_k=5,
@@ -55,17 +82,27 @@ def get_config(user_name):
5582
"graph_db": {
5683
"backend": "neo4j",
5784
"config": {
58-
"uri": "bolt://localhost:7687",
59-
"user": "neo4j",
60-
"password": "12345678",
61-
"db_name": user_name,
85+
"uri": os.getenv("NEO4J_URI", "bolt://localhost:7687"),
86+
"user": os.getenv("NEO4J_USER", "neo4j"),
87+
"password": os.getenv("NEO4J_PASSWORD", "12345678"),
88+
"db_name": os.getenv(
89+
"NEO4J_DB_NAME", "shared-tree-textual-memory-test"
90+
),
91+
"user_name": f"memos{user_name.replace('-', '')}",
92+
"embedding_dimension": int(os.getenv("EMBEDDING_DIMENSION", 768)),
93+
"use_multi_db": False,
6294
"auto_create": True,
6395
},
6496
},
6597
"embedder": {
66-
"backend": "ollama",
98+
"backend": "universal_api",
6799
"config": {
68-
"model_name_or_path": "nomic-embed-text:latest",
100+
"provider": os.getenv("MOS_EMBEDDER_PROVIDER", "openai"),
101+
"api_key": os.getenv("MOS_EMBEDDER_API_KEY", "sk-xxxx"),
102+
"model_name_or_path": os.getenv(
103+
"MOS_EMBEDDER_MODEL", "text-embedding-3-large"
104+
),
105+
"base_url": os.getenv("MOS_EMBEDDER_API_BASE", "http://openai.com"),
69106
},
70107
},
71108
},
@@ -109,7 +146,7 @@ def main():
109146
print(f"\nSearch result for Alice: {search_result}")
110147

111148
# Search memories for Alice
112-
search_result = mos_product.get_all(query="conference", user_id="alice", memory_type="text_mem")
149+
search_result = mos_product.get_all(user_id="alice", memory_type="text_mem")
113150
print(f"\nSearch result for Alice: {search_result}")
114151

115152
# List all users

src/memos/graph_dbs/nebular.py

Lines changed: 67 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import json
12
import traceback
23

34
from contextlib import suppress
@@ -35,7 +36,28 @@ def _compose_node(item: dict[str, Any]) -> tuple[str, str, dict[str, Any]]:
3536

3637
@timed
3738
def _escape_str(value: str) -> str:
38-
return value.replace('"', '\\"')
39+
out = []
40+
for ch in value:
41+
code = ord(ch)
42+
if ch == "\\":
43+
out.append("\\\\")
44+
elif ch == '"':
45+
out.append('\\"')
46+
elif ch == "\n":
47+
out.append("\\n")
48+
elif ch == "\r":
49+
out.append("\\r")
50+
elif ch == "\t":
51+
out.append("\\t")
52+
elif ch == "\b":
53+
out.append("\\b")
54+
elif ch == "\f":
55+
out.append("\\f")
56+
elif code < 0x20 or code in (0x2028, 0x2029):
57+
out.append(f"\\u{code:04x}")
58+
else:
59+
out.append(ch)
60+
return "".join(out)
3961

4062

4163
@timed
@@ -1153,28 +1175,36 @@ def import_graph(self, data: dict[str, Any]) -> None:
11531175
data: A dictionary containing all nodes and edges to be loaded.
11541176
"""
11551177
for node in data.get("nodes", []):
1156-
id, memory, metadata = _compose_node(node)
1178+
try:
1179+
id, memory, metadata = _compose_node(node)
11571180

1158-
if not self.config.use_multi_db and self.config.user_name:
1159-
metadata["user_name"] = self.config.user_name
1181+
if not self.config.use_multi_db and self.config.user_name:
1182+
metadata["user_name"] = self.config.user_name
11601183

1161-
metadata = self._prepare_node_metadata(metadata)
1162-
metadata.update({"id": id, "memory": memory})
1163-
properties = ", ".join(f"{k}: {self._format_value(v, k)}" for k, v in metadata.items())
1164-
node_gql = f"INSERT OR IGNORE (n@Memory {{{properties}}})"
1165-
self.execute_query(node_gql)
1184+
metadata = self._prepare_node_metadata(metadata)
1185+
metadata.update({"id": id, "memory": memory})
1186+
properties = ", ".join(
1187+
f"{k}: {self._format_value(v, k)}" for k, v in metadata.items()
1188+
)
1189+
node_gql = f"INSERT OR IGNORE (n@Memory {{{properties}}})"
1190+
self.execute_query(node_gql)
1191+
except Exception as e:
1192+
logger.error(f"Fail to load node: {node}, error: {e}")
11661193

11671194
for edge in data.get("edges", []):
1168-
source_id, target_id = edge["source"], edge["target"]
1169-
edge_type = edge["type"]
1170-
props = ""
1171-
if not self.config.use_multi_db and self.config.user_name:
1172-
props = f'{{user_name: "{self.config.user_name}"}}'
1173-
edge_gql = f'''
1174-
MATCH (a@Memory {{id: "{source_id}"}}), (b@Memory {{id: "{target_id}"}})
1175-
INSERT OR IGNORE (a) -[e@{edge_type} {props}]-> (b)
1176-
'''
1177-
self.execute_query(edge_gql)
1195+
try:
1196+
source_id, target_id = edge["source"], edge["target"]
1197+
edge_type = edge["type"]
1198+
props = ""
1199+
if not self.config.use_multi_db and self.config.user_name:
1200+
props = f'{{user_name: "{self.config.user_name}"}}'
1201+
edge_gql = f'''
1202+
MATCH (a@Memory {{id: "{source_id}"}}), (b@Memory {{id: "{target_id}"}})
1203+
INSERT OR IGNORE (a) -[e@{edge_type} {props}]-> (b)
1204+
'''
1205+
self.execute_query(edge_gql)
1206+
except Exception as e:
1207+
logger.error(f"Fail to load edge: {edge}, error: {e}")
11781208

11791209
@timed
11801210
def get_all_memory_items(self, scope: str, include_embedding: bool = False) -> (list)[dict]:
@@ -1555,6 +1585,7 @@ def _prepare_node_metadata(self, metadata: dict[str, Any]) -> dict[str, Any]:
15551585
# Normalize embedding type
15561586
embedding = metadata.get("embedding")
15571587
if embedding and isinstance(embedding, list):
1588+
metadata.pop("embedding")
15581589
metadata[self.dim_field] = _normalize([float(x) for x in embedding])
15591590

15601591
return metadata
@@ -1563,26 +1594,41 @@ def _prepare_node_metadata(self, metadata: dict[str, Any]) -> dict[str, Any]:
15631594
def _format_value(self, val: Any, key: str = "") -> str:
15641595
from nebulagraph_python.py_data_types import NVector
15651596

1597+
# None
1598+
if val is None:
1599+
return "NULL"
1600+
# bool
1601+
if isinstance(val, bool):
1602+
return "true" if val else "false"
1603+
# str
15661604
if isinstance(val, str):
15671605
return f'"{_escape_str(val)}"'
1606+
# num
15681607
elif isinstance(val, (int | float)):
15691608
return str(val)
1609+
# time
15701610
elif isinstance(val, datetime):
15711611
return f'datetime("{val.isoformat()}")'
1612+
# list
15721613
elif isinstance(val, list):
15731614
if key == self.dim_field:
15741615
dim = len(val)
15751616
joined = ",".join(str(float(x)) for x in val)
15761617
return f"VECTOR<{dim}, FLOAT>([{joined}])"
15771618
else:
15781619
return f"[{', '.join(self._format_value(v) for v in val)}]"
1620+
# NVector
15791621
elif isinstance(val, NVector):
15801622
if key == self.dim_field:
15811623
dim = len(val)
15821624
joined = ",".join(str(float(x)) for x in val)
15831625
return f"VECTOR<{dim}, FLOAT>([{joined}])"
1584-
elif val is None:
1585-
return "NULL"
1626+
else:
1627+
logger.warning("Invalid NVector")
1628+
# dict
1629+
if isinstance(val, dict):
1630+
j = json.dumps(val, ensure_ascii=False, separators=(",", ":"))
1631+
return f'"{_escape_str(j)}"'
15861632
else:
15871633
return f'"{_escape_str(str(val))}"'
15881634

src/memos/graph_dbs/neo4j.py

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -323,12 +323,11 @@ def edge_exists(
323323
return result.single() is not None
324324

325325
# Graph Query & Reasoning
326-
def get_node(self, id: str, include_embedding: bool = True) -> dict[str, Any] | None:
326+
def get_node(self, id: str, **kwargs) -> dict[str, Any] | None:
327327
"""
328328
Retrieve the metadata and memory of a node.
329329
Args:
330330
id: Node identifier.
331-
include_embedding (bool): Whether to include the large embedding field.
332331
Returns:
333332
Dictionary of node fields, or None if not found.
334333
"""
@@ -345,12 +344,11 @@ def get_node(self, id: str, include_embedding: bool = True) -> dict[str, Any] |
345344
record = session.run(query, params).single()
346345
return self._parse_node(dict(record["n"])) if record else None
347346

348-
def get_nodes(self, ids: list[str], include_embedding: bool = True) -> list[dict[str, Any]]:
347+
def get_nodes(self, ids: list[str], **kwargs) -> list[dict[str, Any]]:
349348
"""
350349
Retrieve the metadata and memory of a list of nodes.
351350
Args:
352351
ids: List of Node identifier.
353-
include_embedding (bool): Whether to include the large embedding field.
354352
Returns:
355353
list[dict]: Parsed node records containing 'id', 'memory', and 'metadata'.
356354
@@ -833,7 +831,7 @@ def clear(self) -> None:
833831
logger.error(f"[ERROR] Failed to clear database '{self.db_name}': {e}")
834832
raise
835833

836-
def export_graph(self, include_embedding: bool = True) -> dict[str, Any]:
834+
def export_graph(self, **kwargs) -> dict[str, Any]:
837835
"""
838836
Export all graph nodes and edges in a structured form.
839837
@@ -914,13 +912,12 @@ def import_graph(self, data: dict[str, Any]) -> None:
914912
target_id=edge["target"],
915913
)
916914

917-
def get_all_memory_items(self, scope: str, include_embedding: bool = True) -> list[dict]:
915+
def get_all_memory_items(self, scope: str, **kwargs) -> list[dict]:
918916
"""
919917
Retrieve all memory items of a specific memory_type.
920918
921919
Args:
922920
scope (str): Must be one of 'WorkingMemory', 'LongTermMemory', or 'UserMemory'.
923-
include_embedding (bool): Whether to include the large embedding field.
924921
Returns:
925922
926923
Returns:
@@ -946,9 +943,7 @@ def get_all_memory_items(self, scope: str, include_embedding: bool = True) -> li
946943
results = session.run(query, params)
947944
return [self._parse_node(dict(record["n"])) for record in results]
948945

949-
def get_structure_optimization_candidates(
950-
self, scope: str, include_embedding: bool = True
951-
) -> list[dict]:
946+
def get_structure_optimization_candidates(self, scope: str, **kwargs) -> list[dict]:
952947
"""
953948
Find nodes that are likely candidates for structure optimization:
954949
- Isolated nodes, nodes with empty background, or nodes with exactly one child.

src/memos/graph_dbs/neo4j_community.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -169,14 +169,12 @@ def search_by_embedding(
169169
# Return consistent format
170170
return [{"id": r.id, "score": r.score} for r in results]
171171

172-
def get_all_memory_items(self, scope: str, include_embedding: bool = True) -> list[dict]:
172+
def get_all_memory_items(self, scope: str, **kwargs) -> list[dict]:
173173
"""
174174
Retrieve all memory items of a specific memory_type.
175175
176176
Args:
177177
scope (str): Must be one of 'WorkingMemory', 'LongTermMemory', or 'UserMemory'.
178-
include_embedding (bool): Whether to include the large embedding field.
179-
180178
Returns:
181179
list[dict]: Full list of memory items under this scope.
182180
"""

src/memos/mem_os/core.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -971,7 +971,7 @@ def get_user_info(self) -> dict[str, Any]:
971971
return {
972972
"user_id": user.user_id,
973973
"user_name": user.user_name,
974-
"role": user.role.value,
974+
"role": user.role.value if hasattr(user.role, "value") else user.role,
975975
"created_at": user.created_at.isoformat(),
976976
"accessible_cubes": [
977977
{

src/memos/mem_user/mysql_user_manager.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,9 @@ class User(Base):
5555

5656
user_id = Column(String(255), primary_key=True, default=lambda: str(uuid.uuid4()))
5757
user_name = Column(String(255), unique=True, nullable=False)
58-
role = Column(String(20), default=UserRole.USER.value, nullable=False)
58+
role = Column(
59+
String(20), default=UserRole.USER.value, nullable=False
60+
) # for sqlite backend this is SQLEnum
5961
created_at = Column(DateTime, default=datetime.now, nullable=False)
6062
updated_at = Column(DateTime, default=datetime.now, onupdate=datetime.now, nullable=False)
6163
is_active = Column(Boolean, default=True, nullable=False)
@@ -65,7 +67,7 @@ class User(Base):
6567
owned_cubes = relationship("Cube", back_populates="owner", cascade="all, delete-orphan")
6668

6769
def __repr__(self):
68-
return f"<User(user_id='{self.user_id}', user_name='{self.user_name}', role='{self.role.value}')>"
70+
return f"<User(user_id='{self.user_id}', user_name='{self.user_name}', role='{self.role}')>"
6971

7072

7173
class Cube(Base):

0 commit comments

Comments
 (0)