Skip to content

Commit cf7ecf7

Browse files
berkecanrizai authored and Manul from Pathway committed
make metadata column in vecstore inputs optional (#7913)
GitOrigin-RevId: 33c6857542c21351ce9f26967a5984c72e9ddbb9
1 parent 11acea1 commit cf7ecf7

File tree

2 files changed

+52
-1
lines changed

2 files changed

+52
-1
lines changed

python/pathway/xpacks/llm/tests/test_vector_store.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -406,3 +406,32 @@ def fake_embeddings_model(x: str) -> list[float]:
406406
(query_result,) = val.as_list() # extract the single match
407407
assert isinstance(query_result, dict)
408408
assert query_result["text"] # just check if some text was returned
409+
410+
411+
def test_docstore_on_table_without_metadata():
    """Run a retrieve query against a server built from a table with no `_metadata`.

    The input table carries only a `data` column of bytes. One retrieve query
    is issued and the test asserts that exactly one match comes back and that
    its text equals the document that was put in.
    """

    @pw.udf
    def fake_embeddings_model(x: str) -> list[float]:
        # Constant vector: the embedding does not depend on the input text.
        return [1.0, 1.0, 0.0]

    # Deliberately only a `data` column -- no `_metadata` column is supplied.
    input_schema = pw.schema_from_types(data=bytes)
    docs = pw.debug.table_from_rows(
        schema=input_schema,
        rows=[("test".encode("utf-8"),)],
    )

    vector_server = VectorStoreServer(docs, embedder=fake_embeddings_model)

    retrieve_queries = pw.debug.table_from_rows(
        schema=vector_server.RetrieveQuerySchema,
        rows=[("Foo", 1, None, None)],
    )
    retrieve_outputs = vector_server.retrieve_query(retrieve_queries)

    _, columns = pw.debug.table_to_dicts(retrieve_outputs)
    (payload,) = columns["result"].values()  # a single output row is expected
    assert isinstance(payload, pw.Json)

    (hit,) = payload.as_list()  # extract the single match
    assert isinstance(hit, dict)
    assert hit["text"] == "test"  # returned text must equal the input document

python/pathway/xpacks/llm/vector_store.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,9 @@
1212
import json
1313
import logging
1414
import threading
15+
import warnings
1516
from collections.abc import Callable, Coroutine
16-
from typing import TYPE_CHECKING, TypeAlias, cast
17+
from typing import TYPE_CHECKING, Iterable, TypeAlias, cast
1718

1819
import jmespath
1920
import requests
@@ -206,18 +207,39 @@ def generic_transformer(x: str) -> list[tuple[str, dict]]:
206207
**kwargs,
207208
)
208209

210+
def _clean_tables(
    self, docs: pw.Table | Iterable[pw.Table]
) -> tuple[pw.Table, ...]:
    """Reduce every input table to its `data` and `_metadata` columns.

    A single ``pw.Table`` is treated as a one-element collection. For a table
    that has no `_metadata` column, a warning is emitted and the column is
    added from an empty dict before the two columns are selected.

    Returns:
        A tuple with one cleaned table per input table, in input order.
    """
    tables = [docs] if isinstance(docs, pw.Table) else docs

    cleaned = []
    for table in tables:
        if "_metadata" not in table.column_names():
            warnings.warn(
                f"`_metadata` column is not present in Table {table}. Filtering will not work for this Table"
            )
            # Add the missing column so the select below works for every table.
            table = table.with_columns(_metadata={})

        cleaned.append(table.select(pw.this.data, pw.this._metadata))

    return tuple(cleaned)
226+
209227
def _build_graph(self) -> dict:
210228
"""
211229
Builds the pathway computation graph for indexing documents and serving queries.
212230
"""
213231
docs_s = self.docs
232+
214233
if not docs_s:
215234
raise ValueError(
216235
"""Please provide at least one data source, e.g. read files from disk:
217236
218237
pw.io.fs.read('./sample_docs', format='binary', mode='static', with_metadata=True)
219238
"""
220239
)
240+
241+
docs_s = self._clean_tables(docs_s)
242+
221243
if len(docs_s) == 1:
222244
(docs,) = docs_s
223245
else:

0 commit comments

Comments
 (0)