|
42 | 42 | from app.utils.health import is_health_ok |
43 | 43 |
|
44 | 44 | router = APIRouter() |
| 45 | +reRankerInstance = Reranker( |
| 46 | + model_name=os.getenv("SIMPLE_RERANKER_MODEL_NAME"), |
| 47 | + model_type=os.getenv("SIMPLE_RERANKER_MODEL_TYPE"), |
| 48 | + lang=os.getenv("SIMPLE_RERANKER_LANG"), |
| 49 | +) |
45 | 50 |
|
46 | 51 |
|
47 | 52 | def get_user_id(request: Request, entity_id: str = None) -> str: |
@@ -704,6 +709,40 @@ async def query_embeddings_by_file_ids(request: Request, body: QueryMultipleBody |
704 | 709 | ) |
705 | 710 | raise HTTPException(status_code=500, detail=str(e)) |
706 | 711 |
|
| 712 | +@router.post("/rerank") |
| 713 | +async def rerank_documents_by_query(request: Request, body: QueryMultipleDocs): |
| 714 | + try: |
| 715 | + docs = [] |
| 716 | + for i, d in enumerate(body.docs): |
| 717 | + if isinstance(d, str): |
| 718 | + docs.append(ReRankDocument(text=d, doc_id=i)) |
| 719 | + else: |
| 720 | + docs.append( |
| 721 | + ReRankDocument( |
| 722 | + text=d.get("text", ""), |
| 723 | + doc_id=d.get("doc_id", i), |
| 724 | + metadata=d.get("metadata", {}) or {}, |
| 725 | + ) |
| 726 | + ) |
| 727 | + |
| 728 | + top_k = body.k |
| 729 | + |
| 730 | + results = reRankerInstance.rank(query=body.query, docs=docs) |
| 731 | + items = results.top_k(top_k) if top_k else results |
| 732 | + |
| 733 | + return [ |
| 734 | + {"text": getattr(r.document, "text", None), "score": r.score} for r in items |
| 735 | + ] |
| 736 | + except Exception as e: |
| 737 | + logger.error( |
| 738 | + "Error in reranking documents | Query: %s | Error: %s | Traceback: %s", |
| 739 | + body.query, |
| 740 | + str(e), |
| 741 | + traceback.format_exc(), |
| 742 | + ) |
| 743 | + raise HTTPException(status_code=500, detail=str(e)) |
| 744 | + |
| 745 | + |
707 | 746 | @router.post("/text") |
708 | 747 | async def extract_text_from_file( |
709 | 748 | request: Request, |
|
0 commit comments