Skip to content

Commit c688ead

Browse files
authored
feat: chat bot api (#302)

* fix: add safe guard when parsing node memory
* feat: add filter as a parameter in tree-text searcher
* feat: add filter for user and long-term memory
* feat: add filter in working memory
* add filter in task-parser
* feat: only mix-retrieve for vector-recall; TODO: mix reranker
* feat: add 'session_id' as an optional parameter for product api
* feat: api 1.0 finish
* maintain: update gitignore
* maintain: update gitignore
* feat: add 'type' in TextualMemory Sources
* feat: add annotation to item
* fix: add session_id to product add
* fix: test
* feat: [WIP] add filter in reranker
* fix: bug in recall
1 parent 8f87b33 commit c688ead

File tree

4 files changed

+83
-3
lines changed

4 files changed

+83
-3
lines changed

src/memos/memories/textual/tree_text_memory/retrieve/recall.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,9 @@ def search_single(vec, filt=None):
217217
all_hits = []
218218
# Path A: without filter
219219
with ContextThreadPoolExecutor() as executor:
220-
futures = [ex.submit(search_single, vec, None) for vec in query_embedding[:max_num]]
220+
futures = [
221+
executor.submit(search_single, vec, None) for vec in query_embedding[:max_num]
222+
]
221223
for f in concurrent.futures.as_completed(futures):
222224
all_hits.extend(f.result() or [])
223225

src/memos/memories/textual/tree_text_memory/retrieve/searcher.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,7 @@ def _retrieve_from_working_memory(
253253
graph_results=items,
254254
top_k=top_k,
255255
parsed_goal=parsed_goal,
256+
search_filter=search_filter,
256257
)
257258

258259
# --- Path B
@@ -292,6 +293,7 @@ def _retrieve_from_long_term_and_user(
292293
graph_results=results,
293294
top_k=top_k,
294295
parsed_goal=parsed_goal,
296+
search_filter=search_filter,
295297
)
296298

297299
@timed

src/memos/reranker/cosine_local.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ def __init__(
4949
self,
5050
level_weights: dict[str, float] | None = None,
5151
level_field: str = "background",
52+
**kwargs,
5253
):
5354
self.level_weights = level_weights or {"topic": 1.0, "concept": 1.0, "fact": 1.0}
5455
self.level_field = level_field

src/memos/reranker/http_bge.py

Lines changed: 77 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,35 @@
1919
if TYPE_CHECKING:
2020
from memos.memories.textual.item import TextualMemoryItem
2121

22+
# Strip a leading "[...]" tag (e.g., "[2025-09-01] ..." or "[meta] ...")
23+
# before sending text to the reranker. This keeps inputs clean and
24+
# avoids misleading the model with bracketed prefixes.
2225
_TAG1 = re.compile(r"^\s*\[[^\]]*\]\s*")
2326

2427

2528
class HTTPBGEReranker(BaseReranker):
2629
"""
27-
HTTP-based BGE reranker. Mirrors your old MemoryReranker, but configurable.
30+
HTTP-based BGE reranker.
31+
32+
This class sends (query, documents[]) to a remote HTTP endpoint that
33+
performs cross-encoder-style re-ranking (e.g., BGE reranker) and returns
34+
relevance scores. It then maps those scores back onto the original
35+
TextualMemoryItem list and returns (item, score) pairs sorted by score.
36+
37+
Notes
38+
-----
39+
- The endpoint is expected to accept JSON:
40+
{
41+
"model": "<model-name>",
42+
"query": "<query text>",
43+
"documents": ["doc1", "doc2", ...]
44+
}
45+
- Two response shapes are supported:
46+
1) {"results": [{"index": <int>, "relevance_score": <float>}, ...]}
47+
where "index" refers to the *position in the documents array*.
48+
2) {"data": [{"score": <float>}, ...]} (aligned by list order)
49+
- If the service fails or responds unexpectedly, this falls back to
50+
returning the original items with 0.0 scores (best-effort).
2851
"""
2952

3053
def __init__(
@@ -35,7 +58,22 @@ def __init__(
3558
timeout: int = 10,
3659
headers_extra: dict | None = None,
3760
rerank_source: list[str] | None = None,
61+
**kwargs,
3862
):
63+
"""
64+
Parameters
65+
----------
66+
reranker_url : str
67+
HTTP endpoint for the reranker service.
68+
token : str, optional
69+
Bearer token for auth. If non-empty, added to the Authorization header.
70+
model : str, optional
71+
Model identifier understood by the server.
72+
timeout : int, optional
73+
Request timeout (seconds).
74+
headers_extra : dict | None, optional
75+
Additional headers to merge into the request headers.
76+
"""
3977
if not reranker_url:
4078
raise ValueError("reranker_url must not be empty")
4179
self.reranker_url = reranker_url
@@ -48,13 +86,37 @@ def __init__(
4886
def rerank(
4987
self,
5088
query: str,
51-
graph_results: list,
89+
graph_results: list[TextualMemoryItem],
5290
top_k: int,
91+
search_filter: dict | None = None,
5392
**kwargs,
5493
) -> list[tuple[TextualMemoryItem, float]]:
94+
"""
95+
Rank candidate memories by relevance to the query.
96+
97+
Parameters
98+
----------
99+
query : str
100+
The search query.
101+
graph_results : list[TextualMemoryItem]
102+
Candidate items to re-rank. Each item is expected to have a
103+
`.memory` str field; non-strings are ignored.
104+
top_k : int
105+
Return at most this many items.
106+
search_filter : dict | None
107+
Currently unused. Present to keep signature compatible.
108+
109+
Returns
110+
-------
111+
list[tuple[TextualMemoryItem, float]]
112+
Re-ranked items with scores, sorted descending by score.
113+
"""
55114
if not graph_results:
56115
return []
57116

117+
# Build a mapping from "payload docs index" -> "original graph_results index"
118+
# Only include items that have a non-empty string memory. This ensures that
119+
# any index returned by the server can be mapped back correctly.
58120
documents = []
59121
if self.concat_source:
60122
documents = concat_original_source(graph_results, self.concat_source)
@@ -74,6 +136,7 @@ def rerank(
74136
payload = {"model": self.model, "query": query, "documents": documents}
75137

76138
try:
139+
# Make the HTTP request to the reranker service
77140
resp = requests.post(
78141
self.reranker_url, headers=headers, json=payload, timeout=self.timeout
79142
)
@@ -83,9 +146,14 @@ def rerank(
83146
scored_items: list[tuple[TextualMemoryItem, float]] = []
84147

85148
if "results" in data:
149+
# Format:
150+
# dict("results": [{"index": int, "relevance_score": float},
151+
# ...])
86152
rows = data.get("results", [])
87153
for r in rows:
88154
idx = r.get("index")
155+
# The returned index refers to 'documents' (i.e., our 'pairs' order),
156+
# so we must map it back to the original graph_results index.
89157
if isinstance(idx, int) and 0 <= idx < len(graph_results):
90158
score = float(r.get("relevance_score", r.get("score", 0.0)))
91159
scored_items.append((graph_results[idx], score))
@@ -94,21 +162,28 @@ def rerank(
94162
return scored_items[: min(top_k, len(scored_items))]
95163

96164
elif "data" in data:
165+
# Format: {"data": [{"score": float}, ...]} aligned by list order
97166
rows = data.get("data", [])
167+
# Build a list of scores aligned with our 'documents' (pairs)
98168
score_list = [float(r.get("score", 0.0)) for r in rows]
99169

100170
if len(score_list) < len(graph_results):
101171
score_list += [0.0] * (len(graph_results) - len(score_list))
102172
elif len(score_list) > len(graph_results):
103173
score_list = score_list[: len(graph_results)]
104174

175+
# Map back to original items using 'pairs'
105176
scored_items = list(zip(graph_results, score_list, strict=False))
106177
scored_items.sort(key=lambda x: x[1], reverse=True)
107178
return scored_items[: min(top_k, len(scored_items))]
108179

109180
else:
181+
# Unexpected response schema: return a 0.0-scored fallback of the first top_k valid docs
182+
# Note: we use 'pairs' to keep alignment with valid (string) docs.
110183
return [(item, 0.0) for item in graph_results[:top_k]]
111184

112185
except Exception as e:
186+
# Network error, timeout, JSON decode error, etc.
187+
# Degrade gracefully by returning first top_k valid docs with 0.0 score.
113188
logger.error(f"[HTTPBGEReranker] request failed: {e}")
114189
return [(item, 0.0) for item in graph_results[:top_k]]

0 commit comments

Comments (0)