Skip to content

Commit 778c3b4

Browse files
authored
feat: chat bot api, add reranker filter; fix pydantic bug (#303)
* fix: add safe guard when parsing node memory * feat: add filter as a parameter in tree-text searcher * feat: add filter for user and long-term memory * feat: add filter in working memory * add filter in task-parser * feat: only mix-retrieve for vector-recall; TODO: mix reranker * feat: add 'session_id' as an optional parameter for product api * feat: api 1.0 finish * maintain: update gitignore * maintain: update gitignore * feat: add 'type' in TextualMemory Sources * feat: add annotation to item * fix: add session_id to product add * fix: test * feat: [WIP] add filter in reranker * fix: bug in recall * feat: finish search filter in reranker * fix: product router pydantic error
1 parent c688ead commit 778c3b4

File tree

3 files changed

+160
-9
lines changed

3 files changed

+160
-9
lines changed

examples/basic_modules/reranker.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,20 @@ def main():
8888
for it, emb in zip(items, doc_embeddings, strict=False):
8989
it.metadata.embedding = emb
9090

91+
items[0].metadata.user_id = "u_123"
92+
items[0].metadata.session_id = "s_abc"
93+
items[0].metadata.tags = [*items[0].metadata.tags, "paris"]
94+
95+
items[1].metadata.user_id = "u_124"
96+
items[1].metadata.session_id = "s_xyz"
97+
items[1].metadata.tags = [*items[1].metadata.tags, "germany"]
98+
items[2].metadata.user_id = "u_125"
99+
items[2].metadata.session_id = "s_ss3"
100+
items[3].metadata.user_id = "u_126"
101+
items[3].metadata.session_id = "s_ss4"
102+
items[4].metadata.user_id = "u_127"
103+
items[4].metadata.session_id = "s_ss5"
104+
91105
# -------------------------------
92106
# 4) Rerank with cosine_local (uses your real embeddings)
93107
# -------------------------------
@@ -124,7 +138,7 @@ def main():
124138
"url": bge_url,
125139
"model": os.getenv("BGE_RERANKER_MODEL", "bge-reranker-v2-m3"),
126140
"timeout": int(os.getenv("BGE_RERANKER_TIMEOUT", "10")),
127-
# "headers_extra": {"Authorization": f"Bearer {os.getenv('BGE_RERANKER_TOKEN')}"}
141+
"boost_weights": {"user_id": 0.5, "tags": 0.2},
128142
},
129143
}
130144
)
@@ -136,6 +150,20 @@ def main():
136150
top_k=10,
137151
)
138152
show_ranked("HTTP BGE Reranker (OpenAI-style API)", ranked_http, top_n=5)
153+
154+
# --- NEW: search_filter with rerank ---
155+
# hit rule:
156+
# - tags including "germany" → score * (1 + 0.2) = 1.2
157+
# - session_id == "germany" matches no item's session (sessions are "s_…"), so it adds no boost
158+
# - project_id (not present on items) → one-time warning, score unchanged
159+
search_filter = {"session_id": "germany", "tags": "germany", "project_id": "demo-p1"}
160+
ranked_http_boosted = http_reranker.rerank(
161+
query=query,
162+
graph_results=items,
163+
top_k=10,
164+
search_filter=search_filter,
165+
)
166+
show_ranked("HTTP BGE Reranker (with search_filter boosts)", ranked_http_boosted, top_n=5)
139167
else:
140168
print("\n[Info] Skipped HTTP BGE scenario because BGE_RERANKER_URL is not set.")
141169

src/memos/reranker/http_bge.py

Lines changed: 128 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33

44
import re
55

6-
from typing import TYPE_CHECKING
6+
from collections.abc import Iterable
7+
from typing import TYPE_CHECKING, Any
78

89
import requests
910

@@ -23,6 +24,28 @@
2324
# before sending text to the reranker. This keeps inputs clean and
2425
# avoids misleading the model with bracketed prefixes.
2526
_TAG1 = re.compile(r"^\s*\[[^\]]*\]\s*")
27+
DEFAULT_BOOST_WEIGHTS = {"user_id": 0.5, "tags": 0.2, "session_id": 0.3}
28+
29+
30+
def _value_matches(item_value: Any, wanted: Any) -> bool:
31+
"""
32+
Generic matching:
33+
- if item_value is list/tuple/set: check membership (any match if wanted is iterable)
34+
- else: equality (any match if wanted is iterable)
35+
"""
36+
37+
def _iterable(x):
38+
# exclude strings from "iterable"
39+
return isinstance(x, Iterable) and not isinstance(x, str | bytes)
40+
41+
if _iterable(item_value):
42+
if _iterable(wanted):
43+
return any(w in item_value for w in wanted)
44+
return wanted in item_value
45+
else:
46+
if _iterable(wanted):
47+
return any(item_value == w for w in wanted)
48+
return item_value == wanted
2649

2750

2851
class HTTPBGEReranker(BaseReranker):
@@ -58,6 +81,9 @@ def __init__(
5881
timeout: int = 10,
5982
headers_extra: dict | None = None,
6083
rerank_source: list[str] | None = None,
84+
boost_weights: dict[str, float] | None = None,
85+
boost_default: float = 0.0,
86+
warn_unknown_filter_keys: bool = True,
6187
**kwargs,
6288
):
6389
"""
@@ -83,6 +109,15 @@ def __init__(
83109
self.headers_extra = headers_extra or {}
84110
self.concat_source = rerank_source
85111

112+
self.boost_weights = (
113+
DEFAULT_BOOST_WEIGHTS.copy()
114+
if boost_weights is None
115+
else {k: float(v) for k, v in boost_weights.items()}
116+
)
117+
self.boost_default = float(boost_default)
118+
self.warn_unknown_filter_keys = bool(warn_unknown_filter_keys)
119+
self._warned_missing_keys: set[str] = set()
120+
86121
def rerank(
87122
self,
88123
query: str,
@@ -117,7 +152,6 @@ def rerank(
117152
# Build a mapping from "payload docs index" -> "original graph_results index"
118153
# Only include items that have a non-empty string memory. This ensures that
119154
# any index returned by the server can be mapped back correctly.
120-
documents = []
121155
if self.concat_source:
122156
documents = concat_original_source(graph_results, self.concat_source)
123157
else:
@@ -155,8 +189,11 @@ def rerank(
155189
# The returned index refers to 'documents' (i.e., our 'pairs' order),
156190
# so we must map it back to the original graph_results index.
157191
if isinstance(idx, int) and 0 <= idx < len(graph_results):
158-
score = float(r.get("relevance_score", r.get("score", 0.0)))
159-
scored_items.append((graph_results[idx], score))
192+
raw_score = float(r.get("relevance_score", r.get("score", 0.0)))
193+
item = graph_results[idx]
194+
# generic boost
195+
score = self._apply_boost_generic(item, raw_score, search_filter)
196+
scored_items.append((item, score))
160197

161198
scored_items.sort(key=lambda x: x[1], reverse=True)
162199
return scored_items[: min(top_k, len(scored_items))]
@@ -172,8 +209,10 @@ def rerank(
172209
elif len(score_list) > len(graph_results):
173210
score_list = score_list[: len(graph_results)]
174211

175-
# Map back to original items using 'pairs'
176-
scored_items = list(zip(graph_results, score_list, strict=False))
212+
scored_items = []
213+
for item, raw_score in zip(graph_results, score_list, strict=False):
214+
score = self._apply_boost_generic(item, raw_score, search_filter)
215+
scored_items.append((item, score))
177216
scored_items.sort(key=lambda x: x[1], reverse=True)
178217
return scored_items[: min(top_k, len(scored_items))]
179218

@@ -187,3 +226,86 @@ def rerank(
187226
# Degrade gracefully by returning first top_k valid docs with 0.0 score.
188227
logger.error(f"[HTTPBGEReranker] request failed: {e}")
189228
return [(item, 0.0) for item in graph_results[:top_k]]
229+
230+
def _get_attr_or_key(self, obj: Any, key: str) -> Any:
231+
"""
232+
Resolve `key` on `obj` with one-level fallback into `obj.metadata`.
233+
234+
Priority:
235+
1) obj.<key>
236+
2) obj[key]
237+
3) obj.metadata.<key>
238+
4) obj.metadata[key]
239+
"""
240+
if obj is None:
241+
return None
242+
243+
# support input like "metadata.user_id"
244+
if "." in key:
245+
head, tail = key.split(".", 1)
246+
base = self._get_attr_or_key(obj, head)
247+
return self._get_attr_or_key(base, tail)
248+
249+
def _resolve(o: Any, k: str):
250+
if o is None:
251+
return None
252+
v = getattr(o, k, None)
253+
if v is not None:
254+
return v
255+
if hasattr(o, "get"):
256+
try:
257+
return o.get(k)
258+
except Exception:
259+
return None
260+
return None
261+
262+
# 1) find in obj
263+
v = _resolve(obj, key)
264+
if v is not None:
265+
return v
266+
267+
# 2) find in obj.metadata
268+
meta = _resolve(obj, "metadata")
269+
if meta is not None:
270+
return _resolve(meta, key)
271+
272+
return None
273+
274+
def _apply_boost_generic(
    self,
    item: TextualMemoryItem,
    base_score: float,
    search_filter: dict | None,
) -> float:
    """Boost `base_score` once per search_filter entry matched by `item`.

    For every key in `search_filter` whose value matches the item
    (resolved through `self._get_attr_or_key`, which also searches
    `item.metadata` and supports dotted paths like "metadata.user_id"),
    the score is multiplied by (1 + weight), where weight comes from
    `self.boost_weights` and falls back to `self.boost_default`. After
    each applied boost the score is clamped into [0, 1]. Keys that
    resolve to nothing trigger a single (deduplicated) warning and
    leave the score untouched.
    """
    if not search_filter:
        return base_score

    boosted = float(base_score)

    for filter_key, expected in search_filter.items():
        # Resolve on the item and on item.metadata transparently.
        actual = self._get_attr_or_key(item, filter_key)

        if actual is None:
            # Unknown key: warn once per key, never affect the score.
            if self.warn_unknown_filter_keys and filter_key not in self._warned_missing_keys:
                logger.warning(
                    "[HTTPBGEReranker] search_filter key '%s' not found on TextualMemoryItem or metadata",
                    filter_key,
                )
            self._warned_missing_keys.add(filter_key)
            continue

        if not _value_matches(actual, expected):
            continue

        weight = float(self.boost_weights.get(filter_key, self.boost_default))
        if weight != 0.0:
            # Multiplicative boost, clamped to keep scores in [0, 1].
            boosted *= 1.0 + weight
            boosted = min(max(0.0, boosted), 1.0)

    return boosted

src/memos/types.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,9 @@ class MessageDict(TypedDict, total=False):
2727

2828
role: MessageRole
2929
content: str
30-
chat_time: str # Optional timestamp for the message, format is not restricted, it can be any vague or precise time string.
31-
message_id: str # Optional unique identifier for the message
30+
chat_time: str | None # Optional timestamp for the message, format is not
31+
# restricted, it can be any vague or precise time string.
32+
message_id: str | None # Optional unique identifier for the message
3233

3334

3435
# Message collections

0 commit comments

Comments
 (0)