Skip to content

Commit a2c2a39

Browse files
committed
feat:fix conflict
1 parent bf1d4fc commit a2c2a39

File tree

2 files changed

+14
-9
lines changed

2 files changed

+14
-9
lines changed

src/memos/reranker/concat.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
1-
from typing import List, Dict, Any, Union, Tuple
1+
from typing import Any
22

33

44
def process_source(
5-
items: List[Tuple[Any, Union[str, Dict[str, Any], List[Any]]]] | None = None,
6-
recent_num: int = 3
7-
) -> str:
5+
items: list[tuple[Any, str | dict[str, Any] | list[Any]]] | None = None, recent_num: int = 3
6+
) -> str:
87
"""
98
Args:
109
items: List of tuples where each tuple contains (memory, source).
@@ -13,6 +12,8 @@ def process_source(
1312
Returns:
1413
str: Concatenated source.
1514
"""
15+
if items is None:
16+
items = []
1617
concat_data = []
1718
for item in items:
1819
memory, source = item
@@ -24,7 +25,7 @@ def process_source(
2425

2526
def concat_original_source(
2627
graph_results: list,
27-
merge_field: List[str]=["sources"],
28+
merge_field: list[str] | None = None,
2829
) -> list[str]:
2930
"""
3031
Merge memory items with original dialogue.
@@ -34,6 +35,8 @@ def concat_original_source(
3435
Returns:
3536
list[str]: List of memory and concat orginal memory.
3637
"""
38+
if merge_field is None:
39+
merge_field = ["sources"]
3740
documents = []
3841
for item in graph_results:
3942
memory = getattr(item, "memory", "")
@@ -43,4 +46,4 @@ def concat_original_source(
4346
sources.append((memory, source))
4447
concat_string = process_source(sources)
4548
documents.append(concat_string)
46-
return documents
49+
return documents

src/memos/reranker/http_bge.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,11 @@
77

88
import requests
99

10+
from memos.log import get_logger
11+
1012
from .base import BaseReranker
1113
from .concat import concat_original_source
12-
from memos.log import get_logger
14+
1315

1416
logger = get_logger(__name__)
1517

@@ -32,7 +34,7 @@ def __init__(
3234
model: str = "bge-reranker-v2-m3",
3335
timeout: int = 10,
3436
headers_extra: dict | None = None,
35-
concat_source: List[str]=["sources"],
37+
concat_source: list[str] | None = None,
3638
):
3739
if not reranker_url:
3840
raise ValueError("reranker_url must not be empty")
@@ -41,7 +43,7 @@ def __init__(
4143
self.model = model
4244
self.timeout = timeout
4345
self.headers_extra = headers_extra or {}
44-
self.concat_source = concat_source
46+
self.concat_source = concat_source or ["sources"]
4547

4648
def rerank(
4749
self,

0 commit comments

Comments
 (0)