Skip to content

Commit c40e443

Browse files
authored
Merge pull request #243 from awslabs/schema-evolution
Schema Evolution
2 parents 9c36b70 + 6604c06 commit c40e443

File tree

7 files changed

+695
-105
lines changed

7 files changed

+695
-105
lines changed

awswrangler/_data_types.py

Lines changed: 22 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -175,15 +175,15 @@ def pyarrow2sqlalchemy( # pylint: disable=too-many-branches,too-many-return-sta
175175
) -> Optional[VisitableType]:
176176
"""Pyarrow to Athena data types conversion."""
177177
if pa.types.is_int8(dtype):
178-
return sqlalchemy.types.SMALLINT
178+
return sqlalchemy.types.SmallInteger
179179
if pa.types.is_int16(dtype):
180-
return sqlalchemy.types.SMALLINT
180+
return sqlalchemy.types.SmallInteger
181181
if pa.types.is_int32(dtype):
182-
return sqlalchemy.types.INTEGER
182+
return sqlalchemy.types.Integer
183183
if pa.types.is_int64(dtype):
184-
return sqlalchemy.types.BIGINT
184+
return sqlalchemy.types.BigInteger
185185
if pa.types.is_float32(dtype):
186-
return sqlalchemy.types.FLOAT
186+
return sqlalchemy.types.Float
187187
if pa.types.is_float64(dtype):
188188
if db_type == "mysql":
189189
return sqlalchemy.dialects.mysql.DOUBLE
@@ -195,25 +195,25 @@ def pyarrow2sqlalchemy( # pylint: disable=too-many-branches,too-many-return-sta
195195
f"{db_type} is a invalid database type, please choose between postgresql, mysql and redshift."
196196
) # pragma: no cover
197197
if pa.types.is_boolean(dtype):
198-
return sqlalchemy.types.BOOLEAN
198+
return sqlalchemy.types.Boolean
199199
if pa.types.is_string(dtype):
200200
if db_type == "mysql":
201-
return sqlalchemy.types.TEXT
201+
return sqlalchemy.types.Text
202202
if db_type == "postgresql":
203-
return sqlalchemy.types.TEXT
203+
return sqlalchemy.types.Text
204204
if db_type == "redshift":
205205
return sqlalchemy.types.VARCHAR(length=256)
206206
raise exceptions.InvalidDatabaseType(
207207
f"{db_type} is a invalid database type. " f"Please choose between postgresql, mysql and redshift."
208208
) # pragma: no cover
209209
if pa.types.is_timestamp(dtype):
210-
return sqlalchemy.types.DATETIME
210+
return sqlalchemy.types.DateTime
211211
if pa.types.is_date(dtype):
212-
return sqlalchemy.types.DATE
212+
return sqlalchemy.types.Date
213213
if pa.types.is_binary(dtype):
214214
if db_type == "redshift":
215215
raise exceptions.UnsupportedType("Binary columns are not supported for Redshift.") # pragma: no cover
216-
return sqlalchemy.types.BINARY
216+
return sqlalchemy.types.Binary
217217
if pa.types.is_decimal(dtype):
218218
return sqlalchemy.types.Numeric(precision=dtype.precision, scale=dtype.scale)
219219
if pa.types.is_dictionary(dtype):
@@ -396,7 +396,7 @@ def cast_pandas_with_athena_types(df: pd.DataFrame, dtype: Dict[str, str]) -> pd
396396
df[col] = (
397397
df[col]
398398
.astype("string")
399-
.apply(lambda x: Decimal(str(x)) if str(x) not in ("", "none", " ", "<NA>") else None)
399+
.apply(lambda x: Decimal(str(x)) if str(x) not in ("", "none", "None", " ", "<NA>") else None)
400400
)
401401
elif pandas_type == "string":
402402
curr_type: str = str(df[col].dtypes)
@@ -405,7 +405,16 @@ def cast_pandas_with_athena_types(df: pd.DataFrame, dtype: Dict[str, str]) -> pd
405405
else:
406406
df[col] = df[col].astype("string")
407407
else:
408-
df[col] = df[col].astype(pandas_type)
408+
try:
409+
df[col] = df[col].astype(pandas_type)
410+
except TypeError as ex:
411+
if "object cannot be converted to an IntegerDtype" not in str(ex):
412+
raise ex # pragma: no cover
413+
df[col] = (
414+
df[col]
415+
.apply(lambda x: int(x) if str(x) not in ("", "none", "None", " ", "<NA>") else None)
416+
.astype(pandas_type)
417+
)
409418
return df
410419

411420

awswrangler/catalog.py

Lines changed: 158 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -150,9 +150,33 @@ def create_parquet_table(
150150
"""
151151
table = sanitize_table_name(table=table)
152152
partitions_types = {} if partitions_types is None else partitions_types
153-
table_input: Dict[str, Any] = _parquet_table_definition(
154-
table=table, path=path, columns_types=columns_types, partitions_types=partitions_types, compression=compression
155-
)
153+
session: boto3.Session = _utils.ensure_session(session=boto3_session)
154+
cat_table_input: Optional[Dict[str, Any]] = _get_table_input(database=database, table=table, boto3_session=session)
155+
table_input: Dict[str, Any]
156+
if (cat_table_input is not None) and (mode in ("append", "overwrite_partitions")):
157+
table_input = cat_table_input
158+
updated: bool = False
159+
cat_cols: Dict[str, str] = {x["Name"]: x["Type"] for x in table_input["StorageDescriptor"]["Columns"]}
160+
for c, t in columns_types.items():
161+
if c not in cat_cols:
162+
_logger.debug("New column %s with type %s.", c, t)
163+
table_input["StorageDescriptor"]["Columns"].append({"Name": c, "Type": t})
164+
updated = True
165+
elif t != cat_cols[c]: # Data type change detected!
166+
raise exceptions.InvalidArgumentValue(
167+
f"Data type change detected on column {c}. Old type: {cat_cols[c]}. New type {t}."
168+
)
169+
if updated is True:
170+
mode = "update"
171+
else:
172+
table_input = _parquet_table_definition(
173+
table=table,
174+
path=path,
175+
columns_types=columns_types,
176+
partitions_types=partitions_types,
177+
compression=compression,
178+
)
179+
table_exist: bool = cat_table_input is not None
156180
_create_table(
157181
database=database,
158182
table=table,
@@ -161,8 +185,9 @@ def create_parquet_table(
161185
columns_comments=columns_comments,
162186
mode=mode,
163187
catalog_versioning=catalog_versioning,
164-
boto3_session=boto3_session,
188+
boto3_session=session,
165189
table_input=table_input,
190+
table_exist=table_exist,
166191
)
167192

168193

@@ -266,7 +291,9 @@ def _parquet_partition_definition(location: str, values: List[str], compression:
266291
}
267292

268293

269-
def get_table_types(database: str, table: str, boto3_session: Optional[boto3.Session] = None) -> Dict[str, str]:
294+
def get_table_types(
295+
database: str, table: str, boto3_session: Optional[boto3.Session] = None
296+
) -> Optional[Dict[str, str]]:
270297
"""Get all columns and types from a table.
271298
272299
Parameters
@@ -280,8 +307,8 @@ def get_table_types(database: str, table: str, boto3_session: Optional[boto3.Ses
280307
281308
Returns
282309
-------
283-
Dict[str, str]
284-
A dictionary as {'col name': 'col data type'}.
310+
Optional[Dict[str, str]]
311+
If table exists, a dictionary like {'col name': 'col data type'}. Otherwise None.
285312
286313
Examples
287314
--------
@@ -291,7 +318,10 @@ def get_table_types(database: str, table: str, boto3_session: Optional[boto3.Ses
291318
292319
"""
293320
client_glue: boto3.client = _utils.client(service_name="glue", session=boto3_session)
294-
response: Dict[str, Any] = client_glue.get_table(DatabaseName=database, Name=table)
321+
try:
322+
response: Dict[str, Any] = client_glue.get_table(DatabaseName=database, Name=table)
323+
except client_glue.exceptions.EntityNotFoundException:
324+
return None
295325
dtypes: Dict[str, str] = {}
296326
for col in response["Table"]["StorageDescriptor"]["Columns"]:
297327
dtypes[col["Name"]] = col["Type"]
@@ -938,6 +968,7 @@ def create_csv_table(
938968
compression=compression,
939969
sep=sep,
940970
)
971+
session: boto3.Session = _utils.ensure_session(session=boto3_session)
941972
_create_table(
942973
database=database,
943974
table=table,
@@ -946,8 +977,9 @@ def create_csv_table(
946977
columns_comments=columns_comments,
947978
mode=mode,
948979
catalog_versioning=catalog_versioning,
949-
boto3_session=boto3_session,
980+
boto3_session=session,
950981
table_input=table_input,
982+
table_exist=does_table_exist(database=database, table=table, boto3_session=session),
951983
)
952984

953985

@@ -961,6 +993,7 @@ def _create_table(
961993
catalog_versioning: bool,
962994
boto3_session: Optional[boto3.Session],
963995
table_input: Dict[str, Any],
996+
table_exist: bool,
964997
):
965998
if description is not None:
966999
table_input["Description"] = description
@@ -978,23 +1011,25 @@ def _create_table(
9781011
par["Comment"] = columns_comments[name]
9791012
session: boto3.Session = _utils.ensure_session(session=boto3_session)
9801013
client_glue: boto3.client = _utils.client(service_name="glue", session=session)
981-
exist: bool = does_table_exist(database=database, table=table, boto3_session=session)
982-
if mode not in ("overwrite", "append", "overwrite_partitions"): # pragma: no cover
1014+
skip_archive: bool = not catalog_versioning
1015+
if mode not in ("overwrite", "append", "overwrite_partitions", "update"): # pragma: no cover
9831016
raise exceptions.InvalidArgument(
9841017
f"{mode} is not a valid mode. It must be 'overwrite', 'append' or 'overwrite_partitions'."
9851018
)
986-
if (exist is True) and (mode == "overwrite"):
987-
skip_archive: bool = not catalog_versioning
1019+
if (table_exist is True) and (mode == "overwrite"):
9881020
partitions_values: List[List[str]] = list(
9891021
_get_partitions(database=database, table=table, boto3_session=session).values()
9901022
)
9911023
client_glue.batch_delete_partition(
9921024
DatabaseName=database, TableName=table, PartitionsToDelete=[{"Values": v} for v in partitions_values]
9931025
)
9941026
client_glue.update_table(DatabaseName=database, TableInput=table_input, SkipArchive=skip_archive)
995-
elif (exist is True) and (mode in ("append", "overwrite_partitions")) and (parameters is not None):
996-
upsert_table_parameters(parameters=parameters, database=database, table=table, boto3_session=session)
997-
elif exist is False:
1027+
elif (table_exist is True) and (mode in ("append", "overwrite_partitions", "update")):
1028+
if parameters is not None:
1029+
upsert_table_parameters(parameters=parameters, database=database, table=table, boto3_session=session)
1030+
if mode == "update":
1031+
client_glue.update_table(DatabaseName=database, TableInput=table_input, SkipArchive=skip_archive)
1032+
elif table_exist is False:
9981033
client_glue.create_table(DatabaseName=database, TableInput=table_input)
9991034

10001035

@@ -1379,6 +1414,88 @@ def get_table_parameters(
13791414
return parameters
13801415

13811416

1417+
def get_table_description(
    database: str, table: str, catalog_id: Optional[str] = None, boto3_session: Optional[boto3.Session] = None
) -> Optional[str]:
    """Get table description.

    Parameters
    ----------
    database : str
        Database name.
    table : str
        Table name.
    catalog_id : str, optional
        The ID of the Data Catalog where the table resides.
        If none is provided, the AWS account ID is used by default.
    boto3_session : boto3.Session(), optional
        Boto3 Session. The default boto3 session will be used if boto3_session receives None.

    Returns
    -------
    Optional[str]
        The table description, or None if the table has no description.

    Examples
    --------
    >>> import awswrangler as wr
    >>> desc = wr.catalog.get_table_description(database="...", table="...")

    """
    client_glue: boto3.client = _utils.client(service_name="glue", session=boto3_session)
    args: Dict[str, str] = {}
    if catalog_id is not None:
        args["CatalogId"] = catalog_id  # pragma: no cover
    args["DatabaseName"] = database
    args["Name"] = table
    response: Dict[str, Any] = client_glue.get_table(**args)
    # "Description" is optional in the Glue Table structure, so a table created
    # without one must not blow up with a KeyError here.
    desc: Optional[str] = response["Table"].get("Description")
    return desc
1454+
1455+
1456+
def get_columns_comments(
    database: str, table: str, catalog_id: Optional[str] = None, boto3_session: Optional[boto3.Session] = None
) -> Dict[str, str]:
    """Get all columns comments.

    Both regular columns and partition keys are considered. Columns that
    have no comment registered in the catalog are omitted from the result.

    Parameters
    ----------
    database : str
        Database name.
    table : str
        Table name.
    catalog_id : str, optional
        The ID of the Data Catalog where the table resides.
        If none is provided, the AWS account ID is used by default.
    boto3_session : boto3.Session(), optional
        Boto3 Session. The default boto3 session will be used if boto3_session receives None.

    Returns
    -------
    Dict[str, str]
        Columns comments. e.g. {"col1": "foo boo bar"}.

    Examples
    --------
    >>> import awswrangler as wr
    >>> comments = wr.catalog.get_columns_comments(database="...", table="...")

    """
    client_glue: boto3.client = _utils.client(service_name="glue", session=boto3_session)
    args: Dict[str, str] = {}
    if catalog_id is not None:
        args["CatalogId"] = catalog_id  # pragma: no cover
    args["DatabaseName"] = database
    args["Name"] = table
    response: Dict[str, Any] = client_glue.get_table(**args)
    table_def: Dict[str, Any] = response["Table"]
    comments: Dict[str, str] = {}
    # "PartitionKeys" may be absent for unpartitioned tables and "Comment" is
    # optional on every column, so neither can be indexed unconditionally.
    for col in table_def["StorageDescriptor"]["Columns"] + table_def.get("PartitionKeys", []):
        if "Comment" in col:
            comments[col["Name"]] = col["Comment"]
    return comments
1497+
1498+
13821499
def upsert_table_parameters(
13831500
parameters: Dict[str, str],
13841501
database: str,
@@ -1465,14 +1582,36 @@ def overwrite_table_parameters(
14651582
... table="...")
14661583
14671584
"""
1585+
session: boto3.Session = _utils.ensure_session(session=boto3_session)
1586+
table_input: Optional[Dict[str, Any]] = _get_table_input(
1587+
database=database, table=table, catalog_id=catalog_id, boto3_session=session
1588+
)
1589+
if table_input is None:
1590+
raise exceptions.InvalidTable(f"Table {table} does not exist on database {database}.")
1591+
table_input["Parameters"] = parameters
1592+
args2: Dict[str, Union[str, Dict[str, Any]]] = {}
1593+
if catalog_id is not None:
1594+
args2["CatalogId"] = catalog_id # pragma: no cover
1595+
args2["DatabaseName"] = database
1596+
args2["TableInput"] = table_input
1597+
client_glue: boto3.client = _utils.client(service_name="glue", session=session)
1598+
client_glue.update_table(**args2)
1599+
return parameters
1600+
1601+
1602+
def _get_table_input(
1603+
database: str, table: str, boto3_session: Optional[boto3.Session], catalog_id: Optional[str] = None
1604+
) -> Optional[Dict[str, str]]:
14681605
client_glue: boto3.client = _utils.client(service_name="glue", session=boto3_session)
14691606
args: Dict[str, str] = {}
14701607
if catalog_id is not None:
14711608
args["CatalogId"] = catalog_id # pragma: no cover
14721609
args["DatabaseName"] = database
14731610
args["Name"] = table
1474-
response: Dict[str, Any] = client_glue.get_table(**args)
1475-
response["Table"]["Parameters"] = parameters
1611+
try:
1612+
response: Dict[str, Any] = client_glue.get_table(**args)
1613+
except client_glue.exceptions.EntityNotFoundException:
1614+
return None
14761615
if "DatabaseName" in response["Table"]:
14771616
del response["Table"]["DatabaseName"]
14781617
if "CreateTime" in response["Table"]:
@@ -1483,10 +1622,4 @@ def overwrite_table_parameters(
14831622
del response["Table"]["CreatedBy"]
14841623
if "IsRegisteredWithLakeFormation" in response["Table"]:
14851624
del response["Table"]["IsRegisteredWithLakeFormation"]
1486-
args2: Dict[str, Union[str, Dict[str, Any]]] = {}
1487-
if catalog_id is not None:
1488-
args2["CatalogId"] = catalog_id # pragma: no cover
1489-
args2["DatabaseName"] = database
1490-
args2["TableInput"] = response["Table"]
1491-
client_glue.update_table(**args2)
1492-
return parameters
1625+
return response["Table"]

0 commit comments

Comments
 (0)