Commit 1cf2c09

refactor: remove torch.
1 parent 8a6b685 commit 1cf2c09

File tree

3 files changed (+15, -20 lines)

apps/models_provider/impl/local_model_provider/model/reranker.py

Lines changed: 14 additions & 12 deletions
@@ -9,11 +9,10 @@
 from typing import Sequence, Optional, Dict, Any, ClassVar
 
 import requests
-import torch
 from langchain_core.callbacks import Callbacks
 from langchain_core.documents import BaseDocumentCompressor, Document
 from transformers import AutoModelForSequenceClassification, AutoTokenizer
-
+import numpy as np
 from models_provider.base_model_provider import MaxKBBaseModel
 from maxkb.const import CONFIG
 
@@ -90,13 +89,16 @@ def compress_documents(self, documents: Sequence[Document], query: str, callback
             Sequence[Document]:
         if documents is None or len(documents) == 0:
             return []
-        with torch.no_grad():
-            inputs = self.tokenizer([[query, document.page_content] for document in documents], padding=True,
-                                    truncation=True, return_tensors='pt', max_length=512)
-            scores = [torch.sigmoid(s).float().item() for s in
-                      self.client(**inputs, return_dict=True).logits.view(-1, ).float()]
-            result = [Document(page_content=documents[index].page_content, metadata={'relevance_score': scores[index]})
-                      for index
-                      in range(len(documents))]
-            result.sort(key=lambda row: row.metadata.get('relevance_score'), reverse=True)
-            return result
+        inputs = self.tokenizer([[query, document.page_content] for document in documents], padding=True,
+                                truncation=True, return_tensors='pt', max_length=512)
+        scores = [self.sigmoid(s).float().item() for s in
+                  self.client(**inputs, return_dict=True).logits.view(-1, ).float()]
+        result = [Document(page_content=documents[index].page_content, metadata={'relevance_score': scores[index]})
+                  for index
+                  in range(len(documents))]
+        result.sort(key=lambda row: row.metadata.get('relevance_score'), reverse=True)
+        return result
+
+    def sigmoid(x):
+        x = np.asarray(x, dtype=np.float64)
+        return 1 / (1 + np.exp(-x))
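For reference, below is a minimal standalone sketch of the post-processing that replaces torch.sigmoid. It is not code from the repository: the sigmoid helper mirrors the one added above, while score_documents and its logits argument are hypothetical names, and the logits are assumed to already be plain floats or something np.asarray accepts.

import numpy as np

def sigmoid(x):
    # Logistic function on plain NumPy values; stand-in for torch.sigmoid on real-valued logits.
    x = np.asarray(x, dtype=np.float64)
    return 1.0 / (1.0 + np.exp(-x))

def score_documents(logits):
    # Hypothetical helper: map raw cross-encoder logits to relevance scores in (0, 1)
    # and return (index, score) pairs sorted from most to least relevant.
    scores = sigmoid(logits)
    order = np.argsort(-scores)
    return [(int(i), float(scores[i])) for i in order]

# Example: logits for three query/document pairs.
print(score_documents([2.3, -0.7, 0.1]))

If the model output is still a torch tensor rather than an array of floats, it would typically need to be detached and moved to the CPU (for example via logits.detach().cpu().numpy()) before NumPy can consume it.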

installer/Dockerfile

Lines changed: 0 additions & 1 deletion
@@ -24,7 +24,6 @@ RUN rm -rf /opt/maxkb-app/ui && \
     pip install poetry==2.0.0 --break-system-packages && \
     poetry config virtualenvs.create false && \
     . /opt/py3/bin/activate && \
-    if [ "$(uname -m)" = "x86_64" ]; then sed -i 's/^torch.*/torch = {version = "2.7.1+cpu", source = "pytorch"}/g' pyproject.toml; fi && \
     poetry install && \
     find /opt/maxkb-app -depth \( -name ".git*" -o -name ".docker*" -o -name ".idea*" -o -name ".editorconfig*" -o -name ".prettierrc*" -o -name "README.md" -o -name "poetry.lock" -o -name "pyproject.toml" \) -exec rm -rf {} + && \
     export MAXKB_CONFIG_TYPE=ENV && python3 /opt/maxkb-app/apps/manage.py compilemessages && \

pyproject.toml

Lines changed: 1 addition & 7 deletions
@@ -37,7 +37,6 @@ langchain-mcp-adapters = "0.1.9"
 langchain-huggingface = "0.3.0"
 langchain-ollama = "0.3.4"
 langgraph = "0.5.3"
-torch = "2.7.1"
 sentence-transformers = "5.0.0"
 
 # 云服务SDK
@@ -80,9 +79,4 @@ pylint = "3.3.7"
 
 [build-system]
 requires = ["poetry-core"]
-build-backend = "poetry.core.masonry.api"
-
-[[tool.poetry.source]]
-name = "pytorch"
-url = "https://download.pytorch.org/whl/cpu"
-priority = "explicit"
+build-backend = "poetry.core.masonry.api"
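With the explicit torch pin and the pytorch package source removed, torch is only present if another dependency pulls it in transitively (sentence-transformers and the transformers model classes generally still require it). A small, hypothetical sanity check for the built environment could look like the sketch below; the list of module names is the only assumption.

import importlib.util

# Hypothetical check: report which of these packages resolve in the current environment.
# torch may still be installed transitively even though it is no longer pinned here.
for name in ("torch", "numpy", "sentence_transformers", "transformers"):
    status = "installed" if importlib.util.find_spec(name) else "missing"
    print(f"{name}: {status}")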

0 commit comments
