
Commit 1a38261

Improving SQL on Athena iterator.
1 parent 1b50043 commit 1a38261

File tree

9 files changed (+462 -74 lines)


awswrangler/__version__.py

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
 __title__ = "awswrangler"
 __description__ = "Utility belt to handle data on AWS."
-__version__ = "0.0b20"
+__version__ = "0.0b27"
 __license__ = "Apache License 2.0"

awswrangler/athena.py

Lines changed: 39 additions & 0 deletions
@@ -1,6 +1,8 @@
 from time import sleep
 import logging

+from awswrangler.exceptions import UnsupportedType
+
 logger = logging.getLogger(__name__)

 QUERY_WAIT_POLLING_DELAY = 0.2  # MILLISECONDS
@@ -12,6 +14,43 @@ def __init__(self, session):
         self._client_athena = session.boto3_session.client(
             service_name="athena", config=session.botocore_config)

+    def get_query_columns_metadata(self, query_execution_id):
+        response = self._client_athena.get_query_results(
+            QueryExecutionId=query_execution_id, MaxResults=1)
+        col_info = response["ResultSet"]["ResultSetMetadata"]["ColumnInfo"]
+        return {x["Name"]: x["Type"] for x in col_info}
+
+    @staticmethod
+    def _type_athena2pandas(dtype):
+        dtype = dtype.lower()
+        if dtype in ["int", "integer", "bigint", "smallint", "tinyint"]:
+            return "Int64"
+        elif dtype in ["float", "double", "real"]:
+            return "float64"
+        elif dtype == "boolean":
+            return "bool"
+        elif dtype in ["string", "char", "varchar", "array", "row", "map"]:
+            return "object"
+        elif dtype in ["timestamp", "date"]:
+            return "datetime64"
+        else:
+            raise UnsupportedType(f"Unsupported Athena type: {dtype}")
+
+    def get_query_dtype(self, query_execution_id):
+        cols_metadata = self.get_query_columns_metadata(
+            query_execution_id=query_execution_id)
+        dtype = {}
+        parse_dates = []
+        for col_name, col_type in cols_metadata.items():
+            ptype = Athena._type_athena2pandas(dtype=col_type)
+            if ptype == "datetime64":
+                parse_dates.append(col_name)
+            else:
+                dtype[col_name] = ptype
+        logger.debug(f"dtype: {dtype}")
+        logger.debug(f"parse_dates: {parse_dates}")
+        return dtype, parse_dates
+
     def run_query(self, query, database, s3_output):
         response = self._client_athena.start_query_execution(
             QueryString=query,
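
Together, these helpers let the result CSV that Athena drops on S3 be read back with faithful dtypes. A minimal sketch of the flow (the session wiring, table, bucket, and local CSV path are illustrative; it assumes run_query returns the QueryExecutionId and that the query has already reached SUCCEEDED):

    import pandas

    import awswrangler

    session = awswrangler.Session()
    query_execution_id = session.athena.run_query(
        query="SELECT * FROM my_table",  # hypothetical table
        database="my_database",
        s3_output="s3://my-bucket/athena-results/")
    # ... wait for the query to finish ...
    dtype, parse_dates = session.athena.get_query_dtype(
        query_execution_id=query_execution_id)
    # Integer columns map to pandas' nullable Int64; timestamp/date
    # columns go into parse_dates so read_csv parses them as datetime64.
    df = pandas.read_csv("local_copy_of_results.csv",
                         dtype=dtype,
                         parse_dates=parse_dates)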

awswrangler/exceptions.py

Lines changed: 4 additions & 0 deletions
@@ -30,5 +30,9 @@ class EmptyS3Object(Exception):
     pass


+class LineTerminatorNotFound(Exception):
+    pass
+
+
 class MissingBatchDetected(Exception):
     pass

awswrangler/glue.py

Lines changed: 31 additions & 29 deletions
@@ -1,6 +1,7 @@
 from math import ceil
 import re
 import logging
+from datetime import datetime, date

 from awswrangler.exceptions import UnsupportedType, UnsupportedFileFormat

@@ -13,7 +14,7 @@ def __init__(self, session):
         self._client_glue = session.boto3_session.client(
             service_name="glue", config=session.botocore_config)

-    def get_table_dtypes(self, database, table):
+    def get_table_athena_types(self, database, table):
         """
         Get all columns names and the related data types
         :param database: Glue database's name
@@ -37,24 +38,44 @@ def get_table_python_types(self, database, table):
         :param table: Glue table's name
         :return: A dictionary as {"col name": "col python type"}
         """
-        dtypes = self.get_table_dtypes(database=database, table=table)
+        dtypes = self.get_table_athena_types(database=database, table=table)
         return {k: Glue._type_athena2python(v) for k, v in dtypes.items()}

+    @staticmethod
+    def _type_pandas2athena(dtype):
+        dtype = dtype.lower()
+        if dtype == "int32":
+            return "int"
+        elif dtype in ["int64", "Int64"]:
+            return "bigint"
+        elif dtype == "float32":
+            return "float"
+        elif dtype == "float64":
+            return "double"
+        elif dtype == "bool":
+            return "boolean"
+        elif dtype == "object" and isinstance(dtype, str):
+            return "string"
+        elif dtype[:10] == "datetime64":
+            return "timestamp"
+        else:
+            raise UnsupportedType(f"Unsupported Pandas type: {dtype}")
+
     @staticmethod
     def _type_athena2python(dtype):
         dtype = dtype.lower()
-        if dtype == "int":
-            return int
-        elif dtype == "bigint":
+        if dtype in ["int", "integer", "bigint", "smallint", "tinyint"]:
             return int
-        elif dtype == "float":
-            return float
-        elif dtype == "double":
+        elif dtype in ["float", "double", "real"]:
             return float
         elif dtype == "boolean":
             return bool
-        elif dtype == "string":
+        elif dtype in ["string", "char", "varchar", "array", "row", "map"]:
             return str
+        elif dtype == "timestamp":
+            return datetime
+        elif dtype == "date":
+            return date
         else:
             raise UnsupportedType(f"Unsupported Athena type: {dtype}")

@@ -157,6 +178,7 @@ def _build_schema(dataframe,
                       partition_cols,
                       preserve_index,
                       cast_columns=None):
+        print(f"dataframe.dtypes:\n{dataframe.dtypes}")
         if not partition_cols:
             partition_cols = []
         schema_built = []
@@ -180,26 +202,6 @@ def _build_schema(dataframe,
         logger.debug(f"schema_built:\n{schema_built}")
         return schema_built

-    @staticmethod
-    def _type_pandas2athena(dtype):
-        dtype = dtype.lower()
-        if dtype == "int32":
-            return "int"
-        elif dtype == "int64":
-            return "bigint"
-        elif dtype == "float32":
-            return "float"
-        elif dtype == "float64":
-            return "double"
-        elif dtype == "bool":
-            return "boolean"
-        elif dtype == "object" and isinstance(dtype, str):
-            return "string"
-        elif dtype[:10] == "datetime64":
-            return "timestamp"
-        else:
-            raise UnsupportedType(f"Unsupported Pandas type: {dtype}")
-
     @staticmethod
     def _parse_table_name(path):
         if path[-1] == "/":
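
For reference, a sketch of what the widened pandas-to-Athena mapping produces (the DataFrame is invented, and the import path is assumed; note that _type_pandas2athena lower-cases its input first, so pandas' nullable "Int64" arrives as "int64" and maps to bigint):

    import pandas

    from awswrangler.glue import Glue  # assumed import path

    df = pandas.DataFrame({
        "id": pandas.Series([1, 2, None], dtype="Int64"),  # -> bigint
        "price": [9.5, 3.0, 1.2],                          # float64 -> double
        "name": ["a", "b", "c"],                           # object -> string
        "ts": pandas.to_datetime(["2019-01-01"] * 3),      # datetime64[ns] -> timestamp
    })
    for name, dtype in df.dtypes.items():
        print(name, "->", Glue._type_pandas2athena(str(dtype)))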

awswrangler/pandas.py

Lines changed: 106 additions & 32 deletions
@@ -2,12 +2,14 @@
 import multiprocessing as mp
 import logging
 from math import floor
+import copy
+import csv

 import pandas
 import pyarrow
 from pyarrow import parquet

-from awswrangler.exceptions import UnsupportedWriteMode, UnsupportedFileFormat, AthenaQueryError, EmptyS3Object
+from awswrangler.exceptions import UnsupportedWriteMode, UnsupportedFileFormat, AthenaQueryError, EmptyS3Object, LineTerminatorNotFound
 from awswrangler.utils import calculate_bounders
 from awswrangler import s3
@@ -41,7 +43,7 @@ def read_csv(
             sep=",",
             lineterminator="\n",
             quotechar='"',
-            quoting=0,
+            quoting=csv.QUOTE_MINIMAL,
             escapechar=None,
             parse_dates=False,
             infer_datetime_format=False,
@@ -119,7 +121,7 @@ def _read_csv_iterator(
             sep=",",
             lineterminator="\n",
             quotechar='"',
-            quoting=0,
+            quoting=csv.QUOTE_MINIMAL,
             escapechar=None,
             parse_dates=False,
             infer_datetime_format=False,
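
One detail worth calling out: the default itself does not change here, since csv.QUOTE_MINIMAL is the integer 0 that was hard-coded before; the named constant is just self-documenting. A quick check in any Python shell:

    import csv

    assert csv.QUOTE_MINIMAL == 0
    assert csv.QUOTE_ALL == 1  # what read_sql_athena passes below
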
@@ -177,38 +179,38 @@ def _read_csv_iterator(
         bounders_len = len(bounders)
         count = 0
         forgotten_bytes = 0
-        cols_names = None
         for ini, end in bounders:
             count += 1
+
             ini -= forgotten_bytes
             end -= 1  # Range is inclusive, contrary to Python's List
             bytes_range = "bytes={}-{}".format(ini, end)
             logger.debug(f"bytes_range: {bytes_range}")
             body = client_s3.get_object(Bucket=bucket_name, Key=key_path, Range=bytes_range)["Body"]\
                 .read()\
-                .decode(encoding, errors="ignore")
+                .decode("utf-8")
             chunk_size = len(body)
             logger.debug(f"chunk_size: {chunk_size}")
-            if body[0] == lineterminator:
-                first_char = 1
-            else:
-                first_char = 0
-            if (count == 1) and (count == bounders_len):
-                last_break_line_idx = chunk_size
-            elif count == 1:  # first chunk
-                last_break_line_idx = body.rindex(lineterminator)
-                forgotten_bytes = chunk_size - last_break_line_idx
+
+            if count == 1:  # first chunk
+                last_char = Pandas._find_terminator(
+                    body=body,
+                    quoting=quoting,
+                    quotechar=quotechar,
+                    lineterminator=lineterminator)
+                forgotten_bytes = len(body[last_char:].encode("utf-8"))
             elif count == bounders_len:  # Last chunk
-                header = None
-                names = cols_names
-                last_break_line_idx = chunk_size
+                last_char = chunk_size
             else:
-                header = None
-                names = cols_names
-                last_break_line_idx = body.rindex(lineterminator)
-                forgotten_bytes = chunk_size - last_break_line_idx
+                last_char = Pandas._find_terminator(
+                    body=body,
+                    quoting=quoting,
+                    quotechar=quotechar,
+                    lineterminator=lineterminator)
+                forgotten_bytes = len(body[last_char:].encode("utf-8"))
+
             df = pandas.read_csv(
-                StringIO(body[first_char:last_break_line_idx]),
+                StringIO(body[:last_char]),
                 header=header,
                 names=names,
                 sep=sep,
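
The bookkeeping above is byte-oriented even though the terminator scan works on decoded text, hence the re-encode when computing forgotten_bytes. A tiny worked sketch of the simple (non-QUOTE_ALL) case, with invented values:

    # A chunk that was cut mid-record:
    body = "col_a,col_b\n1,2\n3,"
    last_char = body.rindex("\n")  # 15: the last complete record ends here
    forgotten_bytes = len(body[last_char:].encode("utf-8"))  # 3 ("\n3,")
    # The next GetObject Range starts forgotten_bytes earlier, so the
    # truncated tail "3," is re-read as the head of the following chunk.
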
@@ -223,7 +225,64 @@ def _read_csv_iterator(
             )
             yield df
             if count == 1:  # first chunk
-                cols_names = df.columns
+                names = df.columns
+                header = None
+
+    @staticmethod
+    def _find_terminator(body, quoting, quotechar, lineterminator):
+        """
+        Scan the body from end to start for a plausible line terminator
+        :param body: String
+        :param quoting: Same as pandas.read_csv()
+        :param quotechar: Same as pandas.read_csv()
+        :param lineterminator: Same as pandas.read_csv()
+        :return: The index of the suspected line terminator
+        """
+        try:
+            if quoting == csv.QUOTE_ALL:
+                index = body.rindex(lineterminator)
+                while True:
+                    i = 0
+                    while True:
+                        i += 1
+                        if index + i <= len(body) - 1:
+                            c = body[index + i]
+                            if c == ",":
+                                pass
+                            elif c == quotechar:
+                                right = True
+                                break
+                            else:
+                                right = False
+                                break
+                        else:
+                            right = True
+                            break
+                    i = 0
+                    while True:
+                        i += 1
+                        if index - i >= 0:
+                            c = body[index - i]
+                            if c == ",":
+                                pass
+                            elif c == quotechar:
+                                left = True
+                                break
+                            else:
+                                left = False
+                                break
+                        else:
+                            left = True
+                            break
+
+                    if right and left:
+                        break
+                    index = body[:index].rindex(lineterminator)
+            else:
+                index = body.rindex(lineterminator)
+        except ValueError:
+            raise LineTerminatorNotFound()
+        return index

     @staticmethod
     def _read_csv_once(
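
The two inner loops implement a neighbour test: under QUOTE_ALL, a lineterminator only counts as a record boundary if, skipping commas, the nearest characters on both sides are quote characters; otherwise the scan falls back to the previous candidate. A hedged illustration (the import path is assumed, and the call exercises a private helper):

    import csv

    from awswrangler.pandas import Pandas  # assumed import path

    # Chunk cut mid-record; the "\n" at index 14 sits inside a quoted field.
    body = '"a","b"\n"c","x\ny'
    index = Pandas._find_terminator(body=body,
                                    quoting=csv.QUOTE_ALL,
                                    quotechar='"',
                                    lineterminator="\n")
    assert index == 7  # falls back to the real boundary after '"b"'
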
@@ -293,7 +352,7 @@ def read_sql_athena(self,
         Executes any SQL query on AWS Athena and returns a DataFrame with the result.
         P.S. If max_result_size is passed, then an iterator of DataFrames is returned.
         :param sql: SQL Query
-        :param database: Glue/Athena Databease
+        :param database: Glue/Athena Database
         :param s3_output: AWS S3 path
         :param max_result_size: Max number of bytes on each request to S3
         :return: Pandas DataFrame or Iterator of Pandas DataFrames if max_result_size != None
@@ -318,8 +377,14 @@
             message_error = f"Query error: {reason}"
             raise AthenaQueryError(message_error)
         else:
+            dtype, parse_dates = self._session.athena.get_query_dtype(
+                query_execution_id=query_execution_id)
             path = f"{s3_output}{query_execution_id}.csv"
-            ret = self.read_csv(path=path, max_result_size=max_result_size)
+            ret = self.read_csv(path=path,
+                                dtype=dtype,
+                                parse_dates=parse_dates,
+                                quoting=csv.QUOTE_ALL,
+                                max_result_size=max_result_size)
         return ret

     def to_csv(
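
End to end, this is what the commit title refers to: Athena results are now read back with QUOTE_ALL quoting and the dtypes discovered from the query metadata, either in one shot or chunk by chunk. A usage sketch (names are placeholders, and it assumes the Session exposes this module as session.pandas, mirroring the self._session.athena access above):

    import awswrangler

    session = awswrangler.Session()

    # Single DataFrame:
    df = session.pandas.read_sql_athena(
        sql="SELECT * FROM my_table",
        database="my_database",
        s3_output="s3://my-bucket/athena-results/")

    # Or an iterator of DataFrames, roughly 128 MB of CSV per request:
    for df in session.pandas.read_sql_athena(
            sql="SELECT * FROM my_table",
            database="my_database",
            s3_output="s3://my-bucket/athena-results/",
            max_result_size=128 * 1024 * 1024):
        print(len(df.index))
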
@@ -623,11 +688,18 @@ def write_csv_dataframe(dataframe, path, preserve_index, fs):
             f.write(csv_buffer)

     @staticmethod
-    def write_parquet_dataframe(dataframe,
-                                path,
-                                preserve_index,
-                                fs,
-                                cast_columns=None):
+    def write_parquet_dataframe(dataframe, path, preserve_index, fs,
+                                cast_columns):
+        if not cast_columns:
+            cast_columns = {}
+        casted_in_pandas = []
+        dtypes = copy.deepcopy(dataframe.dtypes.to_dict())
+        for name, dtype in dtypes.items():
+            if str(dtype) == "Int64":
+                dataframe[name] = dataframe[name].astype("float64")
+                casted_in_pandas.append(name)
+                cast_columns[name] = "int64"
+                logger.debug(f"Casting column {name} Int64 to float64")
         table = pyarrow.Table.from_pandas(df=dataframe,
                                           preserve_index=preserve_index,
                                           safe=False)
@@ -636,13 +708,15 @@ def write_parquet_dataframe(dataframe,
             col_index = table.column_names.index(col_name)
             table = table.set_column(col_index,
                                      table.column(col_name).cast(dtype))
-            logger.debug(f"{col_name} - {col_index} - {dtype}")
-        logger.debug(f"table.schema:\n{table.schema}")
+            logger.debug(
+                f"Casting column {col_name} ({col_index}) to {dtype}")
         with fs.open(path, "wb") as f:
             parquet.write_table(table,
                                 f,
                                 coerce_timestamps="ms",
                                 flavor="spark")
+        for col in casted_in_pandas:
+            dataframe[col] = dataframe[col].astype("Int64")

     def to_redshift(
             self,
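
The Int64 dance at the top of write_parquet_dataframe deserves a note: pandas' nullable Int64 extension type is first cast to float64 so pyarrow can ingest it (NaN standing in for nulls), cast_columns then restores int64 at the Arrow level, and finally the original DataFrame columns are set back to Int64 so the caller's frame is left untouched. A sketch of the round trip in isolation (illustrative values):

    import pandas

    s = pandas.Series([1, None, 3], dtype="Int64")
    as_float = s.astype("float64")       # what pyarrow actually ingests
    restored = as_float.astype("Int64")  # what the caller sees afterwards
    assert restored.equals(s)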
