|
1 | 1 | from typing import Sequence, Optional, Any, Dict |
2 | 2 | from langchain_core.callbacks import Callbacks |
3 | | -from langchain_core.documents import BaseDocumentCompressor, Document |
4 | | -import requests |
5 | | - |
| 3 | +from langchain_core.documents import Document |
| 4 | +from langchain_community.embeddings import OllamaEmbeddings |
6 | 5 | from setting.models_provider.base_model_provider import MaxKBBaseModel |
| 6 | +from sklearn.metrics.pairwise import cosine_similarity |
| 7 | +from pydantic.v1 import BaseModel, Field |
7 | 8 |
|
8 | 9 |
|
9 | | -class OllamaReranker(MaxKBBaseModel, BaseDocumentCompressor): |
10 | | -    api_base: Optional[str] |
11 | | -    """URL of the Ollama server""" |
12 | | -    model_name: Optional[str] |
13 | | -    """The model name to use for reranking""" |
14 | | -    api_key: Optional[str] |
| 10 | +class OllamaReranker(MaxKBBaseModel, OllamaEmbeddings, BaseModel): |
| 11 | +    top_n: Optional[int] = Field(3, description="Number of top documents to return") |
15 | 12 |
|
16 | 13 |     @staticmethod |
17 | | -    def new_instance(model_name, model_credential: Dict[str, object], **model_kwargs): |
18 | | -        return OllamaReranker(api_base=model_credential.get('api_base'), model_name=model_name, |
19 | | -                              api_key=model_credential.get('api_key'), top_n=model_kwargs.get('top_n', 3)) |
20 | | - |
21 | | -    top_n: Optional[int] = 3 |
22 | | - |
23 | | -    def __init__( |
24 | | -            self, api_base: Optional[str] = None, model_name: Optional[str] = None, top_n=3, |
25 | | -            api_key: Optional[str] = None |
26 | | -    ): |
27 | | -        super().__init__() |
28 | | - |
29 | | -        if api_base is None: |
30 | | -            raise ValueError("Please provide server URL") |
31 | | - |
32 | | -        if model_name is None: |
33 | | -            raise ValueError("Please provide the model name") |
34 | | - |
35 | | -        self.api_base = api_base |
36 | | -        self.model_name = model_name |
37 | | -        self.api_key = api_key |
38 | | -        self.top_n = top_n |
| 14 | +    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): |
| 15 | +        optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs) |
| 16 | +        return OllamaReranker( |
| 17 | +            model=model_name, |
| 18 | +            base_url=model_credential.get('api_base'), |
| 19 | +            **optional_params |
| 20 | +        ) |
39 | 21 |
|
40 | 22 |     def compress_documents(self, documents: Sequence[Document], query: str, callbacks: Optional[Callbacks] = None) -> \ |
41 | 23 |             Sequence[Document]: |
42 | | -        """ |
43 | | -        Given a query and a set of documents, rerank them using Ollama API. |
44 | | -        """ |
45 | | -        if not documents or len(documents) == 0: |
46 | | -            return [] |
47 | | - |
48 | | -        # Prepare the data to send to Ollama API |
49 | | -        headers = { |
50 | | -            'Authorization': f'Bearer {self.api_key}'  # Use API key for authentication if required |
51 | | -        } |
52 | | - |
53 | | -        # Format the documents to be sent in a format understood by Ollama's API |
54 | | -        documents_text = [document.page_content for document in documents] |
55 | | - |
56 | | -        # Make a POST request to Ollama's rerank API endpoint |
57 | | -        payload = { |
58 | | -            'model': self.model_name,  # Specify the model |
59 | | -            'query': query, |
60 | | -            'documents': documents_text, |
61 | | -            'top_n': self.top_n |
62 | | -        } |
63 | | - |
64 | | -        try: |
65 | | -            response = requests.post(f'{self.api_base}/v1/rerank', headers=headers, json=payload) |
66 | | -            response.raise_for_status() |
67 | | -            res = response.json() |
68 | | - |
69 | | -            # Ensure the response contains expected results |
70 | | -            if 'results' not in res: |
71 | | -                raise ValueError("The API response did not contain rerank results.") |
| 24 | +        """Rank documents by their embedding similarity to the query. |
| 25 | +
| 26 | +        Args: |
| 27 | +            documents: The documents to rank. |
| 28 | +            query: The query text. |
| 29 | +
| 30 | +        Returns: |
| 31 | +            The top_n most relevant documents, each carrying its cosine similarity in the 'relevance_score' metadata. |
| 32 | +        """ |
| 33 | +        if not documents: |
| 34 | +            return [] |
| 35 | +        # Embed the query and the documents |
| 36 | +        query_embedding = self.embed_query(query) |
| 37 | +        doc_texts = [doc.page_content for doc in documents] |
| 38 | +        document_embeddings = self.embed_documents(doc_texts) |
| 39 | +        # Compute cosine similarity between the query and each document |
| 40 | +        similarities = cosine_similarity([query_embedding], document_embeddings)[0] |
| 41 | +        # Sort by similarity (descending) and keep the top_n documents |
| 42 | +        ranked_docs = sorted(zip(doc_texts, similarities), key=lambda pair: pair[1], reverse=True)[:self.top_n] |
| 43 | +        return [ |
| 44 | +            Document( |
| 45 | +                page_content=doc, |
| 46 | +                metadata={'relevance_score': float(score)} |
| 47 | +            ) |
| 48 | +            for doc, score in ranked_docs |
| 49 | +        ] |
72 | 50 |
|
73 | | -            # Convert the API response into a list of Document objects with relevance scores |
74 | | -            ranked_documents = [ |
75 | | -                Document(page_content=d['text'], metadata={'relevance_score': d['relevance_score']}) |
76 | | -                for d in res['results'] |
77 | | -            ] |
78 | | -            return ranked_documents |
79 | 51 |
|
80 | | -        except requests.exceptions.RequestException as e: |
81 | | -            print(f"Error during API request: {e}") |
82 | | -            return []  # Return an empty list if the request failed |
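
For reference, a minimal usage sketch of the embedding-based reranker introduced by this diff (not part of the commit). The Ollama base URL, the 'nomic-embed-text' model name, and the 'RERANKER' model_type string are placeholder assumptions; top_n falls back to its default of 3.

# Hypothetical usage sketch: OllamaReranker is the class changed in the diff above;
# the server URL and embedding model name are placeholders for a local Ollama instance.
from langchain_core.documents import Document

reranker = OllamaReranker.new_instance(
    model_type='RERANKER',  # placeholder; the argument is not used inside new_instance
    model_name='nomic-embed-text',  # any embedding model served by Ollama
    model_credential={'api_base': 'http://localhost:11434'}
)
docs = [
    Document(page_content='Ollama serves embedding and chat models locally.'),
    Document(page_content='The weather is sunny today.'),
    Document(page_content='MaxKB reranks retrieved chunks before answering.'),
]
for doc in reranker.compress_documents(docs, query='How are retrieved chunks reranked?'):
    print(doc.metadata['relevance_score'], doc.page_content)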