Skip to content

Commit 1368d37

Browse files
wxg0103wangdan-fit2cloud
authored andcommitted
refactor: ollama support rerank
1 parent 22fb799 commit 1368d37

File tree

2 files changed

+36
-71
lines changed

2 files changed

+36
-71
lines changed

apps/setting/models_provider/impl/ollama_model_provider/credential/reranker.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,4 +64,3 @@ def build_model(self, model_info: Dict[str, object]):
6464
return self
6565

6666
api_base = forms.TextInputField('API URL', required=True)
67-
api_key = forms.TextInputField('API Key', required=True)
Lines changed: 36 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -1,82 +1,48 @@
11
from typing import Sequence, Optional, Any, Dict
22
from langchain_core.callbacks import Callbacks
3-
from langchain_core.documents import BaseDocumentCompressor, Document
4-
import requests
5-
3+
from langchain_core.documents import Document
4+
from langchain_community.embeddings import OllamaEmbeddings
65
from setting.models_provider.base_model_provider import MaxKBBaseModel
6+
from sklearn.metrics.pairwise import cosine_similarity
7+
from pydantic.v1 import BaseModel, Field
78

89

9-
class OllamaReranker(MaxKBBaseModel, BaseDocumentCompressor):
10-
api_base: Optional[str]
11-
"""URL of the Ollama server"""
12-
model_name: Optional[str]
13-
"""The model name to use for reranking"""
14-
api_key: Optional[str]
10+
class OllamaReranker(MaxKBBaseModel, OllamaEmbeddings, BaseModel):
11+
top_n: Optional[int] = Field(3, description="Number of top documents to return")
1512

1613
@staticmethod
17-
def new_instance(model_name, model_credential: Dict[str, object], **model_kwargs):
18-
return OllamaReranker(api_base=model_credential.get('api_base'), model_name=model_name,
19-
api_key=model_credential.get('api_key'), top_n=model_kwargs.get('top_n', 3))
20-
21-
top_n: Optional[int] = 3
22-
23-
def __init__(
24-
self, api_base: Optional[str] = None, model_name: Optional[str] = None, top_n=3,
25-
api_key: Optional[str] = None
26-
):
27-
super().__init__()
28-
29-
if api_base is None:
30-
raise ValueError("Please provide server URL")
31-
32-
if model_name is None:
33-
raise ValueError("Please provide the model name")
34-
35-
self.api_base = api_base
36-
self.model_name = model_name
37-
self.api_key = api_key
38-
self.top_n = top_n
14+
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
15+
optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
16+
return OllamaReranker(
17+
model=model_name,
18+
base_url=model_credential.get('api_base'),
19+
**optional_params
20+
)
3921

4022
def compress_documents(self, documents: Sequence[Document], query: str, callbacks: Optional[Callbacks] = None) -> \
4123
Sequence[Document]:
42-
"""
43-
Given a query and a set of documents, rerank them using Ollama API.
44-
"""
45-
if not documents or len(documents) == 0:
46-
return []
47-
48-
# Prepare the data to send to Ollama API
49-
headers = {
50-
'Authorization': f'Bearer {self.api_key}' # Use API key for authentication if required
51-
}
52-
53-
# Format the documents to be sent in a format understood by Ollama's API
54-
documents_text = [document.page_content for document in documents]
55-
56-
# Make a POST request to Ollama's rerank API endpoint
57-
payload = {
58-
'model': self.model_name, # Specify the model
59-
'query': query,
60-
'documents': documents_text,
61-
'top_n': self.top_n
62-
}
63-
64-
try:
65-
response = requests.post(f'{self.api_base}/v1/rerank', headers=headers, json=payload)
66-
response.raise_for_status()
67-
res = response.json()
68-
69-
# Ensure the response contains expected results
70-
if 'results' not in res:
71-
raise ValueError("The API response did not contain rerank results.")
24+
"""Rank documents based on their similarity to the query.
25+
26+
Args:
27+
query: The query text.
28+
documents: The list of document texts to rank.
29+
30+
Returns:
31+
List of documents sorted by relevance to the query.
32+
"""
33+
# 获取查询和文档的嵌入
34+
query_embedding = self.embed_query(query)
35+
documents = [doc.page_content for doc in documents]
36+
document_embeddings = self.embed_documents(documents)
37+
# 计算相似度
38+
similarities = cosine_similarity([query_embedding], document_embeddings)[0]
39+
ranked_docs = [(doc,_) for _, doc in sorted(zip(similarities, documents), reverse=True)][:self.top_n]
40+
return [
41+
Document(
42+
page_content=doc, # 第一个值是文档内容
43+
metadata={'relevance_score': score} # 第二个值是相似度分数
44+
)
45+
for doc, score in ranked_docs
46+
]
7247

73-
# Convert the API response into a list of Document objects with relevance scores
74-
ranked_documents = [
75-
Document(page_content=d['text'], metadata={'relevance_score': d['relevance_score']})
76-
for d in res['results']
77-
]
78-
return ranked_documents
7948

80-
except requests.exceptions.RequestException as e:
81-
print(f"Error during API request: {e}")
82-
return [] # Return an empty list if the request failed

0 commit comments

Comments
 (0)