1616
1717
1818class XInferenceReranker (MaxKBBaseModel , BaseDocumentCompressor ):
19- client : Any
2019 server_url : Optional [str ]
2120 """URL of the xinference server"""
2221 model_uid : Optional [str ]
@@ -30,10 +29,13 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], **
3029
3130 top_n : Optional [int ] = 3
3231
33- def __init__ (
34- self , server_url : Optional [str ] = None , model_uid : Optional [str ] = None , top_n = 3 ,
35- api_key : Optional [str ] = None
36- ):
32+ def compress_documents (self , documents : Sequence [Document ], query : str , callbacks : Optional [Callbacks ] = None ) -> \
33+ Sequence [Document ]:
34+ if documents is None or len (documents ) == 0 :
35+ return []
36+ client : Any
37+ if documents is None or len (documents ) == 0 :
38+ return []
3739 try :
3840 from xinference .client import RESTfulClient
3941 except ImportError :
@@ -45,29 +47,8 @@ def __init__(
4547 " with `pip install xinference` or `pip install xinference_client`."
4648 ) from e
4749
48- super ().__init__ ()
49-
50- if server_url is None :
51- raise ValueError ("Please provide server URL" )
52-
53- if model_uid is None :
54- raise ValueError ("Please provide the model UID" )
55-
56- self .server_url = server_url
57-
58- self .model_uid = model_uid
59-
60- self .api_key = api_key
61-
62- self .client = RESTfulClient (server_url , api_key )
63-
64- self .top_n = top_n
65-
66- def compress_documents (self , documents : Sequence [Document ], query : str , callbacks : Optional [Callbacks ] = None ) -> \
67- Sequence [Document ]:
68- if documents is None or len (documents ) == 0 :
69- return []
70- model : RESTfulRerankModelHandle = self .client .get_model (self .model_uid )
50+ client = RESTfulClient (self .server_url , self .api_key )
51+ model : RESTfulRerankModelHandle = client .get_model (self .model_uid )
7152 res = model .rerank ([document .page_content for document in documents ], query , self .top_n , return_documents = True )
7253 return [Document (page_content = d .get ('document' , {}).get ('text' ),
7354 metadata = {'relevance_score' : d .get ('relevance_score' )}) for d in res .get ('results' , [])]
0 commit comments