1+ #!/usr/bin/env python3
2+ # -*- coding: utf-8 -*-
import functools
import json
import os
from abc import ABC, abstractmethod
from threading import Lock
from typing import Any, List

import aiohttp
import torch
from FlagEmbedding import FlagReranker
from transformers import AutoModelForSequenceClassification, AutoTokenizer

from aperag.query.query import DocumentWithScore
from config.settings import (
    RERANK_BACKEND,
    RERANK_SERVICE_MODEL_UID,
    RERANK_SERVICE_URL,
)
20+
# Default on-disk model location; overridable via the RERANK_MODEL_PATH
# environment variable (see the ranker classes below).
default_rerank_model_path = "/data/models/bge-reranker-large"

# Process-wide mutex guarding lazy model construction, plus the cache of
# already-built ranker services keyed by model type.
mutex = Lock()
rerank_model_cache = {}


def synchronized(func):
    """Decorator that serializes every call to ``func`` under the module mutex.

    Used to make model construction and caching thread-safe.
    """

    # functools.wraps preserves the wrapped function's name/docstring, which
    # the original decorator dropped.
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        with mutex:
            return func(*args, **kwargs)

    return wrapper
35+
36+
class Ranker(ABC):
    """Abstract interface for document rerankers.

    Implementations reorder retrieval results by relevance to a query.
    """

    @abstractmethod
    async def rank(self, query, results: List[DocumentWithScore]):
        """Return ``results`` reordered by relevance to ``query``."""
        ...
42+
43+
class RankerService(Ranker):
    """Facade that selects a concrete ranker from the RERANK_BACKEND setting.

    Supported backends:
      * ``"xinference"`` -- remote rerank service (:class:`XinferenceRanker`)
      * ``"local"``      -- in-process model (:class:`FlagCrossEncoderRanker`)

    Raises:
        Exception: if RERANK_BACKEND names an unsupported backend.
    """

    def __init__(self):
        if RERANK_BACKEND == "xinference":
            self.ranker = XinferenceRanker()
        elif RERANK_BACKEND == "local":
            self.ranker = FlagCrossEncoderRanker()
        else:
            # Fixed message: this selects the *rerank* backend, not the
            # embedding backend; also include the offending value.
            raise Exception(f"Unsupported rerank backend: {RERANK_BACKEND}")

    async def rank(self, query, results: List[DocumentWithScore]):
        """Delegate reranking to the configured backend implementation."""
        return await self.ranker.rank(query, results)
56+
57+
class XinferenceRanker(Ranker):
    """Ranker backed by a remote Xinference ``/v1/rerank`` HTTP endpoint."""

    def __init__(self):
        self.url = f"{RERANK_SERVICE_URL}/v1/rerank"
        self.model_uid = RERANK_SERVICE_MODEL_UID

    async def rank(self, query, results: List[DocumentWithScore]):
        """Rerank ``results`` for ``query`` via the remote service.

        Returns the same DocumentWithScore objects reordered according to
        the index order reported by the service.

        Raises:
            RuntimeError: if the service responds with a non-200 status.
        """
        if not results:
            # Nothing to rank; skip the network round trip entirely.
            return []
        request_body = {
            "model": self.model_uid,
            "documents": [document.text for document in results],
            "query": query,
            "return_documents": False,
        }
        async with aiohttp.ClientSession() as session:
            async with session.post(self.url, json=request_body) as response:
                # Check the status *before* decoding: error responses are not
                # guaranteed to be JSON, and the original also assumed a
                # 'detail' key that may be absent.
                if response.status != 200:
                    detail = await response.text()
                    raise RuntimeError(f"Failed to rerank documents, detail: {detail}")
                response_data = await response.json()
        # Use a distinct loop variable; the original shadowed ``response``.
        indices = [item["index"] for item in response_data["results"]]
        return [results[index] for index in indices]
78+
79+
class ContentRatioRanker(Ranker):
    """Ranker ordering results by their ``content_ratio`` metadata, then score.

    NOTE(review): a query is accepted at construction but the ``rank``
    method sorts purely on per-document metadata and score.
    """

    def __init__(self, query):
        self.query = query

    async def rank(self, query, results: List[DocumentWithScore]):
        """Sort ``results`` in place, highest (content_ratio, score) first."""

        def sort_key(doc):
            # Documents lacking a content_ratio are treated as ratio 1.
            return (doc.metadata.get("content_ratio", 1), doc.score)

        results.sort(key=sort_key, reverse=True)
        return results
88+
89+
class AutoCrossEncoderRanker(Ranker):
    """Cross-encoder reranker using a HuggingFace sequence-classification model.

    The model path defaults to ``default_rerank_model_path`` and may be
    overridden with the RERANK_MODEL_PATH environment variable.
    """

    def __init__(self):
        model_path = os.environ.get("RERANK_MODEL_PATH", default_rerank_model_path)
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_path)
        self.model.eval()  # inference only

    async def rank(self, query, results: List[DocumentWithScore]):
        """Score each (query, document.text) pair and return results sorted
        by model score, highest first.

        Side effect: each document's original position is stored in its
        ``rank_before`` attribute.
        """
        if not results:
            # Guard: the tokenizer rejects an empty batch, which crashed
            # the original implementation.
            return []

        pairs = []
        for idx, result in enumerate(results):
            pairs.append((query, result.text))
            result.rank_before = idx

        with torch.no_grad():
            inputs = self.tokenizer(
                pairs, padding=True, truncation=True, return_tensors="pt", max_length=512
            )
            # view(-1) flattens the logits and tolist() always produces a
            # Python list -- even for a single pair -- so the original
            # isinstance special-casing was unnecessary.
            scores = self.model(**inputs, return_dict=True).logits.view(-1).float().tolist()

        return [doc for _, doc in sorted(zip(scores, results), key=lambda pair: pair[0], reverse=True)]
116+
117+
class FlagCrossEncoderRanker(Ranker):
    """Cross-encoder reranker built on FlagEmbedding's FlagReranker."""

    def __init__(self):
        # Model location is overridable through RERANK_MODEL_PATH.
        model_path = os.environ.get("RERANK_MODEL_PATH", default_rerank_model_path)
        # NOTE: passing use_fp16=True to FlagReranker would speed up computing.
        self.reranker = FlagReranker(model_path)

    async def rank(self, query, results: List[DocumentWithScore]):
        """Score (query, text) pairs and return results sorted by score desc.

        Both the query and each document's text are clipped to 512
        characters before scoring; each document's original index is
        recorded in its ``rank_before`` attribute.
        """
        max_length = 512
        pairs = []
        for position, doc in enumerate(results):
            pairs.append((query[:max_length], doc.text[:max_length]))
            doc.rank_before = position

        if not pairs:
            return []

        with torch.no_grad():
            raw_scores = self.reranker.compute_score(pairs, max_length=max_length)
            # compute_score returns a bare float when given a single pair.
            scores = [raw_scores] if isinstance(raw_scores, float) else raw_scores
            ordering = sorted(zip(scores, results), key=lambda item: item[0], reverse=True)
            reranked = [doc for _, doc in ordering]

        return reranked
142+
143+
@synchronized
def get_rerank_model(model_type: str = "bge-reranker-large"):
    """Return the cached RankerService for ``model_type``, building it lazily.

    The whole function runs under the module mutex (via ``synchronized``),
    so concurrent callers cannot construct the service twice.  Currently
    ``model_type`` acts only as the cache key; it does not select a
    different backend.
    """
    try:
        return rerank_model_cache[model_type]
    except KeyError:
        service = rerank_model_cache[model_type] = RankerService()
        return service
153+
154+
async def rerank(message, results):
    """Rerank ``results`` for ``message`` using the shared cached ranker."""
    ranker = get_rerank_model()
    return await ranker.rank(message, results)