
Commit 661d7ba

Faster parquet streaming + filters with predicate pushdown (#7309)
* faster parquet streaming + add filters config param
* add test
1 parent b60ebb8 commit 661d7ba
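
As a sketch of what the new config param enables (hypothetical usage; the file name and the exact load_dataset call are illustrative assumptions, not part of this commit):

from datasets import load_dataset

# "data.parquet" is a placeholder path. Filters use PyArrow's DNF tuple
# format; with this commit they are pushed down to the Parquet reader, so
# non-matching rows are skipped instead of being decoded and then dropped.
dataset = load_dataset(
    "parquet",
    data_files="data.parquet",
    streaming=True,
    filters=[("col_2", "==", 1)],
)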

File tree

2 files changed: +28 -5 lines changed


src/datasets/packaged_modules/parquet/parquet.py

Lines changed: 18 additions & 5 deletions
@@ -1,8 +1,9 @@
 import itertools
 from dataclasses import dataclass
-from typing import List, Optional
+from typing import List, Optional, Union
 
 import pyarrow as pa
+import pyarrow.dataset as ds
 import pyarrow.parquet as pq
 
 import datasets
@@ -19,6 +20,7 @@ class ParquetConfig(datasets.BuilderConfig):
     batch_size: Optional[int] = None
     columns: Optional[List[str]] = None
     features: Optional[datasets.Features] = None
+    filters: Optional[Union[ds.Expression, List[tuple], List[List[tuple]]]] = None
 
     def __post_init__(self):
         super().__post_init__()
@@ -77,14 +79,25 @@ def _generate_tables(self, files):
                 raise ValueError(
                     f"Tried to load parquet data with columns '{self.config.columns}' with mismatching features '{self.info.features}'"
                 )
+        filter_expr = (
+            pq.filters_to_expression(self.config.filters)
+            if isinstance(self.config.filters, list)
+            else self.config.filters
+        )
         for file_idx, file in enumerate(itertools.chain.from_iterable(files)):
             with open(file, "rb") as f:
-                parquet_file = pq.ParquetFile(f)
-                if parquet_file.metadata.num_row_groups > 0:
-                    batch_size = self.config.batch_size or parquet_file.metadata.row_group(0).num_rows
+                parquet_fragment = ds.ParquetFileFormat().make_fragment(f)
+                if parquet_fragment.row_groups:
+                    batch_size = self.config.batch_size or parquet_fragment.row_groups[0].num_rows
                     try:
                         for batch_idx, record_batch in enumerate(
-                            parquet_file.iter_batches(batch_size=batch_size, columns=self.config.columns)
+                            parquet_fragment.to_batches(
+                                batch_size=batch_size,
+                                columns=self.config.columns,
+                                filter=filter_expr,
+                                batch_readahead=0,
+                                fragment_readahead=0,
+                            )
                         ):
                             pa_table = pa.Table.from_batches([record_batch])
                             # Uncomment for debugging (will print the Arrow table size and elements)
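
For reference, a standalone sketch of the two PyArrow pieces this change combines (pq.filters_to_expression and fragment-based scanning); the file name, column name, and batch size below are placeholder assumptions:

import pyarrow.dataset as ds
import pyarrow.parquet as pq

# A DNF filter (a list of (column, op, value) tuples) converts to a dataset
# expression, mirroring what the diff does when self.config.filters is a list.
filter_expr = pq.filters_to_expression([("col_2", "==", 1)])

# Scanning through a ParquetFileFragment lets PyArrow push the filter down
# to the Parquet reader; the diff disables readahead to keep streaming
# memory usage low.
with open("data.parquet", "rb") as f:  # placeholder path
    fragment = ds.ParquetFileFormat().make_fragment(f)
    for record_batch in fragment.to_batches(
        batch_size=1000,
        filter=filter_expr,
        batch_readahead=0,
        fragment_readahead=0,
    ):
        print(record_batch.num_rows)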

tests/io/test_parquet.py

Lines changed: 10 additions & 0 deletions
@@ -89,6 +89,16 @@ def test_parquet_read_geoparquet(geoparquet_path, tmp_path):
         assert dataset.features[feature].dtype == expected_dtype
 
 
+def test_parquet_read_filters(parquet_path, tmp_path):
+    cache_dir = tmp_path / "cache"
+    filters = [("col_2", "==", 1)]
+    dataset = ParquetDatasetReader(path_or_paths=parquet_path, cache_dir=cache_dir, filters=filters).read()
+
+    assert isinstance(dataset, Dataset)
+    assert all(example["col_2"] == 1 for example in dataset)
+    assert dataset.num_rows == 1
+
+
 def _check_parquet_datasetdict(dataset_dict, expected_features, splits=("train",)):
     assert isinstance(dataset_dict, (DatasetDict, IterableDatasetDict))
     for split in splits:
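
The test drives ParquetDatasetReader directly; assuming reader kwargs are forwarded to the builder config the same way, the higher-level Dataset.from_parquet call (which wraps ParquetDatasetReader) should accept the same argument. A sketch under that assumption, not taken from this commit:

from datasets import Dataset

# "data.parquet" is a placeholder path; `filters` is assumed to reach
# ParquetConfig through the reader's config kwargs.
dataset = Dataset.from_parquet("data.parquet", filters=[("col_2", "==", 1)])
assert all(example["col_2"] == 1 for example in dataset)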
