
Commit eae3f91

Add new ctas_approach for Pandas.read_sql_athena()
1 parent 5e3f89f commit eae3f91

8 files changed: +384 -43 lines changed


awswrangler/athena.py

Lines changed: 10 additions & 2 deletions
@@ -3,6 +3,7 @@
 import logging
 import re
 import unicodedata
+from datetime import datetime, date
 
 from awswrangler.data_types import athena2python
 from awswrangler.exceptions import QueryFailed, QueryCancelled
@@ -162,8 +163,15 @@ def _rows2row(rows: List[Dict[str, List[Dict[str, str]]]],
             vals_varchar: List[Optional[str]] = [x["VarCharValue"] if x else None for x in row["Data"]]
             data: Dict[str, Any] = {}
             for (name, ptype), val in zip(python_types, vals_varchar):
-                if ptype is not None:
-                    data[name] = ptype(val)
+                if val is not None:
+                    if ptype is None:
+                        data[name] = None
+                    elif ptype == date:
+                        data[name] = date(*[int(y) for y in val.split("-")])
+                    elif ptype == datetime:
+                        data[name] = datetime.strptime(val + "000", "%Y-%m-%d %H:%M:%S.%f")
+                    else:
+                        data[name] = ptype(val)
                 else:
                     data[name] = None
             yield data
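
The new branches parse Athena's varchar cell values into native date and datetime objects. A minimal standalone sketch of that parsing, with hypothetical sample values (the "000" padding assumes Athena returns millisecond-precision timestamps that strptime's %f needs widened to microseconds):

    from datetime import date, datetime

    # Hypothetical varchar values as returned by the Athena API
    val_date = "2019-12-30"
    val_timestamp = "2019-12-30 01:02:03.123"

    # Same parsing as the new branches in _rows2row
    parsed_date = date(*[int(y) for y in val_date.split("-")])                    # -> date(2019, 12, 30)
    parsed_ts = datetime.strptime(val_timestamp + "000", "%Y-%m-%d %H:%M:%S.%f")  # pad ms -> microseconds
    print(parsed_date, parsed_ts)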

awswrangler/data_types.py

Lines changed: 3 additions & 0 deletions
@@ -1,6 +1,7 @@
 from typing import List, Tuple, Dict, Callable, Optional
 import logging
 from datetime import datetime, date
+from decimal import Decimal
 
 import pyarrow as pa  # type: ignore
 import pandas as pd  # type: ignore
@@ -74,6 +75,8 @@ def athena2python(dtype: str) -> Optional[type]:
         return date
     elif dtype == "unknown":
         return None
+    elif dtype == "decimal":
+        return Decimal
     else:
         raise UnsupportedType(f"Unsupported Athena type: {dtype}")
 
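
A quick sketch of the new mapping, assuming awswrangler at this commit is importable:

    from decimal import Decimal

    from awswrangler.data_types import athena2python

    # "decimal" now resolves to Python's Decimal; "unknown" still resolves to None
    assert athena2python("decimal") is Decimal
    assert athena2python("unknown") is None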

awswrangler/exceptions.py

Lines changed: 8 additions & 0 deletions
@@ -80,3 +80,11 @@ class ApiError(Exception):
 
 class InvalidCompression(Exception):
     pass
+
+
+class InvalidTable(Exception):
+    pass
+
+
+class InvalidParameters(Exception):
+    pass

awswrangler/glue.py

Lines changed: 20 additions & 2 deletions
@@ -1,10 +1,11 @@
+from typing import Dict, Optional
 from math import ceil
 import re
 import logging
 
 from awswrangler import data_types
 from awswrangler.athena import Athena
-from awswrangler.exceptions import UnsupportedFileFormat, InvalidSerDe, ApiError, UnsupportedType, UndetectedType
+from awswrangler.exceptions import UnsupportedFileFormat, InvalidSerDe, ApiError, UnsupportedType, UndetectedType, InvalidTable, InvalidArguments
 
 logger = logging.getLogger(__name__)
 
@@ -85,7 +86,11 @@ def metadata_to_glue(self,
                                compression=compression,
                                extra_args=extra_args)
 
-    def delete_table_if_exists(self, database, table):
+    def delete_table_if_exists(self, table: str = None, database: Optional[str] = None):
+        if database is None and self._session.athena_database is not None:
+            database = self._session.athena_database
+        if database is None:
+            raise InvalidArguments("You must pass a valid database or have one defined in your Session!")
         try:
             self._client_glue.delete_table(DatabaseName=database, Name=table)
         except self._client_glue.exceptions.EntityNotFoundException:
@@ -372,3 +377,16 @@ def _parse_partitions_tuples(objects_paths, partition_cols):
     @staticmethod
     def _parse_partition_values(path, partition_cols):
         return [re.search(f"/{col}=(.*?)/", path).group(1) for col in partition_cols]
+
+    def get_table_location(self, database: str, table: str):
+        """
+        Get table's location on Glue catalog
+
+        :param database: Database name
+        :param table: Table name
+        """
+        res: Dict = self._client_glue.get_table(DatabaseName=database, Name=table)
+        try:
+            return res["Table"]["StorageDescriptor"]["Location"]
+        except KeyError:
+            raise InvalidTable(f"{database}.{table}")
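
A hedged usage sketch of the two Glue changes; the database and table names below are placeholders, and the Session() is assumed to pick up credentials and region from your environment:

    from awswrangler import Session

    session = Session()  # hypothetical session; credentials/region come from the environment

    # New helper: fetch a table's S3 location straight from the Glue Catalog
    location = session.glue.get_table_location(database="my_db", table="my_table")

    # delete_table_if_exists now takes keyword arguments and can fall back to the
    # Session's athena_database when no database is passed
    session.glue.delete_table_if_exists(table="my_temp_table", database="my_db")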

awswrangler/pandas.py

Lines changed: 109 additions & 10 deletions
@@ -5,7 +5,7 @@
 from math import floor
 import copy
 import csv
-from datetime import datetime
+from datetime import datetime, date
 from decimal import Decimal
 from ast import literal_eval
 
@@ -18,7 +18,8 @@
 
 from awswrangler import data_types
 from awswrangler.exceptions import (UnsupportedWriteMode, UnsupportedFileFormat, AthenaQueryError, EmptyS3Object,
-                                    LineTerminatorNotFound, EmptyDataframe, InvalidSerDe, InvalidCompression)
+                                    LineTerminatorNotFound, EmptyDataframe, InvalidSerDe, InvalidCompression,
+                                    InvalidParameters)
 from awswrangler.utils import calculate_bounders
 from awswrangler import s3
 from awswrangler.athena import Athena
@@ -495,29 +496,100 @@ def read_sql_athena(self,
                         sql: str,
                         database: Optional[str] = None,
                         s3_output: Optional[str] = None,
-                        max_result_size: Optional[int] = None,
                         workgroup: Optional[str] = None,
                         encryption: Optional[str] = None,
-                        kms_key: Optional[str] = None):
+                        kms_key: Optional[str] = None,
+                        ctas_approach: bool = True,
+                        procs_cpu_bound: Optional[int] = None,
+                        max_result_size: Optional[int] = None):
         """
-        Executes any SQL query on AWS Athena and return a Dataframe of the result.
-        P.S. If max_result_size is passed, then a iterator of Dataframes is returned.
+        Executes any SQL query on AWS Athena and returns a Dataframe of the result.
+        There are two approaches, selected through the ctas_approach parameter:
+        1 - ctas_approach True (default):
+            Wraps the query with a CTAS and then reads the table data as Parquet directly from S3.
+            PROS: Faster and handles nested types better.
+            CONS: Can't use max_result_size.
+        2 - ctas_approach False:
+            Runs a regular query on Athena and parses the regular CSV result on S3.
+            PROS: Accepts max_result_size.
+            CONS: Slower (but still faster than other libraries that use the Athena API) and does not handle nested types as well.
+
+        P.S. If ctas_approach is False and max_result_size is passed, then an iterator of Dataframes is returned.
         P.S.S. All default values will be inherited from the Session()
 
         :param sql: SQL Query
         :param database: Glue/Athena Database
         :param s3_output: AWS S3 path
-        :param max_result_size: Max number of bytes on each request to S3
         :param workgroup: The name of the workgroup in which the query is being started. (By default uses the Session() workgroup)
         :param encryption: None|'SSE_S3'|'SSE_KMS'|'CSE_KMS'
         :param kms_key: For SSE-KMS and CSE-KMS, this is the KMS key ARN or ID.
-        :return: Pandas Dataframe or Iterator of Pandas Dataframes if max_result_size != None
+        :param ctas_approach: Wraps the query with a CTAS
+        :param procs_cpu_bound: Number of cores used for CPU bound tasks
+        :param max_result_size: Max number of bytes on each request to S3 (VALID ONLY FOR ctas_approach=False)
+        :return: Pandas Dataframe or Iterator of Pandas Dataframes if max_result_size was passed
         """
+        if ctas_approach is True and max_result_size is not None:
+            raise InvalidParameters("ctas_approach can't use max_result_size!")
         if s3_output is None:
             if self._session.athena_s3_output is not None:
                 s3_output = self._session.athena_s3_output
             else:
                 s3_output = self._session.athena.create_athena_bucket()
+        if ctas_approach is False:
+            return self._read_sql_athena_regular(sql=sql,
+                                                 database=database,
+                                                 s3_output=s3_output,
+                                                 workgroup=workgroup,
+                                                 encryption=encryption,
+                                                 kms_key=kms_key,
+                                                 max_result_size=max_result_size)
+        else:
+            return self._read_sql_athena_ctas(sql=sql,
+                                              database=database,
+                                              s3_output=s3_output,
+                                              workgroup=workgroup,
+                                              encryption=encryption,
+                                              kms_key=kms_key,
+                                              procs_cpu_bound=procs_cpu_bound)
+
+    def _read_sql_athena_ctas(self,
+                              sql: str,
+                              s3_output: str,
+                              database: Optional[str] = None,
+                              workgroup: Optional[str] = None,
+                              encryption: Optional[str] = None,
+                              kms_key: Optional[str] = None,
+                              procs_cpu_bound: Optional[int] = None) -> pd.DataFrame:
+        guid: str = pa.compat.guid()
+        name: str = f"temp_table_{guid}"
+        s3_output = s3_output[:-1] if s3_output[-1] == "/" else s3_output
+        path: str = f"{s3_output}/{name}"
+        query: str = f"CREATE TABLE {name}\n" \
+                     f"WITH(\n" \
+                     f"    format = 'Parquet',\n" \
+                     f"    parquet_compression = 'SNAPPY',\n" \
+                     f"    external_location = '{path}'\n" \
+                     f") AS\n" \
+                     f"{sql}"
+        logger.debug(f"query: {query}")
+        query_id: str = self._session.athena.run_query(query=query,
+                                                       database=database,
+                                                       s3_output=s3_output,
+                                                       workgroup=workgroup,
+                                                       encryption=encryption,
+                                                       kms_key=kms_key)
+        self._session.athena.wait_query(query_execution_id=query_id)
+        self._session.glue.delete_table_if_exists(database=database, table=name)
+        return self.read_parquet(path=path, procs_cpu_bound=procs_cpu_bound)
+
+    def _read_sql_athena_regular(self,
+                                 sql: str,
+                                 s3_output: str,
+                                 database: Optional[str] = None,
+                                 workgroup: Optional[str] = None,
+                                 encryption: Optional[str] = None,
+                                 kms_key: Optional[str] = None,
+                                 max_result_size: Optional[int] = None):
         query_execution_id: str = self._session.athena.run_query(query=sql,
                                                                   database=database,
                                                                   s3_output=s3_output,
@@ -542,7 +614,10 @@ def read_sql_athena(self,
         if max_result_size is None:
             if len(ret.index) > 0:
                 for col in parse_dates:
-                    ret[col] = ret[col].dt.date.replace(to_replace={pd.NaT: None})
+                    if str(ret[col].dtype) == "object":
+                        ret[col] = ret[col].apply(lambda x: date(*[int(y) for y in x.split("-")]))
+                    else:
+                        ret[col] = ret[col].dt.date.replace(to_replace={pd.NaT: None})
             return ret
         else:
             return Pandas._apply_dates_to_generator(generator=ret, parse_dates=parse_dates)
@@ -1151,5 +1226,29 @@ def read_parquet(self,
         use_threads: bool = True if procs_cpu_bound > 1 else False
         fs: S3FileSystem = s3.get_fs(session_primitives=self._session.primitives)
         fs = pa.filesystem._ensure_filesystem(fs)
-        return pq.read_table(source=path, columns=columns, filters=filters,
-                             filesystem=fs).to_pandas(use_threads=use_threads)
+        table = pq.read_table(source=path, columns=columns, filters=filters, filesystem=fs, use_threads=use_threads)
+        # Check if we lose some integer during the conversion (happens when a column has null values)
+        integers = [field.name for field in table.schema if str(field.type).startswith("int")]
+        df = table.to_pandas(use_threads=use_threads, integer_object_nulls=True)
+        for c in integers:
+            if not str(df[c].dtype).startswith("int"):
+                df[c] = df[c].astype("Int64")
+        return df
+
+    def read_table(self,
+                   database: str,
+                   table: str,
+                   columns: Optional[List[str]] = None,
+                   filters: Optional[Union[List[Tuple[Any]], List[Tuple[Any]]]] = None,
+                   procs_cpu_bound: Optional[int] = None) -> pd.DataFrame:
+        """
+        Read a PARQUET table from S3 using the Glue Catalog location, skipping the need for Athena.
+
+        :param database: Database name
+        :param table: Table name
+        :param columns: Names of columns to read from the file
+        :param filters: List of filters to apply, like ``[[('x', '=', 0), ...], ...]``.
+        :param procs_cpu_bound: Number of cores used for CPU bound tasks
+        """
+        path: str = self._session.glue.get_table_location(database=database, table=table)
+        return self.read_parquet(path=path, columns=columns, filters=filters, procs_cpu_bound=procs_cpu_bound)
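
A hedged end-to-end sketch of the new API surface; the database, table, and byte-size values are placeholders, and the Session() is assumed to be configured through the environment:

    from awswrangler import Session

    session = Session()  # hypothetical; AWS credentials and region come from your environment

    # Default (ctas_approach=True): the query is wrapped in a CTAS, written as Parquet
    # under s3_output, and read back with read_parquet. max_result_size is not allowed here.
    df = session.pandas.read_sql_athena(sql="SELECT * FROM my_table", database="my_db")

    # ctas_approach=False: regular Athena query, CSV result parsed from S3.
    # Passing max_result_size returns an iterator of DataFrame chunks.
    for chunk in session.pandas.read_sql_athena(sql="SELECT * FROM my_table",
                                                database="my_db",
                                                ctas_approach=False,
                                                max_result_size=512_000_000):
        print(len(chunk.index))

    # New shortcut: read a Parquet table through its Glue Catalog location, no Athena query at all
    df2 = session.pandas.read_table(database="my_db", table="my_table")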

requirements.txt

Lines changed: 2 additions & 2 deletions
@@ -1,8 +1,8 @@
 numpy~=1.17.4
 pandas~=0.25.3
 pyarrow~=0.15.1
-botocore~=1.13.34
-boto3~=1.10.34
+botocore~=1.13.35
+boto3~=1.10.35
 s3fs~=0.4.0
 tenacity~=6.0.0
 pg8000~=1.13.2

testing/test_awswrangler/test_athena.py

Lines changed: 46 additions & 0 deletions
@@ -1,7 +1,10 @@
 import logging
+from datetime import datetime, date
+from decimal import Decimal
 
 import pytest
 import boto3
+import pandas as pd
 
 from awswrangler import Session
 from awswrangler.exceptions import QueryCancelled, QueryFailed
@@ -193,3 +196,46 @@ def test_query(session, database):
         assert row["_col2"] == 2.0
         assert row["_col3"] is True
         assert row["_col4"] is None
+
+
+def test_query2(session, bucket, database):
+    df = pd.DataFrame({
+        "id": [1, 2, 3],
+        "col_date": [date(194, 1, 12), None, date(2049, 12, 30)],
+        "col_timestamp": [datetime(194, 1, 12, 1, 1, 1, 1000), None,
+                          datetime(2049, 12, 30, 1, 1, 1, 1000)],
+        "col_string": ["foo", None, "boo"],
+        "col_double": [1.1, None, 2.2],
+        "col_decimal": [Decimal((0, (1, 9, 9), -2)), None,
+                        Decimal((0, (1, 9, 0), -2))],
+        "col_int": [1, None, 2]
+    })
+    path = f"s3://{bucket}/test_query2/"
+    session.pandas.to_parquet(dataframe=df,
+                              database=database,
+                              table="test",
+                              path=path,
+                              mode="overwrite",
+                              preserve_index=False)
+    for row in session.athena.query(query="SELECT * FROM test", database=database):
+        if row["id"] == 1:
+            assert row["col_date"] == date(194, 1, 12)
+            assert row["col_timestamp"] == datetime(194, 1, 12, 1, 1, 1, 1000)
+            assert row["col_string"] == "foo"
+            assert row["col_double"] == 1.1
+            assert row["col_decimal"] == Decimal((0, (1, 9, 9), -2))
+            assert row["col_int"] == 1
+        elif row["id"] == 2:
+            assert row["col_date"] is None
+            assert row["col_timestamp"] is None
+            assert row["col_string"] is None
+            assert row["col_double"] is None
+            assert row["col_decimal"] is None
+            assert row["col_int"] is None
+        else:
+            assert row["col_date"] == date(2049, 12, 30)
+            assert row["col_timestamp"] == datetime(2049, 12, 30, 1, 1, 1, 1000)
+            assert row["col_string"] == "boo"
+            assert row["col_double"] == 2.2
+            assert row["col_decimal"] == Decimal((0, (1, 9, 0), -2))
+            assert row["col_int"] == 2
