Skip to content

Commit a4df88b

Browse files
committed
Add Glue.get_table_dtypes and Glue.get_table_python_types.
1 parent ca08638 commit a4df88b

File tree

3 files changed

+150
-28
lines changed

3 files changed

+150
-28
lines changed

awswrangler/athena.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,26 +9,24 @@
99
class Athena:
1010
def __init__(self, session):
    """
    Keep a reference to the session and build the Athena client once.

    The client is created from the session's boto3 session and botocore
    config and stored on the instance so the other methods can reuse it.

    :param session: Object exposing boto3_session and botocore_config
    """
    self._session = session
    boto3_session = session.boto3_session
    self._client_athena = boto3_session.client(
        service_name="athena", config=session.botocore_config)
1214

1315
def run_query(self, query, database, s3_output):
14-
client = self._session.boto3_session.client(
15-
service_name="athena", config=self._session.botocore_config)
16-
response = client.start_query_execution(
16+
response = self._client_athena.start_query_execution(
1717
QueryString=query,
1818
QueryExecutionContext={"Database": database},
1919
ResultConfiguration={"OutputLocation": s3_output},
2020
)
2121
return response["QueryExecutionId"]
2222

2323
def wait_query(self, query_execution_id):
24-
client = self._session.boto3_session.client(
25-
service_name="athena", config=self._session.botocore_config)
2624
final_states = ["FAILED", "SUCCEEDED", "CANCELLED"]
27-
response = client.get_query_execution(
25+
response = self._client_athena.get_query_execution(
2826
QueryExecutionId=query_execution_id)
2927
while (response.get("QueryExecution").get("Status").get("State") not in
3028
final_states):
3129
sleep(QUERY_WAIT_POLLING_DELAY)
32-
response = client.get_query_execution(
30+
response = self._client_athena.get_query_execution(
3331
QueryExecutionId=query_execution_id)
3432
return response

awswrangler/glue.py

Lines changed: 59 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,53 @@
1010
class Glue:
1111
def __init__(self, session):
    """
    Keep a reference to the session and build the Glue client once.

    The client is created from the session's boto3 session and botocore
    config and stored on the instance so the other methods can reuse it.

    :param session: Object exposing boto3_session and botocore_config
    """
    self._session = session
    boto3_session = session.boto3_session
    self._client_glue = boto3_session.client(
        service_name="glue", config=session.botocore_config)
15+
16+
def get_table_dtypes(self, database, table):
    """
    Get all column names and the related data types

    Regular columns are collected first and partition keys second, so a
    partition key sharing a name with a regular column overrides it.

    :param database: Glue database's name
    :param table: Glue table's name
    :return: A dictionary as {"col name": "col dtype"}
    """
    response = self._client_glue.get_table(DatabaseName=database,
                                           Name=table)
    logger.debug(f"get_table response:\n{response}")
    table_def = response["Table"]
    dtypes = {
        col["Name"]: col["Type"]
        for col in table_def["StorageDescriptor"]["Columns"]
    }
    # "PartitionKeys" may be missing from the response for tables without
    # partitions; treat that as "no partition columns" instead of failing
    # with a KeyError.
    for par in table_def.get("PartitionKeys", []):
        dtypes[par["Name"]] = par["Type"]
    return dtypes
32+
33+
def get_table_python_types(self, database, table):
    """
    Get all column names and the related python types

    Looks up the Athena types with get_table_dtypes and converts each one
    with Glue._type_athena2python.

    :param database: Glue database's name
    :param table: Glue table's name
    :return: A dictionary as {"col name": "col python type"}
    """
    athena_types = self.get_table_dtypes(database=database, table=table)
    python_types = {}
    for col_name, athena_type in athena_types.items():
        python_types[col_name] = Glue._type_athena2python(athena_type)
    return python_types
42+
43+
@staticmethod
44+
def _type_athena2python(dtype):
    """
    Translate an Athena type name into the matching Python type.

    The lookup is case-insensitive. Besides the basic numeric, boolean and
    string types it also covers "timestamp" (the type this class writes for
    pandas datetime64 columns) and "date".

    :param dtype: Athena type name (e.g. "bigint", "string")
    :return: The Python type (int, float, bool, str, datetime or date)
    :raises UnsupportedType: If the Athena type has no known mapping
    """
    # Local import keeps this block self-contained without touching the
    # module-level import section.
    import datetime

    dtype = dtype.lower()
    athena2python = {
        "tinyint": int,
        "smallint": int,
        "int": int,
        "integer": int,
        "bigint": int,
        "float": float,
        "double": float,
        "boolean": bool,
        "string": str,
        "timestamp": datetime.datetime,
        "date": datetime.date,
    }
    python_type = athena2python.get(dtype)
    if python_type is None:
        raise UnsupportedType(f"Unsupported Athena type: {dtype}")
    return python_type
1360

1461
def metadata_to_glue(
1562
self,
@@ -53,20 +100,16 @@ def metadata_to_glue(
53100
)
54101

55102
def delete_table_if_exists(self, database, table):
56-
client = self._session.boto3_session.client(
57-
service_name="glue", config=self._session.botocore_config)
58103
try:
59-
client.delete_table(DatabaseName=database, Name=table)
60-
except client.exceptions.EntityNotFoundException:
104+
self._client_glue.delete_table(DatabaseName=database, Name=table)
105+
except self._client_glue.exceptions.EntityNotFoundException:
61106
pass
62107

63108
def does_table_exists(self, database, table):
64-
client = self._session.boto3_session.client(
65-
service_name="glue", config=self._session.botocore_config)
66109
try:
67-
client.get_table(DatabaseName=database, Name=table)
110+
self._client_glue.get_table(DatabaseName=database, Name=table)
68111
return True
69-
except client.exceptions.EntityNotFoundException:
112+
except self._client_glue.exceptions.EntityNotFoundException:
70113
return False
71114

72115
def create_table(self,
@@ -76,8 +119,6 @@ def create_table(self,
76119
path,
77120
file_format,
78121
partition_cols=None):
79-
client = self._session.boto3_session.client(
80-
service_name="glue", config=self._session.botocore_config)
81122
if file_format == "parquet":
82123
table_input = Glue.parquet_table_definition(
83124
table, partition_cols, schema, path)
@@ -86,11 +127,10 @@ def create_table(self,
86127
schema, path)
87128
else:
88129
raise UnsupportedFileFormat(file_format)
89-
client.create_table(DatabaseName=database, TableInput=table_input)
130+
self._client_glue.create_table(DatabaseName=database,
131+
TableInput=table_input)
90132

91133
def add_partitions(self, database, table, partition_paths, file_format):
92-
client = self._session.boto3_session.client(
93-
service_name="glue", config=self._session.botocore_config)
94134
if not partition_paths:
95135
return None
96136
partitions = list()
@@ -106,15 +146,13 @@ def add_partitions(self, database, table, partition_paths, file_format):
106146
for _ in range(pages_num):
107147
page = partitions[:100]
108148
del partitions[:100]
109-
client.batch_create_partition(DatabaseName=database,
110-
TableName=table,
111-
PartitionInputList=page)
149+
self._client_glue.batch_create_partition(DatabaseName=database,
150+
TableName=table,
151+
PartitionInputList=page)
112152

113153
def get_connection_details(self, name):
114-
client = self._session.boto3_session.client(
115-
service_name="glue", config=self._session.botocore_config)
116-
return client.get_connection(Name=name,
117-
HidePassword=False)["Connection"]
154+
return self._client_glue.get_connection(
155+
Name=name, HidePassword=False)["Connection"]
118156

119157
@staticmethod
120158
def _build_schema(dataframe, partition_cols, preserve_index):
@@ -155,7 +193,7 @@ def _type_pandas2athena(dtype):
155193
elif dtype[:10] == "datetime64":
156194
return "timestamp"
157195
else:
158-
raise UnsupportedType("Unsupported Pandas type: " + dtype)
196+
raise UnsupportedType(f"Unsupported Pandas type: {dtype}")
159197

160198
@staticmethod
161199
def _parse_table_name(path):
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
import logging
2+
3+
import pytest
4+
import boto3
5+
import pandas
6+
7+
from awswrangler import Session
8+
9+
logging.basicConfig(
10+
level=logging.INFO,
11+
format="[%(asctime)s][%(levelname)s][%(name)s][%(funcName)s] %(message)s")
12+
logging.getLogger("awswrangler").setLevel(logging.DEBUG)
13+
14+
15+
@pytest.fixture(scope="module")
def cloudformation_outputs():
    """
    Yield the outputs of the "aws-data-wrangler-test-arena" stack.

    The first stack returned by describe_stacks is used and its outputs are
    flattened into a {OutputKey: OutputValue} dictionary.
    """
    cfn_client = boto3.client("cloudformation")
    response = cfn_client.describe_stacks(
        StackName="aws-data-wrangler-test-arena")
    stack = response.get("Stacks")[0]
    yield {
        output.get("OutputKey"): output.get("OutputValue")
        for output in stack.get("Outputs")
    }
23+
24+
25+
@pytest.fixture(scope="module")
def session():
    """Yield one awswrangler Session built with default arguments."""
    wrangler_session = Session()
    yield wrangler_session
28+
29+
30+
@pytest.fixture(scope="module")
def bucket(session, cloudformation_outputs):
    """
    Yield the test bucket name taken from the stack's "BucketName" output.

    All objects under s3://<bucket>/ are deleted before the name is yielded
    and again during teardown. Raises if the output is missing.
    """
    if "BucketName" not in cloudformation_outputs:
        raise Exception("You must deploy the test infrastructure using SAM!")
    bucket_name = cloudformation_outputs.get("BucketName")
    session.s3.delete_objects(path=f"s3://{bucket_name}/")
    yield bucket_name
    session.s3.delete_objects(path=f"s3://{bucket_name}/")
39+
40+
41+
@pytest.fixture(scope="module")
def database(cloudformation_outputs):
    """
    Yield the Glue database name from the stack's "GlueDatabaseName" output.

    Raises if the output is missing.
    """
    if "GlueDatabaseName" not in cloudformation_outputs:
        raise Exception("You must deploy the test infrastructure using SAM!")
    yield cloudformation_outputs.get("GlueDatabaseName")
48+
49+
50+
@pytest.fixture(scope="module")
def table(session, bucket, database):
    """
    Yield the name of a test table built from data_samples/micro.csv.

    The CSV is loaded with pandas and handed to session.pandas.to_parquet
    with path s3://<bucket>/test/, table name "test" and partition columns
    "name" and "date". On teardown the Glue table is deleted if it exists
    and the objects under the path are removed.
    """
    sample = pandas.read_csv("data_samples/micro.csv")
    s3_path = f"s3://{bucket}/test/"
    table_name = "test"
    session.pandas.to_parquet(
        dataframe=sample,
        database=database,
        table=table_name,
        path=s3_path,
        preserve_index=False,
        mode="overwrite",
        procs_cpu_bound=1,
        partition_cols=["name", "date"],
    )
    yield table_name
    session.glue.delete_table_if_exists(database=database, table=table_name)
    session.s3.delete_objects(path=s3_path)
70+
71+
72+
def test_get_table_dtypes(session, database, table):
    """Check the Athena types reported for the columns of the test table."""
    dtypes = session.glue.get_table_dtypes(database=database, table=table)
    expected = (
        ("id", "bigint"),
        ("value", "double"),
        ("name", "string"),
        ("date", "string"),
    )
    for column, athena_type in expected:
        assert dtypes[column] == athena_type
78+
79+
80+
def test_get_table_python_types(session, database, table):
    """Check the Python types reported for the columns of the test table."""
    ptypes = session.glue.get_table_python_types(database=database,
                                                 table=table)
    expected = (
        ("id", int),
        ("value", float),
        ("name", str),
        ("date", str),
    )
    for column, python_type in expected:
        assert ptypes[column] == python_type

0 commit comments

Comments
 (0)