diff --git a/graphgen/operators/read/read.py b/graphgen/operators/read/read.py index c55f3d3d..3ff60c15 100644 --- a/graphgen/operators/read/read.py +++ b/graphgen/operators/read/read.py @@ -53,6 +53,7 @@ def read( working_dir: Optional[str] = "cache", parallelism: int = 4, recursive: bool = True, + read_nums: Optional[int] = None, **reader_kwargs: Any, ) -> ray.data.Dataset: """ @@ -63,6 +64,7 @@ def read( :param working_dir: Directory to cache intermediate files (PDF processing) :param parallelism: Number of parallel workers :param recursive: Whether to scan directories recursively + :param read_nums: Limit the number of documents to read :param reader_kwargs: Additional kwargs passed to readers :return: Ray Dataset containing all documents """ @@ -120,6 +122,9 @@ def read( } ) + if read_nums is not None: + combined_ds = combined_ds.limit(read_nums) + logger.info("[READ] Successfully read files from %s", input_path) return combined_ds