
Commit 3b96385

feat: support parquet client encryption (#2674)
Closes #2642
1 parent d0081d3 commit 3b96385

File tree: 12 files changed, +258 −5 lines

awswrangler/distributed/ray/datasources/arrow_parquet_base_datasource.py

Lines changed: 2 additions & 0 deletions
@@ -40,12 +40,14 @@ def _read_stream(self, f: pa.NativeFile, path: str) -> Iterator[pa.Table]:

         dataset_kwargs = arrow_parquet_args.get("dataset_kwargs", {})
         coerce_int96_timestamp_unit: str | None = dataset_kwargs.get("coerce_int96_timestamp_unit", None)
+        decryption_properties = dataset_kwargs.get("decryption_properties", None)

         table = pq.read_table(
             f,
             use_threads=use_threads,
             columns=columns,
             coerce_int96_timestamp_unit=coerce_int96_timestamp_unit,
+            decryption_properties=decryption_properties,
         )

         table = _add_table_partitions(

awswrangler/distributed/ray/modin/s3/_read_parquet.py

Lines changed: 3 additions & 0 deletions
@@ -39,10 +39,13 @@ def _read_parquet_distributed(
     s3_additional_kwargs: dict[str, Any] | None,
     arrow_kwargs: dict[str, Any],
     bulk_read: bool,
+    decryption_properties: pa.parquet.encryption.DecryptionConfiguration | None = None,
 ) -> pd.DataFrame:
     dataset_kwargs = {}
     if coerce_int96_timestamp_unit:
         dataset_kwargs["coerce_int96_timestamp_unit"] = coerce_int96_timestamp_unit
+    if decryption_properties:
+        dataset_kwargs["decryption_properties"] = decryption_properties

     dataset = read_datasource(
         **_resolve_datasource_parameters(

awswrangler/distributed/ray/modin/s3/_write_orc.py

Lines changed: 2 additions & 0 deletions
@@ -12,6 +12,7 @@
 from awswrangler import exceptions
 from awswrangler.distributed.ray.datasources import ArrowORCDatasink, UserProvidedKeyBlockWritePathProvider
 from awswrangler.distributed.ray.modin._utils import _ray_dataset_from_df
+from awswrangler.typing import ArrowEncryptionConfiguration

 if TYPE_CHECKING:
     from mypy_boto3_s3 import S3Client
@@ -36,6 +37,7 @@ def _to_orc_distributed(
     filename_prefix: str | None = None,
     max_rows_by_file: int | None = 0,
     bucketing: bool = False,
+    encryption_configuration: ArrowEncryptionConfiguration | None = None,
 ) -> list[str]:
     if bucketing:
         # Add bucket id to the prefix

awswrangler/distributed/ray/modin/s3/_write_parquet.py

Lines changed: 2 additions & 0 deletions
@@ -12,6 +12,7 @@
 from awswrangler import exceptions
 from awswrangler.distributed.ray.datasources import ArrowParquetDatasink, UserProvidedKeyBlockWritePathProvider
 from awswrangler.distributed.ray.modin._utils import _ray_dataset_from_df
+from awswrangler.typing import ArrowEncryptionConfiguration

 if TYPE_CHECKING:
     from mypy_boto3_s3 import S3Client
@@ -36,6 +37,7 @@ def _to_parquet_distributed(
     filename_prefix: str | None = "",
     max_rows_by_file: int | None = 0,
     bucketing: bool = False,
+    encryption_configuration: ArrowEncryptionConfiguration | None = None,
 ) -> list[str]:
     if bucketing:
         # Add bucket id to the prefix

awswrangler/s3/_read_parquet.py

Lines changed: 40 additions & 4 deletions
@@ -40,7 +40,7 @@
     _InternalReadTableMetadataReturnValue,
     _TableMetadataReader,
 )
-from awswrangler.typing import RayReadParquetSettings, _ReadTableMetadataReturnValue
+from awswrangler.typing import ArrowDecryptionConfiguration, RayReadParquetSettings, _ReadTableMetadataReturnValue

 if TYPE_CHECKING:
     from mypy_boto3_s3 import S3Client
@@ -56,9 +56,14 @@
 def _pyarrow_parquet_file_wrapper(
     source: Any,
     coerce_int96_timestamp_unit: str | None = None,
+    decryption_properties: pyarrow.parquet.encryption.DecryptionConfiguration | None = None,
 ) -> pyarrow.parquet.ParquetFile:
     try:
-        return pyarrow.parquet.ParquetFile(source=source, coerce_int96_timestamp_unit=coerce_int96_timestamp_unit)
+        return pyarrow.parquet.ParquetFile(
+            source=source,
+            coerce_int96_timestamp_unit=coerce_int96_timestamp_unit,
+            decryption_properties=decryption_properties,
+        )
     except pyarrow.ArrowInvalid as ex:
         if str(ex) == "Parquet file size is 0 bytes":
             _logger.warning("Ignoring empty file...")
@@ -74,6 +79,7 @@ def _read_parquet_metadata_file(
     use_threads: bool | int,
     version_id: str | None = None,
     coerce_int96_timestamp_unit: str | None = None,
+    decryption_properties: pyarrow.parquet.encryption.DecryptionConfiguration | None = None,
 ) -> pa.schema:
     with open_s3_object(
         path=path,
@@ -85,7 +91,9 @@
         s3_additional_kwargs=s3_additional_kwargs,
     ) as f:
         pq_file: pyarrow.parquet.ParquetFile | None = _pyarrow_parquet_file_wrapper(
-            source=f, coerce_int96_timestamp_unit=coerce_int96_timestamp_unit
+            source=f,
+            coerce_int96_timestamp_unit=coerce_int96_timestamp_unit,
+            decryption_properties=decryption_properties,
         )
         if pq_file:
             return pq_file.schema.to_arrow_schema()
@@ -156,6 +164,7 @@ def _read_parquet_file(
     use_threads: bool | int,
     version_id: str | None = None,
     schema: pa.schema | None = None,
+    decryption_properties: pyarrow.parquet.encryption.DecryptionConfiguration | None = None,
 ) -> pa.Table:
     s3_block_size: int = FULL_READ_S3_BLOCK_SIZE if columns else -1  # One shot for a full read or see constant
     with open_s3_object(
@@ -176,6 +185,7 @@
                 use_threads=False,
                 use_pandas_metadata=False,
                 coerce_int96_timestamp_unit=coerce_int96_timestamp_unit,
+                decryption_properties=decryption_properties,
             )
         except pyarrow.ArrowInvalid as ex:
             if "Parquet file size is 0 bytes" in str(ex):
@@ -190,6 +200,7 @@
             pq_file: pyarrow.parquet.ParquetFile | None = _pyarrow_parquet_file_wrapper(
                 source=f,
                 coerce_int96_timestamp_unit=coerce_int96_timestamp_unit,
+                decryption_properties=decryption_properties,
             )
             if pq_file is None:
                 raise exceptions.InvalidFile(f"Invalid Parquet file: {path}")
@@ -212,6 +223,7 @@ def _read_parquet_chunked(
     s3_additional_kwargs: dict[str, str] | None,
     arrow_kwargs: dict[str, Any],
     version_ids: dict[str, str] | None = None,
+    decryption_properties: pyarrow.parquet.encryption.DecryptionConfiguration | None = None,
 ) -> Iterator[pd.DataFrame]:
     next_slice: pd.DataFrame | None = None
     batch_size = BATCH_READ_BLOCK_SIZE if chunked is True else chunked
@@ -229,6 +241,7 @@
             pq_file: pyarrow.parquet.ParquetFile | None = _pyarrow_parquet_file_wrapper(
                 source=f,
                 coerce_int96_timestamp_unit=coerce_int96_timestamp_unit,
+                decryption_properties=decryption_properties,
             )
             if pq_file is None:
                 continue
@@ -278,6 +291,7 @@ def _read_parquet(
     s3_additional_kwargs: dict[str, Any] | None,
     arrow_kwargs: dict[str, Any],
     bulk_read: bool,
+    decryption_properties: pyarrow.parquet.encryption.DecryptionConfiguration | None = None,
 ) -> pd.DataFrame:
     executor: _BaseExecutor = _get_executor(use_threads=use_threads)
     tables = executor.map(
@@ -291,6 +305,7 @@
         itertools.repeat(use_threads),
         [version_ids.get(p) if isinstance(version_ids, dict) else None for p in paths],
         itertools.repeat(schema),
+        itertools.repeat(decryption_properties),
     )
     return _utils.table_refs_to_df(tables, kwargs=arrow_kwargs)

@@ -321,6 +336,7 @@ def read_parquet(
     boto3_session: boto3.Session | None = None,
     s3_additional_kwargs: dict[str, Any] | None = None,
     pyarrow_additional_kwargs: dict[str, Any] | None = None,
+    decryption_configuration: ArrowDecryptionConfiguration | None = None,
 ) -> pd.DataFrame | Iterator[pd.DataFrame]:
     """Read Parquet file(s) from an S3 prefix or list of S3 objects paths.

@@ -425,6 +441,11 @@ def read_parquet(
         Forwarded to `to_pandas` method converting from PyArrow tables to Pandas DataFrame.
         Valid values include "split_blocks", "self_destruct", "ignore_metadata".
         e.g. pyarrow_additional_kwargs={'split_blocks': True}.
+    decryption_configuration: typing.ArrowDecryptionConfiguration, optional
+        ``pyarrow.parquet.encryption.CryptoFactory`` and ``pyarrow.parquet.encryption.KmsConnectionConfig`` objects dict
+        used to create a PyArrow ``CryptoFactory.file_decryption_properties`` object to forward to PyArrow reader.
+        see: https://arrow.apache.org/docs/python/parquet.html#decryption-configuration
+        Client Decryption is not supported in distributed mode.

     Returns
     -------
@@ -508,10 +529,17 @@ def read_parquet(
         coerce_int96_timestamp_unit=coerce_int96_timestamp_unit,
     )

+    decryption_properties = (
+        decryption_configuration["crypto_factory"].file_decryption_properties(
+            decryption_configuration["kms_connection_config"]
+        )
+        if decryption_configuration
+        else None
+    )
+
     arrow_kwargs = _data_types.pyarrow2pandas_defaults(
         use_threads=use_threads, kwargs=pyarrow_additional_kwargs, dtype_backend=dtype_backend
     )
-
     if chunked:
         return _read_parquet_chunked(
             s3_client=s3_client,
@@ -524,6 +552,7 @@
             s3_additional_kwargs=s3_additional_kwargs,
             arrow_kwargs=arrow_kwargs,
             version_ids=version_ids,
+            decryption_properties=decryption_properties,
         )

     return _read_parquet(
@@ -539,6 +568,7 @@
         arrow_kwargs=arrow_kwargs,
         version_ids=version_ids,
         bulk_read=bulk_read,
+        decryption_properties=decryption_properties,
     )


@@ -563,6 +593,7 @@ def read_parquet_table(
     boto3_session: boto3.Session | None = None,
     s3_additional_kwargs: dict[str, Any] | None = None,
     pyarrow_additional_kwargs: dict[str, Any] | None = None,
+    decryption_configuration: ArrowDecryptionConfiguration | None = None,
 ) -> pd.DataFrame | Iterator[pd.DataFrame]:
     """Read Apache Parquet table registered in the AWS Glue Catalog.

@@ -641,6 +672,10 @@ def read_parquet_table(
         Forwarded to `to_pandas` method converting from PyArrow tables to Pandas DataFrame.
         Valid values include "split_blocks", "self_destruct", "ignore_metadata".
         e.g. pyarrow_additional_kwargs={'split_blocks': True}.
+    decryption_configuration: typing.ArrowDecryptionConfiguration, optional
+        ``pyarrow.parquet.encryption.CryptoFactory`` and ``pyarrow.parquet.encryption.KmsConnectionConfig`` objects dict
+        used to create a PyArrow ``CryptoFactory.file_decryption_properties`` object to forward to PyArrow reader.
+        Client Decryption is not supported in distributed mode.

     Returns
     -------
@@ -698,6 +733,7 @@ def read_parquet_table(
         boto3_session=boto3_session,
         s3_additional_kwargs=s3_additional_kwargs,
         pyarrow_additional_kwargs=pyarrow_additional_kwargs,
+        decryption_configuration=decryption_configuration,
     )

     partial_cast_function = functools.partial(
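
The docstring above points to PyArrow's decryption-configuration documentation; the following is a minimal, self-contained sketch of how the new decryption_configuration argument could be used. It relies on a toy in-memory KMS client adapted from the PyArrow documentation pattern; the InMemoryKmsClient class, master-key names, and S3 path are illustrative and not part of this commit.

import base64

import awswrangler as wr
import pyarrow.parquet.encryption as pe


class InMemoryKmsClient(pe.KmsClient):
    """Toy KMS client that keeps master keys in memory (local testing only, not secure)."""

    def __init__(self, kms_connection_config: pe.KmsConnectionConfig) -> None:
        super().__init__()
        self.master_keys_map = kms_connection_config.custom_kms_conf

    def wrap_key(self, key_bytes: bytes, master_key_identifier: str) -> bytes:
        # "Wrap" by prefixing the data key with the master key (not real encryption).
        master_key = self.master_keys_map[master_key_identifier].encode("utf-8")
        return base64.b64encode(master_key + key_bytes)

    def unwrap_key(self, wrapped_key: bytes, master_key_identifier: str) -> bytes:
        master_key = self.master_keys_map[master_key_identifier].encode("utf-8")
        return base64.b64decode(wrapped_key)[len(master_key):]


# Master-key names and values are illustrative.
kms_connection_config = pe.KmsConnectionConfig(
    custom_kms_conf={
        "footer_key": "0123456789012345",
        "column_key": "1234567890123450",
    }
)
crypto_factory = pe.CryptoFactory(InMemoryKmsClient)

# The dict keys match how read_parquet builds file_decryption_properties in the hunk above.
df = wr.s3.read_parquet(
    path="s3://my-bucket/encrypted-dataset/",  # illustrative path
    decryption_configuration={
        "crypto_factory": crypto_factory,
        "kms_connection_config": kms_connection_config,
    },
)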

awswrangler/s3/_read_parquet.pyi

Lines changed: 9 additions & 1 deletion
@@ -5,7 +5,7 @@ import boto3
 import pandas as pd
 import pyarrow as pa

-from awswrangler.typing import RayReadParquetSettings
+from awswrangler.typing import ArrowDecryptionConfiguration, RayReadParquetSettings

 if TYPE_CHECKING:
     from mypy_boto3_s3 import S3Client
@@ -88,6 +88,7 @@ def read_parquet(
     boto3_session: boto3.Session | None = ...,
     s3_additional_kwargs: dict[str, Any] | None = ...,
     pyarrow_additional_kwargs: dict[str, Any] | None = ...,
+    decryption_configuration: ArrowDecryptionConfiguration | None = ...,
 ) -> pd.DataFrame: ...
 @overload
 def read_parquet(
@@ -113,6 +114,7 @@ def read_parquet(
     boto3_session: boto3.Session | None = ...,
     s3_additional_kwargs: dict[str, Any] | None = ...,
     pyarrow_additional_kwargs: dict[str, Any] | None = ...,
+    decryption_configuration: ArrowDecryptionConfiguration | None = ...,
 ) -> Iterator[pd.DataFrame]: ...
 @overload
 def read_parquet(
@@ -138,6 +140,7 @@ def read_parquet(
     boto3_session: boto3.Session | None = ...,
     s3_additional_kwargs: dict[str, Any] | None = ...,
     pyarrow_additional_kwargs: dict[str, Any] | None = ...,
+    decryption_configuration: ArrowDecryptionConfiguration | None = ...,
 ) -> pd.DataFrame | Iterator[pd.DataFrame]: ...
 @overload
 def read_parquet(
@@ -163,6 +166,7 @@ def read_parquet(
     boto3_session: boto3.Session | None = ...,
     s3_additional_kwargs: dict[str, Any] | None = ...,
     pyarrow_additional_kwargs: dict[str, Any] | None = ...,
+    decryption_configuration: ArrowDecryptionConfiguration | None = ...,
 ) -> Iterator[pd.DataFrame]: ...
 @overload
 def read_parquet_table(
@@ -183,6 +187,7 @@ def read_parquet_table(
     boto3_session: boto3.Session | None = ...,
     s3_additional_kwargs: dict[str, Any] | None = ...,
     pyarrow_additional_kwargs: dict[str, Any] | None = ...,
+    decryption_configuration: ArrowDecryptionConfiguration | None = ...,
 ) -> pd.DataFrame: ...
 @overload
 def read_parquet_table(
@@ -203,6 +208,7 @@ def read_parquet_table(
     boto3_session: boto3.Session | None = ...,
     s3_additional_kwargs: dict[str, Any] | None = ...,
     pyarrow_additional_kwargs: dict[str, Any] | None = ...,
+    decryption_configuration: ArrowDecryptionConfiguration | None = ...,
 ) -> Iterator[pd.DataFrame]: ...
 @overload
 def read_parquet_table(
@@ -223,6 +229,7 @@ def read_parquet_table(
     boto3_session: boto3.Session | None = ...,
     s3_additional_kwargs: dict[str, Any] | None = ...,
     pyarrow_additional_kwargs: dict[str, Any] | None = ...,
+    decryption_configuration: ArrowDecryptionConfiguration | None = ...,
 ) -> pd.DataFrame | Iterator[pd.DataFrame]: ...
 @overload
 def read_parquet_table(
@@ -243,4 +250,5 @@ def read_parquet_table(
     boto3_session: boto3.Session | None = ...,
     s3_additional_kwargs: dict[str, Any] | None = ...,
     pyarrow_additional_kwargs: dict[str, Any] | None = ...,
+    decryption_configuration: ArrowDecryptionConfiguration | None = ...,
 ) -> Iterator[pd.DataFrame]: ...
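
The change to awswrangler/typing.py is among the 12 files in this commit but is not shown in this excerpt. Judging from how read_parquet indexes decryption_configuration above, ArrowDecryptionConfiguration is presumably a TypedDict along these lines (a sketch inferred from the diff, not the committed definition):

from typing import TypedDict

import pyarrow.parquet.encryption as pe


class ArrowDecryptionConfiguration(TypedDict):
    # Keys inferred from read_parquet above, which calls
    # decryption_configuration["crypto_factory"].file_decryption_properties(
    #     decryption_configuration["kms_connection_config"]).
    crypto_factory: pe.CryptoFactory
    kms_connection_config: pe.KmsConnectionConfig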

awswrangler/s3/_write.py

Lines changed: 4 additions & 0 deletions
@@ -195,6 +195,7 @@ def _write_to_s3(
         filename_prefix: str | None = None,
         max_rows_by_file: int | None = 0,
         bucketing: bool = False,
+        encryption_configuration: typing.ArrowEncryptionConfiguration | None = None,
     ) -> list[str]:
         pass

@@ -269,6 +270,7 @@ def write(  # noqa: PLR0912,PLR0913,PLR0915
         athena_partition_projection_settings: typing.AthenaPartitionProjectionSettings | None,
         catalog_id: str | None,
         compression_ext: str,
+        encryption_configuration: typing.ArrowEncryptionConfiguration | None,
     ) -> typing._S3WriteDataReturnValue:
         # Initializing defaults
         partition_cols = partition_cols if partition_cols else []
@@ -349,6 +351,7 @@ def write(  # noqa: PLR0912,PLR0913,PLR0915
                 dtype=dtype,
                 max_rows_by_file=max_rows_by_file,
                 use_threads=use_threads,
+                encryption_configuration=encryption_configuration,
             )
         else:
             columns_types: dict[str, str] = {}
@@ -417,6 +420,7 @@ def write(  # noqa: PLR0912,PLR0913,PLR0915
             s3_additional_kwargs=s3_additional_kwargs,
             schema=schema,
             max_rows_by_file=max_rows_by_file,
+            encryption_configuration=encryption_configuration,
         )
         if database and table:
             try:

awswrangler/s3/_write_orc.py

Lines changed: 4 additions & 0 deletions
@@ -156,6 +156,7 @@ def _to_orc(
     filename_prefix: str | None = None,
     max_rows_by_file: int | None = 0,
     bucketing: bool = False,
+    encryption_configuration: typing.ArrowEncryptionConfiguration | None = None,
 ) -> list[str]:
     s3_client = s3_client if s3_client else _utils.client(service_name="s3")
     file_path = _get_file_path(
@@ -216,6 +217,7 @@ def _write_to_s3(
         filename_prefix: str | None = None,
         max_rows_by_file: int | None = 0,
         bucketing: bool = False,
+        encryption_configuration: typing.ArrowEncryptionConfiguration | None = None,
     ) -> list[str]:
         return _to_orc(
             df=df,
@@ -234,6 +236,7 @@ def _write_to_s3(
             filename_prefix=filename_prefix,
             max_rows_by_file=max_rows_by_file,
             bucketing=bucketing,
+            encryption_configuration=encryption_configuration,
         )

     def _create_glue_table(
@@ -712,4 +715,5 @@ def to_orc(
         athena_partition_projection_settings=athena_partition_projection_settings,
         catalog_id=catalog_id,
         compression_ext=compression_ext,
+        encryption_configuration=None,
     )
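
Note that to_orc above forwards encryption_configuration=None, so client-side encryption appears to be wired up only for the Parquet writer in this change. The public to_parquet diff is not part of this excerpt; the sketch below assumes it exposes the same encryption_configuration parameter that _write.py threads through, and that ArrowEncryptionConfiguration also carries a PyArrow EncryptionConfiguration object (the dict keys are assumptions, not shown above). It reuses crypto_factory and kms_connection_config from the read_parquet example earlier.

import awswrangler as wr
import pandas as pd
import pyarrow.parquet.encryption as pe

# Column and key names are illustrative; footer_key/column_key must match the
# master-key identifiers registered in kms_connection_config above.
encryption_config = pe.EncryptionConfiguration(
    footer_key="footer_key",
    column_keys={"column_key": ["sensitive_col"]},
)

wr.s3.to_parquet(
    df=pd.DataFrame({"sensitive_col": [1, 2, 3]}),
    path="s3://my-bucket/encrypted-dataset/data.parquet",  # illustrative path
    encryption_configuration={
        # Assumed ArrowEncryptionConfiguration keys; awswrangler/typing.py is not shown in this excerpt.
        "crypto_factory": crypto_factory,
        "kms_connection_config": kms_connection_config,
        "encryption_config": encryption_config,
    },
)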
