Skip to content

Commit d8f041f

Browse files
committed
feat: refactor rerank
1 parent 1df0d42 commit d8f041f

File tree

5 files changed

+188
-143
lines changed

5 files changed

+188
-143
lines changed

aperag/context/context_test.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,8 @@
2828
from aperag.llm.prompts import DEFAULT_CHINESE_PROMPT_TEMPLATE_V2
2929
from aperag.pipeline.keyword_extractor import IKExtractor
3030
from aperag.query.query import get_packed_answer
31-
from aperag.readers.base_embedding import get_embedding_model, get_rerank_model
31+
from aperag.readers.base_embedding import get_embedding_model
32+
from aperag.rank.reranker import get_rerank_model
3233
from aperag.readers.local_path_embedding import LocalPathEmbedding
3334
from aperag.utils.utils import generate_fulltext_index_name
3435
from aperag.vectorstore.connector import VectorStoreConnectorAdaptor

aperag/pipeline/knowledge_pipeline.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,8 @@
3030
Message, Pipeline, DOCUMENT_URLS
3131
from aperag.pipeline.keyword_extractor import IKExtractor
3232
from aperag.query.query import DocumentWithScore, get_packed_answer
33-
from aperag.readers.base_embedding import get_embedding_model, rerank
33+
from aperag.readers.base_embedding import get_embedding_model
34+
from aperag.rank.reranker import rerank
3435
from aperag.source.utils import async_run
3536
from aperag.utils.utils import (
3637
generate_fulltext_index_name,

aperag/rank/reranker.py

Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,158 @@
1+
#!/usr/bin/env python3
2+
# -*- coding: utf-8 -*-
3+
import json
4+
import os
5+
from abc import ABC, abstractmethod
6+
from threading import Lock
7+
from typing import Any, List
8+
9+
import aiohttp
10+
import torch
11+
from transformers import AutoModelForSequenceClassification, AutoTokenizer
12+
from FlagEmbedding import FlagReranker
13+
14+
from aperag.query.query import DocumentWithScore
15+
from config.settings import (
16+
RERANK_BACKEND,
17+
RERANK_SERVICE_MODEL_UID,
18+
RERANK_SERVICE_URL,
19+
)
20+
21+
# Fallback path to the local cross-encoder weights; overridable via the
# RERANK_MODEL_PATH environment variable (see the ranker constructors below).
default_rerank_model_path = "/data/models/bge-reranker-large"

# Mutex and synchronized decorator (copied for self-containment as requested)
mutex = Lock()
# Cache of ranker instances keyed by model_type; populated lazily by
# get_rerank_model() under the mutex so each model is built at most once.
rerank_model_cache = {}
26+
27+
28+
# synchronized decorator
def synchronized(func):
    """Decorator that serializes calls to *func* through the module-level mutex.

    Used to make get_rerank_model's check-then-insert on the cache atomic.
    """
    # Local import keeps the module import block untouched.
    from functools import wraps

    # functools.wraps preserves the wrapped function's __name__/__doc__,
    # which the original decorator discarded (it reported "wrapper").
    @wraps(func)
    def wrapper(*args, **kwargs):
        with mutex:
            return func(*args, **kwargs)

    return wrapper
35+
36+
37+
class Ranker(ABC):
    """Abstract interface for rerankers: reorder candidate documents for a query."""

    @abstractmethod
    async def rank(self, query, results: List[DocumentWithScore]):
        """Return *results* reordered by relevance to *query*, best first."""
        ...
42+
43+
44+
class RankerService(Ranker):
    """Facade that selects a concrete Ranker implementation from settings.

    RERANK_BACKEND chooses between the remote Xinference service
    ("xinference") and the local FlagEmbedding cross-encoder ("local").

    Raises:
        ValueError: if RERANK_BACKEND names an unknown backend.
    """

    def __init__(self):
        if RERANK_BACKEND == "xinference":
            self.ranker = XinferenceRanker()
        elif RERANK_BACKEND == "local":
            self.ranker = FlagCrossEncoderRanker()
        else:
            # Fixed: the original message said "embedding backend" (a typo the
            # original code even flagged in a comment) and omitted the value.
            # ValueError is still caught by callers handling Exception.
            raise ValueError(f"Unsupported rerank backend: {RERANK_BACKEND}")

    async def rank(self, query, results: List[DocumentWithScore]):
        """Delegate ranking to the configured backend."""
        return await self.ranker.rank(query, results)
56+
57+
58+
class XinferenceRanker(Ranker):
    """Reranker backed by a remote Xinference-compatible /v1/rerank endpoint."""

    def __init__(self):
        self.url = f"{RERANK_SERVICE_URL}/v1/rerank"
        self.model_uid = RERANK_SERVICE_MODEL_UID

    async def rank(self, query, results: List[DocumentWithScore]):
        """POST the query/document pairs to the service and reorder *results*."""
        payload = {
            "model": self.model_uid,
            "documents": [doc.text for doc in results],
            "query": query,
            "return_documents": False,
        }
        async with aiohttp.ClientSession() as session:
            async with session.post(self.url, json=payload) as response:
                body = await response.json()
                if response.status != 200:
                    raise RuntimeError(f"Failed to rerank documents, detail: {body['detail']}")
                # The service returns indices into the submitted document list,
                # already sorted by relevance.
                ordering = [item['index'] for item in body['results']]
        return [results[i] for i in ordering]
78+
79+
80+
class ContentRatioRanker(Ranker):
    """Orders results by their content_ratio metadata (then score), descending."""

    def __init__(self, query):
        # NOTE(review): the constructor-supplied query is stored but never
        # consulted by rank(); preserved for interface parity with call sites.
        self.query = query

    async def rank(self, query, results: List[DocumentWithScore]):
        """Sort *results* in place by (content_ratio, score) descending and return them."""
        def sort_key(doc):
            # Documents without a content_ratio entry default to 1.
            return (doc.metadata.get("content_ratio", 1), doc.score)

        results.sort(key=sort_key, reverse=True)
        return results
88+
89+
90+
class AutoCrossEncoderRanker(Ranker):
    """Local cross-encoder reranker using a HuggingFace sequence-classification model.

    The model path comes from the RERANK_MODEL_PATH environment variable,
    falling back to default_rerank_model_path.
    """

    def __init__(self):
        model_path = os.environ.get("RERANK_MODEL_PATH", default_rerank_model_path)
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_path)
        self.model.eval()

    async def rank(self, query, results: List[DocumentWithScore]):
        """Score each (query, document) pair and return *results* best-first.

        Side effect: annotates each document with its pre-rerank position
        as ``rank_before``.
        """
        pairs = []
        for idx, result in enumerate(results):
            pairs.append((query, result.text))
            result.rank_before = idx

        # Fixed: guard the empty case — the tokenizer raises on an empty
        # batch, and there is nothing to rank anyway. This mirrors
        # FlagCrossEncoderRanker's behavior.
        if not pairs:
            return []

        with torch.no_grad():
            inputs = self.tokenizer(pairs, padding=True, truncation=True, return_tensors='pt', max_length=512)
            # logits.view(-1) is always a 1-D Tensor (even for one pair), so a
            # plain tolist() replaces the original isinstance gymnastics.
            scores = self.model(**inputs, return_dict=True).logits.view(-1).float().tolist()

        # Sort descending by score; the key only reads the score so equal
        # scores never trigger comparison of DocumentWithScore objects.
        return [doc for _, doc in sorted(zip(scores, results), key=lambda k: k[0], reverse=True)]
116+
117+
118+
class FlagCrossEncoderRanker(Ranker):
    """Local cross-encoder reranker built on FlagEmbedding's FlagReranker."""

    def __init__(self):
        model_path = os.environ.get("RERANK_MODEL_PATH", default_rerank_model_path)
        # self.reranker = FlagReranker('BAAI/bge-reranker-large', use_fp16=True) #use fp16 can speed up computing
        self.reranker = FlagReranker(model_path)  # use fp16 can speed up computing

    async def rank(self, query, results: List[DocumentWithScore]):
        """Score each (query, document) pair and return *results* best-first.

        Side effect: annotates each document with its pre-rerank position
        as ``rank_before``.
        """
        max_length = 512
        truncated_query = query[:max_length]

        pairs = []
        for position, doc in enumerate(results):
            doc.rank_before = position
            pairs.append((truncated_query, doc.text[:max_length]))

        if not pairs:
            return []

        with torch.no_grad():
            scores = self.reranker.compute_score(pairs, max_length=max_length)
        # compute_score returns a bare float when given a single pair.
        if isinstance(scores, float):
            scores = [scores]

        ranked = sorted(zip(scores, results), key=lambda pair: pair[0], reverse=True)
        return [doc for _, doc in ranked]
142+
143+
144+
@synchronized
def get_rerank_model(model_type: str = "bge-reranker-large"):
    """Return a process-wide cached RankerService, creating it on first use.

    NOTE(review): model_type only keys the cache; RankerService reads its
    backend from settings, so every cache entry is currently equivalent.
    The parameter is kept for signature consistency.
    """
    cached = rerank_model_cache.get(model_type)
    if cached is None:
        cached = RankerService()
        rerank_model_cache[model_type] = cached
    return cached
153+
154+
155+
async def rerank(message, results):
    """Rerank *results* for *message* using the shared cached ranker."""
    ranker = get_rerank_model()
    return await ranker.rank(message, results)

0 commit comments

Comments
 (0)