Skip to content

Commit 1055bbe

Browse files
Introducing reranked RAG question answerer (#132)
1 parent b22bb00 commit 1055bbe

File tree

2 files changed

+81
-1
lines changed

2 files changed

+81
-1
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@ All notable changes to this project will be documented in this file.
55
This project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
66
## [Unreleased]
77

8+
### Changed
9+
- `pw.xpacks.llm.question_answering.BaseRAGQuestionAnswerer` now supports document reranking. This enables two-stage retrieval where initial vector similarity search is followed by reranking to improve document relevance ordering.
10+
811
### Added
912
- JetStream extension is now supported in both NATS read and write connectors.
1013

python/pathway/xpacks/llm/question_answering.py

Lines changed: 78 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Copyright © 2024 Pathway
22
import json
3+
import logging
34
from abc import ABC, abstractmethod
45
from dataclasses import dataclass, field
56
from typing import TYPE_CHECKING, Any, Callable
@@ -450,6 +451,8 @@ class BaseRAGQuestionAnswerer(SummaryQuestionAnswerer):
450451
Args:
451452
llm: LLM instance for question answering. See https://pathway.com/developers/api-docs/pathway-xpacks-llm/llms for available models.
452453
indexer: Indexing object for search & retrieval to be used for context augmentation.
454+
reranker: Reranker instance to evaluate document relevance. See https://pathway.com/developers/api-docs/pathway-xpacks-llm/rerankers for available models.
455+
If ``None``, behaves as standard RAG without reranking. Defaults to ``None``.
453456
default_llm_name: Default LLM model to be used in queries, only used if ``model`` parameter in post request is not specified.
454457
Omitting or setting this to ``None`` will default to the model name set during LLM's initialization.
455458
prompt_template: Template for document question answering with short response.
@@ -460,12 +463,14 @@ class BaseRAGQuestionAnswerer(SummaryQuestionAnswerer):
460463
Defaults to ``SimpleContextProcessor`` that keeps the 'path' metadata and joins the documents with double new lines.
461464
summarize_template: Template for text summarization. Defaults to ``pathway.xpacks.llm.prompts.prompt_summarize``.
462465
search_topk: Top k parameter for the retrieval. Adjusts number of chunks in the context.
466+
rerank_topk: Number of top-scoring documents to retain after reranking, when a reranker is provided.
467+
Must be provided together with ``reranker``; if only one of ``reranker`` and ``rerank_topk`` is given, a warning is logged and reranking is disabled. Defaults to ``None``.
463468
464469
465470
Example:
466471
467472
>>> import pathway as pw # doctest: +SKIP
468-
>>> from pathway.xpacks.llm import embedders, splitters, llms, parsers # doctest: +SKIP
473+
>>> from pathway.xpacks.llm import embedders, splitters, llms, parsers, rerankers # doctest: +SKIP
469474
>>> from pathway.xpacks.llm.vector_store import VectorStoreServer # doctest: +SKIP
470475
>>> from pathway.udfs import DiskCache, ExponentialBackoffRetryStrategy # doctest: +SKIP
471476
>>> from pathway.xpacks.llm.question_answering import BaseRAGQuestionAnswerer # doctest: +SKIP
@@ -492,10 +497,15 @@ class BaseRAGQuestionAnswerer(SummaryQuestionAnswerer):
492497
... cache_strategy=DiskCache(),
493498
... temperature=0.05,
494499
... )
500+
>>> reranker = rerankers.CrossEncoderReranker( # doctest: +SKIP
501+
... model_name="cross-encoder/ms-marco-MiniLM-L-6-v2",
502+
... cache_strategy=DiskCache(),
503+
... )
495504
>>> prompt_template = "Answer the question. Context: {context}. Question: {query}" # doctest: +SKIP
496505
>>> rag = BaseRAGQuestionAnswerer( # doctest: +SKIP
497506
... llm=chat,
498507
... indexer=vector_server,
508+
... reranker=reranker,
499509
... prompt_template=prompt_template,
500510
... )
501511
>>> app = QASummaryRestServer(app_host, app_port, rag) # doctest: +SKIP
@@ -514,10 +524,13 @@ def __init__(
514524
) = SimpleContextProcessor(),
515525
summarize_template: pw.UDF = prompts.prompt_summarize,
516526
search_topk: int = 6,
527+
reranker: pw.UDF | None = None,
528+
rerank_topk: int | None = None,
517529
) -> None:
518530

519531
self.llm = llm
520532
self.indexer = indexer
533+
self.reranker = reranker
521534

522535
if default_llm_name is None:
523536
default_llm_name = llm.model
@@ -542,8 +555,19 @@ def __init__(
542555

543556
self.summarize_template = summarize_template
544557
self.search_topk = search_topk
558+
self.rerank_topk = rerank_topk
545559
self.server: None | QASummaryRestServer = None
546560

561+
# Check reranker settings mismatch
562+
if (self.reranker is None) != (rerank_topk is None):
563+
logging.warning(
564+
"Incomplete reranker configuration: "
565+
"Both 'reranker' and 'rerank_topk' must be specified. "
566+
"Reranking disabled."
567+
)
568+
self.reranker = None
569+
self.rerank_topk = None
570+
547571
def _init_schemas(self, default_llm_name: str | None = None) -> None:
548572
"""Initialize API schemas with optional and non-optional arguments."""
549573

@@ -563,6 +587,56 @@ class SummarizeQuerySchema(pw.Schema):
563587
self.StatisticsQuerySchema = self.indexer.StatisticsQuerySchema
564588
self.InputsQuerySchema = self.indexer.InputsQuerySchema
565589

590+
def _apply_reranking(self, pw_ai_results: pw.Table) -> pw.Table:
    """Apply reranking to retrieved documents.

    Explodes the per-query document lists into one row per (query, document)
    pair, scores each document's text against the prompt with
    ``self.reranker``, then reassembles the lists ordered by descending
    score and truncated to ``self.rerank_topk`` entries.

    Args:
        pw_ai_results: Table carrying ``prompt``, ``model``, ``filters``,
            ``return_context_docs`` and a ``docs`` column with the
            retrieved documents for each query.

    Returns:
        pw.Table: Table of the same shape, with ``docs`` reordered by
        reranker score (best first) and limited to ``rerank_topk`` items;
        each retained document gains a ``reranker_score`` field.
    """

    @pw.udf
    def add_score_to_doc(doc: pw.Json, score: float) -> dict:
        # Embed the reranker's score into the document payload so it
        # survives reassembly and is visible downstream.
        return {**doc.as_dict(), "reranker_score": score}

    # Flatten docs into rows: one row per (query, document) pair.
    # ``origin_id`` keeps a reference to the originating query row so the
    # groups can be reassembled (and re-keyed) afterwards.
    pw_ai_results_exploded = pw_ai_results.flatten(
        pw_ai_results.docs, origin_id='query_id'
    )

    # Apply reranker: score each document's text against the query prompt.
    # The reranker UDF may be asynchronous — see ``await_futures`` below.
    pw_ai_results_scored = pw_ai_results_exploded.select(
        pw.this.prompt,
        pw.this.model,
        pw.this.filters,
        pw.this.return_context_docs,
        pw.this.query_id,
        doc=pw.this.docs,
        reranker_score=self.reranker(pw.this.docs["text"], pw.this.prompt),  # type: ignore
    )

    # Resolve any pending async UDF results before sorting on the scores.
    pw_ai_results_scored = pw_ai_results_scored.await_futures()

    # Add score and sort: the score is negated so that an ascending
    # ``sort_by`` ordering yields highest-score-first.
    pw_ai_results_scored = pw_ai_results_scored.with_columns(
        sort_key=-pw.this.reranker_score,
        doc_with_score=add_score_to_doc(pw.this.doc, pw.this.reranker_score),
    )

    # Reassemble documents: group rows back per query, collecting the
    # scored docs in sorted order. ``any`` is used for the scalar columns
    # because they are constant within a query group (duplicated by the
    # ``flatten`` above). ``with_id`` re-keys each result row to its
    # originating query row id.
    pw_ai_results = (
        pw_ai_results_scored.groupby(pw.this.query_id, sort_by=pw.this.sort_key)
        .reduce(
            query_id=pw.this.query_id,
            prompt=pw.reducers.any(pw.this.prompt),
            model=pw.reducers.any(pw.this.model),
            filters=pw.reducers.any(pw.this.filters),
            return_context_docs=pw.reducers.any(pw.this.return_context_docs),
            docs=pw.reducers.tuple(pw.this.doc_with_score),
        )
        .with_id(pw.this.query_id)
    )

    # Keep only top k documents after reranking.
    return pw_ai_results.with_columns(
        docs=_limit_documents(pw.this.docs, k=self.rerank_topk)
    )
566640
@pw.table_transformer
567641
def answer_query(self, pw_ai_queries: pw.Table) -> pw.Table:
568642
"""Answer a question based on the available information."""
@@ -578,6 +652,9 @@ def answer_query(self, pw_ai_queries: pw.Table) -> pw.Table:
578652
docs=pw.this.result,
579653
)
580654

655+
if self.reranker is not None:
656+
pw_ai_results = self._apply_reranking(pw_ai_results)
657+
581658
pw_ai_results += pw_ai_results.select(
582659
context=self.docs_to_context_transformer(pw.this.docs)
583660
)

0 commit comments

Comments
 (0)