Skip to content

Commit 92e9f20

Browse files
committed
Improve table creation flow.
1 parent 1f3b45c commit 92e9f20

File tree

10 files changed

+539
-273
lines changed

10 files changed

+539
-273
lines changed

awswrangler/catalog/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
from awswrangler.catalog._add import add_csv_partitions, add_parquet_partitions # noqa
44
from awswrangler.catalog._create import ( # noqa
5+
_create_csv_table,
6+
_create_parquet_table,
57
create_csv_table,
68
create_database,
79
create_parquet_table,
@@ -10,6 +12,7 @@
1012
)
1113
from awswrangler.catalog._delete import delete_database, delete_table_if_exists # noqa
1214
from awswrangler.catalog._get import ( # noqa
15+
_get_table_input,
1316
databases,
1417
get_columns_comments,
1518
get_connection,
@@ -20,8 +23,10 @@
2023
get_partitions,
2124
get_table_description,
2225
get_table_location,
26+
get_table_number_of_versions,
2327
get_table_parameters,
2428
get_table_types,
29+
get_table_versions,
2530
get_tables,
2631
search_tables,
2732
table,

awswrangler/catalog/_create.py

Lines changed: 374 additions & 209 deletions
Large diffs are not rendered by default.

awswrangler/catalog/_get.py

Lines changed: 95 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
from awswrangler import _utils, exceptions
1414
from awswrangler._config import apply_configs
15-
from awswrangler.catalog._utils import _extract_dtypes_from_table_details
15+
from awswrangler.catalog._utils import _catalog_id, _extract_dtypes_from_table_details
1616

1717
_logger: logging.Logger = logging.getLogger(__name__)
1818

@@ -21,16 +21,12 @@ def _get_table_input(
2121
database: str, table: str, boto3_session: Optional[boto3.Session], catalog_id: Optional[str] = None
2222
) -> Optional[Dict[str, Any]]:
2323
client_glue: boto3.client = _utils.client(service_name="glue", session=boto3_session)
24-
args: Dict[str, str] = {}
25-
if catalog_id is not None:
26-
args["CatalogId"] = catalog_id
27-
args["DatabaseName"] = database
28-
args["Name"] = table
2924
try:
30-
response: Dict[str, Any] = client_glue.get_table(**args)
25+
response: Dict[str, Any] = client_glue.get_table(
26+
**_catalog_id(catalog_id=catalog_id, DatabaseName=database, Name=table)
27+
)
3128
except client_glue.exceptions.EntityNotFoundException:
3229
return None
33-
3430
table_input: Dict[str, Any] = {}
3531
for k, v in response["Table"].items():
3632
if k in [
@@ -49,7 +45,6 @@ def _get_table_input(
4945
"TargetTable",
5046
]:
5147
table_input[k] = v
52-
5348
return table_input
5449

5550

@@ -162,10 +157,7 @@ def get_databases(
162157
"""
163158
client_glue: boto3.client = _utils.client(service_name="glue", session=boto3_session)
164159
paginator = client_glue.get_paginator("get_databases")
165-
if catalog_id is None:
166-
response_iterator: Iterator = paginator.paginate()
167-
else:
168-
response_iterator = paginator.paginate(CatalogId=catalog_id)
160+
response_iterator = paginator.paginate(**_catalog_id(catalog_id=catalog_id))
169161
for page in response_iterator:
170162
for db in page["DatabaseList"]:
171163
yield db
@@ -436,10 +428,7 @@ def table(
436428
437429
"""
438430
client_glue: boto3.client = _utils.client(service_name="glue", session=boto3_session)
439-
if catalog_id is None:
440-
tbl: Dict[str, Any] = client_glue.get_table(DatabaseName=database, Name=table)["Table"]
441-
else:
442-
tbl = client_glue.get_table(CatalogId=catalog_id, DatabaseName=database, Name=table)["Table"]
431+
tbl = client_glue.get_table(**_catalog_id(catalog_id=catalog_id, DatabaseName=database, Name=table))["Table"]
443432
df_dict: Dict[str, List] = {"Column Name": [], "Type": [], "Partition": [], "Comment": []}
444433
for col in tbl["StorageDescriptor"]["Columns"]:
445434
df_dict["Column Name"].append(col["Name"])
@@ -522,10 +511,7 @@ def get_connection(
522511
523512
"""
524513
client_glue: boto3.client = _utils.client(service_name="glue", session=boto3_session)
525-
args: Dict[str, Any] = {"Name": name, "HidePassword": False}
526-
if catalog_id is not None:
527-
args["CatalogId"] = catalog_id
528-
return client_glue.get_connection(**args)["Connection"]
514+
return client_glue.get_connection(**_catalog_id(catalog_id=catalog_id, Name=name, HidePassword=False))["Connection"]
529515

530516

531517
def get_engine(
@@ -812,12 +798,9 @@ def get_table_parameters(
812798
813799
"""
814800
client_glue: boto3.client = _utils.client(service_name="glue", session=boto3_session)
815-
args: Dict[str, str] = {}
816-
if catalog_id is not None:
817-
args["CatalogId"] = catalog_id
818-
args["DatabaseName"] = database
819-
args["Name"] = table
820-
response: Dict[str, Any] = client_glue.get_table(**args)
801+
response: Dict[str, Any] = client_glue.get_table(
802+
**_catalog_id(catalog_id=catalog_id, DatabaseName=database, Name=table)
803+
)
821804
parameters: Dict[str, str] = response["Table"]["Parameters"]
822805
return parameters
823806

@@ -851,16 +834,14 @@ def get_table_description(
851834
852835
"""
853836
client_glue: boto3.client = _utils.client(service_name="glue", session=boto3_session)
854-
args: Dict[str, str] = {}
855-
if catalog_id is not None:
856-
args["CatalogId"] = catalog_id
857-
args["DatabaseName"] = database
858-
args["Name"] = table
859-
response: Dict[str, Any] = client_glue.get_table(**args)
837+
response: Dict[str, Any] = client_glue.get_table(
838+
**_catalog_id(catalog_id=catalog_id, DatabaseName=database, Name=table)
839+
)
860840
desc: Optional[str] = response["Table"].get("Description", None)
861841
return desc
862842

863843

844+
@apply_configs
864845
def get_columns_comments(
865846
database: str, table: str, catalog_id: Optional[str] = None, boto3_session: Optional[boto3.Session] = None
866847
) -> Dict[str, str]:
@@ -890,16 +871,91 @@ def get_columns_comments(
890871
891872
"""
892873
client_glue: boto3.client = _utils.client(service_name="glue", session=boto3_session)
893-
args: Dict[str, str] = {}
894-
if catalog_id is not None:
895-
args["CatalogId"] = catalog_id
896-
args["DatabaseName"] = database
897-
args["Name"] = table
898-
response: Dict[str, Any] = client_glue.get_table(**args)
874+
response: Dict[str, Any] = client_glue.get_table(
875+
**_catalog_id(catalog_id=catalog_id, DatabaseName=database, Name=table)
876+
)
899877
comments: Dict[str, str] = {}
900878
for c in response["Table"]["StorageDescriptor"]["Columns"]:
901879
comments[c["Name"]] = c["Comment"]
902880
if "PartitionKeys" in response["Table"]:
903881
for p in response["Table"]["PartitionKeys"]:
904882
comments[p["Name"]] = p["Comment"]
905883
return comments
884+
885+
886+
@apply_configs
def get_table_versions(
    database: str, table: str, catalog_id: Optional[str] = None, boto3_session: Optional[boto3.Session] = None
) -> List[Dict[str, Any]]:
    """Get all versions of a table.

    Parameters
    ----------
    database : str
        Database name.
    table : str
        Table name.
    catalog_id : str, optional
        The ID of the Data Catalog where the table resides.
        If none is provided, the AWS account ID is used by default.
    boto3_session : boto3.Session(), optional
        Boto3 Session. The default boto3 session will be used if boto3_session receives None.

    Returns
    -------
    List[Dict[str, Any]]
        List of table versions (the ``TableVersions`` entries of every response page):
        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/glue.html#Glue.Client.get_table_versions

    Examples
    --------
    >>> import awswrangler as wr
    >>> tables_versions = wr.catalog.get_table_versions(database="...", table="...")

    """
    client_glue: boto3.client = _utils.client(service_name="glue", session=boto3_session)
    paginator = client_glue.get_paginator("get_table_versions")
    versions: List[Dict[str, Any]] = []
    response_iterator = paginator.paginate(**_catalog_id(DatabaseName=database, TableName=table, catalog_id=catalog_id))
    # Flatten the paginated responses into a single list, preserving page order.
    for page in response_iterator:
        versions.extend(page["TableVersions"])
    return versions
924+
925+
926+
@apply_configs
def get_table_number_of_versions(
    database: str, table: str, catalog_id: Optional[str] = None, boto3_session: Optional[boto3.Session] = None
) -> int:
    """Get total number of versions of a table.

    Parameters
    ----------
    database : str
        Database name.
    table : str
        Table name.
    catalog_id : str, optional
        The ID of the Data Catalog where the table resides.
        If none is provided, the AWS account ID is used by default.
    boto3_session : boto3.Session(), optional
        Boto3 Session. The default boto3 session will be used if boto3_session receives None.

    Returns
    -------
    int
        Total number of versions.

    Examples
    --------
    >>> import awswrangler as wr
    >>> num = wr.catalog.get_table_number_of_versions(database="...", table="...")

    """
    client_glue: boto3.client = _utils.client(service_name="glue", session=boto3_session)
    paginator = client_glue.get_paginator("get_table_versions")
    response_iterator = paginator.paginate(**_catalog_id(DatabaseName=database, TableName=table, catalog_id=catalog_id))
    # Only the per-page sizes are needed, so the versions themselves are never accumulated.
    return sum(len(page["TableVersions"]) for page in response_iterator)

awswrangler/s3/_write.py

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
"""Amazon CSV S3 Write Module (PRIVATE)."""
22

33
import logging
4-
from typing import Dict, List, Optional, Tuple
4+
from typing import Any, Dict, List, Optional, Tuple
55

6-
import boto3 # type: ignore
76
import pandas as pd # type: ignore
87

98
from awswrangler import _data_types, _utils, catalog, exceptions
@@ -13,21 +12,25 @@
1312
_COMPRESSION_2_EXT: Dict[Optional[str], str] = {None: "", "gzip": ".gz", "snappy": ".snappy"}
1413

1514

15+
def _extract_dtypes_from_table_input(table_input: Dict[str, Any]) -> Dict[str, str]:
16+
dtypes: Dict[str, str] = {}
17+
for col in table_input["StorageDescriptor"]["Columns"]:
18+
dtypes[col["Name"]] = col["Type"]
19+
if "PartitionKeys" in table_input:
20+
for par in table_input["PartitionKeys"]:
21+
dtypes[par["Name"]] = par["Type"]
22+
return dtypes
23+
24+
1625
def _apply_dtype(
17-
df: pd.DataFrame,
18-
mode: str,
19-
database: Optional[str],
20-
table: Optional[str],
21-
dtype: Dict[str, str],
22-
boto3_session: boto3.Session,
26+
df: pd.DataFrame, dtype: Dict[str, str], catalog_table_input: Optional[Dict[str, Any]], mode: str
2327
) -> pd.DataFrame:
24-
if (mode in ("append", "overwrite_partitions")) and (database is not None) and (table is not None):
25-
catalog_types: Optional[Dict[str, str]] = catalog.get_table_types(
26-
database=database, table=table, boto3_session=boto3_session
27-
)
28-
if catalog_types is not None:
29-
for k, v in catalog_types.items():
30-
dtype[k] = v
28+
if mode in ("append", "overwrite_partitions"):
29+
if catalog_table_input is not None:
30+
catalog_types: Optional[Dict[str, str]] = _extract_dtypes_from_table_input(table_input=catalog_table_input)
31+
if catalog_types is not None:
32+
for k, v in catalog_types.items():
33+
dtype[k] = v
3134
df = _data_types.cast_pandas_with_athena_types(df=df, dtype=dtype)
3235
return df
3336

awswrangler/s3/_write_concurrent.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def write(self, func: Callable, boto3_session: boto3.Session, **func_kwargs) ->
3737
if self._exec is not None:
3838
_logger.debug("Submitting: %s", func)
3939
future = self._exec.submit(
40-
fn=_WriteProxy._caller,
40+
_WriteProxy._caller,
4141
func=func,
4242
boto3_primitives=_utils.boto3_to_primitives(boto3_session=boto3_session),
4343
func_kwargs=func_kwargs,

awswrangler/s3/_write_parquet.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import logging
44
import uuid
5-
from typing import Dict, List, Optional, Tuple, Union
5+
from typing import Any, Dict, List, Optional, Tuple, Union
66

77
import boto3 # type: ignore
88
import pandas as pd # type: ignore
@@ -95,6 +95,7 @@ def to_parquet( # pylint: disable=too-many-arguments,too-many-locals
9595
projection_values: Optional[Dict[str, str]] = None,
9696
projection_intervals: Optional[Dict[str, str]] = None,
9797
projection_digits: Optional[Dict[str, str]] = None,
98+
catalog_id: Optional[str] = None,
9899
) -> Dict[str, Union[List[str], Dict[str, List[str]]]]:
99100
"""Write Parquet file or dataset on Amazon S3.
100101
@@ -196,6 +197,9 @@ def to_parquet( # pylint: disable=too-many-arguments,too-many-locals
196197
Dictionary of partitions names and Athena projections digits.
197198
https://docs.aws.amazon.com/athena/latest/ug/partition-projection-supported-types.html
198199
(e.g. {'col_name': '1', 'col2_name': '2'})
200+
catalog_id : str, optional
201+
The ID of the Data Catalog from which to retrieve Databases.
202+
If none is provided, the AWS account ID is used by default.
199203
200204
Returns
201205
-------
@@ -333,7 +337,12 @@ def to_parquet( # pylint: disable=too-many-arguments,too-many-locals
333337
df, dtype, partition_cols = _sanitize(df=df, dtype=dtype, partition_cols=partition_cols)
334338

335339
# Evaluating dtype
336-
df = _apply_dtype(df=df, mode=mode, database=database, table=table, dtype=dtype, boto3_session=session)
340+
catalog_table_input: Optional[Dict[str, Any]] = None
341+
if database is not None and table is not None:
342+
catalog_table_input: Optional[Dict[str, Any]] = catalog._get_table_input( # pylint: disable=protected-access
343+
database=database, table=table, boto3_session=session, catalog_id=catalog_id
344+
)
345+
df = _apply_dtype(df=df, dtype=dtype, catalog_table_input=catalog_table_input, mode=mode)
337346
schema: pa.Schema = _data_types.pyarrow_schema_from_pandas(
338347
df=df, index=index, ignore_cols=partition_cols, dtype=dtype
339348
)
@@ -376,7 +385,7 @@ def to_parquet( # pylint: disable=too-many-arguments,too-many-locals
376385
columns_types, partitions_types = _data_types.athena_types_from_pandas_partitioned(
377386
df=df, index=index, partition_cols=partition_cols, dtype=dtype
378387
)
379-
catalog.create_parquet_table(
388+
catalog._create_parquet_table( # pylint: disable=protected-access
380389
database=database,
381390
table=table,
382391
path=path,
@@ -395,6 +404,8 @@ def to_parquet( # pylint: disable=too-many-arguments,too-many-locals
395404
projection_values=projection_values,
396405
projection_intervals=projection_intervals,
397406
projection_digits=projection_digits,
407+
catalog_id=catalog_id,
408+
catalog_table_input=catalog_table_input,
398409
)
399410
if partitions_values and (regular_partitions is True):
400411
_logger.debug("partitions_values:\n%s", partitions_values)

0 commit comments

Comments
 (0)