Skip to content

Commit 5eec91a

Browse files
authored
Parquet: add on_bad_file argument to error/warn/skip bad files (#7806)
add on_bad_file
1 parent 02ee330 commit 5eec91a

File tree

1 file changed

+38
-12
lines changed

1 file changed

+38
-12
lines changed

src/datasets/packaged_modules/parquet/parquet.py

Lines changed: 38 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import itertools
22
from dataclasses import dataclass
3-
from typing import Optional, Union
3+
from typing import Literal, Optional, Union
44

55
import pyarrow as pa
66
import pyarrow.dataset as ds
@@ -36,6 +36,13 @@ class ParquetConfig(datasets.BuilderConfig):
3636
Scan-specific options for Parquet fragments.
3737
This is especially useful to configure buffering and caching.
3838
39+
<Added version="4.2.0"/>
40+
on_bad_file (`Literal["error", "warn", "skip"]`, *optional*, defaults to "error"):
41+
Specify what to do upon encountering a bad file (a file that can't be read). Allowed values are:
42+
* 'error', raise an Exception when a bad file is encountered.
43+
* 'warn', raise a warning when a bad file is encountered and skip that file.
44+
* 'skip', skip bad files without raising or warning when they are encountered.
45+
3946
<Added version="4.2.0"/>
4047
4148
Example:
@@ -74,6 +81,7 @@ class ParquetConfig(datasets.BuilderConfig):
7481
features: Optional[datasets.Features] = None
7582
filters: Optional[Union[ds.Expression, list[tuple], list[list[tuple]]]] = None
7683
fragment_scan_options: Optional[ds.ParquetFragmentScanOptions] = None
84+
on_bad_file: Literal["error", "warn", "skip"] = "error"
7785

7886
def __post_init__(self):
7987
super().__post_init__()
@@ -109,9 +117,22 @@ def _split_generators(self, dl_manager):
109117
# Infer features if they are stored in the arrow schema
110118
if self.info.features is None:
111119
for file in itertools.chain.from_iterable(files):
112-
with open(file, "rb") as f:
113-
self.info.features = datasets.Features.from_arrow_schema(pq.read_schema(f))
114-
break
120+
try:
121+
with open(file, "rb") as f:
122+
self.info.features = datasets.Features.from_arrow_schema(pq.read_schema(f))
123+
break
124+
except pa.ArrowInvalid as e:
125+
if self.config.on_bad_file == "error":
126+
logger.error(f"Failed to read schema from '{file}' with error {type(e).__name__}: {e}")
127+
raise
128+
elif self.config.on_bad_file == "warn":
129+
logger.warning(f"Skipping bad schema from '{file}'. {type(e).__name__}: {e}")
130+
else:
131+
logger.debug(f"Skipping bad schema from '{file}'. {type(e).__name__}: {e}")
132+
if self.info.features is None:
133+
raise ValueError(
134+
f"At least one valid data file must be specified, all the data_files are invalid: {self.config.data_files}"
135+
)
115136
splits.append(datasets.SplitGenerator(name=split_name, gen_kwargs={"files": files}))
116137
if self.config.columns is not None and set(self.config.columns) != set(self.info.features):
117138
self.info.features = datasets.Features(
@@ -139,11 +160,11 @@ def _generate_tables(self, files):
139160
)
140161
parquet_file_format = ds.ParquetFileFormat(default_fragment_scan_options=self.config.fragment_scan_options)
141162
for file_idx, file in enumerate(itertools.chain.from_iterable(files)):
142-
with open(file, "rb") as f:
143-
parquet_fragment = parquet_file_format.make_fragment(f)
144-
if parquet_fragment.row_groups:
145-
batch_size = self.config.batch_size or parquet_fragment.row_groups[0].num_rows
146-
try:
163+
try:
164+
with open(file, "rb") as f:
165+
parquet_fragment = parquet_file_format.make_fragment(f)
166+
if parquet_fragment.row_groups:
167+
batch_size = self.config.batch_size or parquet_fragment.row_groups[0].num_rows
147168
for batch_idx, record_batch in enumerate(
148169
parquet_fragment.to_batches(
149170
batch_size=batch_size,
@@ -158,6 +179,11 @@ def _generate_tables(self, files):
158179
# logger.warning(f"pa_table: {pa_table} num rows: {pa_table.num_rows}")
159180
# logger.warning('\n'.join(str(pa_table.slice(i, 1).to_pydict()) for i in range(pa_table.num_rows)))
160181
yield f"{file_idx}_{batch_idx}", self._cast_table(pa_table)
161-
except ValueError as e:
162-
logger.error(f"Failed to read file '{file}' with error {type(e)}: {e}")
163-
raise
182+
except (pa.ArrowInvalid, ValueError) as e:
183+
if self.config.on_bad_file == "error":
184+
logger.error(f"Failed to read file '{file}' with error {type(e).__name__}: {e}")
185+
raise
186+
elif self.config.on_bad_file == "warn":
187+
logger.warning(f"Skipping bad file '{file}'. {type(e).__name__}: {e}")
188+
else:
189+
logger.debug(f"Skipping bad file '{file}'. {type(e).__name__}: {e}")

0 commit comments

Comments
 (0)