Skip to content

Commit 7e8ae7c

Browse files
authored
Feat:update embedding (#732)
* feat: update include embedding * feat: update init * feat: update embedding * fix: code * feat: update feedback
1 parent 522432d commit 7e8ae7c

File tree

5 files changed

+48
-6
lines changed

5 files changed

+48
-6
lines changed

src/memos/api/config.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -395,6 +395,32 @@ def get_reranker_config() -> dict[str, Any]:
395395
},
396396
}
397397

398+
@staticmethod
399+
def get_feedback_reranker_config() -> dict[str, Any]:
400+
"""Get embedder configuration."""
401+
embedder_backend = os.getenv("MOS_FEEDBACK_RERANKER_BACKEND", "http_bge")
402+
403+
if embedder_backend in ["http_bge", "http_bge_strategy"]:
404+
return {
405+
"backend": embedder_backend,
406+
"config": {
407+
"url": os.getenv("MOS_RERANKER_URL"),
408+
"model": os.getenv("MOS_FEEDBACK_RERANKER_MODEL", "bge-reranker-v2-m3"),
409+
"timeout": 10,
410+
"headers_extra": json.loads(os.getenv("MOS_RERANKER_HEADERS_EXTRA", "{}")),
411+
"rerank_source": os.getenv("MOS_RERANK_SOURCE"),
412+
"reranker_strategy": os.getenv("MOS_RERANKER_STRATEGY", "single_turn"),
413+
},
414+
}
415+
else:
416+
return {
417+
"backend": "cosine_local",
418+
"config": {
419+
"level_weights": {"topic": 1.0, "concept": 1.0, "fact": 1.0},
420+
"level_field": "background",
421+
},
422+
}
423+
398424
@staticmethod
399425
def get_embedder_config() -> dict[str, Any]:
400426
"""Get embedder configuration."""

src/memos/api/handlers/component_init.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from memos.api.handlers.config_builders import (
1414
build_chat_llm_config,
1515
build_embedder_config,
16+
build_feedback_reranker_config,
1617
build_graph_db_config,
1718
build_internet_retriever_config,
1819
build_llm_config,
@@ -159,6 +160,7 @@ def init_server() -> dict[str, Any]:
159160
embedder_config = build_embedder_config()
160161
mem_reader_config = build_mem_reader_config()
161162
reranker_config = build_reranker_config()
163+
feedback_reranker_config = build_feedback_reranker_config()
162164
internet_retriever_config = build_internet_retriever_config()
163165
vector_db_config = build_vec_db_config()
164166
pref_extractor_config = build_pref_extractor_config()
@@ -179,6 +181,7 @@ def init_server() -> dict[str, Any]:
179181
embedder = EmbedderFactory.from_config(embedder_config)
180182
mem_reader = MemReaderFactory.from_config(mem_reader_config)
181183
reranker = RerankerFactory.from_config(reranker_config)
184+
feedback_reranker = RerankerFactory.from_config(feedback_reranker_config)
182185
internet_retriever = InternetRetrieverFactory.from_config(
183186
internet_retriever_config, embedder=embedder
184187
)
@@ -305,7 +308,7 @@ def init_server() -> dict[str, Any]:
305308
memory_manager=memory_manager,
306309
mem_reader=mem_reader,
307310
searcher=searcher,
308-
reranker=reranker,
311+
reranker=feedback_reranker,
309312
)
310313

311314
# Initialize Scheduler

src/memos/api/handlers/config_builders.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,16 @@ def build_reranker_config() -> dict[str, Any]:
140140
return RerankerConfigFactory.model_validate(APIConfig.get_reranker_config())
141141

142142

143+
def build_feedback_reranker_config() -> dict[str, Any]:
144+
"""
145+
Build reranker configuration.
146+
147+
Returns:
148+
Validated reranker configuration dictionary
149+
"""
150+
return RerankerConfigFactory.model_validate(APIConfig.get_feedback_reranker_config())
151+
152+
143153
def build_internet_retriever_config() -> dict[str, Any]:
144154
"""
145155
Build internet retriever configuration.

src/memos/graph_dbs/polardb.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1160,13 +1160,15 @@ def get_nodes(
11601160
properties = properties_json if properties_json else {}
11611161

11621162
# Parse embedding from JSONB if it exists
1163-
if embedding_json is not None:
1163+
if embedding_json is not None and kwargs.get("include_embedding"):
11641164
try:
11651165
# remove embedding
1166-
"""
1167-
embedding = json.loads(embedding_json) if isinstance(embedding_json, str) else embedding_json
1168-
# properties["embedding"] = embedding
1169-
"""
1166+
embedding = (
1167+
json.loads(embedding_json)
1168+
if isinstance(embedding_json, str)
1169+
else embedding_json
1170+
)
1171+
properties["embedding"] = embedding
11701172
except (json.JSONDecodeError, TypeError):
11711173
logger.warning(f"Failed to parse embedding for node {node_id}")
11721174
nodes.append(

src/memos/memories/textual/tree.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,7 @@ def get_searcher(
144144
manual_close_internet=manual_close_internet,
145145
process_llm=process_llm,
146146
tokenizer=self.tokenizer,
147+
include_embedding=self.include_embedding,
147148
)
148149
return searcher
149150

0 commit comments

Comments
 (0)