Skip to content

Commit 9d7cdb3

Browse files
feat: Support reading with PyArrow-backed types (#2292)
1 parent 442d362 commit 9d7cdb3

28 files changed

+592
-43
lines changed

awswrangler/_config.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
import botocore.config
1010
import pandas as pd
11+
from typing_extensions import Literal
1112

1213
from awswrangler import exceptions
1314
from awswrangler.typing import AthenaCacheSettings
@@ -48,6 +49,7 @@ class _ConfigArg(NamedTuple):
4849
"workgroup": _ConfigArg(dtype=str, nullable=False, enforced=True),
4950
"chunksize": _ConfigArg(dtype=int, nullable=False, enforced=True),
5051
"suppress_warnings": _ConfigArg(dtype=bool, nullable=False, default=False, loaded=True),
52+
"dtype_backend": _ConfigArg(dtype=str, nullable=True),
5153
# Endpoints URLs
5254
"s3_endpoint_url": _ConfigArg(dtype=str, nullable=True, enforced=True, loaded=True),
5355
"athena_endpoint_url": _ConfigArg(dtype=str, nullable=True, enforced=True, loaded=True),
@@ -411,6 +413,15 @@ def suppress_warnings(self) -> bool:
411413
def suppress_warnings(self, value: bool) -> None:
412414
self._set_config_value(key="suppress_warnings", value=value)
413415

@property
def dtype_backend(self) -> Literal["numpy_nullable", "pyarrow", None]:
    """Property dtype_backend."""
    # NOTE(review): self["dtype_backend"] presumably routes through the
    # _Config item lookup defined elsewhere in this class — confirm.
    configured_backend = self["dtype_backend"]
    return cast(Literal["numpy_nullable", "pyarrow", None], configured_backend)


@dtype_backend.setter
def dtype_backend(self, value: Literal["numpy_nullable", "pyarrow", None]) -> None:
    # Delegate to the shared setter so enforcement/validation stays centralized.
    self._set_config_value(value=value, key="dtype_backend")
424+
414425
@property
415426
def s3_endpoint_url(self) -> Optional[str]:
416427
"""Property s3_endpoint_url."""

awswrangler/_data_types.py

Lines changed: 39 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -342,33 +342,35 @@ def athena2pyarrow(dtype: str) -> pa.DataType: # pylint: disable=too-many-retur
342342
raise exceptions.UnsupportedType(f"Unsupported Athena type: {dtype}")
343343

344344

345-
def athena2pandas(
    dtype: str, dtype_backend: Optional[str] = None
) -> str:  # pylint: disable=too-many-branches,too-many-return-statements
    """Athena to Pandas data types conversion.

    Parameters
    ----------
    dtype : str
        Athena type name, case-insensitive (e.g. "bigint", "varchar(10)").
    dtype_backend : str, optional
        When "pyarrow", return the pandas ``ArrowDtype`` string alias so the
        resulting Series is PyArrow-backed; any other value (including None)
        returns the classic NumPy/nullable-extension dtype name.

    Returns
    -------
    str
        Pandas dtype name.

    Raises
    ------
    exceptions.UnsupportedType
        If the Athena type has no pandas equivalent.
    """
    dtype = dtype.lower()
    if dtype == "tinyint":
        return "Int8" if dtype_backend != "pyarrow" else "int8[pyarrow]"
    if dtype == "smallint":
        return "Int16" if dtype_backend != "pyarrow" else "int16[pyarrow]"
    if dtype in ("int", "integer"):
        return "Int32" if dtype_backend != "pyarrow" else "int32[pyarrow]"
    if dtype == "bigint":
        return "Int64" if dtype_backend != "pyarrow" else "int64[pyarrow]"
    if dtype in ("float", "real"):
        # "float[pyarrow]" is the Arrow float32 alias; mapping to
        # "double[pyarrow]" would silently widen 32-bit columns.
        return "float32" if dtype_backend != "pyarrow" else "float[pyarrow]"
    if dtype == "double":
        return "float64" if dtype_backend != "pyarrow" else "double[pyarrow]"
    if dtype == "boolean":
        return "boolean" if dtype_backend != "pyarrow" else "bool[pyarrow]"
    if (dtype == "string") or dtype.startswith("char") or dtype.startswith("varchar"):
        return "string" if dtype_backend != "pyarrow" else "string[pyarrow]"
    if dtype in ("timestamp", "timestamp with time zone"):
        # Arrow date64 is a date-only type and would drop the time-of-day;
        # timestamps must map to a timestamp alias.
        return "datetime64" if dtype_backend != "pyarrow" else "timestamp[ns][pyarrow]"
    if dtype == "date":
        return "date" if dtype_backend != "pyarrow" else "date32[pyarrow]"
    if dtype.startswith("decimal"):
        # Arrow decimal aliases need explicit precision/scale; fall back to
        # double for the pyarrow backend (lossy, mirrors the numpy path).
        return "decimal" if dtype_backend != "pyarrow" else "double[pyarrow]"
    if dtype in ("binary", "varbinary"):
        return "bytes" if dtype_backend != "pyarrow" else "binary[pyarrow]"
    if dtype in ("array", "row", "map"):
        # Nested types stay as generic Python objects.
        return "object"
    raise exceptions.UnsupportedType(f"Unsupported Athena type: {dtype}")
@@ -465,6 +467,22 @@ def pyarrow2pandas_extension( # pylint: disable=too-many-branches,too-many-retu
465467
return None
466468

467469

470+
def pyarrow2pyarrow_backed_pandas_extension(
    dtype: pa.DataType,
) -> Optional[pd.api.extensions.ExtensionDtype]:
    """Pyarrow to Pandas PyArrow-backed data types conversion."""
    # Every Arrow type is wrapped one-to-one in a pandas ArrowDtype.
    arrow_backed = pd.ArrowDtype(dtype)
    return arrow_backed
475+
476+
477+
def get_pyarrow2pandas_type_mapper(
    dtype_backend: Optional[str] = None,
) -> Callable[[pa.DataType], Optional[pd.api.extensions.ExtensionDtype]]:
    """Return the PyArrow-to-Pandas ``types_mapper`` for the requested dtype backend.

    Parameters
    ----------
    dtype_backend : str, optional
        When "pyarrow", the returned mapper yields pandas ``ArrowDtype``
        (PyArrow-backed) extension types; otherwise the default
        nullable-extension mapper is returned.

    Returns
    -------
    Callable[[pa.DataType], Optional[pd.api.extensions.ExtensionDtype]]
        A function suitable for ``pyarrow.Table.to_pandas(types_mapper=...)``.
    """
    if dtype_backend == "pyarrow":
        return pyarrow2pyarrow_backed_pandas_extension
    return pyarrow2pandas_extension
484+
485+
468486
@engine.dispatch_on_engine
469487
def pyarrow_types_from_pandas( # pylint: disable=too-many-branches,too-many-statements
470488
df: pd.DataFrame, index: bool, ignore_cols: Optional[List[str]] = None, index_left: bool = False
@@ -550,14 +568,16 @@ def pyarrow_types_from_pandas( # pylint: disable=too-many-branches,too-many-sta
550568
return columns_types
551569

552570

553-
def pyarrow2pandas_defaults(use_threads: Union[bool, int], kwargs: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
571+
def pyarrow2pandas_defaults(
572+
use_threads: Union[bool, int], kwargs: Optional[Dict[str, Any]] = None, dtype_backend: Optional[str] = None
573+
) -> Dict[str, Any]:
554574
"""Return Pyarrow to Pandas default dictionary arguments."""
555575
default_kwargs = {
556576
"use_threads": use_threads,
557577
"split_blocks": True,
558578
"self_destruct": True,
559579
"ignore_metadata": False,
560-
"types_mapper": pyarrow2pandas_extension,
580+
"types_mapper": get_pyarrow2pandas_type_mapper(dtype_backend),
561581
}
562582
if kwargs:
563583
default_kwargs.update(kwargs)
@@ -685,7 +705,9 @@ def athena_types_from_pyarrow_schema(
685705
return columns_types, partitions_types
686706

687707

688-
def cast_pandas_with_athena_types(df: pd.DataFrame, dtype: Dict[str, str]) -> pd.DataFrame:
708+
def cast_pandas_with_athena_types(
709+
df: pd.DataFrame, dtype: Dict[str, str], dtype_backend: Optional[str] = None
710+
) -> pd.DataFrame:
689711
"""Cast columns in a Pandas DataFrame."""
690712
mutability_ensured: bool = False
691713
for col, athena_type in dtype.items():
@@ -695,7 +717,7 @@ def cast_pandas_with_athena_types(df: pd.DataFrame, dtype: Dict[str, str]) -> pd
695717
and (athena_type.startswith("struct") is False)
696718
and (athena_type.startswith("map") is False)
697719
):
698-
desired_type: str = athena2pandas(dtype=athena_type)
720+
desired_type: str = athena2pandas(dtype=athena_type, dtype_backend=dtype_backend)
699721
current_type: str = _normalize_pandas_dtype_name(dtype=str(df[col].dtypes))
700722
if desired_type != current_type: # Needs conversion
701723
_logger.debug("current_type: %s -> desired_type: %s", current_type, desired_type)

awswrangler/_databases.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import boto3
99
import pyarrow as pa
10+
from typing_extensions import Literal
1011

1112
import awswrangler.pandas as pd
1213
from awswrangler import _data_types, _utils, exceptions, oracle, secretsmanager
@@ -153,6 +154,7 @@ def _records2df(
153154
safe: bool,
154155
dtype: Optional[Dict[str, pa.DataType]],
155156
timestamp_as_object: bool,
157+
dtype_backend: Literal["numpy_nullable", "pyarrow"],
156158
) -> pd.DataFrame:
157159
arrays: List[pa.Array] = []
158160
for col_values, col_name in zip(tuple(zip(*records)), cols_names): # Transposing
@@ -183,7 +185,7 @@ def _records2df(
183185
self_destruct=True,
184186
integer_object_nulls=False,
185187
date_as_object=True,
186-
types_mapper=_data_types.pyarrow2pandas_extension,
188+
types_mapper=_data_types.get_pyarrow2pandas_type_mapper(dtype_backend=dtype_backend),
187189
safe=safe,
188190
timestamp_as_object=timestamp_as_object,
189191
)
@@ -207,6 +209,7 @@ def _iterate_results(
207209
safe: bool,
208210
dtype: Optional[Dict[str, pa.DataType]],
209211
timestamp_as_object: bool,
212+
dtype_backend: Literal["numpy_nullable", "pyarrow"],
210213
) -> Iterator[pd.DataFrame]:
211214
with con.cursor() as cursor:
212215
cursor.execute(*cursor_args)
@@ -230,6 +233,7 @@ def _iterate_results(
230233
safe=safe,
231234
dtype=dtype,
232235
timestamp_as_object=timestamp_as_object,
236+
dtype_backend=dtype_backend,
233237
)
234238

235239

@@ -240,6 +244,7 @@ def _fetch_all_results(
240244
dtype: Optional[Dict[str, pa.DataType]] = None,
241245
safe: bool = True,
242246
timestamp_as_object: bool = False,
247+
dtype_backend: Literal["numpy_nullable", "pyarrow"] = "pyarrow",
243248
) -> pd.DataFrame:
244249
with con.cursor() as cursor:
245250
cursor.execute(*cursor_args)
@@ -259,6 +264,7 @@ def _fetch_all_results(
259264
dtype=dtype,
260265
safe=safe,
261266
timestamp_as_object=timestamp_as_object,
267+
dtype_backend=dtype_backend,
262268
)
263269

264270

@@ -272,6 +278,7 @@ def read_sql_query(
272278
dtype: Optional[Dict[str, pa.DataType]] = ...,
273279
safe: bool = ...,
274280
timestamp_as_object: bool = ...,
281+
dtype_backend: Literal["numpy_nullable", "pyarrow"] = ...,
275282
) -> pd.DataFrame:
276283
...
277284

@@ -287,6 +294,7 @@ def read_sql_query(
287294
dtype: Optional[Dict[str, pa.DataType]] = ...,
288295
safe: bool = ...,
289296
timestamp_as_object: bool = ...,
297+
dtype_backend: Literal["numpy_nullable", "pyarrow"] = ...,
290298
) -> Iterator[pd.DataFrame]:
291299
...
292300

@@ -302,6 +310,7 @@ def read_sql_query(
302310
dtype: Optional[Dict[str, pa.DataType]] = ...,
303311
safe: bool = ...,
304312
timestamp_as_object: bool = ...,
313+
dtype_backend: Literal["numpy_nullable", "pyarrow"] = ...,
305314
) -> Union[pd.DataFrame, Iterator[pd.DataFrame]]:
306315
...
307316

@@ -315,6 +324,7 @@ def read_sql_query(
315324
dtype: Optional[Dict[str, pa.DataType]] = None,
316325
safe: bool = True,
317326
timestamp_as_object: bool = False,
327+
dtype_backend: Literal["numpy_nullable", "pyarrow"] = "numpy_nullable",
318328
) -> Union[pd.DataFrame, Iterator[pd.DataFrame]]:
319329
"""Read SQL Query (generic)."""
320330
args = _convert_params(sql, params)
@@ -327,6 +337,7 @@ def read_sql_query(
327337
dtype=dtype,
328338
safe=safe,
329339
timestamp_as_object=timestamp_as_object,
340+
dtype_backend=dtype_backend,
330341
)
331342

332343
return _iterate_results(
@@ -337,6 +348,7 @@ def read_sql_query(
337348
dtype=dtype,
338349
safe=safe,
339350
timestamp_as_object=timestamp_as_object,
351+
dtype_backend=dtype_backend,
340352
)
341353
except Exception as ex:
342354
con.rollback()

awswrangler/_utils.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,11 @@ def inner(*args: Any, **kwargs: Any) -> Any:
161161
set([key for key, value in kwargs.items() if value is not None])
162162
)
163163

164+
# Allow kwargs that didn't modify the default value
165+
passed_unsupported_kwargs = {
166+
key for key in passed_unsupported_kwargs if kwargs[key] != signature.parameters[key].default
167+
}
168+
164169
if condition_fn() and len(passed_unsupported_kwargs) > 0:
165170
raise exceptions.InvalidArgument(f"{message} `{', '.join(passed_unsupported_kwargs)}`.")
166171

0 commit comments

Comments
 (0)