Skip to content

Commit 3c01d1e

Browse files
authored
Feat: include embedding config (#726)
feat: update include embedding
1 parent a82149f commit 3c01d1e

File tree

9 files changed

+44
-7
lines changed

9 files changed

+44
-7
lines changed

src/memos/api/config.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -887,6 +887,9 @@ def create_user_config(user_name: str, user_id: str) -> tuple[MOSConfig, General
887887
"bm25": bool(os.getenv("BM25_CALL", "false") == "true"),
888888
"cot": bool(os.getenv("VEC_COT_CALL", "false") == "true"),
889889
},
890+
"include_embedding": bool(
891+
os.getenv("INCLUDE_EMBEDDING", "false") == "true"
892+
),
890893
},
891894
},
892895
"act_mem": {}
@@ -960,6 +963,9 @@ def get_default_cube_config() -> GeneralMemCubeConfig | None:
960963
"cot": bool(os.getenv("VEC_COT_CALL", "false") == "true"),
961964
},
962965
"mode": os.getenv("ASYNC_MODE", "sync"),
966+
"include_embedding": bool(
967+
os.getenv("INCLUDE_EMBEDDING", "false") == "true"
968+
),
963969
},
964970
},
965971
"act_mem": {}

src/memos/api/handlers/component_init.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,7 @@ def init_server() -> dict[str, Any]:
210210
config=default_cube_config.text_mem.config,
211211
internet_retriever=internet_retriever,
212212
tokenizer=tokenizer,
213+
include_embedding=bool(os.getenv("INCLUDE_EMBEDDING", "false") == "true"),
213214
)
214215

215216
logger.debug("Text memory initialized")

src/memos/configs/memory.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,10 @@ class TreeTextMemoryConfig(BaseTextMemoryConfig):
196196
default="sync",
197197
description=("whether use asynchronous mode in memory add"),
198198
)
199+
include_embedding: bool | None = Field(
200+
default=False,
201+
description="Whether to include embedding in the memory retrieval",
202+
)
199203

200204

201205
class SimpleTreeTextMemoryConfig(TreeTextMemoryConfig):

src/memos/graph_dbs/polardb.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3602,6 +3602,11 @@ def _build_node_from_agtype(self, node_agtype, embedding=None):
36023602
return None
36033603

36043604
if embedding is not None:
3605+
if isinstance(embedding, str):
3606+
try:
3607+
embedding = json.loads(embedding)
3608+
except (json.JSONDecodeError, TypeError):
3609+
logger.warning("Failed to parse embedding for node")
36053610
props["embedding"] = embedding
36063611

36073612
# Return standard format directly

src/memos/memories/textual/simple_tree.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ def __init__(
3737
internet_retriever: None = None,
3838
is_reorganize: bool = False,
3939
tokenizer: FastTokenizer | None = None,
40+
include_embedding: bool = False,
4041
):
4142
"""Initialize memory with the given configuration."""
4243
self.config: TreeTextMemoryConfig = config
@@ -65,3 +66,4 @@ def __init__(
6566
)
6667
else:
6768
logger.info("No internet retriever configured")
69+
self.include_embedding = include_embedding

src/memos/memories/textual/tree.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ def __init__(self, config: TreeTextMemoryConfig):
9292
else:
9393
logger.info("No internet retriever configured")
9494
self.tokenizer = None
95+
self.include_embedding = config.include_embedding or False
9596

9697
def add(
9798
self,
@@ -192,6 +193,7 @@ def search(
192193
search_strategy=self.search_strategy,
193194
manual_close_internet=manual_close_internet,
194195
tokenizer=self.tokenizer,
196+
include_embedding=self.include_embedding,
195197
)
196198
return searcher.search(
197199
query,

src/memos/memories/textual/tree_text_memory/retrieve/advanced_searcher.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ def __init__(
3535
manual_close_internet: bool = True,
3636
process_llm: Any | None = None,
3737
tokenizer: FastTokenizer | None = None,
38+
include_embedding: bool = False,
3839
):
3940
super().__init__(
4041
dispatcher_llm=dispatcher_llm,
@@ -46,6 +47,7 @@ def __init__(
4647
search_strategy=search_strategy,
4748
manual_close_internet=manual_close_internet,
4849
tokenizer=tokenizer,
50+
include_embedding=include_embedding,
4951
)
5052

5153
self.stage_retrieve_top = 3

src/memos/memories/textual/tree_text_memory/retrieve/recall.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,15 @@ def __init__(
2222
graph_store: Neo4jGraphDB,
2323
embedder: OllamaEmbedder,
2424
bm25_retriever: EnhancedBM25 | None = None,
25+
include_embedding: bool = False,
2526
):
2627
self.graph_store = graph_store
2728
self.embedder = embedder
2829
self.bm25_retriever = bm25_retriever
2930
self.max_workers = 10
3031
self.filter_weight = 0.6
3132
self.use_bm25 = bool(self.bm25_retriever)
33+
self.include_embedding = include_embedding
3234

3335
def retrieve(
3436
self,
@@ -72,7 +74,7 @@ def retrieve(
7274
# For working memory, retrieve all entries (no session-oriented filtering)
7375
working_memories = self.graph_store.get_all_memory_items(
7476
scope="WorkingMemory",
75-
include_embedding=False,
77+
include_embedding=self.include_embedding,
7678
user_name=user_name,
7779
filter=search_filter,
7880
)
@@ -244,7 +246,9 @@ def process_node(node):
244246
return []
245247

246248
# Load nodes and post-filter
247-
node_dicts = self.graph_store.get_nodes(list(candidate_ids), include_embedding=False)
249+
node_dicts = self.graph_store.get_nodes(
250+
list(candidate_ids), include_embedding=self.include_embedding
251+
)
248252

249253
final_nodes = []
250254
for node in node_dicts:
@@ -291,7 +295,7 @@ def process_node(node):
291295

292296
# Load nodes and post-filter
293297
node_dicts = self.graph_store.get_nodes(
294-
list(candidate_ids), include_embedding=False, user_name=user_name
298+
list(candidate_ids), include_embedding=self.include_embedding, user_name=user_name
295299
)
296300

297301
final_nodes = []
@@ -385,7 +389,10 @@ def search_path_b():
385389
unique_ids = {r["id"] for r in all_hits if r.get("id")}
386390
node_dicts = (
387391
self.graph_store.get_nodes(
388-
list(unique_ids), include_embedding=False, cube_name=cube_name, user_name=user_name
392+
list(unique_ids),
393+
include_embedding=self.include_embedding,
394+
cube_name=cube_name,
395+
user_name=user_name,
389396
)
390397
or []
391398
)
@@ -416,7 +423,9 @@ def _bm25_recall(
416423
key_filters.append({"field": key, "op": "=", "value": value})
417424
corpus_name += "".join(list(search_filter.values()))
418425
candidate_ids = self.graph_store.get_by_metadata(key_filters, user_name=user_name)
419-
node_dicts = self.graph_store.get_nodes(list(candidate_ids), include_embedding=False)
426+
node_dicts = self.graph_store.get_nodes(
427+
list(candidate_ids), include_embedding=self.include_embedding
428+
)
420429

421430
bm25_query = " ".join(list({query, *parsed_goal.keys}))
422431
bm25_results = self.bm25_retriever.search(
@@ -471,7 +480,10 @@ def _fulltext_recall(
471480
unique_ids = {r["id"] for r in all_hits if r.get("id")}
472481
node_dicts = (
473482
self.graph_store.get_nodes(
474-
list(unique_ids), include_embedding=False, cube_name=cube_name, user_name=user_name
483+
list(unique_ids),
484+
include_embedding=self.include_embedding,
485+
cube_name=cube_name,
486+
user_name=user_name,
475487
)
476488
or []
477489
)

src/memos/memories/textual/tree_text_memory/retrieve/searcher.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,13 +47,16 @@ def __init__(
4747
search_strategy: dict | None = None,
4848
manual_close_internet: bool = True,
4949
tokenizer: FastTokenizer | None = None,
50+
include_embedding: bool = False,
5051
):
5152
self.graph_store = graph_store
5253
self.embedder = embedder
5354
self.llm = dispatcher_llm
5455

5556
self.task_goal_parser = TaskGoalParser(dispatcher_llm)
56-
self.graph_retriever = GraphMemoryRetriever(graph_store, embedder, bm25_retriever)
57+
self.graph_retriever = GraphMemoryRetriever(
58+
graph_store, embedder, bm25_retriever, include_embedding=include_embedding
59+
)
5760
self.reranker = reranker
5861
self.reasoner = MemoryReasoner(dispatcher_llm)
5962

0 commit comments

Comments
 (0)