Skip to content

Commit 0e9506d

Browse files
committed
feat: refactor rerank
1 parent d8f041f commit 0e9506d

File tree

2 files changed

+73
-63
lines changed

2 files changed

+73
-63
lines changed

aperag/rank/reranker.py

Lines changed: 71 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -15,78 +15,84 @@
1515
from config.settings import (
1616
RERANK_BACKEND,
1717
RERANK_SERVICE_MODEL_UID,
18+
RERANK_SERVICE_MODEL,
19+
RERANK_SERVICE_TOKEN,
1820
RERANK_SERVICE_URL,
1921
)
2022

2123
default_rerank_model_path = "/data/models/bge-reranker-large"
2224

23-
# Mutex and synchronized decorator (copied for self-containment as requested)
25+
# Mutex and synchronized decorator
2426
mutex = Lock()
2527
rerank_model_cache = {}
2628

27-
28-
# synchronized decorator
2929
def synchronized(func):
    """Decorator that serializes calls to *func* under the module-level mutex.

    Used to make model-cache population race-free. functools.wraps preserves
    the wrapped function's __name__/__doc__ so logging and debugging don't
    report every synchronized function as "wrapper".
    """
    import functools

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        with mutex:
            return func(*args, **kwargs)
    return wrapper
3534

36-
3735
class Ranker(ABC):
    """Abstract interface for rerankers.

    Implementations reorder a list of retrieval results by relevance
    to a query.
    """

    @abstractmethod
    async def rank(self, query, results: List[DocumentWithScore]):
        """Return *results* reordered by decreasing relevance to *query*.

        Implementations in this module may mutate items in place (e.g.
        attaching a ``rank_before`` bookkeeping field) and may return
        either the same list or a reordered copy.
        """
        pass
4239

43-
44-
class RankerService(Ranker):
45-
def __init__(self):
46-
if RERANK_BACKEND == "xinference":
47-
self.ranker = XinferenceRanker()
48-
elif RERANK_BACKEND == "local":
49-
self.ranker = FlagCrossEncoderRanker()
50-
else:
51-
raise Exception(
52-
"Unsupported embedding backend") # Note: Error message refers to embedding backend, might be a typo in original code
53-
54-
async def rank(self, query, results: List[DocumentWithScore]):
55-
return await self.ranker.rank(query, results)
56-
57-
5840
class XinferenceRanker(Ranker):
    """Reranker backed by a Xinference ``/v1/rerank`` HTTP endpoint."""

    def __init__(self):
        self.url = f"{RERANK_SERVICE_URL}/v1/rerank"
        self.model_uid = RERANK_SERVICE_MODEL_UID

    async def rank(self, query, results: List[DocumentWithScore]):
        """Return *results* reordered by the remote service's relevance scores.

        Raises:
            RuntimeError: when the service responds with a non-200 status.
        """
        # Avoid a pointless HTTP round-trip on empty input (mirrors the
        # empty-pairs guard in FlagCrossEncoderRanker).
        if not results:
            return []
        documents = [doc.text for doc in results]
        body = {
            "model": self.model_uid,
            "documents": documents,
            "query": query,
            "return_documents": False,
        }
        async with aiohttp.ClientSession() as session:
            async with session.post(self.url, json=body) as resp:
                data = await resp.json()
                if resp.status != 200:
                    # Error payloads are not guaranteed to carry a 'detail'
                    # key; fall back to the whole body instead of raising a
                    # KeyError that would mask the real failure.
                    raise RuntimeError(f"Failed to rerank, detail: {data.get('detail', data)}")
                indices = [r["index"] for r in data["results"]]
                return [results[i] for i in indices]
60+
61+
class JinaRanker(Ranker):
    """Reranker backed by the Jina AI rerank HTTP API."""

    def __init__(self):
        self.url = RERANK_SERVICE_URL  # e.g. "https://api.jina.ai/v1/rerank"
        self.model = RERANK_SERVICE_MODEL  # e.g. "jina-reranker-v2-base-multilingual"
        self.auth_token = RERANK_SERVICE_TOKEN

    async def rank(self, query, results: List[DocumentWithScore]):
        """Return *results* reordered by the remote service's relevance scores.

        Raises:
            RuntimeError: when the service responds with a non-200 status.
        """
        # Avoid a pointless HTTP round-trip on empty input (mirrors the
        # empty-pairs guard in FlagCrossEncoderRanker).
        if not results:
            return []
        documents = [doc.text for doc in results]
        body = {
            "model": self.model,
            "query": query,
            "top_n": len(documents),
            "documents": documents,
            "return_documents": False
        }
        # Bug fix: the original f"Bearer ${self.auth_token}" injected a
        # literal '$' (a JavaScript template-literal artifact), producing an
        # invalid Authorization header. Also accept tokens configured with or
        # without the "Bearer " prefix, since the settings comment shows the
        # token may already include it.
        token = self.auth_token or ""
        auth = token if token.startswith("Bearer ") else f"Bearer {token}"
        headers = {
            "Content-Type": "application/json",
            "Authorization": auth
        }
        async with aiohttp.ClientSession() as session:
            async with session.post(self.url, headers=headers, json=body) as resp:
                data = await resp.json()
                if resp.status != 200:
                    raise RuntimeError(f"Failed to rerank, detail: {data}")
                indices = [r["index"] for r in data["results"]]
                return [results[i] for i in indices]
7987

8088
class ContentRatioRanker(Ranker):
    """Sorts results by their ``content_ratio`` metadata (then score), descending."""

    def __init__(self, query):
        # Stored for interface parity; rank() uses the query it is passed.
        self.query = query

    async def rank(self, query, results: List[DocumentWithScore]):
        """Sort *results* in place and return the same list.

        Documents without a ``content_ratio`` entry are treated as ratio 1.
        """
        def sort_key(doc):
            return (doc.metadata.get("content_ratio", 1), doc.score)

        results.sort(key=sort_key, reverse=True)
        return results
8895

89-
9096
class AutoCrossEncoderRanker(Ranker):
9197
def __init__(self):
9298
model_path = os.environ.get("RERANK_MODEL_PATH", default_rerank_model_path)
@@ -96,63 +102,65 @@ def __init__(self):
96102

97103
async def rank(self, query, results: List[DocumentWithScore]):
98104
pairs = []
99-
for idx, result in enumerate(results):
100-
pairs.append((query, result.text))
101-
result.rank_before = idx
102-
105+
for idx, doc in enumerate(results):
106+
pairs.append((query, doc.text))
107+
doc.rank_before = idx
103108
with torch.no_grad():
104-
inputs = self.tokenizer(pairs, padding=True, truncation=True, return_tensors='pt', max_length=512)
105-
scores = self.model(**inputs, return_dict=True).logits.view(-1, ).float()
106-
# Ensure scores is iterable even if only one result
107-
if not isinstance(scores, (list, torch.Tensor)) or (
108-
isinstance(scores, torch.Tensor) and scores.numel() == 1 and len(results) == 1):
109-
scores = [scores.item()] if isinstance(scores, torch.Tensor) else [scores]
110-
elif isinstance(scores, torch.Tensor):
109+
inputs = self.tokenizer(
110+
pairs,
111+
padding=True,
112+
truncation=True,
113+
return_tensors='pt',
114+
max_length=512
115+
)
116+
scores = self.model(**inputs, return_dict=True).logits.view(-1,).float()
117+
if isinstance(scores, torch.Tensor):
111118
scores = scores.tolist()
112-
113-
results = [x for _, x in sorted(zip(scores, results), key=lambda k: k[0], reverse=True)]
114-
115-
return results
116-
119+
ranked = sorted(zip(scores, results), key=lambda k: k[0], reverse=True)
120+
return [x for _, x in ranked]
117121

118122
class FlagCrossEncoderRanker(Ranker):
    """Local cross-encoder reranker built on FlagEmbedding's FlagReranker."""

    def __init__(self):
        model_path = os.environ.get("RERANK_MODEL_PATH", default_rerank_model_path)
        self.reranker = FlagReranker(model_path)

    async def rank(self, query, results: List[DocumentWithScore]):
        """Score truncated (query, text) pairs locally and sort descending.

        Records each document's original position in ``rank_before``.
        Returns a new list; empty input yields an empty list.
        """
        max_length = 512
        truncated_query = query[:max_length]
        pairs = []
        for position, doc in enumerate(results):
            doc.rank_before = position
            pairs.append((truncated_query, doc.text[:max_length]))
        if not pairs:
            return []
        with torch.no_grad():
            scores = self.reranker.compute_score(pairs, max_length=max_length)
        # compute_score returns a bare float for a single pair
        if isinstance(scores, float):
            scores = [scores]
        order = sorted(zip(scores, results), key=lambda pair: pair[0], reverse=True)
        return [doc for _, doc in order]
142+
class RankerService(Ranker):
    """Facade that selects a concrete Ranker based on RERANK_BACKEND."""

    def __init__(self):
        if RERANK_BACKEND == "xinference":
            self.ranker = XinferenceRanker()
        elif RERANK_BACKEND == "local":
            self.ranker = FlagCrossEncoderRanker()
        elif RERANK_BACKEND == "jina":
            self.ranker = JinaRanker()
        else:
            # ValueError (a subclass of Exception, so existing catchers still
            # work) and include the offending value so a misconfiguration is
            # diagnosable from the traceback alone.
            raise ValueError(f"Unsupported rerank backend: {RERANK_BACKEND}")

    async def rank(self, query, results: List[DocumentWithScore]):
        """Delegate to the selected backend's rank()."""
        return await self.ranker.rank(query, results)
143155

144156
@synchronized
def get_rerank_model(model_type: str = "bge-reranker-large"):
    """Return the cached RankerService for *model_type*, creating it on first use.

    The @synchronized decorator makes the check-then-insert atomic across
    threads; instances are kept for the process lifetime in
    ``rerank_model_cache``.
    """
    if model_type not in rerank_model_cache:
        rerank_model_cache[model_type] = RankerService()
    return rerank_model_cache[model_type]
153163

154-
155164
async def rerank(message, results):
    """Rerank *results* for *message* using the process-wide reranker service."""
    ranker = get_rerank_model()
    ranked = await ranker.rank(message, results)
    return ranked

config/settings.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -258,6 +258,8 @@
258258

259259
RERANK_BACKEND = env.str("RERANK_BACKEND", default="local")
RERANK_SERVICE_URL = env.str("RERANK_SERVICE_URL", default="http://localhost:9997")
# Default to empty strings so deployments that don't use the jina backend
# don't fail at import time when these env vars are unset; every sibling
# RERANK_* setting already supplies a default.
RERANK_SERVICE_MODEL = env.str("RERANK_SERVICE_MODEL", default="")
RERANK_SERVICE_TOKEN = env.str("RERANK_SERVICE_TOKEN", default="")
# xinference only needs model_uid, doesn't need model name
RERANK_SERVICE_MODEL_UID = env.str("RERANK_SERVICE_MODEL_UID", default="")
263265

0 commit comments

Comments
 (0)