 import functools
 import itertools
 import logging
+import warnings
 from typing import (
     TYPE_CHECKING,
     Any,
 import pyarrow as pa
 import pyarrow.dataset
 import pyarrow.parquet
+from packaging import version
 from typing_extensions import Literal

 from awswrangler import _data_types, _utils, exceptions


 def _pyarrow_parquet_file_wrapper(
-    source: Any, coerce_int96_timestamp_unit: Optional[str] = None
+    source: Any,
+    coerce_int96_timestamp_unit: Optional[str] = None,
 ) -> pyarrow.parquet.ParquetFile:
     try:
         return pyarrow.parquet.ParquetFile(source=source, coerce_int96_timestamp_unit=coerce_int96_timestamp_unit)
@@ -154,6 +157,7 @@ def _read_parquet_file(
     s3_additional_kwargs: Optional[Dict[str, str]],
     use_threads: Union[bool, int],
     version_id: Optional[str] = None,
+    schema: Optional[pa.schema] = None,
 ) -> pa.Table:
     s3_block_size: int = FULL_READ_S3_BLOCK_SIZE if columns else -1  # One shot for a full read or see constant
     with open_s3_object(
@@ -165,14 +169,35 @@ def _read_parquet_file(
         s3_additional_kwargs=s3_additional_kwargs,
         s3_client=s3_client,
     ) as f:
-        pq_file: Optional[pyarrow.parquet.ParquetFile] = _pyarrow_parquet_file_wrapper(
-            source=f,
-            coerce_int96_timestamp_unit=coerce_int96_timestamp_unit,
-        )
-        if pq_file is None:
-            raise exceptions.InvalidFile(f"Invalid Parquet file: {path}")
+        if schema and version.parse(pa.__version__) >= version.parse("8.0.0"):
+            try:
+                table = pyarrow.parquet.read_table(
+                    f,
+                    columns=columns,
+                    schema=schema,
+                    use_threads=False,
+                    use_pandas_metadata=False,
+                    coerce_int96_timestamp_unit=coerce_int96_timestamp_unit,
+                )
+            except pyarrow.ArrowInvalid as ex:
+                if "Parquet file size is 0 bytes" in str(ex):
+                    raise exceptions.InvalidFile(f"Invalid Parquet file: {path}")
+                raise
+        else:
+            if schema:
+                warnings.warn(
+                    "Your version of pyarrow does not support reading with schema. Consider an upgrade to pyarrow 8+.",
+                    UserWarning,
+                )
+            pq_file: Optional[pyarrow.parquet.ParquetFile] = _pyarrow_parquet_file_wrapper(
+                source=f,
+                coerce_int96_timestamp_unit=coerce_int96_timestamp_unit,
+            )
+            if pq_file is None:
+                raise exceptions.InvalidFile(f"Invalid Parquet file: {path}")
+            table = pq_file.read(columns=columns, use_threads=False, use_pandas_metadata=False)
         return _add_table_partitions(
-            table=pq_file.read(columns=columns, use_threads=False, use_pandas_metadata=False),
+            table=table,
             path=path,
             path_root=path_root,
         )
@@ -262,6 +287,7 @@ def _read_parquet(  # pylint: disable=W0613
         itertools.repeat(s3_additional_kwargs),
         itertools.repeat(use_threads),
         [version_ids.get(p) if isinstance(version_ids, dict) else None for p in paths],
+        itertools.repeat(schema),
     )
     return _utils.table_refs_to_df(tables, kwargs=arrow_kwargs)

@@ -281,6 +307,7 @@ def read_parquet(
     columns: Optional[List[str]] = None,
     validate_schema: bool = False,
     coerce_int96_timestamp_unit: Optional[str] = None,
+    schema: Optional[pa.Schema] = None,
     last_modified_begin: Optional[datetime.datetime] = None,
     last_modified_end: Optional[datetime.datetime] = None,
     version_id: Optional[Union[str, Dict[str, str]]] = None,
@@ -359,6 +386,8 @@ def read_parquet(
     coerce_int96_timestamp_unit : str, optional
         Cast timestamps that are stored in INT96 format to a particular resolution (e.g. "ms").
         Setting to None is equivalent to "ns" and therefore INT96 timestamps are inferred as in nanoseconds.
+    schema : pyarrow.Schema, optional
+        Schema to use when reading the file.
     last_modified_begin : datetime, optional
         Filter S3 objects by Last modified date.
         Filter is only applied after listing all objects.
@@ -462,7 +491,6 @@ def read_parquet(
     version_ids = _check_version_id(paths=paths, version_id=version_id)

     # Create PyArrow schema based on file metadata, columns filter, and partitions
-    schema: Optional[pa.schema] = None
     if validate_schema and not bulk_read:
         metadata_reader = _ParquetTableMetadataReader()
         schema = metadata_reader.validate_schemas(
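Below the diff, a minimal usage sketch of the new `schema` argument as exposed through the public `awswrangler.s3.read_parquet` API. The bucket path and column types are illustrative assumptions; only the `schema` parameter itself comes from this change.

```python
# Hypothetical usage sketch (not part of the commit): the S3 path and column
# types below are made up; only the schema= parameter comes from this change.
import pyarrow as pa

import awswrangler as wr

# Explicit schema to enforce while reading. With pyarrow >= 8.0.0 it is passed
# through to pyarrow.parquet.read_table; on older pyarrow a UserWarning is
# emitted and the files are read without it.
schema = pa.schema([("id", pa.int64()), ("name", pa.string())])

df = wr.s3.read_parquet(
    path="s3://my-bucket/my-prefix/",
    schema=schema,
)
```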