
Commit a87867a

Add get_table_parameters, upsert_table_parameters, overwrite_table_parameters. #224
1 parent 014228f commit a87867a

File tree

4 files changed: +260 −35 lines changed


awswrangler/catalog.py

Lines changed: 155 additions & 1 deletion

@@ -5,7 +5,7 @@
 import logging
 import re
 import unicodedata
-from typing import Any, Dict, Iterator, List, Optional, Tuple
+from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
 from urllib.parse import quote_plus

 import boto3  # type: ignore
@@ -989,6 +989,8 @@ def _create_table(
             DatabaseName=database, TableName=table, PartitionsToDelete=[{"Values": v} for v in partitions_values]
         )
         client_glue.update_table(DatabaseName=database, TableInput=table_input, SkipArchive=skip_archive)
+    elif (exist is True) and (mode == "append") and (parameters is not None):
+        upsert_table_parameters(parameters=parameters, database=database, table=table, boto3_session=session)
     elif exist is False:
         client_glue.create_table(DatabaseName=database, TableInput=table_input)

@@ -1333,3 +1335,155 @@ def extract_athena_types(
     return _data_types.athena_types_from_pandas_partitioned(
         df=df, index=index, partition_cols=partition_cols, dtype=dtype, index_left=index_left
     )
+
+
+def get_table_parameters(
+    database: str, table: str, catalog_id: Optional[str] = None, boto3_session: Optional[boto3.Session] = None
+) -> Dict[str, str]:
+    """Get all parameters for a given table.
+
+    Parameters
+    ----------
+    database : str
+        Database name.
+    table : str
+        Table name.
+    catalog_id : str, optional
+        The ID of the Data Catalog from which to retrieve Databases.
+        If none is provided, the AWS account ID is used by default.
+    boto3_session : boto3.Session(), optional
+        Boto3 Session. The default boto3 session will be used if boto3_session receives None.
+
+    Returns
+    -------
+    Dict[str, str]
+        Dictionary of parameters.
+
+    Examples
+    --------
+    >>> import awswrangler as wr
+    >>> pars = wr.catalog.get_table_parameters(database="...", table="...")
+
+    """
+    client_glue: boto3.client = _utils.client(service_name="glue", session=boto3_session)
+    args: Dict[str, str] = {}
+    if catalog_id is not None:
+        args["CatalogId"] = catalog_id  # pragma: no cover
+    args["DatabaseName"] = database
+    args["Name"] = table
+    response: Dict[str, Any] = client_glue.get_table(**args)
+    parameters: Dict[str, str] = response["Table"]["Parameters"]
+    return parameters
+
+
+def upsert_table_parameters(
+    parameters: Dict[str, str],
+    database: str,
+    table: str,
+    catalog_id: Optional[str] = None,
+    boto3_session: Optional[boto3.Session] = None,
+) -> Dict[str, str]:
+    """Insert or update the received parameters.
+
+    Parameters
+    ----------
+    parameters : Dict[str, str]
+        e.g. {"source": "mysql", "destination": "datalake"}
+    database : str
+        Database name.
+    table : str
+        Table name.
+    catalog_id : str, optional
+        The ID of the Data Catalog from which to retrieve Databases.
+        If none is provided, the AWS account ID is used by default.
+    boto3_session : boto3.Session(), optional
+        Boto3 Session. The default boto3 session will be used if boto3_session receives None.
+
+    Returns
+    -------
+    Dict[str, str]
+        All parameters after the upsert.
+
+    Examples
+    --------
+    >>> import awswrangler as wr
+    >>> pars = wr.catalog.upsert_table_parameters(
+    ...     parameters={"source": "mysql", "destination": "datalake"},
+    ...     database="...",
+    ...     table="...")
+
+    """
+    session: boto3.Session = _utils.ensure_session(session=boto3_session)
+    pars: Dict[str, str] = get_table_parameters(
+        database=database, table=table, catalog_id=catalog_id, boto3_session=session
+    )
+    for k, v in parameters.items():
+        pars[k] = v
+    overwrite_table_parameters(
+        parameters=pars, database=database, table=table, catalog_id=catalog_id, boto3_session=session
+    )
+    return pars
+
+
+def overwrite_table_parameters(
+    parameters: Dict[str, str],
+    database: str,
+    table: str,
+    catalog_id: Optional[str] = None,
+    boto3_session: Optional[boto3.Session] = None,
+) -> Dict[str, str]:
+    """Overwrite all existing parameters.
+
+    Parameters
+    ----------
+    parameters : Dict[str, str]
+        e.g. {"source": "mysql", "destination": "datalake"}
+    database : str
+        Database name.
+    table : str
+        Table name.
+    catalog_id : str, optional
+        The ID of the Data Catalog from which to retrieve Databases.
+        If none is provided, the AWS account ID is used by default.
+    boto3_session : boto3.Session(), optional
+        Boto3 Session. The default boto3 session will be used if boto3_session receives None.
+
+    Returns
+    -------
+    Dict[str, str]
+        All parameters after the overwrite (the same as received).
+
+    Examples
+    --------
+    >>> import awswrangler as wr
+    >>> pars = wr.catalog.overwrite_table_parameters(
+    ...     parameters={"source": "mysql", "destination": "datalake"},
+    ...     database="...",
+    ...     table="...")
+
+    """
+    client_glue: boto3.client = _utils.client(service_name="glue", session=boto3_session)
+    args: Dict[str, str] = {}
+    if catalog_id is not None:
+        args["CatalogId"] = catalog_id  # pragma: no cover
+    args["DatabaseName"] = database
+    args["Name"] = table
+    response: Dict[str, Any] = client_glue.get_table(**args)
+    response["Table"]["Parameters"] = parameters
+    if "DatabaseName" in response["Table"]:
+        del response["Table"]["DatabaseName"]
+    if "CreateTime" in response["Table"]:
+        del response["Table"]["CreateTime"]
+    if "UpdateTime" in response["Table"]:
+        del response["Table"]["UpdateTime"]
+    if "CreatedBy" in response["Table"]:
+        del response["Table"]["CreatedBy"]
+    if "IsRegisteredWithLakeFormation" in response["Table"]:
+        del response["Table"]["IsRegisteredWithLakeFormation"]
+    args2: Dict[str, Union[str, Dict[str, Any]]] = {}
+    if catalog_id is not None:
+        args2["CatalogId"] = catalog_id  # pragma: no cover
+    args2["DatabaseName"] = database
+    args2["TableInput"] = response["Table"]
+    client_glue.update_table(**args2)
+    return parameters
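Under the hood all three helpers are a read-modify-write cycle against the Glue Data Catalog. A minimal sketch of the same cycle using boto3 directly, assuming an existing Glue table ("my_db" and "my_table" are placeholder names):

import boto3

# Sketch of the read-modify-write cycle behind upsert_table_parameters.
# "my_db" and "my_table" are placeholders for an existing Glue table.
glue = boto3.client("glue")

table_input = glue.get_table(DatabaseName="my_db", Name="my_table")["Table"]

# Upsert: merge the new parameters over whatever is already stored.
table_input["Parameters"] = {**table_input.get("Parameters", {}), "source": "mysql"}

# get_table returns read-only fields that update_table rejects,
# so they are stripped before the TableInput is written back.
for key in ("DatabaseName", "CreateTime", "UpdateTime", "CreatedBy", "IsRegisteredWithLakeFormation"):
    table_input.pop(key, None)

glue.update_table(DatabaseName="my_db", TableInput=table_input)

Note that this cycle is not atomic: a concurrent writer can slip in between get_table and update_table, which is an acceptable trade-off for slowly changing metadata like these parameters.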

awswrangler/s3.py

Lines changed: 42 additions & 34 deletions

@@ -453,6 +453,10 @@ def to_csv(  # pylint: disable=too-many-arguments
     The table name and all column names will be automatically sanitized using
     `wr.catalog.sanitize_table_name` and `wr.catalog.sanitize_column_name`.

+    Note
+    ----
+    On `append` mode, the `parameters` will be upserted on an existing table.
+
     Note
     ----
     In case of `use_threads=True` the number of threads that will be spawned will be obtained from os.cpu_count().
@@ -640,15 +644,14 @@ def to_csv(  # pylint: disable=too-many-arguments
         paths = [path]
     else:
         mode = "append" if mode is None else mode
-        exist: bool = False
         if columns:
             df = df[columns]
         if (database is not None) and (table is not None):  # Normalize table to respect Athena's standards
             df = catalog.sanitize_dataframe_columns_names(df=df)
             partition_cols = [catalog.sanitize_column_name(p) for p in partition_cols]
             dtype = {catalog.sanitize_column_name(k): v.lower() for k, v in dtype.items()}
             columns_comments = {catalog.sanitize_column_name(k): v for k, v in columns_comments.items()}
-            exist = catalog.does_table_exist(database=database, table=table, boto3_session=session)
+            exist: bool = catalog.does_table_exist(database=database, table=table, boto3_session=session)
             if (exist is True) and (mode in ("append", "overwrite_partitions")):
                 for k, v in catalog.get_table_types(database=database, table=table, boto3_session=session).items():
                     dtype[k] = v
@@ -669,21 +672,20 @@ def to_csv(  # pylint: disable=too-many-arguments
             columns_types, partitions_types = _data_types.athena_types_from_pandas_partitioned(
                 df=df, index=index, partition_cols=partition_cols, dtype=dtype, index_left=True
             )
-            if (exist is False) or (mode == "overwrite"):
-                catalog.create_csv_table(
-                    database=database,
-                    table=table,
-                    path=path,
-                    columns_types=columns_types,
-                    partitions_types=partitions_types,
-                    description=description,
-                    parameters=parameters,
-                    columns_comments=columns_comments,
-                    boto3_session=session,
-                    mode="overwrite",
-                    catalog_versioning=catalog_versioning,
-                    sep=sep,
-                )
+            catalog.create_csv_table(
+                database=database,
+                table=table,
+                path=path,
+                columns_types=columns_types,
+                partitions_types=partitions_types,
+                description=description,
+                parameters=parameters,
+                columns_comments=columns_comments,
+                boto3_session=session,
+                mode=mode,
+                catalog_versioning=catalog_versioning,
+                sep=sep,
+            )
             if partitions_values:
                 _logger.debug("partitions_values:\n%s", partitions_values)
                 catalog.add_csv_partitions(
@@ -869,6 +871,10 @@ def to_parquet(  # pylint: disable=too-many-arguments
     The table name and all column names will be automatically sanitized using
     `wr.catalog.sanitize_table_name` and `wr.catalog.sanitize_column_name`.

+    Note
+    ----
+    On `append` mode, the `parameters` will be upserted on an existing table.
+
     Note
     ----
     In case of `use_threads=True` the number of threads that will be spawned will be obtained from os.cpu_count().
@@ -1058,13 +1064,12 @@ def to_parquet(  # pylint: disable=too-many-arguments
         ]
     else:
         mode = "append" if mode is None else mode
-        exist: bool = False
         if (database is not None) and (table is not None):  # Normalize table to respect Athena's standards
             df = catalog.sanitize_dataframe_columns_names(df=df)
             partition_cols = [catalog.sanitize_column_name(p) for p in partition_cols]
             dtype = {catalog.sanitize_column_name(k): v.lower() for k, v in dtype.items()}
             columns_comments = {catalog.sanitize_column_name(k): v for k, v in columns_comments.items()}
-            exist = catalog.does_table_exist(database=database, table=table, boto3_session=session)
+            exist: bool = catalog.does_table_exist(database=database, table=table, boto3_session=session)
            if (exist is True) and (mode in ("append", "overwrite_partitions")):
                 for k, v in catalog.get_table_types(database=database, table=table, boto3_session=session).items():
                     dtype[k] = v
@@ -1087,21 +1092,20 @@ def to_parquet(  # pylint: disable=too-many-arguments
             columns_types, partitions_types = _data_types.athena_types_from_pandas_partitioned(
                 df=df, index=index, partition_cols=partition_cols, dtype=dtype
             )
-            if (exist is False) or (mode == "overwrite"):
-                catalog.create_parquet_table(
-                    database=database,
-                    table=table,
-                    path=path,
-                    columns_types=columns_types,
-                    partitions_types=partitions_types,
-                    compression=compression,
-                    description=description,
-                    parameters=parameters,
-                    columns_comments=columns_comments,
-                    boto3_session=session,
-                    mode="overwrite",
-                    catalog_versioning=catalog_versioning,
-                )
+            catalog.create_parquet_table(
+                database=database,
+                table=table,
+                path=path,
+                columns_types=columns_types,
+                partitions_types=partitions_types,
+                compression=compression,
+                description=description,
+                parameters=parameters,
+                columns_comments=columns_comments,
+                boto3_session=session,
+                mode=mode,
+                catalog_versioning=catalog_versioning,
+            )
             if partitions_values:
                 _logger.debug("partitions_values:\n%s", partitions_values)
                 catalog.add_parquet_partitions(
@@ -1865,6 +1869,10 @@ def store_parquet_metadata(
     The concept of Dataset goes beyond the simple idea of files and enables more
     complex features like partitioning and catalog integration (AWS Glue Catalog).

+    Note
+    ----
+    On `append` mode, the `parameters` will be upserted on an existing table.
+
     Note
     ----
     In case of `use_threads=True` the number of threads that will be spawned will be obtained from os.cpu_count().
docs/source/api.rst

Lines changed: 3 additions & 0 deletions

@@ -63,6 +63,9 @@ AWS Glue Catalog
     drop_duplicated_columns
     get_engine
     extract_athena_types
+    get_table_parameters
+    upsert_table_parameters
+    overwrite_table_parameters

 Amazon Athena
 -------------

testing/test_awswrangler/test_data_lake.py

Lines changed: 60 additions & 0 deletions

@@ -1478,3 +1478,63 @@ def test_parquet_overwrite_partition_cols(bucket, database, external_schema):

     wr.s3.delete_objects(path=path)
     wr.catalog.delete_table_if_exists(database=database, table=table)
+
+
+def test_catalog_parameters(bucket, database):
+    table = "test_catalog_parameters"
+    path = f"s3://{bucket}/{table}/"
+    wr.s3.delete_objects(path=path)
+    wr.catalog.delete_table_if_exists(database=database, table=table)
+
+    wr.s3.to_parquet(
+        df=pd.DataFrame({"c0": [1, 2]}),
+        path=path,
+        dataset=True,
+        database=database,
+        table=table,
+        mode="overwrite",
+        parameters={"a": "1", "b": "2"},
+    )
+    pars = wr.catalog.get_table_parameters(database=database, table=table)
+    assert pars["a"] == "1"
+    assert pars["b"] == "2"
+    pars["a"] = "0"
+    pars["c"] = "3"
+    wr.catalog.upsert_table_parameters(parameters=pars, database=database, table=table)
+    pars = wr.catalog.get_table_parameters(database=database, table=table)
+    assert pars["a"] == "0"
+    assert pars["b"] == "2"
+    assert pars["c"] == "3"
+    wr.catalog.overwrite_table_parameters(parameters={"d": "4"}, database=database, table=table)
+    pars = wr.catalog.get_table_parameters(database=database, table=table)
+    assert pars.get("a") is None
+    assert pars.get("b") is None
+    assert pars.get("c") is None
+    assert pars["d"] == "4"
+    df = wr.athena.read_sql_table(table=table, database=database)
+    assert len(df.index) == 2
+    assert len(df.columns) == 1
+    assert df.c0.sum() == 3
+
+    wr.s3.to_parquet(
+        df=pd.DataFrame({"c0": [3, 4]}),
+        path=path,
+        dataset=True,
+        database=database,
+        table=table,
+        mode="append",
+        parameters={"e": "5"},
+    )
+    pars = wr.catalog.get_table_parameters(database=database, table=table)
+    assert pars.get("a") is None
+    assert pars.get("b") is None
+    assert pars.get("c") is None
+    assert pars["d"] == "4"
+    assert pars["e"] == "5"
+    df = wr.athena.read_sql_table(table=table, database=database)
+    assert len(df.index) == 4
+    assert len(df.columns) == 1
+    assert df.c0.sum() == 10
+
+    wr.s3.delete_objects(path=path)
+    wr.catalog.delete_table_if_exists(database=database, table=table)
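The test covers the full lifecycle end to end; the parameter semantics alone boil down to this shorter check (placeholder names again, against an existing table):

import awswrangler as wr

# Condensed version of the parameter lifecycle exercised by the test above.
# "my_db" / "my_table" are placeholders for an existing Glue table.
wr.catalog.overwrite_table_parameters(parameters={"a": "1"}, database="my_db", table="my_table")
wr.catalog.upsert_table_parameters(parameters={"b": "2"}, database="my_db", table="my_table")

pars = wr.catalog.get_table_parameters(database="my_db", table="my_table")
assert pars["a"] == "1"  # preserved by the upsert
assert pars["b"] == "2"  # added by the upsert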
