Skip to content

Commit 2686e76

Browse files
committed
fix: xinference rerank error
--bug=1054256 --user=王孝刚 【模型】添加硅基流动的重排序模型失败 https://www.tapd.cn/57709429/s/1679612
1 parent 6cf9109 commit 2686e76

File tree

1 file changed

+9
-28
lines changed
  • apps/setting/models_provider/impl/xinference_model_provider/model

1 file changed

+9
-28
lines changed

apps/setting/models_provider/impl/xinference_model_provider/model/reranker.py

Lines changed: 9 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -16,7 +16,6 @@
1616

1717

1818
class XInferenceReranker(MaxKBBaseModel, BaseDocumentCompressor):
19-
client: Any
2019
server_url: Optional[str]
2120
"""URL of the xinference server"""
2221
model_uid: Optional[str]
@@ -30,10 +29,13 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], **
3029

3130
top_n: Optional[int] = 3
3231

33-
def __init__(
34-
self, server_url: Optional[str] = None, model_uid: Optional[str] = None, top_n=3,
35-
api_key: Optional[str] = None
36-
):
32+
def compress_documents(self, documents: Sequence[Document], query: str, callbacks: Optional[Callbacks] = None) -> \
33+
Sequence[Document]:
34+
if documents is None or len(documents) == 0:
35+
return []
36+
client: Any
37+
if documents is None or len(documents) == 0:
38+
return []
3739
try:
3840
from xinference.client import RESTfulClient
3941
except ImportError:
@@ -45,29 +47,8 @@ def __init__(
4547
" with `pip install xinference` or `pip install xinference_client`."
4648
) from e
4749

48-
super().__init__()
49-
50-
if server_url is None:
51-
raise ValueError("Please provide server URL")
52-
53-
if model_uid is None:
54-
raise ValueError("Please provide the model UID")
55-
56-
self.server_url = server_url
57-
58-
self.model_uid = model_uid
59-
60-
self.api_key = api_key
61-
62-
self.client = RESTfulClient(server_url, api_key)
63-
64-
self.top_n = top_n
65-
66-
def compress_documents(self, documents: Sequence[Document], query: str, callbacks: Optional[Callbacks] = None) -> \
67-
Sequence[Document]:
68-
if documents is None or len(documents) == 0:
69-
return []
70-
model: RESTfulRerankModelHandle = self.client.get_model(self.model_uid)
50+
client = RESTfulClient(self.server_url, self.api_key)
51+
model: RESTfulRerankModelHandle = client.get_model(self.model_uid)
7152
res = model.rerank([document.page_content for document in documents], query, self.top_n, return_documents=True)
7253
return [Document(page_content=d.get('document', {}).get('text'),
7354
metadata={'relevance_score': d.get('relevance_score')}) for d in res.get('results', [])]

0 commit comments

Comments
 (0)