|
9 | 9 | from typing import Sequence, Optional, Dict, Any, ClassVar |
10 | 10 |
|
11 | 11 | import requests |
12 | | -import torch |
13 | 12 | from langchain_core.callbacks import Callbacks |
14 | 13 | from langchain_core.documents import BaseDocumentCompressor, Document |
15 | 14 | from transformers import AutoModelForSequenceClassification, AutoTokenizer |
16 | | - |
| 15 | +import numpy as np |
17 | 16 | from models_provider.base_model_provider import MaxKBBaseModel |
18 | 17 | from maxkb.const import CONFIG |
19 | 18 |
|
@@ -90,13 +89,16 @@ def compress_documents(self, documents: Sequence[Document], query: str, callback |
90 | 89 | Sequence[Document]: |
91 | 90 | if documents is None or len(documents) == 0: |
92 | 91 | return [] |
93 | | - with torch.no_grad(): |
94 | | - inputs = self.tokenizer([[query, document.page_content] for document in documents], padding=True, |
95 | | - truncation=True, return_tensors='pt', max_length=512) |
96 | | - scores = [torch.sigmoid(s).float().item() for s in |
97 | | - self.client(**inputs, return_dict=True).logits.view(-1, ).float()] |
98 | | - result = [Document(page_content=documents[index].page_content, metadata={'relevance_score': scores[index]}) |
99 | | - for index |
100 | | - in range(len(documents))] |
101 | | - result.sort(key=lambda row: row.metadata.get('relevance_score'), reverse=True) |
102 | | - return result |
| 92 | + inputs = self.tokenizer([[query, document.page_content] for document in documents], padding=True, |
| 93 | + truncation=True, return_tensors='pt', max_length=512) |
| 94 | + scores = [self.sigmoid(s).float().item() for s in |
| 95 | + self.client(**inputs, return_dict=True).logits.view(-1, ).float()] |
| 96 | + result = [Document(page_content=documents[index].page_content, metadata={'relevance_score': scores[index]}) |
| 97 | + for index |
| 98 | + in range(len(documents))] |
| 99 | + result.sort(key=lambda row: row.metadata.get('relevance_score'), reverse=True) |
| 100 | + return result |
| 101 | + |
@staticmethod
def sigmoid(x):
    """Return the element-wise logistic sigmoid ``1 / (1 + exp(-x))`` of *x*.

    Declared as a static method because it reads no instance state and is
    invoked through the instance (``self.sigmoid(value)``); without the
    decorator the instance would be bound to ``x`` and the real argument
    would be rejected as an unexpected extra positional.

    Args:
        x: A scalar or array-like of real numbers; it is converted with
            ``np.asarray(..., dtype=np.float64)``.

    Returns:
        A NumPy float64 scalar for scalar input, or a float64 ``ndarray``
        with the same shape as *x* otherwise. Every value lies in [0, 1].
    """
    # NOTE(review): the visible call site chains ``.float().item()`` on the
    # result; NumPy values expose ``.item()`` but not ``.float()`` -- confirm
    # the caller is adjusted accordingly.
    x = np.asarray(x, dtype=np.float64)
    # exp(-|x|) never exceeds 1, so it cannot overflow the way exp(-x)
    # does for large negative x (which would emit a RuntimeWarning).
    decay = np.exp(-np.abs(x))
    result = np.where(x >= 0, 1.0 / (1.0 + decay), decay / (1.0 + decay))
    # Indexing with an empty tuple turns a 0-d array back into a NumPy
    # scalar and leaves arrays of higher rank unchanged.
    return result[()]
0 commit comments