1+ #!/usr/bin/env python3
2+ # -*- coding: utf-8 -*-
import functools
import json
import os
from abc import ABC, abstractmethod
from threading import Lock
from typing import Any, List

import aiohttp
import torch
from FlagEmbedding import FlagReranker
from transformers import AutoModelForSequenceClassification, AutoTokenizer

from aperag.query.query import DocumentWithScore
from config.settings import (
    RERANK_BACKEND,
    RERANK_SERVICE_MODEL,
    RERANK_SERVICE_MODEL_UID,
    RERANK_SERVICE_TOKEN,
    RERANK_SERVICE_URL,
)
22+
# Fallback model checkpoint used when RERANK_MODEL_PATH is not set.
default_rerank_model_path = "/data/models/bge-reranker-large"

# Module-level lock and cache shared by get_rerank_model(): the
# @synchronized decorator serializes access so each RankerService is
# constructed exactly once per model_type.
mutex = Lock()
rerank_model_cache = {}

def synchronized(func):
    """Decorator that runs *func* while holding the module-level mutex."""

    @functools.wraps(func)  # preserve __name__/__doc__ of the wrapped function
    def wrapper(*args, **kwargs):
        with mutex:
            return func(*args, **kwargs)

    return wrapper
34+
class Ranker(ABC):
    """Abstract interface for document rerankers."""

    @abstractmethod
    async def rank(self, query, results: List[DocumentWithScore]):
        """Return *results* reordered by relevance to *query*."""
39+
class XinferenceRanker(Ranker):
    """Rerank documents via a Xinference rerank HTTP endpoint."""

    def __init__(self):
        # Xinference exposes an OpenAI-compatible /v1/rerank route.
        self.url = f"{RERANK_SERVICE_URL}/v1/rerank"
        self.model_uid = RERANK_SERVICE_MODEL_UID

    async def rank(self, query, results: List[DocumentWithScore]):
        """Return *results* reordered by the service's relevance ranking.

        Raises:
            RuntimeError: if the service responds with a non-200 status.
        """
        if not results:
            return []
        body = {
            "model": self.model_uid,
            "documents": [doc.text for doc in results],
            "query": query,
            "return_documents": False,
        }
        async with aiohttp.ClientSession() as session:
            async with session.post(self.url, json=body) as resp:
                # Check the status BEFORE decoding: error responses are not
                # guaranteed to be JSON, so resp.json() could raise a decode
                # error that masks the real failure (and data["detail"] could
                # KeyError even when it is JSON).
                if resp.status != 200:
                    detail = await resp.text()
                    raise RuntimeError(f"Failed to rerank, detail: {detail}")
                data = await resp.json()
        # The service returns result objects carrying the original index;
        # map them back onto the input documents.
        indices = [r["index"] for r in data["results"]]
        return [results[i] for i in indices]
60+
class JinaRanker(Ranker):
    """Rerank documents via the Jina AI rerank HTTP API."""

    def __init__(self):
        self.url = RERANK_SERVICE_URL  # e.g. "https://api.jina.ai/v1/rerank"
        self.model = RERANK_SERVICE_MODEL  # e.g. "jina-reranker-v2-base-multilingual"
        # NOTE(review): the original inline comment suggested the configured
        # token may already include a "Bearer " prefix, yet the header below
        # prepends "Bearer " again — confirm the token is stored bare.
        self.auth_token = RERANK_SERVICE_TOKEN

    async def rank(self, query, results: List[DocumentWithScore]):
        """Return *results* reordered by the Jina service's relevance ranking.

        Raises:
            RuntimeError: if the service responds with a non-200 status.
        """
        if not results:
            return []
        documents = [doc.text for doc in results]
        body = {
            "model": self.model,
            "query": query,
            "top_n": len(documents),
            "documents": documents,
            "return_documents": False,
        }
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.auth_token}",
        }
        async with aiohttp.ClientSession() as session:
            async with session.post(self.url, headers=headers, json=body) as resp:
                # Check the status BEFORE decoding: error responses are not
                # guaranteed to be JSON, and resp.json() raising would mask
                # the real failure.
                if resp.status != 200:
                    detail = await resp.text()
                    raise RuntimeError(f"Failed to rerank, detail: {detail}")
                data = await resp.json()
        # Map the returned original indices back onto the input documents.
        indices = [r["index"] for r in data["results"]]
        return [results[i] for i in indices]
87+
class ContentRatioRanker(Ranker):
    """Rerank locally by (content_ratio metadata, score), highest first."""

    def __init__(self, query=None):
        # `query` is unused by the ranking itself; it is kept (now optional)
        # for backward compatibility with callers that pass it.
        self.query = query

    async def rank(self, query, results: List[DocumentWithScore]):
        """Sort *results* in place, descending by content_ratio then score.

        Documents without a "content_ratio" metadata entry default to 1.
        Returns the same (now sorted) list.
        """
        results.sort(
            key=lambda doc: (doc.metadata.get("content_ratio", 1), doc.score),
            reverse=True,
        )
        return results
95+
class AutoCrossEncoderRanker(Ranker):
    """Cross-encoder reranker backed by a local transformers checkpoint."""

    def __init__(self):
        # Checkpoint location may be overridden via RERANK_MODEL_PATH.
        model_path = os.environ.get("RERANK_MODEL_PATH", default_rerank_model_path)
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_path)
        self.model.eval()  # inference only — disable dropout etc.

    async def rank(self, query, results: List[DocumentWithScore]):
        """Score every (query, document) pair and return results best-first.

        Side effect: sets `doc.rank_before` to each document's pre-rerank
        position (matching FlagCrossEncoderRanker).
        """
        # Guard the empty case: the tokenizer raises on an empty batch, and
        # FlagCrossEncoderRanker already returns [] here — keep them consistent.
        if not results:
            return []
        pairs = []
        for idx, doc in enumerate(results):
            pairs.append((query, doc.text))
            doc.rank_before = idx
        with torch.no_grad():
            inputs = self.tokenizer(
                pairs,
                padding=True,
                truncation=True,
                return_tensors='pt',
                max_length=512,
            )
            # One relevance logit per pair; flatten to a 1-D score vector.
            scores = self.model(**inputs, return_dict=True).logits.view(-1).float()
        scores = scores.tolist()
        # Only the score participates in the comparison (key=...), so equal
        # scores keep their original relative order (stable sort).
        ranked = sorted(zip(scores, results), key=lambda pair: pair[0], reverse=True)
        return [doc for _, doc in ranked]
121+
class FlagCrossEncoderRanker(Ranker):
    """Cross-encoder reranker built on FlagEmbedding's FlagReranker."""

    def __init__(self):
        model_path = os.environ.get("RERANK_MODEL_PATH", default_rerank_model_path)
        self.reranker = FlagReranker(model_path)

    async def rank(self, query, results: List[DocumentWithScore]):
        """Score (query, document) pairs and return results best-first.

        Side effect: sets `doc.rank_before` to each document's pre-rerank
        position. Both query and document text are truncated to 512 chars.
        """
        max_length = 512
        pairs = [(query[:max_length], doc.text[:max_length]) for doc in results]
        for position, doc in enumerate(results):
            doc.rank_before = position
        if not pairs:
            return []
        with torch.no_grad():
            scores = self.reranker.compute_score(pairs, max_length=max_length)
        if isinstance(scores, float):
            # compute_score collapses a single pair to a bare float.
            scores = [scores]
        ordered = sorted(zip(scores, results), key=lambda item: item[0], reverse=True)
        return [doc for _, doc in ordered]
141+
class RankerService(Ranker):
    """Facade that selects the concrete Ranker from RERANK_BACKEND settings."""

    def __init__(self):
        if RERANK_BACKEND == "xinference":
            self.ranker = XinferenceRanker()
        elif RERANK_BACKEND == "local":
            self.ranker = FlagCrossEncoderRanker()
        elif RERANK_BACKEND == "jina":
            self.ranker = JinaRanker()
        else:
            # ValueError (a subclass of Exception, so existing handlers still
            # match) and include the offending value for diagnosability.
            raise ValueError(f"Unsupported rerank backend: {RERANK_BACKEND!r}")

    async def rank(self, query, results: List[DocumentWithScore]):
        """Delegate reranking to the configured backend."""
        return await self.ranker.rank(query, results)
155+
@synchronized
def get_rerank_model(model_type: str = "bge-reranker-large"):
    """Return the cached RankerService for *model_type*, creating it on first use.

    Thread-safe: the @synchronized decorator serializes cache access.
    """
    service = rerank_model_cache.get(model_type)
    if service is None:
        service = RankerService()
        rerank_model_cache[model_type] = service
    return service
163+
async def rerank(message, results):
    """Rerank *results* against *message* using the process-wide ranker service."""
    ranker = get_rerank_model()
    return await ranker.rank(message, results)