Skip to content

Commit 73a0987

Browse files
committed
Add catalog_versioning to parameters handling. #342
1 parent c26af10 commit 73a0987

File tree

3 files changed

+44
-3
lines changed

3 files changed

+44
-3
lines changed

awswrangler/catalog/_create.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,7 @@ def _create_table( # pylint: disable=too-many-branches,too-many-statements
163163
def _upsert_table_parameters(
164164
parameters: Dict[str, str],
165165
database: str,
166+
catalog_versioning: bool,
166167
catalog_id: Optional[str],
167168
table_input: Dict[str, Any],
168169
boto3_session: Optional[boto3.Session],
@@ -180,20 +181,25 @@ def _upsert_table_parameters(
180181
catalog_id=catalog_id,
181182
boto3_session=boto3_session,
182183
table_input=table_input,
184+
catalog_versioning=catalog_versioning,
183185
)
184186
return pars
185187

186188

187189
def _overwrite_table_parameters(
188190
parameters: Dict[str, str],
189191
database: str,
192+
catalog_versioning: bool,
190193
catalog_id: Optional[str],
191194
table_input: Dict[str, Any],
192195
boto3_session: Optional[boto3.Session],
193196
) -> Dict[str, str]:
194197
table_input["Parameters"] = parameters
195198
client_glue: boto3.client = _utils.client(service_name="glue", session=boto3_session)
196-
client_glue.update_table(**_catalog_id(catalog_id=catalog_id, DatabaseName=database, TableInput=table_input))
199+
skip_archive: bool = not catalog_versioning
200+
client_glue.update_table(
201+
**_catalog_id(catalog_id=catalog_id, DatabaseName=database, TableInput=table_input, SkipArchive=skip_archive)
202+
)
197203
return parameters
198204

199205

@@ -346,6 +352,7 @@ def upsert_table_parameters(
346352
parameters: Dict[str, str],
347353
database: str,
348354
table: str,
355+
catalog_versioning: bool = False,
349356
catalog_id: Optional[str] = None,
350357
boto3_session: Optional[boto3.Session] = None,
351358
) -> Dict[str, str]:
@@ -359,6 +366,8 @@ def upsert_table_parameters(
359366
Database name.
360367
table : str
361368
Table name.
369+
catalog_versioning : bool
370+
If True and `mode="overwrite"`, creates an archived version of the table catalog before updating it.
362371
catalog_id : str, optional
363372
The ID of the Data Catalog from which to retrieve Databases.
364373
If none is provided, the AWS account ID is used by default.
@@ -386,7 +395,12 @@ def upsert_table_parameters(
386395
if table_input is None:
387396
raise exceptions.InvalidArgumentValue(f"Table {database}.{table} does not exist.")
388397
return _upsert_table_parameters(
389-
parameters=parameters, database=database, boto3_session=session, catalog_id=catalog_id, table_input=table_input,
398+
parameters=parameters,
399+
database=database,
400+
boto3_session=session,
401+
catalog_id=catalog_id,
402+
table_input=table_input,
403+
catalog_versioning=catalog_versioning,
390404
)
391405

392406

@@ -395,6 +409,7 @@ def overwrite_table_parameters(
395409
parameters: Dict[str, str],
396410
database: str,
397411
table: str,
412+
catalog_versioning: bool = False,
398413
catalog_id: Optional[str] = None,
399414
boto3_session: Optional[boto3.Session] = None,
400415
) -> Dict[str, str]:
@@ -408,6 +423,8 @@ def overwrite_table_parameters(
408423
Database name.
409424
table : str
410425
Table name.
426+
catalog_versioning : bool
427+
If True and `mode="overwrite"`, creates an archived version of the table catalog before updating it.
411428
catalog_id : str, optional
412429
The ID of the Data Catalog from which to retrieve Databases.
413430
If none is provided, the AWS account ID is used by default.
@@ -435,7 +452,12 @@ def overwrite_table_parameters(
435452
if table_input is None:
436453
raise exceptions.InvalidTable(f"Table {table} does not exist on database {database}.")
437454
return _overwrite_table_parameters(
438-
parameters=parameters, database=database, catalog_id=catalog_id, table_input=table_input, boto3_session=session,
455+
parameters=parameters,
456+
database=database,
457+
catalog_id=catalog_id,
458+
table_input=table_input,
459+
boto3_session=session,
460+
catalog_versioning=catalog_versioning,
439461
)
440462

441463

tests/test__routines.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ def test_routine_0(glue_database, glue_table, path, use_threads, concurrent_part
2727
use_threads=use_threads,
2828
concurrent_partitioning=concurrent_partitioning,
2929
)["paths"]
30+
assert wr.catalog.get_table_number_of_versions(table=glue_table, database=glue_database) == 1
3031
wr.s3.wait_objects_exist(paths=paths, use_threads=use_threads)
3132
df2 = wr.athena.read_sql_table(glue_table, glue_database, use_threads=use_threads)
3233
assert df.shape == df2.shape
@@ -55,6 +56,7 @@ def test_routine_0(glue_database, glue_table, path, use_threads, concurrent_part
5556
use_threads=use_threads,
5657
concurrent_partitioning=concurrent_partitioning,
5758
)["paths"]
59+
assert wr.catalog.get_table_number_of_versions(table=glue_table, database=glue_database) == 1
5860
wr.s3.wait_objects_exist(paths=paths, use_threads=use_threads)
5961
df2 = wr.athena.read_sql_table(glue_table, glue_database, use_threads=use_threads)
6062
assert df.shape == df2.shape
@@ -83,6 +85,7 @@ def test_routine_0(glue_database, glue_table, path, use_threads, concurrent_part
8385
use_threads=use_threads,
8486
concurrent_partitioning=concurrent_partitioning,
8587
)["paths"]
88+
assert wr.catalog.get_table_number_of_versions(table=glue_table, database=glue_database) == 1
8689
wr.s3.wait_objects_exist(paths=paths, use_threads=use_threads)
8790
df2 = wr.athena.read_sql_table(glue_table, glue_database, use_threads=use_threads)
8891
assert len(df.columns) == len(df2.columns)
@@ -112,6 +115,7 @@ def test_routine_0(glue_database, glue_table, path, use_threads, concurrent_part
112115
use_threads=use_threads,
113116
concurrent_partitioning=concurrent_partitioning,
114117
)["paths"]
118+
assert wr.catalog.get_table_number_of_versions(table=glue_table, database=glue_database) == 1
115119
wr.s3.wait_objects_exist(paths=paths, use_threads=use_threads)
116120
df2 = wr.athena.read_sql_table(glue_table, glue_database, use_threads=use_threads)
117121
assert len(df2.columns) == 2
@@ -142,6 +146,7 @@ def test_routine_0(glue_database, glue_table, path, use_threads, concurrent_part
142146
use_threads=use_threads,
143147
concurrent_partitioning=concurrent_partitioning,
144148
)["paths"]
149+
assert wr.catalog.get_table_number_of_versions(table=glue_table, database=glue_database) == 1
145150
wr.s3.wait_objects_exist(paths=paths, use_threads=use_threads)
146151
df2 = wr.athena.read_sql_table(glue_table, glue_database, use_threads=use_threads)
147152
assert len(df2.columns) == 3
@@ -174,6 +179,7 @@ def test_routine_0(glue_database, glue_table, path, use_threads, concurrent_part
174179
use_threads=use_threads,
175180
concurrent_partitioning=concurrent_partitioning,
176181
)["paths"]
182+
assert wr.catalog.get_table_number_of_versions(table=glue_table, database=glue_database) == 1
177183
wr.s3.wait_objects_exist(paths=paths, use_threads=use_threads)
178184
df2 = wr.athena.read_sql_table(glue_table, glue_database, use_threads=use_threads)
179185
assert df.shape == df2.shape
@@ -204,6 +210,7 @@ def test_routine_0(glue_database, glue_table, path, use_threads, concurrent_part
204210
concurrent_partitioning=concurrent_partitioning,
205211
use_threads=use_threads,
206212
)["paths"]
213+
assert wr.catalog.get_table_number_of_versions(table=glue_table, database=glue_database) == 1
207214
wr.s3.wait_objects_exist(paths=paths, use_threads=use_threads)
208215
df2 = wr.athena.read_sql_table(glue_table, glue_database, use_threads=use_threads)
209216
assert len(df2.columns) == 2
@@ -235,6 +242,7 @@ def test_routine_0(glue_database, glue_table, path, use_threads, concurrent_part
235242
use_threads=use_threads,
236243
concurrent_partitioning=concurrent_partitioning,
237244
)["paths"]
245+
assert wr.catalog.get_table_number_of_versions(table=glue_table, database=glue_database) == 1
238246
wr.s3.wait_objects_exist(paths=paths, use_threads=use_threads)
239247
df2 = wr.athena.read_sql_table(glue_table, glue_database, use_threads=use_threads)
240248
assert len(df2.columns) == 3
@@ -268,6 +276,7 @@ def test_routine_1(glue_database, glue_table, path):
268276
parameters={"num_cols": str(len(df.columns)), "num_rows": str(len(df.index))},
269277
columns_comments={"c0": "0"},
270278
)
279+
assert wr.catalog.get_table_number_of_versions(table=glue_table, database=glue_database) == 1
271280
df2 = wr.athena.read_sql_table(glue_table, glue_database)
272281
assert df.shape == df2.shape
273282
assert df.c0.sum() == df2.c0.sum()
@@ -294,6 +303,7 @@ def test_routine_1(glue_database, glue_table, path):
294303
parameters={"num_cols": str(len(df.columns)), "num_rows": str(len(df.index))},
295304
columns_comments={"c1": "1"},
296305
)
306+
assert wr.catalog.get_table_number_of_versions(table=glue_table, database=glue_database) == 1
297307
df2 = wr.athena.read_sql_table(glue_table, glue_database)
298308
assert df.shape == df2.shape
299309
assert df.c1.sum() == df2.c1.sum()
@@ -320,6 +330,7 @@ def test_routine_1(glue_database, glue_table, path):
320330
parameters={"num_cols": str(len(df.columns)), "num_rows": str(len(df.index) * 2)},
321331
columns_comments={"c1": "1"},
322332
)
333+
assert wr.catalog.get_table_number_of_versions(table=glue_table, database=glue_database) == 1
323334
df2 = wr.athena.read_sql_table(glue_table, glue_database)
324335
assert len(df.columns) == len(df2.columns)
325336
assert len(df.index) * 2 == len(df2.index)
@@ -348,6 +359,7 @@ def test_routine_1(glue_database, glue_table, path):
348359
parameters={"num_cols": "2", "num_rows": "9"},
349360
columns_comments={"c1": "1", "c2": "2"},
350361
)
362+
assert wr.catalog.get_table_number_of_versions(table=glue_table, database=glue_database) == 1
351363
df2 = wr.athena.read_sql_table(glue_table, glue_database)
352364
assert len(df2.columns) == 2
353365
assert len(df2.index) == 9
@@ -376,6 +388,7 @@ def test_routine_1(glue_database, glue_table, path):
376388
parameters={"num_cols": "2", "num_rows": "2"},
377389
columns_comments={"c0": "zero", "c1": "one"},
378390
)
391+
assert wr.catalog.get_table_number_of_versions(table=glue_table, database=glue_database) == 1
379392
df2 = wr.athena.read_sql_table(glue_table, glue_database)
380393
assert df.shape == df2.shape
381394
assert df.c1.sum() == df2.c1.astype(int).sum()
@@ -405,6 +418,7 @@ def test_routine_1(glue_database, glue_table, path):
405418
parameters={"num_cols": "2", "num_rows": "3"},
406419
columns_comments={"c0": "zero", "c1": "one"},
407420
)
421+
assert wr.catalog.get_table_number_of_versions(table=glue_table, database=glue_database) == 1
408422
df2 = wr.athena.read_sql_table(glue_table, glue_database)
409423
assert len(df2.columns) == 2
410424
assert len(df2.index) == 3
@@ -435,6 +449,7 @@ def test_routine_1(glue_database, glue_table, path):
435449
parameters={"num_cols": "3", "num_rows": "4"},
436450
columns_comments={"c0": "zero", "c1": "one", "c2": "two"},
437451
)
452+
assert wr.catalog.get_table_number_of_versions(table=glue_table, database=glue_database) == 1
438453
df2 = wr.athena.read_sql_table(glue_table, glue_database)
439454
assert len(df2.columns) == 3
440455
assert len(df2.index) == 4

tests/test_athena.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -531,6 +531,7 @@ def test_catalog_versioning(path, glue_database, glue_table):
531531
paths = wr.s3.to_parquet(
532532
df=df, path=path, dataset=True, database=glue_database, table=glue_table, mode="overwrite"
533533
)["paths"]
534+
assert wr.catalog.get_table_number_of_versions(table=glue_table, database=glue_database) == 1
534535
wr.s3.wait_objects_exist(paths=paths, use_threads=False)
535536
df = wr.athena.read_sql_table(table=glue_table, database=glue_database)
536537
assert len(df.index) == 2
@@ -548,6 +549,7 @@ def test_catalog_versioning(path, glue_database, glue_table):
548549
mode="overwrite",
549550
catalog_versioning=True,
550551
)["paths"]
552+
assert wr.catalog.get_table_number_of_versions(table=glue_table, database=glue_database) == 2
551553
wr.s3.wait_objects_exist(paths=paths1, use_threads=False)
552554
df = wr.athena.read_sql_table(table=glue_table, database=glue_database)
553555
assert len(df.index) == 2
@@ -566,6 +568,7 @@ def test_catalog_versioning(path, glue_database, glue_table):
566568
catalog_versioning=True,
567569
index=False,
568570
)["paths"]
571+
assert wr.catalog.get_table_number_of_versions(table=glue_table, database=glue_database) == 3
569572
wr.s3.wait_objects_exist(paths=paths2, use_threads=False)
570573
wr.s3.wait_objects_not_exist(paths=paths1, use_threads=False)
571574
df = wr.athena.read_sql_table(table=glue_table, database=glue_database)
@@ -585,6 +588,7 @@ def test_catalog_versioning(path, glue_database, glue_table):
585588
catalog_versioning=False,
586589
index=False,
587590
)["paths"]
591+
assert wr.catalog.get_table_number_of_versions(table=glue_table, database=glue_database) == 3
588592
wr.s3.wait_objects_exist(paths=paths3, use_threads=False)
589593
wr.s3.wait_objects_not_exist(paths=paths2, use_threads=False)
590594
df = wr.athena.read_sql_table(table=glue_table, database=glue_database)

0 commit comments

Comments
 (0)