2323import asyncio
2424import logging
2525import re
26+ from contextlib import asynccontextmanager
2627from dataclasses import fields
2728from typing import Any , Dict , List , Optional , Tuple
2829
2930import kopf
3031from aiopg import Cursor
31- from kubernetes_asyncio .client import ApiException , CoreV1Api , CustomObjectsApi
32+ from kubernetes_asyncio .client import (
33+ ApiException ,
34+ AppsV1Api ,
35+ CoreV1Api ,
36+ CustomObjectsApi ,
37+ )
3238from psycopg2 import DatabaseError , ProgrammingError
33- from psycopg2 .errors import DuplicateTable
39+ from psycopg2 .errors import DuplicateTable , UndefinedTable
3440from psycopg2 .extensions import AsIs , QuotedString , quote_ident
3541
3642from crate .operator .config import config
5258 get_crash_scheme ,
5359 run_crash_command ,
5460 scale_backup_metrics_deployment ,
61+ suspend_or_start_grand_central ,
5562)
5663from crate .operator .restore_backup_repository_data import BackupRepositoryData
5764from crate .operator .utils import crate
@@ -370,23 +377,7 @@ def get_restore_keyword(self, *, cursor: Cursor):
370377 if not tables or (len (tables ) == 1 and tables [0 ].lower () == "all" ):
371378 return "ALL"
372379
373- def quote_table (table ):
374- """
375- Ensure table names are correctly quoted. If it contains a schema
376- (e.g., 'doc.nyc_taxi'), quote both the schema and the table using
377- psycopg2.extensions.quote_ident.
378- """
379- if "." in table :
380- schema , table_name = table .split ("." , 1 )
381- else :
382- schema , table_name = None , table
383-
384- quoted_schema = quote_ident (schema , cursor ._impl ) if schema else None
385- quoted_table = quote_ident (table_name , cursor ._impl )
386-
387- return f"{ quoted_schema } .{ quoted_table } " if quoted_schema else quoted_table
388-
389- formatted_tables = [quote_table (table .strip ()) for table in tables ]
380+ formatted_tables = [quote_table (table .strip (), cursor ) for table in tables ]
390381
391382 return f'TABLE { "," .join (formatted_tables )} '
392383
@@ -577,6 +568,7 @@ async def handle( # type: ignore
577568 ):
578569 async with GlobalApiClient () as api_client :
579570 core = CoreV1Api (api_client )
571+ apps = AppsV1Api (api_client )
580572 data = await get_source_backup_repository_data (
581573 core ,
582574 namespace ,
@@ -596,16 +588,35 @@ async def handle( # type: ignore
596588 conn_factory , repository , snapshot , logger
597589 )
598590
599- await self ._start_restore_snapshot (
591+ async with restore_internal_tables_context (
592+ apps ,
593+ namespace ,
594+ name ,
600595 conn_factory ,
601596 repository ,
602597 snapshot ,
603- restore_type ,
604598 logger ,
599+ restore_type ,
605600 tables ,
606- partitions ,
607- sections ,
608- )
601+ ) as internal_tables :
602+ await internal_tables .rename_duplicated_tables ()
603+
604+ try :
605+ await self ._start_restore_snapshot (
606+ conn_factory ,
607+ repository ,
608+ snapshot ,
609+ restore_type ,
610+ logger ,
611+ tables ,
612+ partitions ,
613+ sections ,
614+ )
615+ except Exception as e :
616+ await internal_tables .restore_tables ()
617+ raise e
618+ else :
619+ await internal_tables .cleanup_tables ()
609620
610621 @staticmethod
611622 async def _create_backup_repository (
@@ -1114,3 +1125,207 @@ async def handle( # type: ignore
11141125 name = name ,
11151126 body = body ,
11161127 )
1128+
1129+
@asynccontextmanager
async def restore_internal_tables_context(
    apps,
    namespace,
    name,
    conn_factory,
    repository,
    snapshot,
    logger,
    restore_type,
    tables,
):
    """
    Async context manager wrapping a snapshot restore that may touch
    grand central (``gc.*``) tables.

    On entry a :class:`RestoreInternalTables` helper is built and asked to
    work out which gc tables are affected. If there are any,
    ``suspend_or_start_grand_central`` is called with ``suspend=True`` before
    the helper is yielded. On exit - whether or not the body raised - the same
    check is made again and ``suspend_or_start_grand_central`` is called with
    ``suspend=False``.

    :param apps: Client handed to ``suspend_or_start_grand_central``.
    :param namespace: Namespace handed to ``suspend_or_start_grand_central``.
    :param name: Name handed to ``suspend_or_start_grand_central``.
    :param conn_factory: Callable returning an async connection context.
    :param repository: Name of the snapshot repository.
    :param snapshot: Name of the snapshot being restored.
    :param logger: Logger used for progress messages.
    :param restore_type: The requested restore type value.
    :param tables: Optional list of tables requested for restore.
    """
    gc_helper = RestoreInternalTables(
        conn_factory=conn_factory,
        repository=repository,
        snapshot=snapshot,
        logger=logger,
    )
    await gc_helper.set_gc_tables(restore_type, tables)

    if gc_helper.has_tables_to_process():
        logger.info("Suspending GC operations before restoring internal tables")
        await suspend_or_start_grand_central(apps, namespace, name, suspend=True)

    try:
        yield gc_helper
    finally:
        # Re-checked on the way out so the resume mirrors the suspend above.
        if gc_helper.has_tables_to_process():
            logger.info("Resuming GC operations after restoring internal tables")
            await suspend_or_start_grand_central(
                apps, namespace, name, suspend=False
            )
1153+
1154+
class RestoreInternalTables:
    """
    Manage grand central (``gc`` schema) tables around a snapshot restore.

    Tables that exist both in the snapshot and in the cluster are renamed to
    ``<name>_temp`` before the restore. Afterwards they are either renamed
    back (restore failed) or dropped (restore succeeded).
    """

    def __init__(
        self,
        conn_factory,
        repository: str,
        snapshot: str,
        logger: logging.Logger,
    ):
        """
        :param conn_factory: Callable returning an async connection context.
        :param repository: Name of the snapshot repository.
        :param snapshot: Name of the snapshot to inspect.
        :param logger: Logger used for progress and error messages.
        """
        self.conn_factory = conn_factory
        self.repository: str = repository
        self.snapshot: str = snapshot
        self.logger: logging.Logger = logger

        # Fully qualified ("gc.<table>") names present both in the snapshot
        # and in the cluster; populated by ``set_gc_tables``.
        self.gc_tables: list[str] = []

    def has_tables_to_process(self) -> bool:
        """Return ``True`` when at least one gc table was collected."""
        return bool(self.gc_tables)

    async def set_gc_tables(
        self, restore_type: str, tables: Optional[list[str]] = None
    ):
        """
        Retrieve the grand central tables from the snapshot to be restored.

        Only the ``ALL`` and ``TABLES`` restore types are considered. For
        ``TABLES`` with an explicit list, only the requested ``gc.*`` tables
        are looked up in the snapshot; otherwise every ``gc.*`` table of the
        snapshot is a candidate. A candidate is recorded in ``gc_tables`` only
        if a table with the same name currently exists in the ``gc`` schema.

        :param restore_type: The requested restore type value.
        :param tables: Optional list of tables requested for restore.
        :raises kopf.PermanentError: If the database lookup fails.
        """
        if restore_type not in (
            SnapshotRestoreType.ALL.value,
            SnapshotRestoreType.TABLES.value,
        ):
            return

        params: list[str] = [self.repository, self.snapshot]

        if restore_type == SnapshotRestoreType.TABLES.value and tables is not None:
            # Names are stripped the same way the restore statement strips
            # them, so surrounding whitespace does not hide a gc table.
            requested = [
                stripped
                for stripped in (table.strip() for table in tables)
                if stripped.startswith("gc.")
            ]

            # There is no gc table to restore, no need to proceed further
            if not requested:
                return

            # Table names come from user input: pass them as query
            # parameters instead of splicing them into the SQL text.
            placeholders = ",".join(["%s"] * len(requested))
            where_stmt = f"t IN ({placeholders})"
            params.extend(requested)
        else:
            # "%%" because the statement is executed with parameters.
            where_stmt = "t LIKE 'gc.%%'"

        try:
            async with self.conn_factory() as conn:
                async with conn.cursor(timeout=120) as cursor:
                    await cursor.execute(
                        "WITH tables AS ("
                        "  SELECT unnest(tables) AS t "
                        "  FROM sys.snapshots "
                        "  WHERE repository=%s AND name=%s"
                        ") "
                        f"SELECT * FROM tables WHERE {where_stmt};",
                        tuple(params),
                    )
                    snapshot_gc_tables = await cursor.fetchall()

                    if snapshot_gc_tables:
                        await cursor.execute('SHOW TABLES FROM "gc";')
                        existing_rows = await cursor.fetchall()
                        existing_gc_tables = {
                            f"gc.{row[0]}" for row in existing_rows or []
                        }
                        self.gc_tables.extend(
                            table
                            for (table,) in snapshot_gc_tables
                            if table in existing_gc_tables
                        )

        except DatabaseError as e:
            self.logger.warning(
                "DatabaseError in RestoreInternalTables.set_gc_tables",
                exc_info=e,
            )
            raise kopf.PermanentError(
                "internal tables couldn't be retrieved."
            ) from e

    async def _rename_table(self, cursor, old_name: str, new_name: str):
        """
        Rename ``old_name`` to ``new_name``; a missing source table is logged
        and skipped.

        Both names are spliced into the statement verbatim, so callers must
        pass identifiers that are already quoted.
        """
        self.logger.info(f"Renaming GC table: {old_name} to {new_name}")
        try:
            await cursor.execute(f"ALTER TABLE {old_name} RENAME TO {new_name};")
        except UndefinedTable:
            self.logger.warning(f"Table {old_name} does not exist. Skipping.")

    async def rename_duplicated_tables(self):
        """
        If the snapshot contains grand central tables, rename them if they exist
        in the cluster in order to recreate the new ones from the snapshot.

        NOTE(review): if a rename fails part-way, tables renamed earlier in
        the loop keep their ``_temp`` name; presumably the caller is expected
        to invoke ``restore_tables`` in that case - confirm.

        :raises kopf.PermanentError: If a rename fails with a database error.
        """
        if not self.has_tables_to_process():
            return

        try:
            async with self.conn_factory() as conn:
                async with conn.cursor(timeout=120) as cursor:
                    for table in self.gc_tables:
                        table_name = quote_table(table, cursor)
                        # RENAME TO takes the bare table name, without schema.
                        temp_table_name = table_without_schema(
                            f"{table}_temp", cursor
                        )
                        await self._rename_table(cursor, table_name, temp_table_name)
        except DatabaseError as e:
            self.logger.warning(
                "DatabaseError in RestoreInternalTables.rename_duplicated_tables",
                exc_info=e,
            )
            raise kopf.PermanentError("internal tables couldn't be renamed.") from e

    async def restore_tables(self):
        """
        If the restore operation failed, rename back the gc tables
        to their original names.

        :raises kopf.PermanentError: If a rename fails with a database error.
        """
        if not self.has_tables_to_process():
            return

        try:
            async with self.conn_factory() as conn:
                async with conn.cursor(timeout=120) as cursor:
                    for table in self.gc_tables:
                        table_name = table_without_schema(table, cursor)
                        temp_table_name = quote_table(f"{table}_temp", cursor)
                        await self._rename_table(cursor, temp_table_name, table_name)

        except DatabaseError as e:
            self.logger.warning(
                "DatabaseError in RestoreInternalTables.restore_tables", exc_info=e
            )
            raise kopf.PermanentError("internal table couldn't be renamed.") from e

    async def cleanup_tables(self):
        """
        After a successful restore, the temporary renamed gc tables can be dropped.

        :raises kopf.PermanentError: If a drop fails with a database error.
        """
        if not self.has_tables_to_process():
            return

        try:
            async with self.conn_factory() as conn:
                async with conn.cursor(timeout=120) as cursor:
                    for table in self.gc_tables:
                        temp_table_name = quote_table(f"{table}_temp", cursor)
                        self.logger.info(f"Dropping old GC table: {temp_table_name}")
                        await cursor.execute(f"DROP TABLE IF EXISTS {temp_table_name};")
        except DatabaseError as e:
            self.logger.warning(
                "DatabaseError in RestoreInternalTables.cleanup_tables", exc_info=e
            )
            raise kopf.PermanentError(
                "internal temporary table couldn't be dropped."
            ) from e
1303+
1304+
def quote_table(table, cursor) -> str:
    """
    Return ``table`` as a quoted SQL identifier.

    The name is split on its first dot: a non-empty part before the dot is
    treated as the schema and quoted separately from the table part (e.g.
    ``doc.nyc_taxi``). Quoting is done with
    ``psycopg2.extensions.quote_ident`` on the cursor's underlying
    implementation.

    :param table: Table name, optionally prefixed with ``<schema>.``.
    :param cursor: Database cursor whose ``_impl`` is used for quoting.
    :return: ``"schema"."table"`` or just ``"table"``.
    """
    prefix, dot, remainder = table.partition(".")
    if dot:
        schema, table_name = prefix, remainder
    else:
        schema, table_name = None, table

    quoted_parts = []
    if schema:
        quoted_parts.append(quote_ident(schema, cursor._impl))
    quoted_parts.append(quote_ident(table_name, cursor._impl))

    return ".".join(quoted_parts)
1320+
1321+
def table_without_schema(table, cursor) -> str:
    """
    Returns the table name without schema, ensuring it's correctly quoted.

    The name is split on the first dot only, the same way ``quote_table``
    separates schema and table, so both helpers agree on what the table part
    is even when the name contains more than one dot.

    :param table: The full table name, possibly including schema.
    :param cursor: The database cursor used for quoting.
    :return: The quoted table name without schema.
    """
    # maxsplit=1: with no dot the single element is the table itself,
    # otherwise the last element is everything after the first dot.
    table_name = table.split(".", 1)[-1]
    return quote_ident(table_name, cursor._impl)
0 commit comments