99import json
1010import warnings
1111from collections .abc import Callable
12- from typing import TYPE_CHECKING , Iterable , TypeAlias
12+ from enum import Enum
13+ from typing import TYPE_CHECKING , Iterable
1314
1415import jmespath
1516import requests
@@ -46,6 +47,11 @@ def _get_jmespath_filter(metadata_filter: str, filepath_globpattern: str) -> str
4647 return None
4748
4849
class IndexingStatus(str, Enum):
    """Lifecycle stage of a document inside the indexing pipeline.

    Mixes in ``str`` so members compare equal to — and serialize as —
    their plain string values (e.g. when embedded in a JSON response).
    """

    # Document has been parsed/chunked and is present in the index.
    INDEXED = "INDEXED"
    # Document was ingested but parsing has not produced chunks yet.
    INGESTED = "INGESTED"
4955class DocumentStore :
5056 """
5157 Builds a document indexing pipeline for processing documents and querying closest documents
@@ -213,7 +219,17 @@ class FilterSchema(pw.Schema):
213219 default_value = None , description = "An optional Glob pattern for the file path"
214220 )
215221
216- InputsQuerySchema : TypeAlias = FilterSchema
class InputsQuerySchema(pw.Schema):
    # Request schema for the inputs endpoint.
    # NOTE(review): the first two columns duplicate FilterSchema's fields
    # (this class replaced a plain ``InputsQuerySchema: TypeAlias = FilterSchema``);
    # consider inheriting from FilterSchema if pw.Schema supports it — verify.
    metadata_filter: str | None = pw.column_definition(
        default_value=None, description="Metadata filter in JMESPath format"
    )
    filepath_globpattern: str | None = pw.column_definition(
        default_value=None, description="An optional Glob pattern for the file path"
    )
    # When set, each returned file carries an ``_indexing_status`` field
    # (see format_inputs / IndexingStatus elsewhere in this file).
    return_status: bool = pw.column_definition(
        default_value=False,
        description="Flag whether _indexing_status should be returned for each file",
    )
217233
218234 class InputsResultSchema (pw .Schema ):
219235 result : list [pw .Json ]
@@ -298,9 +314,18 @@ def build_pipeline(self):
298314
299315 docs = pw .Table .concat_reindex (* cleaned_tables )
300316
@pw.udf
def add_file_id(metadata: pw.Json, id) -> dict:
    """Return *metadata* as a plain dict with the row id stored under ``_file_id``.

    The id is stringified so the resulting metadata stays JSON-serializable.
    Works on a copy produced by ``as_dict()``; the incoming Json is not mutated.
    """
    metadata_dict = metadata.as_dict()
    # Convert inline instead of rebinding the builtin-shadowing ``id`` parameter.
    metadata_dict["_file_id"] = str(id)
    return metadata_dict
323+
301324 # rename columns to be consistent with the rest of the pipeline
302325 self .input_docs : pw .Table [DocumentStore ._RawDocumentSchema ] = docs .select (
303- text = pw .this .data , metadata = pw .this ._metadata
326+ text = pw .this .data ,
327+ metadata = add_file_id (pw .this ._metadata , pw .this .id ),
328+ # path=pw.this.metadata["path"].as_str(),
304329 )
305330
306331 # PARSING
@@ -327,6 +352,32 @@ def build_pipeline(self):
327352 metadata_column = self .chunked_docs .metadata ,
328353 )
329354
355+ progress_table = self .input_docs .select (
356+ file_id = pw .this .metadata ["_file_id" ].as_str (),
357+ metadata = pw .this .metadata ,
358+ )
359+ chunked_stats = (
360+ self .chunked_docs .with_columns (
361+ file_id = pw .this .metadata ["_file_id" ].as_str ()
362+ )
363+ .groupby (pw .this .file_id )
364+ .reduce (
365+ file_id = pw .this .file_id ,
366+ chunks = pw .reducers .count (),
367+ )
368+ )
369+ self .progress_table = (
370+ progress_table .join_left (
371+ chunked_stats ,
372+ progress_table .file_id == chunked_stats .file_id ,
373+ )
374+ .select (
375+ * pw .left ,
376+ chunks = pw .right .chunks ,
377+ )
378+ .with_columns (is_parsed = pw .this .chunks .is_not_none ())
379+ )
380+
330381 parsed_docs_with_metadata = self .parsed_docs .with_columns (
331382 modified = pw .this .metadata ["modified_at" ].as_int (),
332383 indexed = pw .this .metadata ["seen_at" ].as_int (),
@@ -394,34 +445,68 @@ def inputs_query(
394445 # TODO: compare this approach to first joining queries to documents, then filtering,
395446 # then grouping to get each response.
396447 # The "dumb" tuple approach has more work precomputed for an all inputs query
397- all_metas = self .input_docs .reduce (
398- metadatas = pw .reducers .tuple (pw .this .metadata )
448+ all_metas = self .progress_table .reduce (
449+ metadatas = pw .reducers .tuple (pw .this .metadata ),
450+ is_parsed = pw .reducers .tuple (pw .this .is_parsed ),
399451 )
400452
401453 input_queries = self .merge_filters (input_queries )
402454
@pw.udf
def format_inputs(
    metadatas: list[pw.Json] | None,
    metadata_filter: str | None,
    return_status: bool,
    is_parsed: list[bool] | None,
) -> list[pw.Json]:
    """Build the per-file metadata list returned by the inputs query.

    Args:
        metadatas: metadata of each input file (None when there are no files,
            e.g. the left join found no match).
        metadata_filter: optional JMESPath expression restricting the files.
        return_status: when True, each result gains an ``_indexing_status``
            field: INDEXED if the file was parsed into chunks, INGESTED otherwise.
        is_parsed: per-file parsed flags, positionally aligned with
            ``metadatas`` (None when there are no files).

    Returns:
        Filtered (and optionally status-annotated) metadata objects, with the
        internal ``_file_id`` key stripped.
    """
    metadatas = metadatas if metadatas is not None else []
    # The left join yields None for is_parsed alongside a None metadatas;
    # default it so the pairing below never crashes.
    is_parsed = is_parsed if is_parsed is not None else [False] * len(metadatas)

    def remove_id(m):
        metadata_dict = m.as_dict()
        del metadata_dict["_file_id"]
        return pw.Json(metadata_dict)

    # Pair each metadata with its parsed flag BEFORE filtering: filtering the
    # metadata list alone and then zipping with the unfiltered flags would
    # assign statuses to the wrong files whenever the filter drops entries.
    pairs = list(zip((remove_id(m) for m in metadatas), is_parsed))

    if metadata_filter:
        pairs = [
            (m, status)
            for (m, status) in pairs
            if jmespath.search(
                metadata_filter, m.as_dict(), options=_knn_lsh._glob_options
            )
        ]

    if return_status:
        return [
            pw.Json(
                {
                    "_indexing_status": (
                        IndexingStatus.INDEXED if status else IndexingStatus.INGESTED
                    ),
                    **m.as_dict(),
                }
            )
            for (m, status) in pairs
        ]

    return [m for (m, _status) in pairs]
419496
420497 input_results = input_queries .join_left (all_metas , id = input_queries .id ).select (
421- all_metas .metadatas , input_queries .metadata_filter
498+ all_metas .metadatas ,
499+ input_queries .metadata_filter ,
500+ input_queries .return_status ,
501+ all_metas .is_parsed ,
422502 )
423503 input_results = input_results .select (
424- result = format_inputs (pw .this .metadatas , pw .this .metadata_filter )
504+ result = format_inputs (
505+ pw .this .metadatas ,
506+ pw .this .metadata_filter ,
507+ pw .this .return_status ,
508+ pw .this .is_parsed ,
509+ )
425510 )
426511 return input_results
427512
@@ -623,6 +708,7 @@ def get_input_files(
623708 self ,
624709 metadata_filter : str | None = None ,
625710 filepath_globpattern : str | None = None ,
711+ return_status : bool = False ,
626712 ):
627713 """
628714 Fetch information on documents in the vector store.
@@ -633,13 +719,16 @@ def get_input_files(
633719 satisfying this filtering.
634720 filepath_globpattern: optional glob pattern specifying which documents
635721 will be searched for this query.
722+ return_status: flag telling whether `_indexing_status` should be returned
723+ for each document
636724 """
637725 url = self .url + "/v1/inputs"
638726 response = requests .post (
639727 url ,
640728 json = {
641729 "metadata_filter" : metadata_filter ,
642730 "filepath_globpattern" : filepath_globpattern ,
731+ "return_status" : return_status ,
643732 },
644733 headers = self ._get_request_headers (),
645734 timeout = self .timeout ,
0 commit comments