Commit a1f6271

Adding glue catalog utilities
1 parent 8245e9e commit a1f6271

3 files changed: +444 −1 lines changed

awswrangler/glue.py

Lines changed: 128 additions & 1 deletion
@@ -1,8 +1,11 @@
-from typing import Dict, Optional
+from typing import Dict, Optional, Any, Iterator
 from math import ceil
+from itertools import islice
 import re
 import logging
 
+from pandas import DataFrame
+
 from awswrangler import data_types
 from awswrangler.athena import Athena
 from awswrangler.exceptions import UnsupportedFileFormat, InvalidSerDe, ApiError, UnsupportedType, UndetectedType, InvalidTable, InvalidArguments
@@ -390,3 +393,127 @@ def get_table_location(self, database: str, table: str):
             return res["Table"]["StorageDescriptor"]["Location"]
         except KeyError:
             raise InvalidTable(f"{database}.{table}")
+
+    def get_databases(self, catalog_id: Optional[str] = None) -> Iterator[Dict[str, Any]]:
+        """
+        Get an iterator of databases
+
+        :param catalog_id: The ID of the Data Catalog from which to retrieve Databases. If none is provided, the AWS account ID is used by default.
+        :return: Iterator[Dict[str, Any]] of Databases
+        """
+        paginator = self._client_glue.get_paginator("get_databases")
+        if catalog_id is None:
+            response_iterator = paginator.paginate()
+        else:
+            response_iterator = paginator.paginate(CatalogId=catalog_id)
+        for page in response_iterator:
+            for db in page["DatabaseList"]:
+                yield db
+
+    def get_tables(self, catalog_id: Optional[str] = None, database: Optional[str] = None, search: Optional[str] = None, prefix: Optional[str] = None, suffix: Optional[str] = None) -> Iterator[Dict[str, Any]]:
+        """
+        Get an iterator of tables
+
+        :param catalog_id: The ID of the Data Catalog from which to retrieve Databases. If none is provided, the AWS account ID is used by default.
+        :param database: Filter a specific database
+        :param search: Select by a specific string on table name
+        :param prefix: Select by a specific prefix on table name
+        :param suffix: Select by a specific suffix on table name
+        :return: Iterator[Dict[str, Any]] of Tables
+        """
+        paginator = self._client_glue.get_paginator("get_tables")
+        args: Dict[str, str] = {}
+        if catalog_id is not None:
+            args["CatalogId"] = catalog_id
+        if (prefix is not None) and (suffix is not None) and (search is not None):
+            args["Expression"] = f"{prefix}.*{search}.*{suffix}"
+        elif (prefix is not None) and (suffix is not None):
+            args["Expression"] = f"{prefix}.*{suffix}"
+        elif search is not None:
+            args["Expression"] = f".*{search}.*"
+        elif prefix is not None:
+            args["Expression"] = f"{prefix}.*"
+        elif suffix is not None:
+            args["Expression"] = f".*{suffix}"
+        if database is not None:
+            databases = [database]
+        else:
+            databases = [x["Name"] for x in self.get_databases(catalog_id=catalog_id)]
+        for db in databases:
+            args["DatabaseName"] = db
+            response_iterator = paginator.paginate(**args)
+            for page in response_iterator:
+                for tbl in page["TableList"]:
+                    yield tbl
+
+    def tables(self, limit: int = 100, catalog_id: Optional[str] = None, database: Optional[str] = None, search: Optional[str] = None, prefix: Optional[str] = None, suffix: Optional[str] = None) -> DataFrame:
+        table_iter = self.get_tables(catalog_id=catalog_id, database=database, search=search, prefix=prefix, suffix=suffix)
+        tables = islice(table_iter, limit)
+        df_dict = {
+            "Database": [],
+            "Table": [],
+            "Description": [],
+            "Columns": [],
+            "Partitions": []
+        }
+        for table in tables:
+            df_dict["Database"].append(table["DatabaseName"])
+            df_dict["Table"].append(table["Name"])
+            if "Description" in table:
+                df_dict["Description"].append(table["Description"])
+            else:
+                df_dict["Description"].append("")
+            df_dict["Columns"].append(", ".join([x["Name"] for x in table["StorageDescriptor"]["Columns"]]))
+            df_dict["Partitions"].append(", ".join([x["Name"] for x in table["PartitionKeys"]]))
+        return DataFrame(data=df_dict)
+
+    def databases(self, limit: int = 100, catalog_id: Optional[str] = None) -> DataFrame:
+        database_iter = self.get_databases(catalog_id=catalog_id)
+        dbs = islice(database_iter, limit)
+        df_dict = {
+            "Database": [],
+            "Description": []
+        }
+        for db in dbs:
+            df_dict["Database"].append(db["Name"])
+            if "Description" in db:
+                df_dict["Description"].append(db["Description"])
+            else:
+                df_dict["Description"].append("")
+        return DataFrame(data=df_dict)
+
+    def table(self, database: str, name: str, catalog_id: Optional[str] = None) -> DataFrame:
+        if catalog_id is None:
+            table: Dict[str, Any] = self._client_glue.get_table(
+                DatabaseName=database,
+                Name=name
+            )["Table"]
+        else:
+            table = self._client_glue.get_table(
+                CatalogId=catalog_id,
+                DatabaseName=database,
+                Name=name
+            )["Table"]
+        df_dict = {
+            "Column Name": [],
+            "Type": [],
+            "Partition": [],
+            "Comment": []
+        }
+        for col in table["StorageDescriptor"]["Columns"]:
+            df_dict["Column Name"].append(col["Name"])
+            df_dict["Type"].append(col["Type"])
+            df_dict["Partition"].append(False)
+            if "Comment" in table:
+                df_dict["Comment"].append(table["Comment"])
+            else:
+                df_dict["Comment"].append("")
+        for col in table["PartitionKeys"]:
+            df_dict["Column Name"].append(col["Name"])
+            df_dict["Type"].append(col["Type"])
+            df_dict["Partition"].append(True)
+            if "Comment" in table:
+                df_dict["Comment"].append(table["Comment"])
+            else:
+                df_dict["Comment"].append("")
+        return DataFrame(data=df_dict)
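
For context, a minimal usage sketch of the new catalog utilities. It assumes the awswrangler 0.x entry point, where a Session object exposes this Glue wrapper as session.glue; the database, table, and suffix names below are hypothetical placeholders, not part of this commit.

    import awswrangler

    session = awswrangler.Session()  # assumed 0.x-style Session exposing session.glue

    # Lazily iterate over every database visible in the Data Catalog
    for db in session.glue.get_databases():
        print(db["Name"])

    # Summarize up to 50 tables whose names end in "_logs" as a pandas DataFrame
    df_tables = session.glue.tables(limit=50, suffix="_logs")  # "_logs" is a placeholder

    # Column-level view (name, type, partition flag, comment) of a single table
    df_cols = session.glue.table(database="my_database", name="my_table")  # placeholder names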
