
Commit a70e5aa

Improve postgres/redshift read. #427 #431
1 parent b1c1b79 commit a70e5aa

File tree

6 files changed: +60 -12 lines

README.md

Lines changed: 1 addition & 1 deletion
@@ -6,7 +6,7 @@
 
 > An [AWS Professional Service](https://aws.amazon.com/professional-services/) open source initiative | [email protected]
 
-[![Release](https://img.shields.io/badge/release-1.9.6-brightgreen.svg)](https://pypi.org/project/awswrangler/)
+[![Release](https://img.shields.io/badge/release-1.10.0-brightgreen.svg)](https://pypi.org/project/awswrangler/)
 [![Python Version](https://img.shields.io/badge/python-3.6%20%7C%203.7%20%7C%203.8-brightgreen.svg)](https://anaconda.org/conda-forge/awswrangler)
 [![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
 [![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0)

awswrangler/__metadata__.py

Lines changed: 1 addition & 1 deletion
@@ -7,5 +7,5 @@
 
 __title__: str = "awswrangler"
 __description__: str = "Pandas on AWS."
-__version__: str = "1.9.6"
+__version__: str = "1.10.0"
 __license__: str = "Apache License 2.0"

awswrangler/db.py

Lines changed: 25 additions & 8 deletions
@@ -187,17 +187,22 @@ def _records2df(
     records: List[Tuple[Any]],
     cols_names: List[str],
     index: Optional[Union[str, List[str]]],
-    dtype: Optional[Dict[str, pa.DataType]] = None,
+    safe: bool,
+    dtype: Optional[Dict[str, pa.DataType]],
 ) -> pd.DataFrame:
     arrays: List[pa.Array] = []
     for col_values, col_name in zip(tuple(zip(*records)), cols_names):  # Transposing
         if (dtype is None) or (col_name not in dtype):
             try:
-                array: pa.Array = pa.array(obj=col_values, safe=True)  # Creating Arrow array
+                array: pa.Array = pa.array(obj=col_values, safe=safe)  # Creating Arrow array
             except pa.ArrowInvalid as ex:
                 array = _data_types.process_not_inferred_array(ex, values=col_values)  # Creating Arrow array
         else:
-            array = pa.array(obj=col_values, type=dtype[col_name], safe=True)  # Creating Arrow array with dtype
+            try:
+                array = pa.array(obj=col_values, type=dtype[col_name], safe=safe)  # Creating Arrow array with dtype
+            except pa.ArrowInvalid:
+                array = pa.array(obj=col_values, safe=safe)  # Creating Arrow array
+                array = array.cast(target_type=dtype[col_name], safe=safe)  # Casting
         arrays.append(array)
     table = pa.Table.from_arrays(arrays=arrays, names=cols_names)  # Creating arrow Table
     df: pd.DataFrame = table.to_pandas(  # Creating Pandas DataFrame
@@ -207,6 +212,7 @@ def _records2df(
         integer_object_nulls=False,
         date_as_object=True,
         types_mapper=_data_types.pyarrow2pandas_extension,
+        safe=safe,
     )
     if index is not None:
         df.set_index(index, inplace=True)
@@ -218,13 +224,14 @@ def _iterate_cursor(
     chunksize: int,
     cols_names: List[str],
     index: Optional[Union[str, List[str]]],
-    dtype: Optional[Dict[str, pa.DataType]] = None,
+    safe: bool,
+    dtype: Optional[Dict[str, pa.DataType]],
 ) -> Iterator[pd.DataFrame]:
     while True:
         records = cursor.fetchmany(chunksize)
         if not records:
             break
-        yield _records2df(records=records, cols_names=cols_names, index=index, dtype=dtype)
+        yield _records2df(records=records, cols_names=cols_names, index=index, safe=safe, dtype=dtype)
 
 
 def _convert_params(sql: str, params: Optional[Union[List[Any], Tuple[Any, ...], Dict[Any, Any]]]) -> List[Any]:
@@ -366,6 +373,7 @@ def read_sql_query(
     params: Optional[Union[List[Any], Tuple[Any, ...], Dict[Any, Any]]] = None,
     chunksize: Optional[int] = None,
     dtype: Optional[Dict[str, pa.DataType]] = None,
+    safe: bool = True,
 ) -> Union[pd.DataFrame, Iterator[pd.DataFrame]]:
     """Return a DataFrame corresponding to the result set of the query string.
 
@@ -395,6 +403,8 @@ def read_sql_query(
     dtype : Dict[str, pyarrow.DataType], optional
         Specifying the datatype for columns.
         The keys should be the column names and the values should be the PyArrow types.
+    safe : bool
+        Check for overflows or other unsafe data type conversions.
 
     Returns
     -------
@@ -425,9 +435,11 @@ def read_sql_query(
     args = _convert_params(sql, params)
     cursor = _con.execute(*args)
     if chunksize is None:
-        return _records2df(records=cursor.fetchall(), cols_names=cursor.keys(), index=index_col, dtype=dtype)
+        return _records2df(
+            records=cursor.fetchall(), cols_names=cursor.keys(), index=index_col, dtype=dtype, safe=safe
+        )
     return _iterate_cursor(
-        cursor=cursor, chunksize=chunksize, cols_names=cursor.keys(), index=index_col, dtype=dtype
+        cursor=cursor, chunksize=chunksize, cols_names=cursor.keys(), index=index_col, dtype=dtype, safe=safe
     )
 
 
@@ -439,6 +451,7 @@ def read_sql_table(
     params: Optional[Union[List[Any], Tuple[Any, ...], Dict[Any, Any]]] = None,
     chunksize: Optional[int] = None,
     dtype: Optional[Dict[str, pa.DataType]] = None,
+    safe: bool = True,
 ) -> Union[pd.DataFrame, Iterator[pd.DataFrame]]:
     """Return a DataFrame corresponding to the result set of the query string.
 
@@ -471,6 +484,8 @@ def read_sql_table(
     dtype : Dict[str, pyarrow.DataType], optional
         Specifying the datatype for columns.
         The keys should be the column names and the values should be the PyArrow types.
+    safe : bool
+        Check for overflows or other unsafe data type conversions.
 
     Returns
     -------
@@ -502,7 +517,9 @@ def read_sql_table(
         sql: str = f"SELECT * FROM {table}"
     else:
         sql = f"SELECT * FROM {schema}.{table}"
-    return read_sql_query(sql=sql, con=con, index_col=index_col, params=params, chunksize=chunksize, dtype=dtype)
+    return read_sql_query(
+        sql=sql, con=con, index_col=index_col, params=params, chunksize=chunksize, dtype=dtype, safe=safe
+    )
 
 
 def get_redshift_temp_engine(
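The key change in _records2df is the new safe flag (forwarded from read_sql_query/read_sql_table and into Table.to_pandas) plus a fallback when PyArrow cannot build an array directly with the requested dtype: the array is first built with an inferred type and then cast. Below is a minimal, self-contained sketch of that fallback, assuming only pyarrow is installed; the sample values and target type are illustrative (not from the commit), and the sketch additionally catches pa.ArrowTypeError for robustness across pyarrow versions, whereas the commit itself handles pa.ArrowInvalid.

from decimal import Decimal

import pyarrow as pa

values = (Decimal("1.99"), None, Decimal("1.90"))  # e.g. a Redshift DECIMAL column
target = pa.float64()                               # the type requested via the dtype argument
safe = True                                         # pass safe=False to skip overflow checks

try:
    # Original code path: build the Arrow array directly with the requested type.
    array = pa.array(obj=values, type=target, safe=safe)
except (pa.ArrowInvalid, pa.ArrowTypeError):
    # Fallback added by this commit: infer the type first, then cast.
    array = pa.array(obj=values, safe=safe)             # inferred as decimal128
    array = array.cast(target_type=target, safe=safe)   # cast to the requested float64

print(array.type)  # double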

docs/source/install.rst

Lines changed: 1 addition & 1 deletion
@@ -59,7 +59,7 @@ Go to your Glue PySpark job and create a new *Job parameters* key/value:
 
 To install a specific version, set the value for above Job parameter as follows:
 
-* Value: ``awswrangler==1.9.6``
+* Value: ``awswrangler==1.10.0``
 
 `Official Glue PySpark Reference <https://docs.aws.amazon.com/glue/latest/dg/reduced-start-times-spark-etl-jobs.html#reduced-start-times-new-features>`_
tests/test_db.py

Lines changed: 31 additions & 0 deletions
@@ -704,3 +704,34 @@ def test_redshift_copy_extras(path, redshift_table, databases_parameters, use_th
     assert df.int16.sum() * num == df2.int16.sum()
     assert df.int32.sum() * num == df2.int32.sum()
     assert df.int64.sum() * num == df2.int64.sum()
+
+
+def test_redshift_decimal_cast(redshift_table):
+    df = pd.DataFrame(
+        {
+            "col0": [Decimal((0, (1, 9, 9), -2)), None, Decimal((0, (1, 9, 0), -2))],
+            "col1": [Decimal((0, (1, 9, 9), -2)), None, Decimal((0, (1, 9, 0), -2))],
+            "col2": [Decimal((0, (1, 9, 9), -2)), None, Decimal((0, (1, 9, 0), -2))],
+        }
+    )
+    engine = wr.catalog.get_engine(connection="aws-data-wrangler-redshift")
+    wr.db.to_sql(df, engine, name=redshift_table)
+    df2 = wr.db.read_sql_table(
+        schema="public", table=redshift_table, con=engine, dtype={"col0": "float32", "col1": "float64", "col2": "Int64"}
+    )
+    assert df2.dtypes.to_list() == ["float32", "float64", "Int64"]
+    assert 3.88 <= df2.col0.sum() <= 3.89
+    assert 3.88 <= df2.col1.sum() <= 3.89
+    assert df2.col2.sum() == 2
+
+
+def test_postgresql_out_of_bound():
+    engine = wr.catalog.get_engine(connection="aws-data-wrangler-postgresql")
+    sql = """
+    SELECT TO_TIMESTAMP(
+        '9999-12-31 9:30:20',
+        'YYYY-MM-DD HH:MI:SS'
+    )::timestamp without time zone;
+    """
+    df = wr.db.read_sql_query(sql=sql, con=engine, safe=False)
+    assert df.shape == (1, 1)
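As the tests above suggest, callers hitting values that PyArrow refuses to convert (for example PostgreSQL timestamps beyond the pandas datetime64[ns] range) can now pass safe=False, and dtype can still force per-column types. A hedged usage sketch, assuming the Glue Catalog connection name used by the test suite exists in your account; the table and column names in the second call are hypothetical:

import awswrangler as wr
import pyarrow as pa

# Engine resolved from a Glue Catalog connection (name taken from the tests above).
engine = wr.catalog.get_engine(connection="aws-data-wrangler-postgresql")

# safe=False disables PyArrow's overflow checks, so out-of-bound timestamps
# no longer raise during the Arrow -> pandas conversion.
df = wr.db.read_sql_query(
    sql="SELECT TO_TIMESTAMP('9999-12-31 9:30:20', 'YYYY-MM-DD HH:MI:SS')::timestamp without time zone",
    con=engine,
    safe=False,
)

# dtype maps column names to PyArrow types, per the updated docstring.
df2 = wr.db.read_sql_table(
    table="my_table",                # hypothetical table name
    schema="public",
    con=engine,
    dtype={"price": pa.float64()},   # hypothetical column name
    safe=True,
)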

tests/test_metadata.py

Lines changed: 1 addition & 1 deletion
@@ -2,7 +2,7 @@
 
 
 def test_metadata():
-    assert wr.__version__ == "1.9.6"
+    assert wr.__version__ == "1.10.0"
     assert wr.__title__ == "awswrangler"
     assert wr.__description__ == "Pandas on AWS."
     assert wr.__license__ == "Apache License 2.0"
