
Commit 8c9eb63

feat: Add ability to pass schema to s3.read_parquet (#2328)
* feat: Allow passing a pyarrow.Schema to wr.s3.read_parquet()
* Fix file handle for read_table
* Add packaging dependency
* Raise an error if reading an empty file
* Throw an exception when file size is 0
* [skip ci] Add warning
1 parent 6aa22bc commit 8c9eb63
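
For reference, a minimal usage sketch of the new parameter (the bucket path is a placeholder; the column names mirror the test added below). On pyarrow >= 8.0.0 the schema is forwarded to pyarrow.parquet.read_table; on older versions a UserWarning is emitted and the previous code path is used:

import awswrangler as wr
import pyarrow as pa

# Explicit schema for the dataset columns (placeholder names, matching the new test).
schema = pa.schema(
    [
        pa.field("c0", pa.int64()),
        pa.field("c1", pa.int64()),
        pa.field("par", pa.string()),
    ]
)

# "s3://my-bucket/my-dataset/" is a placeholder path.
df = wr.s3.read_parquet("s3://my-bucket/my-dataset/", schema=schema)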

File tree

5 files changed: +52 -14 lines changed

  awswrangler/distributed/ray/datasources/arrow_parquet_datasource.py
  awswrangler/s3/_read_parquet.py
  poetry.lock
  pyproject.toml
  tests/unit/test_s3_parquet.py

awswrangler/distributed/ray/datasources/arrow_parquet_datasource.py

Lines changed: 5 additions & 0 deletions
@@ -29,6 +29,7 @@
     _handle_read_os_error,
 )
 
+from awswrangler import exceptions
 from awswrangler._arrow import _add_table_partitions, _df_to_table
 from awswrangler.distributed.ray import ray_remote
 from awswrangler.distributed.ray.datasources.arrow_parquet_base_datasource import ArrowParquetBaseDatasource
@@ -243,6 +244,10 @@ def __init__(
             self._metadata = meta_provider.prefetch_file_metadata(pq_ds.pieces, **prefetch_remote_args) or []
         except OSError as e:
             _handle_read_os_error(e, paths)
+        except pyarrow.ArrowInvalid as ex:
+            if "Parquet file size is 0 bytes" in str(ex):
+                raise exceptions.InvalidFile(f"Invalid Parquet file. {str(ex)}")
+            raise
         self._pq_ds = pq_ds
         self._meta_provider = meta_provider
         self._inferred_schema = inferred_schema
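
The new except-block above targets the message pyarrow emits for zero-byte objects. A quick local reproduction of that failure mode, independent of S3 and Ray (the file name is a placeholder):

import pyarrow as pa
import pyarrow.parquet

# Create a 0-byte placeholder file.
open("empty.parquet", "wb").close()

try:
    pyarrow.parquet.ParquetFile("empty.parquet")
except pa.ArrowInvalid as ex:
    # pyarrow reports the zero-byte condition in the error message,
    # which is the substring the except-block above matches on.
    assert "Parquet file size is 0 bytes" in str(ex)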

awswrangler/s3/_read_parquet.py

Lines changed: 37 additions & 9 deletions
@@ -4,6 +4,7 @@
 import functools
 import itertools
 import logging
+import warnings
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -20,6 +21,7 @@
 import pyarrow as pa
 import pyarrow.dataset
 import pyarrow.parquet
+from packaging import version
 from typing_extensions import Literal
 
 from awswrangler import _data_types, _utils, exceptions
@@ -54,7 +56,8 @@
 
 
 def _pyarrow_parquet_file_wrapper(
-    source: Any, coerce_int96_timestamp_unit: Optional[str] = None
+    source: Any,
+    coerce_int96_timestamp_unit: Optional[str] = None,
 ) -> pyarrow.parquet.ParquetFile:
     try:
         return pyarrow.parquet.ParquetFile(source=source, coerce_int96_timestamp_unit=coerce_int96_timestamp_unit)
@@ -154,6 +157,7 @@ def _read_parquet_file(
     s3_additional_kwargs: Optional[Dict[str, str]],
     use_threads: Union[bool, int],
     version_id: Optional[str] = None,
+    schema: Optional[pa.schema] = None,
 ) -> pa.Table:
     s3_block_size: int = FULL_READ_S3_BLOCK_SIZE if columns else -1  # One shot for a full read or see constant
     with open_s3_object(
@@ -165,14 +169,35 @@
         s3_additional_kwargs=s3_additional_kwargs,
         s3_client=s3_client,
     ) as f:
-        pq_file: Optional[pyarrow.parquet.ParquetFile] = _pyarrow_parquet_file_wrapper(
-            source=f,
-            coerce_int96_timestamp_unit=coerce_int96_timestamp_unit,
-        )
-        if pq_file is None:
-            raise exceptions.InvalidFile(f"Invalid Parquet file: {path}")
+        if schema and version.parse(pa.__version__) >= version.parse("8.0.0"):
+            try:
+                table = pyarrow.parquet.read_table(
+                    f,
+                    columns=columns,
+                    schema=schema,
+                    use_threads=False,
+                    use_pandas_metadata=False,
+                    coerce_int96_timestamp_unit=coerce_int96_timestamp_unit,
+                )
+            except pyarrow.ArrowInvalid as ex:
+                if "Parquet file size is 0 bytes" in str(ex):
+                    raise exceptions.InvalidFile(f"Invalid Parquet file: {path}")
+                raise
+        else:
+            if schema:
+                warnings.warn(
+                    "Your version of pyarrow does not support reading with schema. Consider an upgrade to pyarrow 8+.",
+                    UserWarning,
+                )
+            pq_file: Optional[pyarrow.parquet.ParquetFile] = _pyarrow_parquet_file_wrapper(
+                source=f,
+                coerce_int96_timestamp_unit=coerce_int96_timestamp_unit,
+            )
+            if pq_file is None:
+                raise exceptions.InvalidFile(f"Invalid Parquet file: {path}")
+            table = pq_file.read(columns=columns, use_threads=False, use_pandas_metadata=False)
         return _add_table_partitions(
-            table=pq_file.read(columns=columns, use_threads=False, use_pandas_metadata=False),
+            table=table,
             path=path,
             path_root=path_root,
         )
@@ -262,6 +287,7 @@ def _read_parquet(  # pylint: disable=W0613
         itertools.repeat(s3_additional_kwargs),
         itertools.repeat(use_threads),
         [version_ids.get(p) if isinstance(version_ids, dict) else None for p in paths],
+        itertools.repeat(schema),
     )
     return _utils.table_refs_to_df(tables, kwargs=arrow_kwargs)
 
@@ -281,6 +307,7 @@ def read_parquet(
     columns: Optional[List[str]] = None,
     validate_schema: bool = False,
     coerce_int96_timestamp_unit: Optional[str] = None,
+    schema: Optional[pa.Schema] = None,
     last_modified_begin: Optional[datetime.datetime] = None,
     last_modified_end: Optional[datetime.datetime] = None,
     version_id: Optional[Union[str, Dict[str, str]]] = None,
@@ -359,6 +386,8 @@ def read_parquet(
     coerce_int96_timestamp_unit : str, optional
         Cast timestamps that are stored in INT96 format to a particular resolution (e.g. "ms").
         Setting to None is equivalent to "ns" and therefore INT96 timestamps are inferred as in nanoseconds.
+    schema : pyarrow.Schema, optional
+        Schema to use when reading the file.
     last_modified_begin : datetime, optional
         Filter S3 objects by Last modified date.
         Filter is only applied after listing all objects.
@@ -462,7 +491,6 @@ def read_parquet(
     version_ids = _check_version_id(paths=paths, version_id=version_id)
 
     # Create PyArrow schema based on file metadata, columns filter, and partitions
-    schema: Optional[pa.schema] = None
     if validate_schema and not bulk_read:
         metadata_reader = _ParquetTableMetadataReader()
         schema = metadata_reader.validate_schemas(
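
Stripped of the S3 plumbing, the new code path above reduces to the following standalone pyarrow call; a minimal sketch assuming pyarrow >= 8.0.0 and a local placeholder file (awswrangler passes an open S3 object instead):

import pyarrow as pa
import pyarrow.parquet
from packaging import version

# The schema argument of read_table requires a recent pyarrow,
# hence the version gate in the diff above.
assert version.parse(pa.__version__) >= version.parse("8.0.0")

schema = pa.schema([pa.field("c0", pa.int64()), pa.field("par", pa.string())])
table = pyarrow.parquet.read_table(
    "part-0.parquet",        # placeholder local file
    columns=["c0"],          # optional column projection
    schema=schema,           # read against this explicit schema
    use_threads=False,
    use_pandas_metadata=False,
)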

poetry.lock

Lines changed: 3 additions & 4 deletions
Some generated files are not rendered by default.

pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -34,6 +34,7 @@ pandas = ">=1.2.0,!=1.5.0,<3.0.0" # Exclusion per: https://github.com/aws/aws-sd
 numpy = "^1.18"
 pyarrow = ">=7.0.0"
 typing-extensions = "^4.4.0"
+packaging = "^23.1"
 
 # Databases
 redshift-connector = { version = "^2.0.0", optional = true }

tests/unit/test_s3_parquet.py

Lines changed: 6 additions & 1 deletion
@@ -630,14 +630,19 @@ def test_parquet_compression(path, compression) -> None:
 
 
 @pytest.mark.parametrize("use_threads", [True, False, 2])
-def test_empty_file(path, use_threads):
+@pytest.mark.parametrize(
+    "schema", [None, pa.schema([pa.field("c0", pa.int64()), pa.field("c1", pa.int64()), pa.field("par", pa.string())])]
+)
+def test_empty_file(path, use_threads, schema):
     df = pd.DataFrame({"c0": [1, 2, 3], "c1": [None, None, None], "par": ["a", "b", "c"]})
     df.index = df.index.astype("Int64")
     df["c0"] = df["c0"].astype("Int64")
     df["par"] = df["par"].astype("string")
     wr.s3.to_parquet(df, path, index=True, dataset=True, partition_cols=["par"])
     bucket, key = wr._utils.parse_path(f"{path}test.csv")
     boto3.client("s3").put_object(Body=b"", Bucket=bucket, Key=key)
+    with pytest.raises(wr.exceptions.InvalidFile):
+        wr.s3.read_parquet(path, use_threads=use_threads, ignore_empty=False, schema=schema)
     df2 = wr.s3.read_parquet(path, dataset=True, use_threads=use_threads)
     df2["par"] = df2["par"].astype("string")
     assert_pandas_equals(df, df2)
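
The parametrized test exercises the InvalidFile error path both with schema=None and with an explicit schema. The fallback warning for older pyarrow is not covered here; if needed, it could be asserted along these lines (a hedged sketch, meaningful only in an environment with pyarrow < 8.0.0, reusing the path and schema values from the test above):

import pytest

# Only fires when the installed pyarrow predates read_table's schema support.
with pytest.warns(UserWarning, match="pyarrow 8"):
    wr.s3.read_parquet(path, schema=schema)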

0 commit comments