Skip to content

Commit 155613f

Browse files
szymondudycz and berkecanrizai
authored and committed
Fix doc_post_processors in document_store (#8276)
Co-authored-by: berkecanrizai <63911408+berkecanrizai@users.noreply.github.com> GitOrigin-RevId: 7099adbc821b5f7f6bac29a81d983fcef30b895f
1 parent 050b034 commit 155613f

File tree

5 files changed

+120
-7
lines changed

5 files changed

+120
-7
lines changed

CHANGELOG.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,11 @@ This project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.htm
1313

1414
### Changed
1515
- **BREAKING**: Changed the argument in `DoclingParser` from `parse_images` (bool) into `image_parsing_strategy` (Literal["llm"] | None)
16+
- **BREAKING**: `doc_post_processors` argument in the `pw.xpacks.llm.document_store.DocumentStore` no longer accepts `pw.UDF`.
1617
- Better error messages when using `pathway spawn` with multiple workers. Now error messages are printed only from the worker experiencing the error directly.
1718

18-
### Removed
19-
19+
### Fixed
20+
- `doc_post_processors` argument in the `pw.xpacks.llm.document_store.DocumentStore` had no effect. This is now fixed.
2021

2122
## [0.19.0] - 2025-02-20
2223

python/pathway/xpacks/llm/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from ._typing import Doc, DocTransformer, DocTransformerCallable # isort: skip
33

44
from . import (
5+
document_store,
56
embedders,
67
llms,
78
parsers,
@@ -13,6 +14,7 @@
1314
)
1415

1516
__all__ = [
17+
"document_store",
1618
"embedders",
1719
"llms",
1820
"parsers",

python/pathway/xpacks/llm/_utils.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,3 +98,12 @@ def _to_dict(element: dict | pw.Json):
9898
return element.as_dict()
9999
else:
100100
return element
101+
102+
103+
def _wrap_doc_post_processor(fun: Callable[[str, dict], tuple[str, dict]]) -> pw.UDF:
    """Wrap a plain ``(text, metadata) -> (text, metadata)`` callable as a ``pw.UDF``.

    Documents in the store carry metadata as ``pw.Json``; the wrapper converts
    it to a plain ``dict`` before delegating to *fun*, so users can write
    post-processors against ordinary Python types.
    """

    @pw.udf
    def wrapper(text: str, metadata: pw.Json) -> tuple[str, dict]:
        # pw.Json -> dict conversion happens here, per row.
        return fun(text, metadata.as_dict())

    return wrapper

python/pathway/xpacks/llm/document_store.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from pathway.stdlib.indexing.data_index import _SCORE, DataIndex
2020
from pathway.stdlib.indexing.retrievers import AbstractRetrieverFactory
2121
from pathway.stdlib.ml.classifiers import _knn_lsh
22+
from pathway.xpacks.llm._utils import _wrap_doc_post_processor
2223
from pathway.xpacks.llm.utils import combine_metadata
2324

2425
from ._utils import _wrap_udf
@@ -46,7 +47,8 @@ class DocumentStore:
4647
parser: callable that parses file contents into a list of documents.
4748
splitter: callable that splits long documents.
4849
doc_post_processors: optional list of callables that modify parsed files and metadata.
49-
any callable takes two arguments (text: str, metadata: dict) and returns them as a tuple.
50+
Each doc_post_processor is a Callable that takes two arguments
51+
(text: str, metadata: dict) and returns them as a tuple.
5052
"""
5153

5254
def __init__(
@@ -56,7 +58,7 @@ def __init__(
5658
parser: Callable[[bytes], list[tuple[str, dict]]] | pw.UDF | None = None,
5759
splitter: Callable[[str], list[tuple[str, dict]]] | pw.UDF | None = None,
5860
doc_post_processors: (
59-
list[Callable[[str, dict], tuple[str, dict]] | pw.UDF] | None
61+
list[Callable[[str, dict], tuple[str, dict]]] | None
6062
) = None,
6163
):
6264
self.docs = docs
@@ -67,7 +69,9 @@ def __init__(
6769
else pathway.xpacks.llm.parsers.Utf8Parser()
6870
)
6971
self.doc_post_processors: list[pw.UDF] = (
70-
[_wrap_udf(p) for p in doc_post_processors] if doc_post_processors else []
72+
[_wrap_doc_post_processor(p) for p in doc_post_processors]
73+
if doc_post_processors
74+
else []
7175
)
7276
self.splitter: pw.UDF = (
7377
_wrap_udf(splitter)
@@ -247,6 +251,21 @@ def apply_processor(
247251
# `metadata` column: old_meta_dict -> old_meta_dict | new_meta_dict
248252
return combine_metadata(processed_docs)
249253

254+
@pw.table_transformer
def apply_doc_post_processor(
    self, table: pw.Table, processor: pw.UDF
) -> pw.Table[_DocumentSchema]:
    """Apply a single document post-processor UDF to a table of documents.

    ``processor`` receives ``(text, metadata)`` for each row and returns a
    single ``(new_text, new_metadata)`` tuple, which is unpacked back into
    the ``text`` and ``metadata`` columns.

    NOTE(review): unlike ``apply_processor``, a doc post-processor maps one
    document to exactly one document, so there is no flattening step and no
    ``combine_metadata`` call here — the previous comments claiming otherwise
    were copy-pasted from ``apply_processor`` and did not describe this method.
    """
    processed_docs: pw.Table[DocumentStore._DocumentSchema] = table.select(
        # `text` temporarily holds the (new_text, new_metadata) tuple...
        text=processor(pw.this.text, pw.this.metadata)
        # ...which is split back into the two schema columns.
    ).select(text=pw.this.text[0], metadata=pw.this.text[1])
    return processed_docs
268+
250269
def build_pipeline(self):
251270

252271
cleaned_tables = self._clean_tables(self.docs)
@@ -272,13 +291,13 @@ def build_pipeline(self):
272291
# POST PROCESSING
273292
self.post_processed_docs = self.parsed_docs
274293
for post_processor in self.doc_post_processors:
275-
self.post_processed_docs = self.apply_processor(
294+
self.post_processed_docs = self.apply_doc_post_processor(
276295
self.post_processed_docs, post_processor
277296
)
278297

279298
# CHUNKING
280299
self.chunked_docs: pw.Table[DocumentStore._DocumentSchema] = (
281-
self.apply_processor(self.parsed_docs, self.splitter)
300+
self.apply_processor(self.post_processed_docs, self.splitter)
282301
)
283302

284303
# INDEXING

python/pathway/xpacks/llm/tests/test_document_store.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from pathway.tests.utils import assert_table_equality
2020
from pathway.xpacks.llm.document_store import DocumentStore
2121
from pathway.xpacks.llm.servers import DocumentStoreServer
22+
from pathway.xpacks.llm.tests import mocks
2223

2324

2425
class DebugStatsInputSchema(DocumentStore.StatisticsQuerySchema):
@@ -663,3 +664,84 @@ def fake_embeddings_model(x: str) -> list[float]:
663664
(val,) = rows["result"].values()
664665
assert isinstance(val, pw.Json)
665666
assert len(val.as_list()) == 2
667+
668+
669+
def test_docstore_post_processor():
    """A doc post-processor that appends to the text must affect retrieved chunks."""

    def add_baz(text: str, metadata: dict) -> tuple:
        return (text + "baz", metadata)

    input_docs = pw.debug.table_from_rows(
        schema=pw.schema_from_types(data=bytes, _metadata=dict),
        rows=[(b"test", {"foo": "bar"})],
    )

    knn_factory = BruteForceKnnFactory(
        dimensions=3,
        reserved_space=10,
        embedder=mocks.fake_embeddings_model,
        metric=BruteForceKnnMetricKind.COS,
    )

    store = DocumentStore(
        input_docs,
        retriever_factory=knn_factory,
        doc_post_processors=[add_baz],
    )

    queries = pw.debug.table_from_rows(
        schema=DocumentStore.RetrieveQuerySchema,
        rows=[("Foo", 1, None, None)],
    )
    outputs = store.retrieve_query(queries)

    _, result_rows = pw.debug.table_to_dicts(outputs)
    (result_json,) = result_rows["result"].values()
    assert isinstance(result_json, pw.Json)

    # Exactly one match is expected; the post-processor must have run on it.
    (match,) = result_json.as_list()
    assert isinstance(match, dict)
    assert match["text"] == "testbaz"
707+
708+
709+
def test_docstore_metadata_post_processor():
    """A doc post-processor that mutates metadata must affect retrieved metadata."""

    def add_id(text: str, metadata: dict) -> tuple:
        metadata["id"] = 1
        return (text, metadata)

    input_docs = pw.debug.table_from_rows(
        schema=pw.schema_from_types(data=bytes, _metadata=dict),
        rows=[(b"test", {"foo": "bar"})],
    )

    knn_factory = BruteForceKnnFactory(
        dimensions=3,
        reserved_space=10,
        embedder=mocks.fake_embeddings_model,
        metric=BruteForceKnnMetricKind.COS,
    )

    store = DocumentStore(
        input_docs,
        retriever_factory=knn_factory,
        doc_post_processors=[add_id],
    )

    queries = pw.debug.table_from_rows(
        schema=DocumentStore.RetrieveQuerySchema,
        rows=[("Foo", 1, None, None)],
    )
    outputs = store.retrieve_query(queries)

    _, result_rows = pw.debug.table_to_dicts(outputs)
    (result_json,) = result_rows["result"].values()
    assert isinstance(result_json, pw.Json)

    # Exactly one match is expected; its metadata must carry the injected id.
    (match,) = result_json.as_list()
    assert isinstance(match, dict)
    assert match["metadata"]["id"] == 1

0 commit comments

Comments
 (0)