Skip to content
Merged
2 changes: 2 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ Changelog
Unreleased
----------

* Remove the existing grand-central tables from the cluster when restoring a full snapshot or specific grand-central tables from a snapshot.

2.53.0 (2025-09-25)
-------------------

Expand Down
38 changes: 18 additions & 20 deletions crate/operator/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -663,17 +663,9 @@ async def suspend_or_start_cluster(
apps, namespace, backup_metrics_name, 1
)
# scale grand central deployment back up if it exists
deployment = await read_grand_central_deployment(
namespace=namespace, name=name
await suspend_or_start_grand_central(
apps, namespace, name, suspend=False
)

if deployment:
await update_deployment_replicas(
apps,
namespace,
f"{GRAND_CENTRAL_RESOURCE_PREFIX}-{name}",
1,
)
await send_operation_progress_notification(
namespace=namespace,
name=name,
Expand Down Expand Up @@ -726,17 +718,9 @@ async def suspend_or_start_cluster(
apps, namespace, backup_metrics_name, 0
)
# scale grand central deployment down if it exists
deployment = await read_grand_central_deployment(
namespace=namespace, name=name
await suspend_or_start_grand_central(
apps, namespace, name, suspend=True
)

if deployment:
await update_deployment_replicas(
apps,
namespace,
f"{GRAND_CENTRAL_RESOURCE_PREFIX}-{name}",
0,
)
await send_operation_progress_notification(
namespace=namespace,
name=name,
Expand All @@ -762,6 +746,20 @@ async def suspend_or_start_cluster(
await delete_lb_service(core, namespace, name)


async def suspend_or_start_grand_central(
    apps: AppsV1Api, namespace: str, name: str, suspend: bool
):
    """
    Scale the grand central deployment of a cluster down or back up.

    Nothing happens when no grand central deployment is found for the cluster.

    :param apps: An instance of the Kubernetes Apps V1 API.
    :param namespace: The Kubernetes namespace of the cluster.
    :param name: The name of the cluster the deployment belongs to.
    :param suspend: ``True`` scales the deployment to 0 replicas, ``False``
        scales it to 1 replica.
    """
    found = await read_grand_central_deployment(namespace=namespace, name=name)
    if not found:
        return

    replicas = 0 if suspend else 1
    deployment_name = f"{GRAND_CENTRAL_RESOURCE_PREFIX}-{name}"
    await update_deployment_replicas(apps, namespace, deployment_name, replicas)


async def _get_connection_factory(core, namespace: str, name: str):
"""
Returns a connection factory.
Expand Down
263 changes: 239 additions & 24 deletions crate/operator/restore_backup.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,20 @@
import asyncio
import logging
import re
from contextlib import asynccontextmanager
from dataclasses import fields
from typing import Any, Dict, List, Optional, Tuple

import kopf
from aiopg import Cursor
from kubernetes_asyncio.client import ApiException, CoreV1Api, CustomObjectsApi
from kubernetes_asyncio.client import (
ApiException,
AppsV1Api,
CoreV1Api,
CustomObjectsApi,
)
from psycopg2 import DatabaseError, ProgrammingError
from psycopg2.errors import DuplicateTable
from psycopg2.errors import DuplicateTable, UndefinedTable
from psycopg2.extensions import AsIs, QuotedString, quote_ident

from crate.operator.config import config
Expand All @@ -52,6 +58,7 @@
get_crash_scheme,
run_crash_command,
scale_backup_metrics_deployment,
suspend_or_start_grand_central,
)
from crate.operator.restore_backup_repository_data import BackupRepositoryData
from crate.operator.utils import crate
Expand Down Expand Up @@ -370,23 +377,7 @@ def get_restore_keyword(self, *, cursor: Cursor):
if not tables or (len(tables) == 1 and tables[0].lower() == "all"):
return "ALL"

def quote_table(table):
"""
Ensure table names are correctly quoted. If it contains a schema
(e.g., 'doc.nyc_taxi'), quote both the schema and the table using
psycopg2.extensions.quote_ident.
"""
if "." in table:
schema, table_name = table.split(".", 1)
else:
schema, table_name = None, table

quoted_schema = quote_ident(schema, cursor._impl) if schema else None
quoted_table = quote_ident(table_name, cursor._impl)

return f"{quoted_schema}.{quoted_table}" if quoted_schema else quoted_table

formatted_tables = [quote_table(table.strip()) for table in tables]
formatted_tables = [quote_table(table.strip(), cursor) for table in tables]

return f'TABLE {",".join(formatted_tables)}'

Expand Down Expand Up @@ -577,6 +568,7 @@ async def handle( # type: ignore
):
async with GlobalApiClient() as api_client:
core = CoreV1Api(api_client)
apps = AppsV1Api(api_client)
data = await get_source_backup_repository_data(
core,
namespace,
Expand All @@ -596,16 +588,35 @@ async def handle( # type: ignore
conn_factory, repository, snapshot, logger
)

await self._start_restore_snapshot(
async with restore_internal_tables_context(
apps,
namespace,
name,
conn_factory,
repository,
snapshot,
restore_type,
logger,
restore_type,
tables,
partitions,
sections,
)
) as internal_tables:
await internal_tables.rename_duplicated_tables()

try:
await self._start_restore_snapshot(
conn_factory,
repository,
snapshot,
restore_type,
logger,
tables,
partitions,
sections,
)
except Exception as e:
await internal_tables.restore_tables()
raise e
else:
await internal_tables.cleanup_tables()

@staticmethod
async def _create_backup_repository(
Expand Down Expand Up @@ -1114,3 +1125,207 @@ async def handle( # type: ignore
name=name,
body=body,
)


@asynccontextmanager
async def restore_internal_tables_context(
    apps,
    namespace,
    name,
    conn_factory,
    repository,
    snapshot,
    logger,
    restore_type,
    tables,
):
    """
    Async context manager wrapping a restore that may touch grand central tables.

    A :class:`RestoreInternalTables` helper is created and asked to collect the
    grand central tables affected by the restore. When there are any, grand
    central is suspended before the body runs and resumed afterwards, whether
    the body succeeded or raised.

    :param apps: An instance of the Kubernetes Apps V1 API.
    :param namespace: The Kubernetes namespace of the cluster.
    :param name: The name of the cluster.
    :param conn_factory: Callable returning a database connection context.
    :param repository: Name of the snapshot repository.
    :param snapshot: Name of the snapshot being restored.
    :param logger: Logger used for progress messages.
    :param restore_type: The kind of restore that was requested.
    :param tables: The tables requested for restore, if any.
    :return: Yields the :class:`RestoreInternalTables` helper.
    """
    handler = RestoreInternalTables(conn_factory, repository, snapshot, logger)
    await handler.set_gc_tables(restore_type, tables)

    async def _switch_grand_central(*, suspend: bool, message: str):
        # Only touch the deployment when there is something to process.
        if not handler.has_tables_to_process():
            return
        logger.info(message)
        await suspend_or_start_grand_central(apps, namespace, name, suspend=suspend)

    await _switch_grand_central(
        suspend=True,
        message="Suspending GC operations before restoring internal tables",
    )
    try:
        yield handler
    finally:
        await _switch_grand_central(
            suspend=False,
            message="Resuming GC operations after restoring internal tables",
        )


class RestoreInternalTables:
    """
    Handle grand central (``gc`` schema) tables around a snapshot restore.

    Tables that are part of the snapshot *and* already exist in the cluster are
    renamed to ``<table>_temp`` before the restore. Afterwards they are either
    dropped (restore succeeded) or renamed back (restore failed).
    """

    def __init__(
        self,
        conn_factory,
        repository: str,
        snapshot: str,
        logger: logging.Logger,
    ):
        """
        :param conn_factory: Callable returning an async connection context.
        :param repository: Name of the snapshot repository.
        :param snapshot: Name of the snapshot to restore.
        :param logger: Logger used for progress and warning messages.
        """
        self.conn_factory = conn_factory
        self.repository: str = repository
        self.snapshot: str = snapshot
        self.logger: logging.Logger = logger

        # Fully qualified names (``gc.<table>``) found both in the snapshot
        # and in the cluster; filled by ``set_gc_tables``.
        self.gc_tables: list[str] = []

    def has_tables_to_process(self) -> bool:
        """Return ``True`` when at least one gc table has to be handled."""
        return bool(self.gc_tables)

    @staticmethod
    def _gc_tables_filter(tables: list[str]) -> tuple[str, tuple[str, ...]]:
        """
        Build a parameterized ``WHERE`` fragment for the requested gc tables.

        Names are stripped first, as the restore statement strips them too.
        The names are returned as query parameters instead of being
        interpolated into the SQL text, because they are user-supplied.

        :param tables: The table names requested for restore.
        :return: ``(where_fragment, params)``; the fragment is empty when no
            requested table lives in the ``gc`` schema.
        """
        stripped = (table.strip() for table in tables)
        gc_tables = tuple(table for table in stripped if table.startswith("gc."))
        if not gc_tables:
            return "", ()

        placeholders = ", ".join(["%s"] * len(gc_tables))
        return f"t IN ({placeholders})", gc_tables

    async def set_gc_tables(
        self, restore_type: str, tables: Optional[list[str]] = None
    ):
        """
        Retrieve the grand central tables from the snapshot to be restored.

        Only tables that also exist in the cluster's ``gc`` schema are kept,
        since only those collide with the restore.

        :param restore_type: The requested restore type; anything other than
            a full or a table restore is ignored.
        :param tables: The tables requested for a table restore.
        :raises kopf.PermanentError: When the tables cannot be queried.
        """

        if restore_type not in [
            SnapshotRestoreType.ALL.value,
            SnapshotRestoreType.TABLES.value,
        ]:
            return

        if restore_type == SnapshotRestoreType.TABLES.value and tables is not None:
            where_stmt, where_params = self._gc_tables_filter(tables)

            # There is no gc table to restore, no need to proceed further
            if not where_stmt:
                return
        else:
            # ``%%`` is a literal ``%`` once the driver applies the parameters.
            where_stmt, where_params = "t LIKE 'gc.%%'", ()

        try:
            async with self.conn_factory() as conn:
                async with conn.cursor(timeout=120) as cursor:
                    await cursor.execute(
                        "WITH tables AS ("
                        "  SELECT unnest(tables) AS t "
                        "  FROM sys.snapshots "
                        "  WHERE repository=%s AND name=%s"
                        ") "
                        f"SELECT * FROM tables WHERE {where_stmt};",
                        (self.repository, self.snapshot, *where_params),
                    )
                    snapshot_gc_tables = await cursor.fetchall()
                    if not snapshot_gc_tables:
                        return

                    await cursor.execute('SHOW TABLES FROM "gc";')
                    existing_rows = await cursor.fetchall()
                    if not existing_rows:
                        return

                    existing_gc_tables = {f"gc.{row[0]}" for row in existing_rows}
                    self.gc_tables.extend(
                        table
                        for (table,) in snapshot_gc_tables
                        if table in existing_gc_tables
                    )

        except DatabaseError as e:
            self.logger.warning(
                "DatabaseError in RestoreInternalTables.set_gc_tables",
                exc_info=e,
            )
            raise kopf.PermanentError("internal tables couldn't be retrieved.") from e

    async def _rename_table(self, cursor, old_name: str, new_name: str):
        """
        Rename ``old_name`` to ``new_name``; a missing table is only logged.

        Both names are expected to be already quoted by the caller.
        """
        self.logger.info(f"Renaming GC table: {old_name} to {new_name}")
        try:
            await cursor.execute(f"ALTER TABLE {old_name} RENAME TO {new_name};")
        except UndefinedTable:
            self.logger.warning(f"Table {old_name} does not exist. Skipping.")

    async def rename_duplicated_tables(self):
        """
        If the snapshot contains grand central tables, rename them if they exist
        in the cluster in order to recreate the new ones from the snapshot.

        :raises kopf.PermanentError: When a table cannot be renamed.
        """
        if not self.has_tables_to_process():
            return

        try:
            async with self.conn_factory() as conn:
                async with conn.cursor(timeout=120) as cursor:
                    for table in self.gc_tables:
                        table_name = quote_table(table, cursor)
                        # The rename target is given without its schema.
                        temp_table_name = table_without_schema(f"{table}_temp", cursor)
                        await self._rename_table(cursor, table_name, temp_table_name)
        except DatabaseError as e:
            self.logger.warning(
                "DatabaseError in RestoreInternalTables.rename_duplicated_tables",
                exc_info=e,
            )
            raise kopf.PermanentError("internal tables couldn't be renamed.") from e

    async def restore_tables(self):
        """
        If the restore operation failed, rename back the gc tables
        to their original names.

        :raises kopf.PermanentError: When a table cannot be renamed back.
        """
        if not self.has_tables_to_process():
            return

        try:
            async with self.conn_factory() as conn:
                async with conn.cursor(timeout=120) as cursor:
                    for table in self.gc_tables:
                        table_name = table_without_schema(table, cursor)
                        temp_table_name = quote_table(f"{table}_temp", cursor)
                        await self._rename_table(cursor, temp_table_name, table_name)

        except DatabaseError as e:
            self.logger.warning(
                "DatabaseError in RestoreInternalTables.restore_tables", exc_info=e
            )
            raise kopf.PermanentError("internal table couldn't be renamed.") from e

    async def cleanup_tables(self):
        """
        After a successful restore, the temporary renamed gc tables can be dropped.

        :raises kopf.PermanentError: When a temporary table cannot be dropped.
        """
        if not self.has_tables_to_process():
            return

        try:
            async with self.conn_factory() as conn:
                async with conn.cursor(timeout=120) as cursor:
                    for table in self.gc_tables:
                        temp_table_name = quote_table(f"{table}_temp", cursor)
                        self.logger.info(f"Dropping old GC table: {temp_table_name}")
                        await cursor.execute(f"DROP TABLE IF EXISTS {temp_table_name};")
        except DatabaseError as e:
            self.logger.warning(
                "DatabaseError in RestoreInternalTables.cleanup_tables", exc_info=e
            )
            raise kopf.PermanentError(
                "internal temporary table couldn't be dropped."
            ) from e


def quote_table(table, cursor) -> str:
    """
    Return ``table`` as a safely quoted identifier.

    The name is split at its first dot: a name such as ``doc.nyc_taxi`` has
    both the schema and the table quoted with
    psycopg2.extensions.quote_ident and is returned as ``"schema"."table"``.
    A name without a schema part (or with an empty one) yields only the
    quoted table.
    """
    head, dot, tail = table.partition(".")
    if dot:
        schema, table_name = head, tail
    else:
        schema, table_name = None, table

    schema_ident = quote_ident(schema, cursor._impl) if schema else None
    table_ident = quote_ident(table_name, cursor._impl)

    if not schema_ident:
        return table_ident
    return f"{schema_ident}.{table_ident}"


def table_without_schema(table, cursor) -> str:
    """
    Returns the table name without schema, ensuring it's correctly quoted.

    The name is split at the first dot only, the same way ``quote_table``
    splits it, so both helpers agree on what the table part of a name is.

    :param table: The full table name, possibly including schema.
    :param cursor: The database cursor used for quoting.
    :return: The quoted table name without schema.
    """
    _schema, dot, remainder = table.partition(".")
    table_name = remainder if dot else table
    return quote_ident(table_name, cursor._impl)
Loading