Skip to content

Commit 93b946b

Browse files
authored
Merge pull request #55 from awslabs/athena-query
Add Athena.query()
2 parents 6f57a36 + 63070ff commit 93b946b

File tree

4 files changed

+70
-6
lines changed

4 files changed

+70
-6
lines changed

awswrangler/athena.py

Lines changed: 55 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
1+
from typing import Dict, List, Tuple, Optional, Any, Iterator
12
from time import sleep
23
import logging
34
import ast
45
import re
56
import unicodedata
67

7-
from awswrangler import data_types
8+
from awswrangler.data_types import athena2python, athena2pandas
89
from awswrangler.exceptions import QueryFailed, QueryCancelled
910

1011
logger = logging.getLogger(__name__)
@@ -30,7 +31,7 @@ def get_query_dtype(self, query_execution_id):
3031
parse_dates = []
3132
converters = {}
3233
for col_name, col_type in cols_metadata.items():
33-
pandas_type = data_types.athena2pandas(dtype=col_type)
34+
pandas_type = athena2pandas(dtype=col_type)
3435
if pandas_type in ["datetime64", "date"]:
3536
parse_timestamps.append(col_name)
3637
if pandas_type == "date":
@@ -122,6 +123,58 @@ def repair_table(self, database, table, s3_output=None, workgroup=None):
122123
self.wait_query(query_execution_id=query_id)
123124
return query_id
124125

126+
@staticmethod
127+
def _rows2row(rows: List[Dict[str, List[Dict[str, str]]]],
128+
python_types: List[Tuple[str, Optional[type]]]) -> Iterator[Dict[str, Any]]:
129+
for row in rows:
130+
vals_varchar: List[Optional[str]] = [x["VarCharValue"] if x else None for x in row["Data"]]
131+
data: Dict[str, Any] = {}
132+
for (name, ptype), val in zip(python_types, vals_varchar):
133+
if ptype is not None:
134+
data[name] = ptype(val)
135+
else:
136+
data[name] = None
137+
yield data
138+
139+
def get_results(self, query_execution_id: str) -> Iterator[Dict[str, Any]]:
140+
"""
141+
Get a query results and return a list of rows
142+
:param query_execution_id: Query execution ID
143+
:return: Iterator os lists
144+
"""
145+
res: Dict = self._client_athena.get_query_results(QueryExecutionId=query_execution_id)
146+
cols_info: List[Dict] = res["ResultSet"]["ResultSetMetadata"]["ColumnInfo"]
147+
athena_types: List[Tuple[str, str]] = [(x["Label"], x["Type"]) for x in cols_info]
148+
logger.info(f"athena_types: {athena_types}")
149+
python_types: List[Tuple[str, Optional[type]]] = [(n, athena2python(dtype=t)) for n, t in athena_types]
150+
logger.info(f"python_types: {python_types}")
151+
rows: List[Dict[str, List[Dict[str, str]]]] = res["ResultSet"]["Rows"][1:]
152+
for row in Athena._rows2row(rows=rows, python_types=python_types):
153+
yield row
154+
next_token: Optional[str] = res.get("NextToken")
155+
while next_token is not None:
156+
logger.info(f"next_token: {next_token}")
157+
res = self._client_athena.get_query_results(QueryExecutionId=query_execution_id, NextToken=next_token)
158+
rows = res["ResultSet"]["Rows"]
159+
for row in Athena._rows2row(rows=rows, python_types=python_types):
160+
yield row
161+
next_token = res.get("NextToken")
162+
163+
def query(self, query: str, database: str, s3_output: str = None,
164+
workgroup: str = None) -> Iterator[Dict[str, Any]]:
165+
"""
166+
Run a SQL Query against AWS Athena and return the result as a Iterator of lists
167+
168+
:param query: SQL query
169+
:param database: Glue database name
170+
:param s3_output: AWS S3 path
171+
:param workgroup: Athena workgroup (By default uses de Session() workgroup)
172+
:return: Query execution ID
173+
"""
174+
query_id: str = self.run_query(query=query, database=database, s3_output=s3_output, workgroup=workgroup)
175+
self.wait_query(query_execution_id=query_id)
176+
return self.get_results(query_execution_id=query_id)
177+
125178
@staticmethod
126179
def _normalize_name(name):
127180
name = "".join(c for c in unicodedata.normalize("NFD", name) if unicodedata.category(c) != "Mn")

awswrangler/data_types.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import List, Tuple, Dict, Callable
1+
from typing import List, Tuple, Dict, Callable, Optional
22
import logging
33
from datetime import datetime, date
44

@@ -56,7 +56,7 @@ def athena2pyarrow(dtype: str) -> str:
5656
raise UnsupportedType(f"Unsupported Athena type: {dtype}")
5757

5858

59-
def athena2python(dtype: str) -> type:
59+
def athena2python(dtype: str) -> Optional[type]:
6060
dtype = dtype.lower()
6161
if dtype in ["int", "integer", "bigint", "smallint", "tinyint"]:
6262
return int
@@ -70,6 +70,8 @@ def athena2python(dtype: str) -> type:
7070
return datetime
7171
elif dtype == "date":
7272
return date
73+
elif dtype == "unknown":
74+
return None
7375
else:
7476
raise UnsupportedType(f"Unsupported Athena type: {dtype}")
7577

awswrangler/s3.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,8 @@ def delete_objects(self, path):
7575
procs = []
7676
args = {"Bucket": bucket, "MaxKeys": 1000, "Prefix": path}
7777
logger.debug(f"Arguments: \n{args}")
78-
next_continuation_token = True
79-
while next_continuation_token:
78+
next_continuation_token = ""
79+
while next_continuation_token is not None:
8080
res = client.list_objects_v2(**args)
8181
if not res.get("Contents"):
8282
break

testing/test_awswrangler/test_athena.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,3 +186,12 @@ def test_query_failed(session, database):
186186
query_execution_id = session.athena.run_query(query="SELECT random(-1)", database=database)
187187
with pytest.raises(QueryFailed):
188188
assert session.athena.wait_query(query_execution_id=query_execution_id)
189+
190+
191+
def test_query(session, database):
192+
row = list(session.athena.query(query="SELECT 'foo', 1, 2.0, true, null", database=database))[0]
193+
assert row["_col0"] == "foo"
194+
assert row["_col1"] == 1
195+
assert row["_col2"] == 2.0
196+
assert row["_col3"] is True
197+
assert row["_col4"] is None

0 commit comments

Comments
 (0)