Skip to content

Commit 29fc11d

Browse files
authored
Merge pull request #705 from apecloud/support_rerank
Support rerank
2 parents 1df0d42 + c9b4262 commit 29fc11d

File tree

6 files changed

+198
-143
lines changed

6 files changed

+198
-143
lines changed

aperag/context/context_test.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,8 @@
2828
from aperag.llm.prompts import DEFAULT_CHINESE_PROMPT_TEMPLATE_V2
2929
from aperag.pipeline.keyword_extractor import IKExtractor
3030
from aperag.query.query import get_packed_answer
31-
from aperag.readers.base_embedding import get_embedding_model, get_rerank_model
31+
from aperag.readers.base_embedding import get_embedding_model
32+
from aperag.rank.reranker import get_rerank_model
3233
from aperag.readers.local_path_embedding import LocalPathEmbedding
3334
from aperag.utils.utils import generate_fulltext_index_name
3435
from aperag.vectorstore.connector import VectorStoreConnectorAdaptor

aperag/pipeline/knowledge_pipeline.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,8 @@
3030
Message, Pipeline, DOCUMENT_URLS
3131
from aperag.pipeline.keyword_extractor import IKExtractor
3232
from aperag.query.query import DocumentWithScore, get_packed_answer
33-
from aperag.readers.base_embedding import get_embedding_model, rerank
33+
from aperag.readers.base_embedding import get_embedding_model
34+
from aperag.rank.reranker import rerank
3435
from aperag.source.utils import async_run
3536
from aperag.utils.utils import (
3637
generate_fulltext_index_name,

aperag/rank/reranker.py

Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import functools
import json
import os
from abc import ABC, abstractmethod
from threading import Lock
from typing import Any, List

import aiohttp
import torch
from FlagEmbedding import FlagReranker
from transformers import AutoModelForSequenceClassification, AutoTokenizer

from aperag.query.query import DocumentWithScore
from config.settings import (
    RERANK_BACKEND,
    RERANK_SERVICE_MODEL,
    RERANK_SERVICE_MODEL_UID,
    RERANK_SERVICE_TOKEN,
    RERANK_SERVICE_URL,
)
22+
23+
# Default local path for the cross-encoder rerank model weights; individual
# rankers allow overriding it via the RERANK_MODEL_PATH environment variable.
default_rerank_model_path = "/data/models/bge-reranker-large"

# Mutex and synchronized decorator
# Module-wide lock used by @synchronized so that concurrent callers of
# get_rerank_model() populate the cache only once.
mutex = Lock()
# Cache of RankerService instances keyed by model_type string.
rerank_model_cache = {}
28+
29+
def synchronized(func):
    """Decorator that serializes every call to *func* under the module-wide
    ``mutex``.

    Used to make the model-cache population in ``get_rerank_model`` thread-safe.
    """
    @functools.wraps(func)  # preserve the wrapped function's name/docstring
    def wrapper(*args, **kwargs):
        with mutex:
            return func(*args, **kwargs)
    return wrapper
34+
35+
class Ranker(ABC):
    """Interface for rerankers that reorder retrieved documents by relevance."""

    @abstractmethod
    async def rank(self, query, results: List[DocumentWithScore]):
        """Return *results* reordered from most to least relevant to *query*."""
39+
40+
class XinferenceRanker(Ranker):
    """Rerank documents by calling an Xinference-compatible ``/v1/rerank`` API."""

    def __init__(self):
        # RERANK_SERVICE_URL is the service base URL; the endpoint path is fixed.
        self.url = f"{RERANK_SERVICE_URL}/v1/rerank"
        self.model_uid = RERANK_SERVICE_MODEL_UID

    async def rank(self, query, results: List["DocumentWithScore"]):
        """Reorder *results* by relevance to *query*.

        Raises:
            RuntimeError: if the service responds with a non-200 status.
        """
        if not results:
            # Nothing to rank; skip the HTTP round trip entirely.
            return []
        documents = [doc.text for doc in results]
        body = {
            "model": self.model_uid,
            "documents": documents,
            "query": query,
            "return_documents": False,
        }
        async with aiohttp.ClientSession() as session:
            async with session.post(self.url, json=body) as resp:
                if resp.status != 200:
                    # Error bodies are not guaranteed to be JSON or to carry a
                    # "detail" key; report the raw payload so a secondary
                    # decode/KeyError cannot mask the real failure.
                    detail = await resp.text()
                    raise RuntimeError(f"Failed to rerank, detail: {detail}")
                data = await resp.json()
        # The service returns indices into the submitted document list,
        # ordered from most to least relevant.
        indices = [r["index"] for r in data["results"]]
        return [results[i] for i in indices]
60+
61+
class JinaRanker(Ranker):
    """Rerank documents via the Jina AI rerank HTTP API."""

    def __init__(self):
        self.url = RERANK_SERVICE_URL  # e.g. "https://api.jina.ai/v1/rerank"
        self.model = RERANK_SERVICE_MODEL  # e.g. "jina-reranker-v2-base-multilingual"
        self.auth_token = RERANK_SERVICE_TOKEN  # token, with or without "Bearer " prefix

    async def rank(self, query, results: List["DocumentWithScore"]):
        """Reorder *results* by relevance to *query*.

        Raises:
            RuntimeError: if the service responds with a non-200 status.
        """
        if not results:
            # Nothing to rank; skip the HTTP round trip entirely.
            return []
        documents = [doc.text for doc in results]
        body = {
            "model": self.model,
            "query": query,
            "top_n": len(documents),
            "documents": documents,
            "return_documents": False
        }
        # Accept the token either bare or already carrying the "Bearer " scheme;
        # unconditionally prefixing would otherwise send "Bearer Bearer <token>".
        auth = self.auth_token or ""
        if not auth.startswith("Bearer "):
            auth = f"Bearer {auth}"
        headers = {
            "Content-Type": "application/json",
            "Authorization": auth
        }
        async with aiohttp.ClientSession() as session:
            async with session.post(self.url, headers=headers, json=body) as resp:
                if resp.status != 200:
                    # Error bodies may not be JSON; report the raw payload so the
                    # decode step cannot mask the real failure.
                    detail = await resp.text()
                    raise RuntimeError(f"Failed to rerank, detail: {detail}")
                data = await resp.json()
        indices = [r["index"] for r in data["results"]]
        return [results[i] for i in indices]
87+
88+
class ContentRatioRanker(Ranker):
    """Rerank by each document's "content_ratio" metadata entry (descending),
    breaking ties by the original retrieval score.

    Documents lacking a "content_ratio" entry are treated as ratio 1.
    """

    def __init__(self, query=None):
        # *query* is unused by this heuristic; it is kept (now optional) for
        # backward compatibility with existing call sites.
        self.query = query

    async def rank(self, query, results: List["DocumentWithScore"]):
        """Return *results* sorted by (content_ratio, score), highest first.

        Unlike an in-place ``list.sort``, the caller's list is left untouched.
        """
        return sorted(
            results,
            key=lambda doc: (doc.metadata.get("content_ratio", 1), doc.score),
            reverse=True,
        )
95+
96+
class AutoCrossEncoderRanker(Ranker):
    """Rerank with a local HuggingFace cross-encoder (e.g. bge-reranker-large).

    The model path comes from the RERANK_MODEL_PATH environment variable,
    falling back to ``default_rerank_model_path``.
    """

    def __init__(self):
        model_path = os.environ.get("RERANK_MODEL_PATH", default_rerank_model_path)
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_path)
        # Inference only: switch off dropout / training-mode layers.
        self.model.eval()

    async def rank(self, query, results: List["DocumentWithScore"]):
        """Score each (query, document) pair and return *results* best-first.

        Side effect: records each document's pre-rerank position on
        ``doc.rank_before``.
        """
        pairs = []
        for idx, doc in enumerate(results):
            pairs.append((query, doc.text))
            doc.rank_before = idx
        if not pairs:
            # Guard the empty batch: tokenizing [] would fail. Also keeps the
            # behavior consistent with FlagCrossEncoderRanker.
            return []
        with torch.no_grad():
            inputs = self.tokenizer(
                pairs,
                padding=True,
                truncation=True,
                return_tensors='pt',
                max_length=512
            )
            scores = self.model(**inputs, return_dict=True).logits.view(-1,).float()
        if isinstance(scores, torch.Tensor):
            scores = scores.tolist()
        # Stable sort on score only; the key avoids comparing documents on ties.
        ranked = sorted(zip(scores, results), key=lambda k: k[0], reverse=True)
        return [x for _, x in ranked]
121+
122+
class FlagCrossEncoderRanker(Ranker):
    """Rerank with a local FlagEmbedding ``FlagReranker`` cross-encoder."""

    def __init__(self):
        self.reranker = FlagReranker(
            os.environ.get("RERANK_MODEL_PATH", default_rerank_model_path)
        )

    async def rank(self, query, results: List[DocumentWithScore]):
        """Score every (query, document) pair and return *results* best-first.

        Side effect: records each document's pre-rerank position on
        ``doc.rank_before``.
        """
        limit = 512
        clipped_query = query[:limit]
        candidate_pairs = []
        for position, document in enumerate(results):
            document.rank_before = position
            candidate_pairs.append((clipped_query, document.text[:limit]))
        if not candidate_pairs:
            return []
        with torch.no_grad():
            scores = self.reranker.compute_score(candidate_pairs, max_length=limit)
        # A single pair yields a bare float rather than a list of scores.
        if isinstance(scores, float):
            scores = [scores]
        # Stable descending sort by score; indices avoid comparing documents.
        order = sorted(range(len(results)), key=lambda i: scores[i], reverse=True)
        return [results[i] for i in order]
141+
142+
class RankerService(Ranker):
    """Facade that selects the concrete reranker from the RERANK_BACKEND setting.

    Supported backends: "xinference", "local" (FlagEmbedding cross-encoder),
    and "jina".
    """

    def __init__(self):
        if RERANK_BACKEND == "xinference":
            self.ranker = XinferenceRanker()
        elif RERANK_BACKEND == "local":
            self.ranker = FlagCrossEncoderRanker()
        elif RERANK_BACKEND == "jina":
            self.ranker = JinaRanker()
        else:
            # ValueError (still a subclass of Exception, so existing handlers
            # keep matching) and include the offending value for diagnosis.
            raise ValueError(f"Unsupported rerank backend: {RERANK_BACKEND!r}")

    async def rank(self, query, results: List["DocumentWithScore"]):
        """Delegate ranking to the configured backend ranker."""
        return await self.ranker.rank(query, results)
155+
156+
@synchronized
def get_rerank_model(model_type: str = "bge-reranker-large"):
    """Return the cached ``RankerService`` for *model_type*, building it on
    first use.

    Thread-safe: ``@synchronized`` serializes cache access. Note that the
    configured RERANK_BACKEND, not *model_type*, decides which backend is
    constructed; *model_type* only keys the cache.
    """
    try:
        return rerank_model_cache[model_type]
    except KeyError:
        service = RankerService()
        rerank_model_cache[model_type] = service
        return service
163+
164+
async def rerank(message, results):
    """Rerank *results* for *message* using the shared cached rerank model."""
    return await get_rerank_model().rank(message, results)

0 commit comments

Comments
 (0)