Skip to content

Commit b29b61f

Browse files
committed
fix: Handle suggestions from copilot.
1 parent 278e770 commit b29b61f

File tree

4 files changed

+59
-17
lines changed

4 files changed

+59
-17
lines changed

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,8 @@ The following environment variables are required to run the application:
9292
- `AWS_SESSION_TOKEN`: (Optional) may be needed for bedrock embeddings
9393
- `GOOGLE_APPLICATION_CREDENTIALS`: (Optional) needed for Google VertexAI embeddings. This should be a path to a service account credential file in JSON format, as accepted by [langchain](https://python.langchain.com/api_reference/google_vertexai/index.html)
9494
- `RAG_CHECK_EMBEDDING_CTX_LENGTH` (Optional) Default is true, disabling this will send raw input to the embedder, use this for custom embedding models.
95+
- `SIMPLE_RERANKER_MODEL_NAME` (Optional) defaults to `mixedbread-ai/mxbai-rerank-large-v1`, more options at (https://github.com/AnswerDotAI/rerankers)
96+
- `SIMPLE_RERANKER_MODEL_TYPE` (Optional) defaults to `cross-encoder`, more options at (https://github.com/AnswerDotAI/rerankers)
9597

9698
Make sure to set these environment variables before running the application. You can set them in a `.env` file or as system environment variables.
9799

app/routes/document_routes.py

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -42,13 +42,11 @@
4242
from app.utils.health import is_health_ok
4343

4444
router = APIRouter()
45-
reRankerInstance = Reranker(
46-
model_name=os.getenv("SIMPLE_RERANKER_MODEL_NAME"),
47-
model_type=os.getenv("SIMPLE_RERANKER_MODEL_TYPE"),
48-
lang=os.getenv("SIMPLE_RERANKER_LANG"),
45+
reranker_instance = Reranker(
46+
model_name=os.getenv("SIMPLE_RERANKER_MODEL_NAME", "mixedbread-ai/mxbai-rerank-large-v1"),
47+
model_type=os.getenv("SIMPLE_RERANKER_MODEL_TYPE", "cross-encoder"),
4948
)
5049

51-
5250
def get_user_id(request: Request, entity_id: str = None) -> str:
5351
"""Extract user ID from request or entity_id."""
5452
if not hasattr(request.state, "user"):
@@ -711,23 +709,27 @@ async def query_embeddings_by_file_ids(request: Request, body: QueryMultipleBody
711709

712710
@router.post("/rerank")
713711
async def rerank_documents_by_query(request: Request, body: QueryMultipleDocs):
712+
"""
713+
Rerank documents based on relevance to a query using a reranking model.
714+
715+
Args:
716+
request: The FastAPI request object
717+
body: Contains query string, list of documents, and optional k value
718+
719+
Returns:
720+
List of ranked documents with their scores
721+
"""
722+
714723
try:
724+
if not body.docs:
725+
raise HTTPException(status_code=400, detail="docs list cannot be empty")
715726
docs = []
716727
for i, d in enumerate(body.docs):
717-
if isinstance(d, str):
718-
docs.append(ReRankDocument(text=d, doc_id=i))
719-
else:
720-
docs.append(
721-
ReRankDocument(
722-
text=d.get("text", ""),
723-
doc_id=d.get("doc_id", i),
724-
metadata=d.get("metadata", {}) or {},
725-
)
726-
)
728+
docs.append(ReRankDocument(text=d, doc_id=i))
727729

728730
top_k = body.k
729731

730-
results = reRankerInstance.rank(query=body.query, docs=docs)
732+
results = reranker_instance.rank(query=body.query, docs=docs)
731733
items = results.top_k(top_k) if top_k else results
732734

733735
return [

docker-compose.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ services:
1717
- DB_PORT=5432
1818
ports:
1919
- "8000:8000"
20-
runtime: nvidia
20+
runtime: ${DOCKER_RUNTIME:-runc}
2121
volumes:
2222
- ./uploads:/app/uploads
2323
- ~/.cache/huggingface:/root/.cache/huggingface:rw

tests/test_main.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,3 +263,41 @@ def test_extract_text_from_file(tmp_path, auth_headers):
263263
assert json_data["file_id"] == "test_text_123"
264264
assert json_data["filename"] == "test_text_extraction.txt"
265265
assert json_data["known_type"] is True # text files are known types
266+
267+
def test_query_rerank(auth_headers):
268+
# Successful reranking with string documents
269+
data = {
270+
"query": "I love you",
271+
"docs": ["I hate you", "I really like you"],
272+
"k": 1
273+
}
274+
response = client.post("/rerank", json=data, headers=auth_headers)
275+
assert response.status_code == 200
276+
json_data = response.json()
277+
assert isinstance(json_data, list)
278+
assert len(json_data) == 1
279+
doc = json_data[0]
280+
assert doc["text"] == "I really like you"
281+
282+
# Handling of the k parameter (top_k filtering)
283+
data = {
284+
"query": "I love you",
285+
"docs": ["I hate you", "I really like you", "I love you too"],
286+
"k": 2
287+
}
288+
response = client.post("/rerank", json=data, headers=auth_headers)
289+
assert response.status_code == 200
290+
json_data = response.json()
291+
assert isinstance(json_data, list)
292+
assert len(json_data) == 2
293+
assert json_data[0]["text"] == "I really like you"
294+
assert json_data[1]["text"] == "I love you too"
295+
296+
# Error handling for invalid inputs
297+
data = {
298+
"query": "I love you",
299+
"docs": [123, 456],
300+
"k": 1
301+
}
302+
response = client.post("/rerank", json=data, headers=auth_headers)
303+
assert response.status_code == 422

0 commit comments

Comments
 (0)