Skip to content

Commit 5600dfe

Browse files
committed
feat: pass reranker from tree config
1 parent f799237 commit 5600dfe

File tree

4 files changed

+26
-123
lines changed

4 files changed

+26
-123
lines changed

examples/basic_modules/tree_textual_memory_reranker.py

Lines changed: 0 additions & 121 deletions
This file was deleted.

src/memos/configs/memory.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from memos.configs.graph_db import GraphDBConfigFactory
88
from memos.configs.internet_retriever import InternetRetrieverConfigFactory
99
from memos.configs.llm import LLMConfigFactory
10+
from memos.configs.reranker import RerankerConfigFactory
1011
from memos.configs.vec_db import VectorDBConfigFactory
1112
from memos.exceptions import ConfigurationError
1213

@@ -151,6 +152,10 @@ class TreeTextMemoryConfig(BaseTextMemoryConfig):
151152
default_factory=EmbedderConfigFactory,
152153
description="Embedder configuration for the memory embedding",
153154
)
155+
reranker: RerankerConfigFactory | None = Field(
156+
None,
157+
description="Reranker configuration (optional, defaults to cosine_local).",
158+
)
154159
graph_db: GraphDBConfigFactory = Field(
155160
...,
156161
default_factory=GraphDBConfigFactory,

src/memos/memories/textual/tree.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from typing import Any
99

1010
from memos.configs.memory import TreeTextMemoryConfig
11+
from memos.configs.reranker import RerankerConfigFactory
1112
from memos.embedders.factory import EmbedderFactory, OllamaEmbedder
1213
from memos.graph_dbs.factory import GraphStoreFactory, Neo4jGraphDB
1314
from memos.llms.factory import AzureLLM, LLMFactory, OllamaLLM, OpenAILLM
@@ -19,6 +20,7 @@
1920
InternetRetrieverFactory,
2021
)
2122
from memos.memories.textual.tree_text_memory.retrieve.searcher import Searcher
23+
from memos.reranker.factory import RerankerFactory
2224
from memos.types import MessageList
2325

2426

@@ -39,6 +41,20 @@ def __init__(self, config: TreeTextMemoryConfig):
3941
)
4042
self.embedder: OllamaEmbedder = EmbedderFactory.from_config(config.embedder)
4143
self.graph_store: Neo4jGraphDB = GraphStoreFactory.from_config(config.graph_db)
44+
if config.reranker is None:
45+
default_cfg = RerankerConfigFactory.model_validate(
46+
{
47+
"backend": "cosine_local",
48+
"config": {
49+
"level_weights": {"topic": 1.0, "concept": 1.0, "fact": 1.0},
50+
"level_field": "background",
51+
},
52+
}
53+
)
54+
self.reranker = RerankerFactory.from_config(default_cfg)
55+
else:
56+
self.reranker = RerankerFactory.from_config(config.reranker)
57+
4258
self.is_reorganize = config.reorganize
4359

4460
self.memory_manager: MemoryManager = MemoryManager(
@@ -131,6 +147,7 @@ def search(
131147
self.dispatcher_llm,
132148
self.graph_store,
133149
self.embedder,
150+
self.reranker,
134151
internet_retriever=None,
135152
moscube=moscube,
136153
)
@@ -139,6 +156,7 @@ def search(
139156
self.dispatcher_llm,
140157
self.graph_store,
141158
self.embedder,
159+
self.reranker,
142160
internet_retriever=self.internet_retriever,
143161
moscube=moscube,
144162
)

src/memos/memories/textual/tree_text_memory/retrieve/searcher.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,12 @@
88
from memos.llms.factory import AzureLLM, OllamaLLM, OpenAILLM
99
from memos.log import get_logger
1010
from memos.memories.textual.item import SearchedTreeNodeTextualMemoryMetadata, TextualMemoryItem
11+
from memos.reranker.base import BaseReranker
1112
from memos.utils import timed
1213

1314
from .internet_retriever_factory import InternetRetrieverFactory
1415
from .reasoner import MemoryReasoner
1516
from .recall import GraphMemoryRetriever
16-
from .reranker import MemoryReranker
1717
from .task_goal_parser import TaskGoalParser
1818

1919

@@ -26,6 +26,7 @@ def __init__(
2626
dispatcher_llm: OpenAILLM | OllamaLLM | AzureLLM,
2727
graph_store: Neo4jGraphDB,
2828
embedder: OllamaEmbedder,
29+
reranker: BaseReranker,
2930
internet_retriever: InternetRetrieverFactory | None = None,
3031
moscube: bool = False,
3132
):
@@ -34,7 +35,7 @@ def __init__(
3435

3536
self.task_goal_parser = TaskGoalParser(dispatcher_llm)
3637
self.graph_retriever = GraphMemoryRetriever(self.graph_store, self.embedder)
37-
self.reranker = MemoryReranker(dispatcher_llm, self.embedder)
38+
self.reranker = reranker
3839
self.reasoner = MemoryReasoner(dispatcher_llm)
3940

4041
# Create internet retriever from config if provided

0 commit comments

Comments
 (0)