Skip to content

Commit d213e33

Browse files
szymondudyczKamilPiechowiak
authored and committed
Refactoring Deckretriever (#7671)
Co-authored-by: Kamil Piechowiak <32928185+KamilPiechowiak@users.noreply.github.com> GitOrigin-RevId: 0212b7d459cf1767ed68635aa42fff179c435fd4
1 parent da48f52 commit d213e33

File tree

4 files changed

+167
-13
lines changed

4 files changed

+167
-13
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,14 @@ This project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.htm
66
## [Unreleased]
77

88
### Added
9+
- `pw.xpacks.llm.document_store.SlidesDocumentStore`, which is a subclass of `pw.xpacks.llm.document_store.DocumentStore` customized for retrieving slides from presentations.
910
- `pw.temporal.inactivity_detection` and `pw.temporal.utc_now` functions, allowing for alerting and other time-dependent use cases.
1011

1112
### Changed
1213
- `pw.Table.concat`, `pw.Table.with_id`, `pw.Table.with_id_from` no longer perform checks if ids are unique. It improves memory usage.
1314
- table operations that store values (like `pw.Table.join`, `pw.Table.update_cells`) no longer store columns that are not used downstream.
1415
- `append_only` column property is now propagated better (there are more places where we can infer it).
16+
- **BREAKING**: Unused arguments of the `pw.xpacks.llm.question_answering.DeckRetriever` constructor are no longer accepted.
1517

1618
### Fixed
1719
- `query_as_of_now` of `pw.stdlib.indexing.DataIndex` and `pw.stdlib.indexing.HybridIndex` now works in constant memory for an infinite query stream (no query-related data is kept after a query is answered).

python/pathway/xpacks/llm/document_store.py

Lines changed: 64 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
"""
99

1010
from collections.abc import Callable
11-
from typing import TYPE_CHECKING, Iterable
11+
from typing import TYPE_CHECKING, Iterable, TypeAlias
1212

1313
import jmespath
1414

@@ -185,7 +185,7 @@ class FilterSchema(pw.Schema):
185185
default_value=None, description="An optional Glob pattern for the file path"
186186
)
187187

188-
InputsQuerySchema = FilterSchema
188+
InputsQuerySchema: TypeAlias = FilterSchema
189189

190190
class InputsResultSchema(pw.Schema):
191191
result: list[pw.Json]
@@ -365,7 +365,7 @@ def _get_jmespath_filter(
365365

366366
@pw.table_transformer
367367
def inputs_query(
368-
self, input_queries: pw.Table[InputsQuerySchema] # type: ignore
368+
self, input_queries: pw.Table[InputsQuerySchema]
369369
) -> pw.Table[InputsResultSchema]:
370370
"""
371371
Query ``DocumentStore`` for the list of input documents.
@@ -448,3 +448,64 @@ def retrieve_query(
448448
@property
449449
def index(self) -> DataIndex:
450450
return self._retriever
451+
452+
453+
class SlidesDocumentStore(DocumentStore):
454+
"""
455+
Document store for the ``slide-search`` application.
456+
Builds a document indexing pipeline and starts an HTTP REST server.
457+
458+
Adds to the ``DocumentStore`` a new method ``parsed_documents_query`` that returns the
459+
metadata of the documents after the parsing and document post-processing stages.
460+
"""
461+
462+
excluded_response_metadata = ["b64_image"]
463+
464+
@pw.table_transformer
465+
def parsed_documents_query(
466+
self,
467+
parse_docs_queries: pw.Table[DocumentStore.InputsQuerySchema],
468+
) -> pw.Table:
469+
"""
470+
Query the SlidesDocumentStore for the list of documents with the associated
471+
metadata after the parsing stage.
472+
"""
473+
docs = self.parsed_docs
474+
475+
all_metas = docs.reduce(metadatas=pw.reducers.tuple(pw.this.metadata))
476+
477+
parse_docs_queries = self.merge_filters(parse_docs_queries)
478+
479+
@pw.udf
480+
def format_inputs(
481+
metadatas: list[pw.Json] | None,
482+
metadata_filter: str | None,
483+
) -> list[pw.Json]:
484+
metadatas = metadatas if metadatas is not None else []
485+
if metadata_filter:
486+
metadatas = [
487+
m
488+
for m in metadatas
489+
if jmespath.search(
490+
metadata_filter, m.value, options=_knn_lsh._glob_options
491+
)
492+
]
493+
494+
metadata_list: list[dict] = [m.as_dict() for m in metadatas]
495+
496+
for metadata in metadata_list:
497+
for metadata_key in self.excluded_response_metadata:
498+
metadata.pop(metadata_key, None)
499+
500+
return [pw.Json(m) for m in metadata_list]
501+
502+
input_results = parse_docs_queries.join_left(
503+
all_metas, id=parse_docs_queries.id
504+
).select(
505+
all_metas.metadatas,
506+
parse_docs_queries.metadata_filter,
507+
)
508+
input_results = input_results.select(
509+
result=format_inputs(pw.this.metadatas, pw.this.metadata_filter)
510+
)
511+
return input_results

python/pathway/xpacks/llm/question_answering.py

Lines changed: 90 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,17 @@
1111
from pathway.internals import ColumnReference, Table, udfs
1212
from pathway.stdlib.indexing import DataIndex
1313
from pathway.xpacks.llm import Doc, llms, prompts
14-
from pathway.xpacks.llm.document_store import DocumentStore
14+
from pathway.xpacks.llm.document_store import DocumentStore, SlidesDocumentStore
1515
from pathway.xpacks.llm.llms import BaseChat, prompt_chat_single_qa
1616
from pathway.xpacks.llm.prompts import prompt_qa_geometric_rag
17-
from pathway.xpacks.llm.vector_store import VectorStoreClient, VectorStoreServer
17+
from pathway.xpacks.llm.vector_store import (
18+
SlidesVectorStoreServer,
19+
VectorStoreClient,
20+
VectorStoreServer,
21+
)
1822

1923
if TYPE_CHECKING:
20-
from pathway.xpacks.llm.servers import QASummaryRestServer
24+
from pathway.xpacks.llm.servers import QARestServer, QASummaryRestServer
2125

2226

2327
@pw.udf
@@ -455,14 +459,23 @@ def summarize_query(self, summarize_queries: pw.Table) -> pw.Table:
455459

456460
@pw.table_transformer
457461
def retrieve(self, retrieve_queries: pw.Table) -> pw.Table:
462+
"""
463+
Retrieve documents from the index.
464+
"""
458465
return self.indexer.retrieve_query(retrieve_queries)
459466

460467
@pw.table_transformer
461468
def statistics(self, statistics_queries: pw.Table) -> pw.Table:
469+
"""
470+
Get statistics about indexed files.
471+
"""
462472
return self.indexer.statistics_query(statistics_queries)
463473

464474
@pw.table_transformer
465475
def list_documents(self, list_documents_queries: pw.Table) -> pw.Table:
476+
"""
477+
Get list of documents from the retriever.
478+
"""
466479
return self.indexer.inputs_query(list_documents_queries)
467480

468481
def build_server(
@@ -682,14 +695,45 @@ def answer_query(self, pw_ai_queries: pw.Table) -> pw.Table:
682695
return result
683696

684697

685-
class DeckRetriever(BaseRAGQuestionAnswerer):
686-
"""Class for slides search."""
698+
class DeckRetriever(BaseQuestionAnswerer):
699+
"""
700+
Builds the logic for the Retriever of slides.
701+
702+
Args:
703+
indexer: document store for parsing and indexing slides.
704+
search_topk: Number of slides to be returned by the `answer_query` method.
705+
"""
687706

688707
excluded_response_metadata = ["b64_image"]
689708

709+
def __init__(
710+
self,
711+
indexer: SlidesDocumentStore | SlidesVectorStoreServer,
712+
*,
713+
search_topk: int = 6,
714+
) -> None:
715+
self.indexer = indexer
716+
self._init_schemas()
717+
self.search_topk = search_topk
718+
719+
self.server: None | QARestServer = None
720+
self._pending_endpoints: list[tuple] = []
721+
722+
def _init_schemas(
723+
self,
724+
) -> None:
725+
class PWAIQuerySchema(pw.Schema):
726+
prompt: str
727+
filters: str | None = pw.column_definition(default_value=None)
728+
729+
self.AnswerQuerySchema = PWAIQuerySchema
730+
self.RetrieveQuerySchema = self.indexer.RetrieveQuerySchema
731+
self.StatisticsQuerySchema = self.indexer.StatisticsQuerySchema
732+
self.InputsQuerySchema = self.indexer.InputsQuerySchema
733+
690734
@pw.table_transformer
691735
def answer_query(self, pw_ai_queries: pw.Table) -> pw.Table:
692-
"""Return similar docs from the index."""
736+
"""Return slides similar to the given query."""
693737

694738
pw_ai_results = pw_ai_queries + self.indexer.retrieve_query(
695739
pw_ai_queries.select(
@@ -720,6 +764,46 @@ def _format_results(docs: pw.Json) -> pw.Json:
720764

721765
return pw_ai_results
722766

767+
@pw.table_transformer
768+
def retrieve(self, retrieve_queries: pw.Table) -> pw.Table:
769+
return self.indexer.retrieve_query(retrieve_queries)
770+
771+
@pw.table_transformer
772+
def statistics(self, statistics_queries: pw.Table) -> pw.Table:
773+
return self.indexer.statistics_query(statistics_queries)
774+
775+
@pw.table_transformer
776+
def list_documents(self, list_documents_queries: pw.Table) -> pw.Table:
777+
return self.indexer.parsed_documents_query(list_documents_queries)
778+
779+
def build_server(
780+
self,
781+
host: str,
782+
port: int,
783+
**rest_kwargs,
784+
):
785+
warn(
786+
"build_server method is deprecated. Instead, use explicitly a server from pw.xpacks.llm.servers.",
787+
DeprecationWarning,
788+
stacklevel=2,
789+
)
790+
# circular import
791+
from pathway.xpacks.llm.servers import QARestServer
792+
793+
self.server = QARestServer(host, port, self, **rest_kwargs)
794+
795+
def run_server(self, *args, **kwargs):
796+
warn(
797+
"run_server method is deprecated. Instead, use explicitly a server from pw.xpacks.llm.servers.",
798+
DeprecationWarning,
799+
stacklevel=2,
800+
)
801+
if self.server is None:
802+
raise ValueError(
803+
"HTTP server is not built, initialize it with `build_server`"
804+
)
805+
self.server.run(*args, **kwargs)
806+
723807

724808
def send_post_request(
725809
url: str, data: dict, headers: dict = {}, timeout: int | None = None

python/pathway/xpacks/llm/vector_store.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
import logging
1414
import threading
1515
from collections.abc import Callable, Coroutine
16-
from typing import TYPE_CHECKING, cast
16+
from typing import TYPE_CHECKING, TypeAlias, cast
1717

1818
import jmespath
1919
import requests
@@ -330,7 +330,7 @@ class FilterSchema(pw.Schema):
330330
default_value=None, description="An optional Glob pattern for the file path"
331331
)
332332

333-
InputsQuerySchema = FilterSchema
333+
InputsQuerySchema: TypeAlias = FilterSchema
334334

335335
@staticmethod
336336
def merge_filters(queries: pw.Table):
@@ -363,7 +363,7 @@ def _get_jmespath_filter(
363363

364364
@pw.table_transformer
365365
def inputs_query(
366-
self, input_queries: pw.Table[InputsQuerySchema] # type:ignore
366+
self, input_queries: pw.Table[InputsQuerySchema]
367367
) -> pw.Table[InputResultSchema]:
368368
docs = self._graph["docs"]
369369
# TODO: compare this approach to first joining queries to documents, then filtering,
@@ -576,7 +576,7 @@ class SlidesVectorStoreServer(VectorStoreServer):
576576
@pw.table_transformer
577577
def inputs_query(
578578
self,
579-
input_queries: pw.Table[VectorStoreServer.InputsQuerySchema], # type:ignore
579+
input_queries: pw.Table[VectorStoreServer.InputsQuerySchema],
580580
) -> pw.Table:
581581
docs = self._graph["parsed_docs"]
582582

@@ -617,6 +617,13 @@ def format_inputs(
617617
)
618618
return input_results
619619

620+
@pw.table_transformer
621+
def parsed_documents_query(
622+
self,
623+
parse_docs_queries: pw.Table[VectorStoreServer.InputsQuerySchema],
624+
) -> pw.Table:
625+
return self.inputs_query(parse_docs_queries)
626+
620627

621628
class VectorStoreClient:
622629
"""

0 commit comments

Comments
 (0)