Skip to content

Commit 1801ff6

Browse files
authored
Add exist_ok flag to create_database (#645)
* Add exist_ok flag to create_database * Refactoring based on review
1 parent 31ba2a0 commit 1801ff6

File tree

5 files changed

+371
-322
lines changed

5 files changed

+371
-322
lines changed

awswrangler/catalog/_create.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -473,6 +473,7 @@ def create_database(
473473
name: str,
474474
description: Optional[str] = None,
475475
catalog_id: Optional[str] = None,
476+
exist_ok: bool = False,
476477
boto3_session: Optional[boto3.Session] = None,
477478
) -> None:
478479
"""Create a database in AWS Glue Catalog.
@@ -486,6 +487,9 @@ def create_database(
486487
catalog_id : str, optional
487488
The ID of the Data Catalog from which to retrieve Databases.
488489
If none is provided, the AWS account ID is used by default.
490+
exist_ok : bool
491+
If set to True, will not raise an Exception if a Database with the same name already exists.
492+
In this case the description will be updated if it is different from the current one.
489493
boto3_session : boto3.Session(), optional
490494
Boto3 Session. The default boto3 session will be used if boto3_session receives None.
491495
@@ -501,16 +505,19 @@ def create_database(
501505
... name='awswrangler_test'
502506
... )
503507
"""
504-
args: Dict[str, str] = {}
505508
client_glue: boto3.client = _utils.client(service_name="glue", session=boto3_session)
506-
args["Name"] = name
509+
args: Dict[str, str] = {"Name": name}
507510
if description is not None:
508511
args["Description"] = description
509512

510-
if catalog_id is not None:
511-
client_glue.create_database(CatalogId=catalog_id, DatabaseInput=args)
512-
else:
513-
client_glue.create_database(DatabaseInput=args)
513+
try:
514+
r = client_glue.get_database(Name=name)
515+
if not exist_ok:
516+
raise exceptions.AlreadyExists(f"Database {name} already exists and <exist_ok> is set to False.")
517+
if description and description != r["Database"].get("Description", ""):
518+
client_glue.update_database(**_catalog_id(catalog_id=catalog_id, Name=name, DatabaseInput=args))
519+
except client_glue.exceptions.EntityNotFoundException:
520+
client_glue.create_database(**_catalog_id(catalog_id=catalog_id, DatabaseInput=args))
514521

515522

516523
@apply_configs

awswrangler/exceptions.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,3 +107,7 @@ class InvalidFile(Exception):
107107

108108
class FailedQualityCheck(Exception):
109109
"""FailedQualityCheck."""
110+
111+
112+
class AlreadyExists(Exception):
113+
"""AlreadyExists."""

tests/conftest.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,11 @@ def redshift_external_schema(cloudformation_outputs, databases_parameters, glue_
170170
return "aws_data_wrangler_external"
171171

172172

173+
@pytest.fixture(scope="session")
174+
def account_id():
175+
return boto3.client("sts").get_caller_identity().get("Account")
176+
177+
173178
@pytest.fixture(scope="function")
174179
def glue_ctas_database():
175180
name = f"db_{get_time_str_with_random_suffix()}"
@@ -282,3 +287,10 @@ def assert_filename_prefix(filename, filename_prefix, test_prefix):
282287
assert not filename.startswith(test_prefix)
283288

284289
return assert_filename_prefix
290+
291+
292+
@pytest.fixture(scope="function")
293+
def random_glue_database():
294+
database_name = get_time_str_with_random_suffix()
295+
yield database_name
296+
wr.catalog.delete_database(database_name)

0 commit comments

Comments
 (0)