Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,8 @@ The following environment variables are required to run the application:
- `AWS_SESSION_TOKEN`: (Optional) may be needed for bedrock embeddings
- `GOOGLE_APPLICATION_CREDENTIALS`: (Optional) needed for Google VertexAI embeddings. This should be a path to a service account credential file in JSON format, as accepted by [langchain](https://python.langchain.com/api_reference/google_vertexai/index.html)
- `RAG_CHECK_EMBEDDING_CTX_LENGTH` (Optional) Default is true, disabling this will send raw input to the embedder, use this for custom embedding models.
- `SIMPLE_RERANKER_MODEL_NAME`: (Optional) defaults to `mixedbread-ai/mxbai-rerank-large-v1`; see the [rerankers](https://github.com/AnswerDotAI/rerankers) project for more options
- `SIMPLE_RERANKER_MODEL_TYPE`: (Optional) defaults to `cross-encoder`; see the [rerankers](https://github.com/AnswerDotAI/rerankers) project for more options

Make sure to set these environment variables before running the application. You can set them in a `.env` file or as system environment variables.

Expand Down
6 changes: 6 additions & 0 deletions app/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,9 @@ class QueryMultipleBody(BaseModel):
query: str
file_ids: List[str]
k: int = 4


class QueryMultipleDocs(BaseModel):
    """Request body for the /rerank endpoint: rank raw document strings against a query."""

    # Natural-language query the documents are scored against.
    query: str
    # Raw document texts to rerank, in the caller's original order.
    docs: List[str]
    # Number of top-scoring documents to return; the endpoint treats a
    # falsy value (0) as "return all ranked documents".
    k: int = 4
44 changes: 43 additions & 1 deletion app/routes/document_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from shutil import copyfileobj
from typing import List, Iterable, TYPE_CHECKING
from concurrent.futures import ThreadPoolExecutor
from rerankers import Reranker, Document as ReRankDocument
from fastapi import (
APIRouter,
Request,
Expand Down Expand Up @@ -43,6 +44,7 @@
QueryRequestBody,
DocumentResponse,
QueryMultipleBody,
QueryMultipleDocs,
)
from app.services.vector_store.async_pg_vector import AsyncPgVector
from app.utils.document_loader import (
Expand All @@ -54,7 +56,10 @@
from app.utils.health import is_health_ok

router = APIRouter()

# Shared reranker used by the /rerank endpoint. Constructed once at module
# import time so every request reuses the same model instance.
# NOTE(review): constructing the Reranker here means the model is loaded
# (and presumably downloaded on first run) during app startup/import —
# confirm this startup cost is acceptable; a lazy initializer would defer it.
# Model name/type are configurable via SIMPLE_RERANKER_MODEL_* (see README).
reranker_instance = Reranker(
    model_name=os.getenv("SIMPLE_RERANKER_MODEL_NAME", "mixedbread-ai/mxbai-rerank-large-v1"),
    model_type=os.getenv("SIMPLE_RERANKER_MODEL_TYPE", "cross-encoder"),
)

def calculate_num_batches(total: int, batch_size: int) -> int:
"""Calculate the number of batches needed to process total items."""
Expand Down Expand Up @@ -1002,6 +1007,43 @@ async def query_embeddings_by_file_ids(request: Request, body: QueryMultipleBody
)
raise HTTPException(status_code=500, detail=str(e))

@router.post("/rerank")
async def rerank_documents_by_query(request: Request, body: QueryMultipleDocs):
    """
    Rerank documents based on relevance to a query using a reranking model.

    Args:
        request: The FastAPI request object
        body: Contains query string, list of documents, and optional k value

    Returns:
        List of ``{"text": ..., "score": ...}`` dicts ordered by the reranker.
        When ``body.k`` is falsy (0), every ranked document is returned.

    Raises:
        HTTPException: 400 if ``docs`` is empty; 500 if reranking fails.
    """
    # Validate BEFORE the try block: raising HTTPException(400) inside the
    # try would be caught by the generic `except Exception` below and
    # re-raised as a 500, hiding the intended client error.
    if not body.docs:
        raise HTTPException(status_code=400, detail="docs list cannot be empty")

    try:
        # Wrap each raw string; doc_id preserves the caller's original index.
        docs = [
            ReRankDocument(text=text, doc_id=i) for i, text in enumerate(body.docs)
        ]

        results = reranker_instance.rank(query=body.query, docs=docs)
        # k == 0 (falsy) means "return everything the reranker produced".
        items = results.top_k(body.k) if body.k else results

        return [
            {"text": getattr(r.document, "text", None), "score": r.score}
            for r in items
        ]
    except Exception as e:
        logger.error(
            "Error in reranking documents | Query: %s | Error: %s | Traceback: %s",
            body.query,
            str(e),
            traceback.format_exc(),
        )
        raise HTTPException(status_code=500, detail=str(e))


@router.post("/text")
async def extract_text_from_file(
Expand Down
2 changes: 2 additions & 0 deletions docker-compose.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@ services:
- DB_PORT=5432
ports:
- "8000:8000"
runtime: ${DOCKER_RUNTIME:-runc}
volumes:
- ./uploads:/app/uploads
- ~/.cache/huggingface:/root/.cache/huggingface:rw
depends_on:
- db
env_file:
Expand Down
2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -37,5 +37,7 @@ python-magic==0.4.27
python-pptx==1.0.2
xlrd==2.0.2
pydantic==2.9.2
rerankers[transformers,flashrank]==0.6.0
chardet==5.2.0
tenacity>=9.0.0
38 changes: 38 additions & 0 deletions tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,3 +275,41 @@ def test_extract_text_from_file(tmp_path, auth_headers):
assert json_data["file_id"] == "test_text_123"
assert json_data["filename"] == "test_text_extraction.txt"
assert json_data["known_type"] is True # text files are known types

def test_query_rerank(auth_headers):
    """Exercise /rerank: basic ranking, top-k truncation, and input validation."""

    def rerank(payload):
        return client.post("/rerank", json=payload, headers=auth_headers)

    # Successful reranking with string documents: the best match is returned.
    resp = rerank({
        "query": "I love you",
        "docs": ["I hate you", "I really like you"],
        "k": 1,
    })
    assert resp.status_code == 200
    ranked = resp.json()
    assert isinstance(ranked, list)
    assert len(ranked) == 1
    assert ranked[0]["text"] == "I really like you"

    # The k parameter truncates the ranked list to the top-k entries, in order.
    resp = rerank({
        "query": "I love you",
        "docs": ["I hate you", "I really like you", "I love you too"],
        "k": 2,
    })
    assert resp.status_code == 200
    ranked = resp.json()
    assert isinstance(ranked, list)
    assert len(ranked) == 2
    assert [doc["text"] for doc in ranked] == ["I really like you", "I love you too"]

    # Non-string documents fail request-model validation with a 422.
    resp = rerank({
        "query": "I love you",
        "docs": [123, 456],
        "k": 1,
    })
    assert resp.status_code == 422