Skip to content

Commit 5fd04b0

Browse files
Add/delete column to Glue Catalog table(#453)
1 parent 563c444 commit 5fd04b0

File tree

5 files changed

+187
-4
lines changed

5 files changed

+187
-4
lines changed

awswrangler/catalog/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Amazon Glue Catalog Module."""
22

3-
from awswrangler.catalog._add import add_csv_partitions, add_parquet_partitions # noqa
3+
from awswrangler.catalog._add import add_column, add_csv_partitions, add_parquet_partitions # noqa
44
from awswrangler.catalog._create import ( # noqa
55
_create_csv_table,
66
_create_parquet_table,
@@ -12,6 +12,7 @@
1212
)
1313
from awswrangler.catalog._delete import ( # noqa
1414
delete_all_partitions,
15+
delete_column,
1516
delete_database,
1617
delete_partitions,
1718
delete_table_if_exists,

awswrangler/catalog/_add.py

Lines changed: 68 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,12 @@
77

88
from awswrangler import _utils, exceptions
99
from awswrangler._config import apply_configs
10-
from awswrangler.catalog._definitions import _csv_partition_definition, _parquet_partition_definition
10+
from awswrangler.catalog._definitions import (
11+
_check_column_type,
12+
_csv_partition_definition,
13+
_parquet_partition_definition,
14+
_update_table_definition,
15+
)
1116
from awswrangler.catalog._utils import _catalog_id, sanitize_table_name
1217

1318
_logger: logging.Logger = logging.getLogger(__name__)
@@ -157,3 +162,65 @@ def add_parquet_partitions(
157162
_add_partitions(
158163
database=database, table=table, boto3_session=boto3_session, inputs=inputs, catalog_id=catalog_id
159164
)
165+
166+
167+
@apply_configs
168+
def add_column(
169+
database: str,
170+
table: str,
171+
column_name: str,
172+
column_type: str = "string",
173+
column_comment: Optional[str] = None,
174+
boto3_session: Optional[boto3.Session] = None,
175+
catalog_id: Optional[str] = None,
176+
) -> None:
177+
"""Delete a column in a AWS Glue Catalog table.
178+
179+
Parameters
180+
----------
181+
database : str
182+
Database name.
183+
table : str
184+
Table name.
185+
column_name : str
186+
Column name
187+
column_type : str
188+
Column type.
189+
column_comment : str
190+
Column Comment
191+
boto3_session : boto3.Session(), optional
192+
Boto3 Session. The default boto3 session will be used if boto3_session receive None.
193+
catalog_id : str, optional
194+
The ID of the Data Catalog from which to retrieve Databases.
195+
If none is provided, the AWS account ID is used by default.
196+
197+
Returns
198+
-------
199+
None
200+
None
201+
202+
Examples
203+
--------
204+
>>> import awswrangler as wr
205+
>>> wr.catalog.add_column(
206+
... database='my_db',
207+
... table='my_table',
208+
... column_name='my_col',
209+
... column_type='int'
210+
... )
211+
"""
212+
if _check_column_type(column_type):
213+
client_glue: boto3.client = _utils.client(service_name="glue", session=boto3_session)
214+
table_res: Dict[str, Any] = client_glue.get_table(DatabaseName=database, Name=table)
215+
table_input: Dict[str, Any] = _update_table_definition(table_res)
216+
table_input["StorageDescriptor"]["Columns"].append(
217+
{"Name": column_name, "Type": column_type, "Comment": column_comment}
218+
)
219+
res: Dict[str, Any] = client_glue.update_table(
220+
**_catalog_id(catalog_id=catalog_id, DatabaseName=database, TableInput=table_input)
221+
)
222+
if ("Errors" in res) and res["Errors"]:
223+
for error in res["Errors"]:
224+
if "ErrorDetail" in error:
225+
if "ErrorCode" in error["ErrorDetail"]:
226+
raise exceptions.ServiceApiError(str(res["Errors"]))

awswrangler/catalog/_definitions.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,27 @@
55

66
_logger: logging.Logger = logging.getLogger(__name__)
77

8+
_LEGAL_COLUMN_TYPES = [
9+
"array",
10+
"bigint",
11+
"binary",
12+
"boolean",
13+
"char",
14+
"date",
15+
"decimal",
16+
"double",
17+
"float",
18+
"int",
19+
"interval",
20+
"map",
21+
"set",
22+
"smallint",
23+
"string",
24+
"struct",
25+
"timestamp",
26+
"tinyint",
27+
]
28+
829

930
def _parquet_table_definition(
1031
table: str, path: str, columns_types: Dict[str, str], partitions_types: Dict[str, str], compression: Optional[str]
@@ -138,3 +159,32 @@ def _csv_partition_definition(
138159
{"Name": cname, "Type": dtype} for cname, dtype in columns_types.items()
139160
]
140161
return definition
162+
163+
164+
def _check_column_type(column_type: str) -> bool:
165+
if column_type not in _LEGAL_COLUMN_TYPES:
166+
raise ValueError(f"{column_type} is not a legal data type.")
167+
return True
168+
169+
170+
def _update_table_definition(current_definition: Dict[str, Any]) -> Dict[str, Any]:
171+
definition: Dict[str, Any] = dict()
172+
keep_keys = [
173+
"Name",
174+
"Description",
175+
"Owner",
176+
"LastAccessTime",
177+
"LastAnalyzedTime",
178+
"Retention",
179+
"StorageDescriptor",
180+
"PartitionKeys",
181+
"ViewOriginalText",
182+
"ViewExpandedText",
183+
"TableType",
184+
"Parameters",
185+
"TargetTable",
186+
]
187+
for key in current_definition["Table"]:
188+
if key in keep_keys:
189+
definition[key] = current_definition["Table"][key]
190+
return definition

awswrangler/catalog/_delete.py

Lines changed: 57 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
"""AWS Glue Catalog Delete Module."""
22

33
import logging
4-
from typing import List, Optional
4+
from typing import Any, Dict, List, Optional
55

66
import boto3
77

8-
from awswrangler import _utils
8+
from awswrangler import _utils, exceptions
99
from awswrangler._config import apply_configs
10+
from awswrangler.catalog._definitions import _update_table_definition
1011
from awswrangler.catalog._get import _get_partitions
1112
from awswrangler.catalog._utils import _catalog_id
1213

@@ -181,3 +182,57 @@ def delete_all_partitions(
181182
boto3_session=boto3_session,
182183
)
183184
return partitions_values
185+
186+
187+
@apply_configs
188+
def delete_column(
189+
database: str,
190+
table: str,
191+
column_name: str,
192+
boto3_session: Optional[boto3.Session] = None,
193+
catalog_id: Optional[str] = None,
194+
) -> None:
195+
"""Delete a column in a AWS Glue Catalog table.
196+
197+
Parameters
198+
----------
199+
database : str
200+
Database name.
201+
table : str
202+
Table name.
203+
column_name : str
204+
Column name
205+
boto3_session : boto3.Session(), optional
206+
Boto3 Session. The default boto3 session will be used if boto3_session receive None.
207+
catalog_id : str, optional
208+
The ID of the Data Catalog from which to retrieve Databases.
209+
If none is provided, the AWS account ID is used by default.
210+
211+
Returns
212+
-------
213+
None
214+
None
215+
216+
Examples
217+
--------
218+
>>> import awswrangler as wr
219+
>>> wr.catalog.delete_column(
220+
... database='my_db',
221+
... table='my_table',
222+
... column_name='my_col',
223+
... )
224+
"""
225+
client_glue: boto3.client = _utils.client(service_name="glue", session=boto3_session)
226+
table_res: Dict[str, Any] = client_glue.get_table(DatabaseName=database, Name=table)
227+
table_input: Dict[str, Any] = _update_table_definition(table_res)
228+
table_input["StorageDescriptor"]["Columns"] = [
229+
i for i in table_input["StorageDescriptor"]["Columns"] if i["Name"] != column_name
230+
]
231+
res: Dict[str, Any] = client_glue.update_table(
232+
**_catalog_id(catalog_id=catalog_id, DatabaseName=database, TableInput=table_input)
233+
)
234+
if ("Errors" in res) and res["Errors"]:
235+
for error in res["Errors"]:
236+
if "ErrorDetail" in error:
237+
if "ErrorCode" in error["ErrorDetail"]:
238+
raise exceptions.ServiceApiError(str(res["Errors"]))

tests/test_athena.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,16 @@ def test_catalog(path: str, glue_database: str, glue_table: str) -> None:
244244
assert len(tables) > 0
245245
for tbl in tables:
246246
assert tbl["DatabaseName"] == glue_database
247+
# add & delete column
248+
wr.catalog.add_column(
249+
database=glue_database, table=glue_table, column_name="col2", column_type="int", column_comment="comment"
250+
)
251+
dtypes = wr.catalog.get_table_types(database=glue_database, table=glue_table)
252+
assert len(dtypes) == 5
253+
assert dtypes["col2"] == "int"
254+
wr.catalog.delete_column(database=glue_database, table=glue_table, column_name="col2")
255+
dtypes = wr.catalog.get_table_types(database=glue_database, table=glue_table)
256+
assert len(dtypes) == 4
247257
# search
248258
tables = list(wr.catalog.search_tables(text="parquet", catalog_id=account_id))
249259
assert len(tables) > 0

0 commit comments

Comments
 (0)