from functools import wraps

from config.settings import (
    RERANK_BACKEND,
    RERANK_SERVICE_MODEL,
    RERANK_SERVICE_MODEL_UID,
    RERANK_SERVICE_TOKEN,
    RERANK_SERVICE_URL,
)
2022
# Fallback path for the local cross-encoder weights when RERANK_MODEL_PATH
# is not set in the environment.
default_rerank_model_path = "/data/models/bge-reranker-large"

# Module-wide lock serializing lazy model construction, and the cache of
# already-built ranker instances keyed by model type.
mutex = Lock()
rerank_model_cache = {}


def synchronized(func):
    """Decorator that serializes every call to *func* behind the module mutex."""

    @wraps(func)  # preserve the wrapped function's name/docstring for debugging
    def wrapper(*args, **kwargs):
        with mutex:
            return func(*args, **kwargs)

    return wrapper
36-
class Ranker(ABC):
    """Abstract interface: reorder a list of scored documents for a query."""

    @abstractmethod
    async def rank(self, query, results: List[DocumentWithScore]):
        """Return *results* reordered by relevance to *query*."""
4239
43-
44- class RankerService (Ranker ):
45- def __init__ (self ):
46- if RERANK_BACKEND == "xinference" :
47- self .ranker = XinferenceRanker ()
48- elif RERANK_BACKEND == "local" :
49- self .ranker = FlagCrossEncoderRanker ()
50- else :
51- raise Exception (
52- "Unsupported embedding backend" ) # Note: Error message refers to embedding backend, might be a typo in original code
53-
54- async def rank (self , query , results : List [DocumentWithScore ]):
55- return await self .ranker .rank (query , results )
56-
57-
class XinferenceRanker(Ranker):
    """Reranker backed by an Xinference ``/v1/rerank`` HTTP endpoint."""

    def __init__(self):
        self.url = f"{RERANK_SERVICE_URL}/v1/rerank"
        self.model_uid = RERANK_SERVICE_MODEL_UID

    async def rank(self, query, results: List[DocumentWithScore]):
        """Reorder *results* by relevance to *query* using the remote service.

        Raises:
            RuntimeError: if the service answers with a non-200 status.
        """
        if not results:
            return []
        documents = [doc.text for doc in results]
        body = {
            "model": self.model_uid,
            "documents": documents,
            "query": query,
            "return_documents": False,
        }
        async with aiohttp.ClientSession() as session:
            async with session.post(self.url, json=body) as resp:
                # Check the status BEFORE parsing: error bodies are not
                # guaranteed to be JSON, and the original .json()-first order
                # turned HTTP errors into confusing decode exceptions.
                if resp.status != 200:
                    detail = await resp.text()
                    raise RuntimeError(f"Failed to rerank, detail: {detail}")
                data = await resp.json()
        # The service returns indices into the submitted document list,
        # ordered by relevance.
        indices = [r["index"] for r in data["results"]]
        return [results[i] for i in indices]
60+
class JinaRanker(Ranker):
    """Reranker backed by the Jina AI rerank HTTP API."""

    def __init__(self):
        self.url = RERANK_SERVICE_URL        # e.g. "https://api.jina.ai/v1/rerank"
        self.model = RERANK_SERVICE_MODEL    # e.g. "jina-reranker-v2-base-multilingual"
        self.auth_token = RERANK_SERVICE_TOKEN

    async def rank(self, query, results: List[DocumentWithScore]):
        """Reorder *results* by relevance to *query* via the Jina API.

        Raises:
            RuntimeError: if the service answers with a non-200 status.
        """
        if not results:
            return []
        documents = [doc.text for doc in results]
        body = {
            "model": self.model,
            "query": query,
            "top_n": len(documents),
            "documents": documents,
            "return_documents": False,
        }
        # Bug fix: the original built f"Bearer ${self.auth_token}" — the stray
        # JS-style "$" produced an invalid "Bearer $<token>" header. Also
        # tolerate tokens configured with the "Bearer " prefix already present
        # (the settings comment suggested "Bearer YOUR_JINA_TOKEN").
        token = self.auth_token
        if not token.startswith("Bearer "):
            token = f"Bearer {token}"
        headers = {
            "Content-Type": "application/json",
            "Authorization": token,
        }
        async with aiohttp.ClientSession() as session:
            async with session.post(self.url, headers=headers, json=body) as resp:
                # Status first: error bodies are not guaranteed to be JSON.
                if resp.status != 200:
                    detail = await resp.text()
                    raise RuntimeError(f"Failed to rerank, detail: {detail}")
                data = await resp.json()
        indices = [r["index"] for r in data["results"]]
        return [results[i] for i in indices]
7987
class ContentRatioRanker(Ranker):
    """Orders documents by their ``content_ratio`` metadata, then by score.

    NOTE(review): *query* is stored at construction but never used by
    rank(); kept as-is to preserve the existing constructor signature.
    """

    def __init__(self, query):
        self.query = query

    async def rank(self, query, results: List[DocumentWithScore]):
        """Sort *results* in place (highest ratio/score first) and return them."""
        def sort_key(doc):
            # Missing content_ratio defaults to 1; ties break on the score.
            return (doc.metadata.get("content_ratio", 1), doc.score)

        results.sort(key=sort_key, reverse=True)
        return results
8895
89-
9096class AutoCrossEncoderRanker (Ranker ):
9197 def __init__ (self ):
9298 model_path = os .environ .get ("RERANK_MODEL_PATH" , default_rerank_model_path )
@@ -96,63 +102,65 @@ def __init__(self):
96102
97103 async def rank (self , query , results : List [DocumentWithScore ]):
98104 pairs = []
99- for idx , result in enumerate (results ):
100- pairs .append ((query , result .text ))
101- result .rank_before = idx
102-
105+ for idx , doc in enumerate (results ):
106+ pairs .append ((query , doc .text ))
107+ doc .rank_before = idx
103108 with torch .no_grad ():
104- inputs = self .tokenizer (pairs , padding = True , truncation = True , return_tensors = 'pt' , max_length = 512 )
105- scores = self .model (** inputs , return_dict = True ).logits .view (- 1 , ).float ()
106- # Ensure scores is iterable even if only one result
107- if not isinstance (scores , (list , torch .Tensor )) or (
108- isinstance (scores , torch .Tensor ) and scores .numel () == 1 and len (results ) == 1 ):
109- scores = [scores .item ()] if isinstance (scores , torch .Tensor ) else [scores ]
110- elif isinstance (scores , torch .Tensor ):
109+ inputs = self .tokenizer (
110+ pairs ,
111+ padding = True ,
112+ truncation = True ,
113+ return_tensors = 'pt' ,
114+ max_length = 512
115+ )
116+ scores = self .model (** inputs , return_dict = True ).logits .view (- 1 ,).float ()
117+ if isinstance (scores , torch .Tensor ):
111118 scores = scores .tolist ()
112-
113- results = [x for _ , x in sorted (zip (scores , results ), key = lambda k : k [0 ], reverse = True )]
114-
115- return results
116-
119+ ranked = sorted (zip (scores , results ), key = lambda k : k [0 ], reverse = True )
120+ return [x for _ , x in ranked ]
117121
class FlagCrossEncoderRanker(Ranker):
    """Local cross-encoder reranker built on FlagEmbedding's FlagReranker."""

    def __init__(self):
        model_path = os.environ.get("RERANK_MODEL_PATH", default_rerank_model_path)
        self.reranker = FlagReranker(model_path)

    async def rank(self, query, results: List[DocumentWithScore]):
        """Score each (query, document) pair and return *results* sorted by
        relevance, descending.

        Side effect: tags each document with ``rank_before`` (its original
        position) before reordering.
        """
        max_length = 512
        truncated_query = query[:max_length]
        pairs = []
        for position, doc in enumerate(results):
            doc.rank_before = position
            pairs.append((truncated_query, doc.text[:max_length]))
        if not pairs:
            return []
        with torch.no_grad():
            scores = self.reranker.compute_score(pairs, max_length=max_length)
        # FlagReranker returns a bare float when given a single pair.
        if isinstance(scores, float):
            scores = [scores]
        order = sorted(zip(scores, results), key=lambda pair: pair[0], reverse=True)
        return [doc for _, doc in order]
class RankerService(Ranker):
    """Facade that selects a concrete Ranker implementation from settings.

    The backend is chosen once at construction from RERANK_BACKEND:
    "xinference" (remote Xinference service), "local" (FlagReranker
    cross-encoder), or "jina" (Jina AI API).
    """

    def __init__(self):
        if RERANK_BACKEND == "xinference":
            self.ranker = XinferenceRanker()
        elif RERANK_BACKEND == "local":
            self.ranker = FlagCrossEncoderRanker()
        elif RERANK_BACKEND == "jina":
            self.ranker = JinaRanker()
        else:
            # ValueError subclasses Exception, so existing `except Exception`
            # handlers still catch it; include the offending value so the
            # misconfiguration is obvious from the message alone.
            raise ValueError(f"Unsupported rerank backend: {RERANK_BACKEND!r}")

    async def rank(self, query, results: List[DocumentWithScore]):
        """Delegate ranking to the selected backend."""
        return await self.ranker.rank(query, results)
143155
@synchronized
def get_rerank_model(model_type: str = "bge-reranker-large"):
    """Return a process-wide RankerService singleton, cached per model_type.

    NOTE(review): model_type currently only keys the cache — RankerService
    itself ignores it; the parameter is kept for signature stability.
    """
    cached = rerank_model_cache.get(model_type)
    if cached is None:
        cached = RankerService()
        rerank_model_cache[model_type] = cached
    return cached
153163
154-
async def rerank(message, results):
    """Rerank *results* for *message* using the shared rerank model."""
    ranker = get_rerank_model()
    return await ranker.rank(message, results)
0 commit comments