From 56379030ddd448d9a7a4ca326ff4739452e2b716 Mon Sep 17 00:00:00 2001 From: Daniel Persson Date: Tue, 19 Aug 2025 09:02:21 +0200 Subject: [PATCH] feat: ReRanker endpoint using local models. --- README.md | 2 ++ app/models.py | 6 +++++ app/routes/document_routes.py | 44 ++++++++++++++++++++++++++++++++++- docker-compose.yaml | 2 ++ requirements.txt | 2 ++ tests/test_main.py | 38 ++++++++++++++++++++++++++++++ 6 files changed, 93 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index bee6a1f3..18a27ab5 100644 --- a/README.md +++ b/README.md @@ -121,6 +121,8 @@ The following environment variables are required to run the application: - `AWS_SESSION_TOKEN`: (Optional) may be needed for bedrock embeddings - `GOOGLE_APPLICATION_CREDENTIALS`: (Optional) needed for Google VertexAI embeddings. This should be a path to a service account credential file in JSON format, as accepted by [langchain](https://python.langchain.com/api_reference/google_vertexai/index.html) - `RAG_CHECK_EMBEDDING_CTX_LENGTH` (Optional) Default is true, disabling this will send raw input to the embedder, use this for custom embedding models. +- `SIMPLE_RERANKER_MODEL_NAME` (Optional) defaults to `mixedbread-ai/mxbai-rerank-large-v1`, more options at (https://github.com/AnswerDotAI/rerankers) +- `SIMPLE_RERANKER_MODEL_TYPE` (Optional) defaults to `cross-encoder`, more options at (https://github.com/AnswerDotAI/rerankers) Make sure to set these environment variables before running the application. You can set them in a `.env` file or as system environment variables. diff --git a/app/models.py b/app/models.py index 835764f9..809dbd80 100644 --- a/app/models.py +++ b/app/models.py @@ -42,3 +42,9 @@ class QueryMultipleBody(BaseModel): query: str file_ids: List[str] k: int = 4 + + +class QueryMultipleDocs(BaseModel): + query: str + docs: List[str] + k: int = 4 diff --git a/app/routes/document_routes.py b/app/routes/document_routes.py index c59e1ef5..6aebe4db 100644 --- a/app/routes/document_routes.py +++ b/app/routes/document_routes.py @@ -7,6 +7,7 @@ from shutil import copyfileobj from typing import List, Iterable, TYPE_CHECKING from concurrent.futures import ThreadPoolExecutor +from rerankers import Reranker, Document as ReRankDocument from fastapi import ( APIRouter, Request, @@ -43,6 +44,7 @@ QueryRequestBody, DocumentResponse, QueryMultipleBody, + QueryMultipleDocs, ) from app.services.vector_store.async_pg_vector import AsyncPgVector from app.utils.document_loader import ( @@ -54,7 +56,10 @@ from app.utils.health import is_health_ok router = APIRouter() - +reranker_instance = Reranker( + model_name=os.getenv("SIMPLE_RERANKER_MODEL_NAME", "mixedbread-ai/mxbai-rerank-large-v1"), + model_type=os.getenv("SIMPLE_RERANKER_MODEL_TYPE", "cross-encoder"), +) def calculate_num_batches(total: int, batch_size: int) -> int: """Calculate the number of batches needed to process total items.""" @@ -1002,6 +1007,43 @@ async def query_embeddings_by_file_ids(request: Request, body: QueryMultipleBody ) raise HTTPException(status_code=500, detail=str(e)) +@router.post("/rerank") +async def rerank_documents_by_query(request: Request, body: QueryMultipleDocs): + """ + Rerank documents based on relevance to a query using a reranking model. + + Args: + request: The FastAPI request object + body: Contains query string, list of documents, and optional k value + + Returns: + List of ranked documents with their scores + """ + + try: + if not body.docs: + raise HTTPException(status_code=400, detail="docs list cannot be empty") + docs = [] + for i, d in enumerate(body.docs): + docs.append(ReRankDocument(text=d, doc_id=i)) + + top_k = body.k + + results = reranker_instance.rank(query=body.query, docs=docs) + items = results.top_k(top_k) if top_k else results + + return [ + {"text": getattr(r.document, "text", None), "score": r.score} for r in items + ] + except Exception as e: + logger.error( + "Error in reranking documents | Query: %s | Error: %s | Traceback: %s", + body.query, + str(e), + traceback.format_exc(), + ) + raise HTTPException(status_code=500, detail=str(e)) + @router.post("/text") async def extract_text_from_file( diff --git a/docker-compose.yaml b/docker-compose.yaml index 225299ef..56c648d9 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -17,8 +17,10 @@ services: - DB_PORT=5432 ports: - "8000:8000" + runtime: ${DOCKER_RUNTIME:-runc} volumes: - ./uploads:/app/uploads + - ~/.cache/huggingface:/root/.cache/huggingface:rw depends_on: - db env_file: diff --git a/requirements.txt b/requirements.txt index 4f1f96c8..c7ede3e3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -37,5 +37,7 @@ python-magic==0.4.27 python-pptx==1.0.2 xlrd==2.0.2 pydantic==2.9.2 +rerankers[transformers]==0.6.0 +rerankers[flashrank]==0.6.0 chardet==5.2.0 tenacity>=9.0.0 diff --git a/tests/test_main.py b/tests/test_main.py index df940d2c..d00e5146 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -275,3 +275,41 @@ def test_extract_text_from_file(tmp_path, auth_headers): assert json_data["file_id"] == "test_text_123" assert json_data["filename"] == "test_text_extraction.txt" assert json_data["known_type"] is True # text files are known types + +def test_query_rerank(auth_headers): + # Successful reranking with string documents + data = { + "query": "I love you", + "docs": ["I hate you", "I really like you"], + "k": 1 + } + response = client.post("/rerank", json=data, headers=auth_headers) + assert response.status_code == 200 + json_data = response.json() + assert isinstance(json_data, list) + assert len(json_data) == 1 + doc = json_data[0] + assert doc["text"] == "I really like you" + + # Handling of the k parameter (top_k filtering) + data = { + "query": "I love you", + "docs": ["I hate you", "I really like you", "I love you too"], + "k": 2 + } + response = client.post("/rerank", json=data, headers=auth_headers) + assert response.status_code == 200 + json_data = response.json() + assert isinstance(json_data, list) + assert len(json_data) == 2 + assert json_data[0]["text"] == "I really like you" + assert json_data[1]["text"] == "I love you too" + + # Error handling for invalid inputs + data = { + "query": "I love you", + "docs": [123, 456], + "k": 1 + } + response = client.post("/rerank", json=data, headers=auth_headers) + assert response.status_code == 422