Skip to content

Commit 0b0cad8

Browse files
authored
feat: add orginal context for reranking (#284)
* fix: add memory * update: update orginal data * Chore: Change version to v1.0.1 * feat:fix conflict * fix: update memory get * fix: ci code * update: search_reranker * change: rerank_source for reranking * update config
1 parent 5285829 commit 0b0cad8

File tree

4 files changed

+84
-7
lines changed

4 files changed

+84
-7
lines changed

src/memos/api/config.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,8 +100,10 @@ def get_reranker_config() -> dict[str, Any]:
100100
"backend": "http_bge",
101101
"config": {
102102
"url": os.getenv("MOS_RERANKER_URL"),
103-
"model": "bge-reranker-v2-m3",
103+
"model": os.getenv("MOS_RERANKER_MODEL", "bge-reranker-v2-m3"),
104104
"timeout": 10,
105+
"headers_extra": os.getenv("MOS_RERANKER_HEADERS_EXTRA"),
106+
"rerank_source": os.getenv("MOS_RERANK_SOURCE"),
105107
},
106108
}
107109
else:

src/memos/reranker/concat.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
import re
2+
3+
from typing import Any
4+
5+
6+
_TAG1 = re.compile(r"^\s*\[[^\]]*\]\s*")
7+
8+
9+
def process_source(
10+
items: list[tuple[Any, str | dict[str, Any] | list[Any]]] | None = None, recent_num: int = 3
11+
) -> str:
12+
"""
13+
Args:
14+
items: List of tuples where each tuple contains (memory, source).
15+
source can be str, Dict, or List.
16+
recent_num: Number of recent items to concatenate.
17+
Returns:
18+
str: Concatenated source.
19+
"""
20+
if items is None:
21+
items = []
22+
concat_data = []
23+
memory = None
24+
for item in items:
25+
memory, source = item
26+
for content in source:
27+
if isinstance(content, str):
28+
if "assistant:" in content:
29+
continue
30+
concat_data.append(content)
31+
if memory is not None:
32+
concat_data = [memory, *concat_data]
33+
return "\n".join(concat_data)
34+
35+
36+
def concat_original_source(
37+
graph_results: list,
38+
merge_field: list[str] | None = None,
39+
) -> list[str]:
40+
"""
41+
Merge memory items with original dialogue.
42+
Args:
43+
graph_results (list[TextualMemoryItem]): List of memory items with embeddings.
44+
merge_field (List[str]): List of fields to merge.
45+
Returns:
46+
list[str]: List of memory and concat orginal memory.
47+
"""
48+
if merge_field is None:
49+
merge_field = ["sources"]
50+
documents = []
51+
for item in graph_results:
52+
memory = _TAG1.sub("", m) if isinstance((m := getattr(item, "memory", None)), str) else m
53+
sources = []
54+
for field in merge_field:
55+
source = getattr(item.metadata, field, "")
56+
sources.append((memory, source))
57+
concat_string = process_source(sources)
58+
documents.append(concat_string)
59+
return documents

src/memos/reranker/factory.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ def from_config(cfg: RerankerConfigFactory | None) -> BaseReranker | None:
2929
model=c.get("model", "bge-reranker-v2-m3"),
3030
timeout=int(c.get("timeout", 10)),
3131
headers_extra=c.get("headers_extra"),
32+
rerank_source=c.get("rerank_source"),
3233
)
3334

3435
if backend in {"cosine_local", "cosine"}:

src/memos/reranker/http_bge.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,13 @@
77

88
import requests
99

10+
from memos.log import get_logger
11+
1012
from .base import BaseReranker
13+
from .concat import concat_original_source
14+
15+
16+
logger = get_logger(__name__)
1117

1218

1319
if TYPE_CHECKING:
@@ -28,6 +34,7 @@ def __init__(
2834
model: str = "bge-reranker-v2-m3",
2935
timeout: int = 10,
3036
headers_extra: dict | None = None,
37+
rerank_source: list[str] | None = None,
3138
):
3239
if not reranker_url:
3340
raise ValueError("reranker_url must not be empty")
@@ -36,6 +43,7 @@ def __init__(
3643
self.model = model
3744
self.timeout = timeout
3845
self.headers_extra = headers_extra or {}
46+
self.concat_source = rerank_source
3947

4048
def rerank(
4149
self,
@@ -47,11 +55,18 @@ def rerank(
4755
if not graph_results:
4856
return []
4957

50-
documents = [
51-
(_TAG1.sub("", m) if isinstance((m := getattr(item, "memory", None)), str) else m)
52-
for item in graph_results
53-
]
54-
documents = [d for d in documents if isinstance(d, str) and d]
58+
documents = []
59+
if self.concat_source:
60+
documents = concat_original_source(graph_results, self.concat_source)
61+
else:
62+
documents = [
63+
(_TAG1.sub("", m) if isinstance((m := getattr(item, "memory", None)), str) else m)
64+
for item in graph_results
65+
]
66+
documents = [d for d in documents if isinstance(d, str) and d]
67+
68+
logger.info(f"[HTTPBGERerankerSample] query: {query} , documents: {documents[:5]}...")
69+
5570
if not documents:
5671
return []
5772

@@ -95,5 +110,5 @@ def rerank(
95110
return [(item, 0.0) for item in graph_results[:top_k]]
96111

97112
except Exception as e:
98-
print(f"[HTTPBGEReranker] request failed: {e}")
113+
logger.error(f"[HTTPBGEReranker] request failed: {e}")
99114
return [(item, 0.0) for item in graph_results[:top_k]]

0 commit comments

Comments
 (0)