|
6 | 6 | @date:2024/9/2 16:42 |
7 | 7 | @desc: |
8 | 8 | """ |
9 | | -from typing import Dict |
| 9 | +from http import HTTPStatus |
| 10 | +from typing import Sequence, Optional, Any, Dict |
10 | 11 |
|
11 | | -from langchain_community.document_compressors import DashScopeRerank |
| 12 | +import dashscope |
| 13 | +from langchain_core.callbacks import Callbacks |
| 14 | +from langchain_core.documents import BaseDocumentCompressor, Document |
| 15 | +from langchain_core.documents import BaseDocumentCompressor |
12 | 16 |
|
13 | 17 | from models_provider.base_model_provider import MaxKBBaseModel |
14 | 18 |
|
15 | 19 |
|
16 | | -class AliyunBaiLianReranker(MaxKBBaseModel, DashScopeRerank): |
| 20 | +class AliyunBaiLianReranker(MaxKBBaseModel, BaseDocumentCompressor): |
| 21 | + model: Optional[str] |
| 22 | + api_key: Optional[str] |
| 23 | + |
| 24 | + top_n: Optional[int] = 3 # 取前 N 个最相关的结果 |
| 25 | + |
| 26 | + @staticmethod |
| 27 | + def is_cache_model(): |
| 28 | + return False |
| 29 | + |
17 | 30 | @staticmethod |
18 | 31 | def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): |
19 | | - return AliyunBaiLianReranker(model=model_name, dashscope_api_key=model_credential.get('dashscope_api_key'), |
| 32 | + return AliyunBaiLianReranker(model=model_name, |
| 33 | + api_key=model_credential.get('dashscope_api_key'), |
20 | 34 | top_n=model_kwargs.get('top_n', 3)) |
| 35 | + |
| 36 | + def compress_documents(self, documents: Sequence[Document], query: str, callbacks: Optional[Callbacks] = None) -> \ |
| 37 | + Sequence[Document]: |
| 38 | + if not documents: |
| 39 | + return [] |
| 40 | + |
| 41 | + texts = [doc.page_content for doc in documents] |
| 42 | + resp = dashscope.TextReRank.call( |
| 43 | + model=self.model, |
| 44 | + query=query, |
| 45 | + documents=texts, |
| 46 | + top_n=self.top_n, |
| 47 | + api_key=self.api_key, |
| 48 | + return_documents=True |
| 49 | + ) |
| 50 | + if resp.status_code == HTTPStatus.OK: |
| 51 | + return [ |
| 52 | + Document( |
| 53 | + page_content=item.get('document', {}).get('text', ''), |
| 54 | + metadata={'relevance_score': item.get('relevance_score')} |
| 55 | + ) |
| 56 | + for item in resp.output.get('results', []) |
| 57 | + ] |
| 58 | + else: |
| 59 | + return [] |
0 commit comments