Skip to content

Commit f37ceed

Browse files
committed
fix: code suffix
1 parent 2f2d11a commit f37ceed

File tree

10 files changed

+98
-98
lines changed

10 files changed

+98
-98
lines changed

src/memos/reranker/base.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22
from __future__ import annotations
33

44
from abc import ABC, abstractmethod
5-
from typing import TYPE_CHECKING, Any
5+
from typing import TYPE_CHECKING
6+
67

78
if TYPE_CHECKING:
89
from memos.memories.textual.item import TextualMemoryItem
@@ -21,4 +22,4 @@ def rerank(
2122
**kwargs,
2223
) -> list[tuple[TextualMemoryItem, float]]:
2324
"""Return top_k (item, score) sorted by score desc."""
24-
raise NotImplementedError
25+
raise NotImplementedError

src/memos/reranker/factory.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88

99
from .cosine_local import CosineLocalReranker
1010
from .http_bge import HTTPBGEReranker
11-
from .noop import NoopReranker
1211
from .http_bge_strategy import HTTPBGERerankerStrategy
12+
from .noop import NoopReranker
1313

1414

1515
if TYPE_CHECKING:

src/memos/reranker/http_bge_strategy.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,9 @@
99
import requests
1010

1111
from memos.log import get_logger
12-
13-
from .base import BaseReranker
1412
from memos.reranker.strategies import RerankerStrategyFactory
1513

14+
from .base import BaseReranker
1615

1716

1817
logger = get_logger(__name__)
@@ -151,7 +150,9 @@ def rerank(
151150
if not graph_results:
152151
return []
153152

154-
tracker, original_items, documents = self.reranker_strategy.prepare_documents(query, graph_results, top_k)
153+
tracker, original_items, documents = self.reranker_strategy.prepare_documents(
154+
query, graph_results, top_k
155+
)
155156

156157
logger.info(
157158
f"[HTTPBGEWithSourceReranker] strategy: {self.reranker_strategy}, "
@@ -167,9 +168,7 @@ def rerank(
167168

168169
try:
169170
# Make the HTTP request to the reranker service
170-
resp = requests.post(
171-
self.reranker_url, headers=headers, json=payload, timeout=30
172-
)
171+
resp = requests.post(self.reranker_url, headers=headers, json=payload, timeout=self.timeout)
173172
resp.raise_for_status()
174173
data = resp.json()
175174

@@ -192,13 +191,13 @@ def rerank(
192191
ranked_indices.append(idx)
193192
scores.append(raw_score)
194193
reconstructed_items = self.reranker_strategy.reconstruct_items(
195-
ranked_indices=ranked_indices,
196-
scores=scores,
197-
tracker=tracker,
198-
original_items=original_items,
194+
ranked_indices=ranked_indices,
195+
scores=scores,
196+
tracker=tracker,
197+
original_items=original_items,
199198
top_k=top_k,
200199
graph_results=graph_results,
201-
documents=documents
200+
documents=documents,
202201
)
203202
return reconstructed_items
204203

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
from .factory import RerankerStrategyFactory
22

33

4-
__all__ = ["RerankerStrategyFactory"]
4+
__all__ = ["RerankerStrategyFactory"]

src/memos/reranker/strategies/base.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
from abc import ABC, abstractmethod
2-
from typing import TYPE_CHECKING, Any
2+
from typing import Any
3+
34
from memos.memories.textual.item import TextualMemoryItem
5+
46
from .dialogue_common import DialogueRankingTracker
57

8+
69
class BaseRerankerStrategy(ABC):
710
"""Abstract interface for memory rerankers with concatenation strategy."""
811

@@ -16,21 +19,21 @@ def prepare_documents(
1619
) -> tuple[DialogueRankingTracker, dict[str, Any], list[str]]:
1720
"""
1821
Prepare documents for ranking based on the strategy.
19-
22+
2023
Args:
2124
query: The search query
2225
graph_results: List of TextualMemoryItem objects to process
2326
top_k: Maximum number of items to return
2427
**kwargs: Additional strategy-specific parameters
25-
28+
2629
Returns:
27-
tuple[DialogueRankingTracker, dict[str, Any], list[str]]:
30+
tuple[DialogueRankingTracker, dict[str, Any], list[str]]:
2831
- Tracker: DialogueRankingTracker instance
2932
- original_items: Dict mapping memory_id to original TextualMemoryItem
3033
- documents: List of text documents ready for ranking
3134
"""
3235
raise NotImplementedError
33-
36+
3437
@abstractmethod
3538
def reconstruct_items(
3639
self,
@@ -43,16 +46,16 @@ def reconstruct_items(
4346
) -> list[tuple[TextualMemoryItem, float]]:
4447
"""
4548
Reconstruct TextualMemoryItem objects from ranked results.
46-
49+
4750
Args:
4851
ranked_indices: List of indices sorted by relevance
4952
scores: Corresponding relevance scores
5053
tracker: DialogueRankingTracker instance
5154
original_items: Dict mapping memory_id to original TextualMemoryItem
5255
top_k: Maximum number of items to return
5356
**kwargs: Additional strategy-specific parameters
54-
57+
5558
Returns:
5659
List of (reconstructed_memory_item, aggregated_score) tuples
5760
"""
58-
raise NotImplementedError
61+
raise NotImplementedError

src/memos/reranker/strategies/concat_background.py

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,25 @@
11
# memos/reranker/strategies/single_turn.py
22
from __future__ import annotations
3+
34
import re
5+
46
from typing import Any
5-
from collections import defaultdict
6-
from copy import deepcopy
7+
78
from .base import BaseRerankerStrategy
8-
from .dialogue_common import DialogueRankingTracker, strip_memory_tags, extract_content
9+
from .dialogue_common import DialogueRankingTracker
10+
911

1012
_TAG1 = re.compile(r"^\s*\[[^\]]*\]\s*")
1113

14+
1215
class ConcatBackgroundStrategy(BaseRerankerStrategy):
1316
"""
1417
Concat background strategy.
15-
18+
1619
This strategy processes dialogue pairs by concatenating background and
1720
user and assistant messages into single strings for ranking. Each dialogue pair becomes a
1821
separate document for ranking.
19-
"""
22+
"""
2023

2124
def prepare_documents(
2225
self,
@@ -27,33 +30,33 @@ def prepare_documents(
2730
) -> tuple[DialogueRankingTracker, dict[str, Any], list[str]]:
2831
"""
2932
Prepare documents based on single turn concatenation strategy.
30-
33+
3134
Args:
3235
query: The search query
3336
graph_results: List of graph results
3437
top_k: Maximum number of items to return
35-
38+
3639
Returns:
37-
tuple[DialogueRankingTracker, dict[str, Any], list[str]]:
40+
tuple[DialogueRankingTracker, dict[str, Any], list[str]]:
3841
- Tracker: DialogueRankingTracker instance
3942
- original_items: Dict mapping memory_id to original TextualMemoryItem
4043
- documents: List of text documents ready for ranking
4144
"""
42-
45+
4346
original_items = {}
4447
tracker = DialogueRankingTracker()
4548
documents = []
4649
for item in graph_results:
4750
memory = getattr(item, "memory", None)
4851
if isinstance(memory, str):
4952
memory = _TAG1.sub("", memory)
50-
53+
5154
background = ""
5255
if hasattr(item, "metadata") and hasattr(item.metadata, "background"):
5356
background = getattr(item.metadata, "background", "")
5457
if not isinstance(background, str):
5558
background = ""
56-
59+
5760
documents.append(f"{memory}\n{background}")
5861
return tracker, original_items, documents
5962

@@ -68,26 +71,24 @@ def reconstruct_items(
6871
) -> list[tuple[Any, float]]:
6972
"""
7073
Reconstruct TextualMemoryItem objects from ranked dialogue pairs.
71-
74+
7275
Args:
7376
ranked_indices: List of dialogue pair indices sorted by relevance
7477
scores: Corresponding relevance scores
7578
tracker: DialogueRankingTracker instance
7679
original_items: Dict mapping memory_id to original TextualMemoryItem
7780
top_k: Maximum number of items to return
78-
81+
7982
Returns:
8083
List of (reconstructed_memory_item, aggregated_score) tuples
8184
"""
82-
graph_results = kwargs.get("graph_results", None)
83-
documents = kwargs.get("documents", None)
85+
graph_results = kwargs.get("graph_results")
86+
documents = kwargs.get("documents")
8487
reconstructed_items = []
8588
for idx in ranked_indices:
8689
item = graph_results[idx]
8790
item.memory = f"{item.memory}\n{documents[idx]}"
8891
reconstructed_items.append((item, scores[idx]))
89-
92+
9093
reconstructed_items.sort(key=lambda x: x[1], reverse=True)
9194
return reconstructed_items[:top_k]
92-
93-

src/memos/reranker/strategies/dialogue_common.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
from __future__ import annotations
22

33
import re
4+
45
from typing import Any, Literal
6+
57
from pydantic import BaseModel
6-
from memos.memories.textual.item import SourceMessage
7-
from memos.memories.textual.item import TextualMemoryItem
8+
9+
from memos.memories.textual.item import SourceMessage, TextualMemoryItem
10+
811

912
# Strip a leading "[...]" tag (e.g., "[2025-09-01] ..." or "[meta] ...")
1013
# before sending text to the reranker. This keeps inputs clean and
@@ -17,10 +20,11 @@ def strip_memory_tags(item: TextualMemoryItem) -> str:
1720
memory = _TAG1.sub("", m) if isinstance((m := getattr(item, "memory", None)), str) else m
1821
return memory
1922

23+
2024
def extract_content(msg: dict[str, Any] | str) -> str:
2125
"""Extract content from message, handling both string and dict formats."""
2226
if isinstance(msg, dict):
23-
return msg.get('content', str(msg))
27+
return msg.get("content", str(msg))
2428
if isinstance(msg, SourceMessage):
2529
return msg.content
2630
return str(msg)
@@ -33,7 +37,7 @@ class DialoguePair(BaseModel):
3337
memory_id: str # ID of the source TextualMemoryItem
3438
memory: str
3539
pair_index: int # Index of this pair within the source memory's dialogue
36-
user_msg: str | dict[str, Any] | SourceMessage # User message content
40+
user_msg: str | dict[str, Any] | SourceMessage # User message content
3741
assistant_msg: str | dict[str, Any] | SourceMessage # Assistant message content
3842
combined_text: str # The concatenated text used for ranking
3943
chat_time: str | None = None
@@ -56,14 +60,14 @@ def __init__(self):
5660
self.dialogue_pairs: list[DialoguePair] = []
5761

5862
def add_dialogue_pair(
59-
self,
60-
memory_id: str,
63+
self,
64+
memory_id: str,
6165
pair_index: int,
62-
user_msg: str | dict[str, Any],
66+
user_msg: str | dict[str, Any],
6367
assistant_msg: str | dict[str, Any],
6468
memory: str,
6569
chat_time: str | None = None,
66-
concat_format: Literal["user_assistant", "user_only"] = "user_assistant"
70+
concat_format: Literal["user_assistant", "user_only"] = "user_assistant",
6771
) -> str:
6872
"""Add a dialogue pair and return its unique ID."""
6973
user_content = extract_content(user_msg)
@@ -85,7 +89,7 @@ def add_dialogue_pair(
8589
assistant_msg=assistant_msg,
8690
combined_text=combined_text,
8791
memory=memory,
88-
chat_time=chat_time
92+
chat_time=chat_time,
8993
)
9094

9195
self.dialogue_pairs.append(dialogue_pair)

src/memos/reranker/strategies/factory.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,18 @@
11
# memos/reranker/factory.py
22
from __future__ import annotations
33

4-
from typing import TYPE_CHECKING, Any,ClassVar
5-
from .single_turn import SingleTurnStrategy
4+
from typing import TYPE_CHECKING, Any, ClassVar
5+
66
from .concat_background import ConcatBackgroundStrategy
7+
from .single_turn import SingleTurnStrategy
78
from .singleturn_outmem import SingleTurnOutMemStrategy
89

10+
911
if TYPE_CHECKING:
1012
from .base import BaseRerankerStrategy
1113

12-
class RerankerStrategyFactory():
14+
15+
class RerankerStrategyFactory:
1316
"""Factory class for creating reranker strategy instances."""
1417

1518
backend_to_class: ClassVar[dict[str, Any]] = {
@@ -19,9 +22,7 @@ class RerankerStrategyFactory():
1922
}
2023

2124
@classmethod
22-
def from_config(
23-
cls, config_factory: str = "single_turn"
24-
) -> BaseRerankerStrategy:
25+
def from_config(cls, config_factory: str = "single_turn") -> BaseRerankerStrategy:
2526
if config_factory not in cls.backend_to_class:
2627
raise ValueError(f"Invalid backend: {config_factory}")
2728
strategy_class = cls.backend_to_class[config_factory]

0 commit comments

Comments
 (0)