11import logging
22import re
3+ from collections import defaultdict
34from collections .abc import Iterable , Iterator
45from dataclasses import dataclass
56from functools import partial
@@ -73,6 +74,12 @@ def sql_alter_from(self, catalog):
7374 f" TBLPROPERTIES ('upgraded_from' = '{ self .key } ');"
7475 )
7576
77+ def sql_unset_upgraded_to (self , catalog ):
78+ return (
79+ f"ALTER { self .kind } `{ catalog } `.`{ self .database } `.`{ self .name } ` "
80+ f"UNSET TBLPROPERTIES IF EXISTS('upgraded_to');"
81+ )
82+
7683
7784@dataclass
7885class TableError :
@@ -82,6 +89,14 @@ class TableError:
8289 error : str | None = None
8390
8491
92+ @dataclass
93+ class MigrationCount :
94+ database : str
95+ managed_tables : int = 0
96+ external_tables : int = 0
97+ views : int = 0
98+
99+
85100class TablesCrawler (CrawlerBase ):
86101 def __init__ (self , backend : SqlBackend , schema ):
87102 """
@@ -252,3 +267,103 @@ def _init_seen_tables(self):
252267
253268 def _table_already_upgraded (self , target ) -> bool :
254269 return target in self ._seen_tables
270+
271+ def _get_tables_to_revert (self , schema : str | None = None , table : str | None = None ) -> list [Table ]:
272+ schema = schema .lower () if schema else None
273+ table = table .lower () if table else None
274+ upgraded_tables = []
275+ if table and not schema :
276+ logger .error ("Cannot accept 'Table' parameter without 'Schema' parameter" )
277+ if len (self ._seen_tables ) == 0 :
278+ self ._init_seen_tables ()
279+
280+ for cur_table in self ._tc .snapshot ():
281+ if schema and cur_table .database != schema :
282+ continue
283+ if table and cur_table .name != table :
284+ continue
285+ if cur_table .key in self ._seen_tables .values ():
286+ upgraded_tables .append (cur_table )
287+ return upgraded_tables
288+
289+ def revert_migrated_tables (
290+ self , schema : str | None = None , table : str | None = None , * , delete_managed : bool = False
291+ ):
292+ upgraded_tables = self ._get_tables_to_revert (schema = schema , table = table )
293+ # reverses the _seen_tables dictionary to key by the source table
294+ reverse_seen = {v : k for (k , v ) in self ._seen_tables .items ()}
295+ tasks = []
296+ for upgraded_table in upgraded_tables :
297+ if upgraded_table .kind == "VIEW" or upgraded_table .object_type == "EXTERNAL" or delete_managed :
298+ tasks .append (partial (self ._revert_migrated_table , upgraded_table , reverse_seen [upgraded_table .key ]))
299+ continue
300+ logger .info (
301+ f"Skipping { upgraded_table .object_type } Table { upgraded_table .database } .{ upgraded_table .name } "
302+ f"upgraded_to { upgraded_table .upgraded_to } "
303+ )
304+ Threads .strict ("revert migrated tables" , tasks )
305+
306+ def _revert_migrated_table (self , table : Table , target_table_key : str ):
307+ logger .info (
308+ f"Reverting { table .object_type } table { table .database } .{ table .name } upgraded_to { table .upgraded_to } "
309+ )
310+ self ._backend .execute (table .sql_unset_upgraded_to ("hive_metastore" ))
311+ self ._backend .execute (f"DROP { table .kind } IF EXISTS { target_table_key } " )
312+
313+ def _get_revert_count (self , schema : str | None = None , table : str | None = None ) -> list [MigrationCount ]:
314+ upgraded_tables = self ._get_tables_to_revert (schema = schema , table = table )
315+
316+ table_by_database = defaultdict (list )
317+ for cur_table in upgraded_tables :
318+ table_by_database [cur_table .database ].append (cur_table )
319+
320+ migration_list = []
321+ for cur_database in table_by_database .keys ():
322+ external_tables = 0
323+ managed_tables = 0
324+ views = 0
325+ for current_table in table_by_database [cur_database ]:
326+ if current_table .upgraded_to is not None :
327+ if current_table .kind == "VIEW" :
328+ views += 1
329+ continue
330+ if current_table .object_type == "EXTERNAL" :
331+ external_tables += 1
332+ continue
333+ if current_table .object_type == "MANAGED" :
334+ managed_tables += 1
335+ continue
336+ migration_list .append (
337+ MigrationCount (
338+ database = cur_database , managed_tables = managed_tables , external_tables = external_tables , views = views
339+ )
340+ )
341+ return migration_list
342+
343+ def is_upgraded (self , schema : str , table : str ) -> bool :
344+ result = self ._backend .fetch (f"SHOW TBLPROPERTIES `{ schema } `.`{ table } `" )
345+ for value in result :
346+ if value ["key" ] == "upgraded_to" :
347+ logger .info (f"{ schema } .{ table } is set as upgraded" )
348+ return True
349+ logger .info (f"{ schema } .{ table } is set as not upgraded" )
350+ return False
351+
352+ def print_revert_report (self , * , delete_managed : bool ) -> bool | None :
353+ migrated_count = self ._get_revert_count ()
354+ if not migrated_count :
355+ logger .info ("No migrated tables were found." )
356+ return False
357+ print ("The following is the count of migrated tables and views found in scope:" )
358+ print ("Database | External Tables | Managed Table | Views |" )
359+ print ("=" * 88 )
360+ for count in migrated_count :
361+ print (f"{ count .database :<30} | { count .external_tables :16} | { count .managed_tables :16} | { count .views :16} |" )
362+ print ("=" * 88 )
363+ print ("Migrated External Tables and Views (targets) will be deleted" )
364+ if delete_managed :
365+ print ("Migrated Manged Tables (targets) will be deleted" )
366+ else :
367+ print ("Migrated Manged Tables (targets) will be left intact." )
368+ print ("To revert and delete Migrated Tables, add --delete_managed true flag to the command." )
369+ return True
0 commit comments