Skip to content

Commit f799237

Browse files
committed
feat: add reranker Factory
1 parent 24ed7c0 commit f799237

File tree

8 files changed

+436
-0
lines changed

8 files changed

+436
-0
lines changed

examples/basic_modules/reranker.py

Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
import os
2+
import uuid
3+
4+
from dotenv import load_dotenv
5+
6+
from memos import log
7+
from memos.configs.embedder import EmbedderConfigFactory
8+
from memos.configs.reranker import RerankerConfigFactory
9+
from memos.embedders.factory import EmbedderFactory
10+
from memos.memories.textual.item import TextualMemoryItem, TreeNodeTextualMemoryMetadata
11+
from memos.reranker.factory import RerankerFactory
12+
13+
14+
load_dotenv()
15+
logger = log.get_logger(__name__)
16+
17+
18+
def make_item(text: str) -> TextualMemoryItem:
    """Create a minimal demo TextualMemoryItem wrapping *text*.

    The embedding list is deliberately left empty here; the caller attaches
    real vectors after running the embedder.
    """
    # Fixed demo metadata — only the memory text varies between items.
    metadata_fields = {
        "user_id": None,
        "session_id": None,
        "status": "activated",
        "type": "fact",
        "memory_time": "2024-01-01",
        "source": "conversation",
        "confidence": 100.0,
        "tags": [],
        "visibility": "public",
        "updated_at": "2025-01-01T00:00:00",
        "memory_type": "LongTermMemory",
        "key": "demo_key",
        "sources": ["demo://example"],
        "embedding": [],
        "background": "demo background...",
    }
    item_id = str(uuid.uuid4())
    return TextualMemoryItem(
        id=item_id,
        memory=text,
        metadata=TreeNodeTextualMemoryMetadata(**metadata_fields),
    )
41+
42+
43+
def show_ranked(title: str, ranked: list[tuple[TextualMemoryItem, float]], top_n: int = 5) -> None:
    """Print the first *top_n* (item, score) pairs under a section header."""
    print(f"\n=== {title} ===")
    rank = 0
    for item, score in ranked[:top_n]:
        rank += 1
        text = item.memory
        # Truncate long memories so each entry stays on one readable line.
        if len(text) > 80:
            preview = text[:80] + "..."
        else:
            preview = text
        print(f"[#{rank}] score={score:.6f} | {preview}")
48+
49+
50+
def _build_embedder():
    """Build the real-vector embedder from environment configuration."""
    embedder_cfg = EmbedderConfigFactory.model_validate(
        {
            "backend": "universal_api",
            "config": {
                "provider": "openai",  # or "azure"
                "api_key": os.getenv("OPENAI_API_KEY"),
                "model_name_or_path": "text-embedding-3-large",
                "base_url": os.getenv("OPENAI_API_BASE"),  # optional
            },
        }
    )
    return EmbedderFactory.from_config(embedder_cfg)


def _embed_query_and_items(embedder, query, items):
    """Embed the query and every item memory in one batch.

    Returns the query vector; item vectors are written back onto
    ``item.metadata.embedding`` in place.
    """
    texts_to_embed = [query] + [it.memory for it in items]
    vectors = embedder.embed(texts_to_embed)  # real vectors from your provider/model
    query_embedding = vectors[0]
    doc_embeddings = vectors[1:]
    # strict=True: the embedder must return exactly one vector per input
    # text; a length mismatch is a provider bug and should fail loudly
    # rather than silently leave some items without embeddings.
    for it, emb in zip(items, doc_embeddings, strict=True):
        it.metadata.embedding = emb
    return query_embedding


def _run_cosine_scenario(query, items, query_embedding):
    """Rerank with the local cosine backend (uses the real embeddings)."""
    cosine_cfg = RerankerConfigFactory.model_validate(
        {
            "backend": "cosine_local",
            "config": {
                # structural boosts (optional): uses metadata.background
                "level_weights": {"topic": 1.0, "concept": 1.0, "fact": 1.0},
                "level_field": "background",
            },
        }
    )
    cosine_reranker = RerankerFactory.from_config(cosine_cfg)
    ranked_cosine = cosine_reranker.rerank(
        query=query,
        graph_results=items,
        top_k=10,
        query_embedding=query_embedding,  # required by cosine_local
    )
    show_ranked("CosineLocal Reranker (with real embeddings)", ranked_cosine, top_n=5)


def _run_http_bge_scenario(query, items):
    """Optionally rerank via an HTTP BGE service; skipped when no URL is set."""
    bge_url = os.getenv("BGE_RERANKER_URL")  # e.g., "http://xxx.x.xxxxx.xxx:xxxx/v1/rerank"
    if not bge_url:
        print("\n[Info] Skipped HTTP BGE scenario because BGE_RERANKER_URL is not set.")
        return
    http_cfg = RerankerConfigFactory.model_validate(
        {
            "backend": "http_bge",
            "config": {
                "url": bge_url,
                "model": os.getenv("BGE_RERANKER_MODEL", "bge-reranker-v2-m3"),
                "timeout": int(os.getenv("BGE_RERANKER_TIMEOUT", "10")),
                # "headers_extra": {"Authorization": f"Bearer {os.getenv('BGE_RERANKER_TOKEN')}"}
            },
        }
    )
    http_reranker = RerankerFactory.from_config(http_cfg)
    ranked_http = http_reranker.rerank(
        query=query,
        graph_results=items,  # uses item.memory internally as documents
        top_k=10,
    )
    show_ranked("HTTP BGE Reranker (OpenAI-style API)", ranked_http, top_n=5)


def main():
    """Demo: embed a query plus documents, then rerank them two ways.

    1. Build the embedder from environment variables.
    2. Embed query + documents with real vectors.
    3. Rerank locally via cosine similarity.
    4. Optionally rerank via an HTTP BGE endpoint.
    """
    embedder = _build_embedder()

    query = "What is the capital of France?"
    items = [
        make_item("Paris is the capital of France."),
        make_item("Berlin is the capital of Germany."),
        make_item("The capital of Brazil is Brasilia."),
        make_item("Apples and bananas are common fruits."),
        make_item("The Eiffel Tower is a famous landmark in Paris."),
    ]

    query_embedding = _embed_query_and_items(embedder, query, items)
    _run_cosine_scenario(query, items, query_embedding)
    _run_http_bge_scenario(query, items)


if __name__ == "__main__":
    main()

src/memos/configs/reranker.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
# memos/configs/reranker.py
2+
from __future__ import annotations
3+
4+
from typing import Any
5+
6+
from pydantic import BaseModel, Field
7+
8+
9+
class RerankerConfigFactory(BaseModel):
10+
"""
11+
{
12+
"backend": "http_bge" | "cosine_local" | "noop",
13+
"config": { ... backend-specific ... }
14+
}
15+
"""
16+
17+
backend: str = Field(..., description="Reranker backend id")
18+
config: dict[str, Any] = Field(default_factory=dict, description="Backend-specific options")

src/memos/reranker/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
from .factory import RerankerFactory
2+
3+
4+
__all__ = ["RerankerFactory"]

src/memos/reranker/base.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
# memos/reranker/base.py
2+
from __future__ import annotations
3+
4+
from abc import ABC, abstractmethod
5+
from typing import TYPE_CHECKING
6+
7+
8+
if TYPE_CHECKING:
9+
from memos.memories.textual.item import TextualMemoryItem
10+
11+
12+
class BaseReranker(ABC):
13+
"""Abstract interface for memory rerankers."""
14+
15+
@abstractmethod
16+
def rerank(
17+
self,
18+
query: str,
19+
graph_results: list,
20+
top_k: int,
21+
**kwargs,
22+
) -> list[tuple[TextualMemoryItem, float]]:
23+
"""Return top_k (item, score) sorted by score desc."""
24+
raise NotImplementedError

src/memos/reranker/cosine_local.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
# memos/reranker/cosine_local.py
2+
from __future__ import annotations
3+
4+
from typing import TYPE_CHECKING
5+
6+
from .base import BaseReranker
7+
8+
9+
if TYPE_CHECKING:
10+
from memos.memories.textual.item import TextualMemoryItem
11+
12+
try:
13+
import numpy as _np
14+
15+
_HAS_NUMPY = True
16+
except Exception:
17+
_HAS_NUMPY = False
18+
19+
20+
def _cosine_one_to_many(q: list[float], m: list[list[float]]) -> list[float]:
21+
"""
22+
Compute cosine similarities between a single vector q and a matrix m (rows are candidates).
23+
"""
24+
if not _HAS_NUMPY:
25+
26+
def dot(a, b): # lowercase per N806
27+
return sum(x * y for x, y in zip(a, b, strict=False))
28+
29+
def norm(a): # lowercase per N806
30+
return sum(x * x for x in a) ** 0.5
31+
32+
qn = norm(q) or 1e-10
33+
sims = []
34+
for v in m:
35+
vn = norm(v) or 1e-10
36+
sims.append(dot(q, v) / (qn * vn))
37+
return sims
38+
39+
qv = _np.asarray(q, dtype=float) # lowercase
40+
mv = _np.asarray(m, dtype=float) # lowercase
41+
qn = _np.linalg.norm(qv) or 1e-10
42+
mn = _np.linalg.norm(mv, axis=1) # lowercase
43+
dots = mv @ qv
44+
return (dots / (mn * qn + 1e-10)).tolist()
45+
46+
47+
class CosineLocalReranker(BaseReranker):
    """Rerank items by cosine similarity between the query embedding and each
    item's stored embedding, with an optional per-level weight boost."""

    def __init__(
        self,
        level_weights: dict[str, float] | None = None,
        level_field: str = "background",
    ):
        # Default weights treat every level equally.
        if level_weights is None:
            level_weights = {"topic": 1.0, "concept": 1.0, "fact": 1.0}
        self.level_weights = level_weights
        self.level_field = level_field

    def rerank(
        self,
        query: str,
        graph_results: list,
        top_k: int,
        **kwargs,
    ) -> list[tuple[TextualMemoryItem, float]]:
        """Return up to top_k (item, score) pairs, best first.

        Requires ``query_embedding`` in kwargs; without it (or without any
        item embeddings) the method degrades to constant placeholder scores.
        """
        if not graph_results:
            return []

        query_embedding = kwargs.get("query_embedding")
        if not query_embedding:
            # No query vector: cannot score, so return items unranked at 0.0.
            return [(item, 0.0) for item in graph_results[:top_k]]

        # Only items that actually carry an embedding can be scored.
        candidates = []
        for candidate in graph_results:
            meta = getattr(candidate, "metadata", None)
            if meta and getattr(meta, "embedding", None):
                candidates.append(candidate)

        if not candidates:
            # No embeddings at all: neutral 0.5 score, original order.
            return [(item, 0.5) for item in graph_results[:top_k]]

        sims = _cosine_one_to_many(
            query_embedding, [c.metadata.embedding for c in candidates]
        )

        # Apply the structural level boost (defaults to 1.0 for unknown levels).
        scored = []
        for candidate, sim in zip(candidates, sims, strict=False):
            level = getattr(candidate.metadata, self.level_field, None)
            boost = self.level_weights.get(level, 1.0)
            scored.append((candidate, sim * boost))
        scored.sort(key=lambda pair: pair[1], reverse=True)

        result = scored[:top_k]
        if len(result) < top_k:
            # Pad with embedding-less items (sentinel score -1.0), keeping
            # their original order from graph_results.
            seen = {item.id for item, _ in result}
            fillers = [(item, -1.0) for item in graph_results if item.id not in seen]
            result.extend(fillers[: top_k - len(result)])

        return result

src/memos/reranker/factory.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
# memos/reranker/factory.py
2+
from __future__ import annotations
3+
4+
from typing import TYPE_CHECKING, Any
5+
6+
from .cosine_local import CosineLocalReranker
7+
from .http_bge import HTTPBGEReranker
8+
from .noop import NoopReranker
9+
10+
11+
if TYPE_CHECKING:
12+
from memos.configs.reranker import RerankerConfigFactory
13+
14+
from .base import BaseReranker
15+
16+
17+
class RerankerFactory:
    """Maps a RerankerConfigFactory onto a concrete reranker instance."""

    @staticmethod
    def from_config(cfg: RerankerConfigFactory | None) -> BaseReranker | None:
        """Instantiate the backend named by ``cfg.backend`` (case-insensitive).

        Returns None when no config is supplied; raises ValueError for an
        unknown backend id.
        """
        if not cfg:
            return None

        name = (cfg.backend or "").lower()
        options: dict[str, Any] = cfg.config or {}

        if name in ("http_bge", "bge"):
            # Several key spellings are accepted for the service URL.
            url = options.get("url") or options.get("endpoint") or options.get("reranker_url")
            return HTTPBGEReranker(
                reranker_url=url,
                model=options.get("model", "bge-reranker-v2-m3"),
                timeout=int(options.get("timeout", 10)),
                headers_extra=options.get("headers_extra"),
            )

        if name in ("cosine_local", "cosine"):
            return CosineLocalReranker(
                level_weights=options.get("level_weights"),
                level_field=options.get("level_field", "background"),
            )

        if name in ("noop", "none", "disabled"):
            return NoopReranker()

        raise ValueError(f"Unknown reranker backend: {cfg.backend}")

0 commit comments

Comments
 (0)