Skip to content

Commit a0eeecb

Browse files
committed
properly handle table's schema
1 parent 35b7510 commit a0eeecb

File tree

2 files changed

+107
-70
lines changed

2 files changed

+107
-70
lines changed

crate/operator/restore_backup.py

Lines changed: 38 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -370,23 +370,7 @@ def get_restore_keyword(self, *, cursor: Cursor):
370370
if not tables or (len(tables) == 1 and tables[0].lower() == "all"):
371371
return "ALL"
372372

373-
def quote_table(table):
374-
"""
375-
Ensure table names are correctly quoted. If it contains a schema
376-
(e.g., 'doc.nyc_taxi'), quote both the schema and the table using
377-
psycopg2.extensions.quote_ident.
378-
"""
379-
if "." in table:
380-
schema, table_name = table.split(".", 1)
381-
else:
382-
schema, table_name = None, table
383-
384-
quoted_schema = quote_ident(schema, cursor._impl) if schema else None
385-
quoted_table = quote_ident(table_name, cursor._impl)
386-
387-
return f"{quoted_schema}.{quoted_table}" if quoted_schema else quoted_table
388-
389-
formatted_tables = [quote_table(table.strip()) for table in tables]
373+
formatted_tables = [quote_table(table.strip(), cursor) for table in tables]
390374

391375
return f'TABLE {",".join(formatted_tables)}'
392376

@@ -1157,7 +1141,7 @@ async def remove_duplicated_tables(self, tables: Optional[List[str]] = None):
11571141
async with self.conn_factory() as conn:
11581142
async with conn.cursor(timeout=120) as cursor:
11591143
if tables is not None:
1160-
gc_tables = self.get_gc_tables(cursor, tables)
1144+
gc_tables = [t for t in tables if t.startswith("gc.")]
11611145
tables_str = ",".join(f"'{table}'" for table in gc_tables)
11621146
where_stmt = f"t IN ({tables_str})"
11631147
else:
@@ -1176,8 +1160,10 @@ async def remove_duplicated_tables(self, tables: Optional[List[str]] = None):
11761160
self.gc_tables = [table[0] for table in tables] if tables else []
11771161
for table in self.gc_tables:
11781162
self.logger.info(f"Renaming GC table: {table} to {table}_temp")
1163+
table_name = quote_table(table, cursor)
1164+
temp_table_name = table_without_schema(f"{table}_temp", cursor)
11791165
await cursor.execute(
1180-
f"ALTER TABLE {table} RENAME TO {table}_temp;"
1166+
f"ALTER TABLE {table_name} RENAME TO {temp_table_name};"
11811167
)
11821168
except DatabaseError as e:
11831169
self.logger.warning(
@@ -1199,8 +1185,10 @@ async def restore_tables(self):
11991185
async with conn.cursor(timeout=120) as cursor:
12001186
for table in self.gc_tables:
12011187
self.logger.info(f"Renaming GC table: {table}_temp to {table}")
1188+
table_name = table_without_schema(table, cursor)
1189+
temp_table_name = quote_table(f"{table}_temp", cursor)
12021190
await cursor.execute(
1203-
f"ALTER TABLE {table}_temp RENAME TO {table};"
1191+
f"ALTER TABLE {temp_table_name} RENAME TO {table_name};"
12041192
)
12051193
except DatabaseError as e:
12061194
self.logger.warning(
@@ -1220,17 +1208,39 @@ async def cleanup_tables(self):
12201208
async with conn.cursor(timeout=120) as cursor:
12211209
for table in self.gc_tables:
12221210
self.logger.info(f"Dropping old GC table: {table}_temp")
1223-
await cursor.execute(f"DROP TABLE {table}_temp;")
1211+
temp_table_name = quote_table(f"{table}_temp", cursor)
1212+
await cursor.execute(f"DROP TABLE {temp_table_name};")
12241213
except DatabaseError as e:
12251214
self.logger.warning(
12261215
"DatabaseError in RestoreGCTables.restore_tables", exc_info=e
12271216
)
12281217
raise kopf.PermanentError("grand-central table couldn't be renamed.")
12291218

1230-
@staticmethod
1231-
def get_gc_tables(cursor, tables: list[str]) -> list[str]:
1232-
return [
1233-
quote_ident(table, cursor._impl)
1234-
for table in tables
1235-
if table.startswith("gc.")
1236-
]
1219+
1220+
def quote_table(table, cursor) -> str:
    """
    Build a quoted SQL identifier for ``table``.

    A name of the form ``schema.table`` (e.g. ``doc.nyc_taxi``) is split at
    its first dot and each half is quoted separately with ``quote_ident``.
    A name without a dot, or with an empty schema part, yields only the
    quoted table name.

    :param table: The table name, optionally prefixed with a schema.
    :param cursor: The database cursor whose ``_impl`` is passed to
        ``quote_ident``.
    :return: ``<quoted schema>.<quoted table>`` or just ``<quoted table>``.
    """
    schema, dot, remainder = table.partition(".")

    if dot and schema:
        # Quote the schema first, then the table, each on its own.
        quoted_schema = quote_ident(schema, cursor._impl)
        quoted_name = quote_ident(remainder, cursor._impl)
        return f"{quoted_schema}.{quoted_name}"

    # Without a dot, ``partition`` leaves the whole name in ``schema``, so
    # fall back to the original string; with an empty schema, use the rest.
    bare_name = remainder if dot else table
    return quote_ident(bare_name, cursor._impl)
1235+
1236+
1237+
def table_without_schema(table, cursor) -> str:
    """
    Return the quoted table name with any leading schema removed.

    The name is split at the first dot only, the same way ``quote_table``
    splits it, so both helpers agree on where the schema ends (for
    ``gc.a.b`` the table part is ``a.b``, not ``a``).

    :param table: The full table name, possibly including schema.
    :param cursor: The database cursor used for quoting.
    :return: The quoted table name without schema.
    """
    # partition() keeps everything after the first dot in one piece; when
    # there is no dot at all the input is already a bare table name.
    _, dot, remainder = table.partition(".")
    table_name = remainder if dot else table
    return quote_ident(table_name, cursor._impl)

tests/test_restore_backup.py

Lines changed: 69 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,20 @@ def backup_repository_data(faker):
9494
}
9595

9696

97+
@pytest.fixture
def mock_quote_ident():
    """
    Patch ``crate.operator.restore_backup.quote_ident`` while a test runs.

    The replacement wraps a value in double quotes unless it already starts
    and ends with a double quote, in which case it is returned unchanged.
    """

    def _fake_quote_ident(value, connection):
        # ``connection`` is unused; it is accepted only so the signature
        # matches the two-argument call made through the patched name.
        already_quoted = value.startswith('"') and value.endswith('"')
        return value if already_quoted else f'"{value}"'

    patcher = mock.patch(
        "crate.operator.restore_backup.quote_ident", side_effect=_fake_quote_ident
    )
    with patcher:
        yield
109+
110+
97111
@pytest.mark.k8s
98112
@pytest.mark.asyncio
99113
@mock.patch("crate.operator.webhooks.webhook_client._send")
@@ -353,6 +367,7 @@ async def test_restore_backup_aws(
353367
@mock.patch("crate.operator.webhooks.webhook_client._send")
354368
@mock.patch.object(RestoreBackupSubHandler, "_create_backup_repository")
355369
@mock.patch.object(RestoreBackupSubHandler, "_ensure_snapshot_exists")
370+
@mock.patch.object(RestoreBackupSubHandler, "_start_restore_snapshot")
356371
@mock.patch.object(RestoreInternalTables, "remove_duplicated_tables")
357372
@mock.patch.object(RestoreInternalTables, "cleanup_tables")
358373
@pytest.mark.parametrize("gc_enabled", [True, False])
@@ -782,24 +797,17 @@ async def test_restore_backup_create_repo_fails(
782797
),
783798
],
784799
)
785-
def test_get_restore_type_keyword(restore_type, expected_keyword, params):
800+
def test_get_restore_type_keyword(
801+
restore_type, expected_keyword, params, mock_quote_ident
802+
):
786803
cursor = mock.AsyncMock()
787-
788-
def mock_quote_ident(value, connection):
789-
if value.startswith('"') and value.endswith('"'):
790-
return value
791-
return f'"{value}"'
792-
793-
with mock.patch(
794-
"crate.operator.restore_backup.quote_ident", side_effect=mock_quote_ident
795-
):
796-
func_kwargs = {}
797-
if params:
798-
func_kwargs[restore_type.value] = params
799-
restore_keyword = RestoreType.create(
800-
restore_type.value, **func_kwargs
801-
).get_restore_keyword(cursor=cursor)
802-
assert restore_keyword == expected_keyword
804+
func_kwargs = {}
805+
if params:
806+
func_kwargs[restore_type.value] = params
807+
restore_keyword = RestoreType.create(
808+
restore_type.value, **func_kwargs
809+
).get_restore_keyword(cursor=cursor)
810+
assert restore_keyword == expected_keyword
803811

804812

805813
@pytest.mark.asyncio
@@ -981,8 +989,10 @@ async def does_credentials_secret_exist(
981989
def replace_gc_tables_data(faker, mock_cratedb_connection):
982990
repository = faker.domain_word()
983991
snapshot = faker.domain_word()
984-
table_a = f"gc.{faker.domain_word()}"
985-
table_b = f"gc.{faker.domain_word()}"
992+
table_a = faker.domain_word()
993+
table_b = faker.domain_word()
994+
tables = [table_a, table_b]
995+
tables_with_schema = [f"gc.{t}" for t in tables]
986996

987997
mock_cursor = mock_cratedb_connection["mock_cursor"]
988998
mock_logger = mock.Mock(spec=logging.Logger)
@@ -995,8 +1005,8 @@ def replace_gc_tables_data(faker, mock_cratedb_connection):
9951005
return (
9961006
repository,
9971007
snapshot,
998-
table_a,
999-
table_b,
1008+
tables,
1009+
tables_with_schema,
10001010
mock_cursor,
10011011
gc_tables_cls,
10021012
)
@@ -1006,18 +1016,18 @@ def replace_gc_tables_data(faker, mock_cratedb_connection):
10061016
@pytest.mark.parametrize(
10071017
"gc_enabled, type_all", [(True, False), (False, True), (True, True), (False, False)]
10081018
)
1009-
async def test_replace_gc_tables(gc_enabled, type_all, replace_gc_tables_data):
1010-
repository, snapshot, table_a, table_b, mock_cursor, gc_tables_cls = (
1019+
async def test_replace_gc_tables(
1020+
gc_enabled, type_all, replace_gc_tables_data, mock_quote_ident
1021+
):
1022+
repository, snapshot, tables, tables_with_schema, mock_cursor, gc_tables_cls = (
10111023
replace_gc_tables_data
10121024
)
1013-
mock_cursor.fetchall.return_value = [(table_a,), (table_b,)] if gc_enabled else []
10141025

1015-
tables_param = None if type_all else [table_a, table_b]
1026+
fetch_response = [(t,) for t in tables_with_schema]
1027+
mock_cursor.fetchall.return_value = fetch_response if gc_enabled else []
1028+
tables_param = None if type_all else tables_with_schema
10161029

1017-
with mock.patch(
1018-
"crate.operator.restore_backup.quote_ident", side_effect=[table_a, table_b]
1019-
):
1020-
await gc_tables_cls.remove_duplicated_tables(tables_param)
1030+
await gc_tables_cls.remove_duplicated_tables(tables_param)
10211031

10221032
if type_all:
10231033
stmts = [
@@ -1039,26 +1049,35 @@ async def test_replace_gc_tables(gc_enabled, type_all, replace_gc_tables_data):
10391049
" FROM sys.snapshots "
10401050
" WHERE repository=%s AND name=%s"
10411051
") "
1042-
f"SELECT * FROM tables WHERE t IN ('{table_a}','{table_b}');",
1052+
"SELECT * FROM tables WHERE t IN "
1053+
f"('{tables_with_schema[0]}','{tables_with_schema[1]}');",
10431054
(repository, snapshot),
10441055
),
10451056
]
10461057

10471058
if gc_enabled:
1048-
stmts.append(mock.call(f"ALTER TABLE {table_a} RENAME TO {table_a}_temp;"))
1049-
stmts.append(mock.call(f"ALTER TABLE {table_b} RENAME TO {table_b}_temp;"))
1059+
stmts.append(
1060+
mock.call(f'ALTER TABLE "gc"."{tables[0]}" RENAME TO "{tables[0]}_temp";')
1061+
)
1062+
stmts.append(
1063+
mock.call(f'ALTER TABLE "gc"."{tables[1]}" RENAME TO "{tables[1]}_temp";')
1064+
)
10501065

10511066
mock_cursor.execute.assert_has_awaits(stmts)
10521067
assert gc_tables_cls.gc_tables_renamed is True
10531068

10541069

10551070
@pytest.mark.asyncio
10561071
@pytest.mark.parametrize("gc_tables_renamed", [True, False])
1057-
async def test_restore_gc_tables(gc_tables_renamed, replace_gc_tables_data):
1058-
_, _, table_a, table_b, mock_cursor, gc_tables_cls = replace_gc_tables_data
1072+
async def test_restore_gc_tables(
1073+
gc_tables_renamed, replace_gc_tables_data, mock_quote_ident
1074+
):
1075+
_, _, tables, tables_with_schema, mock_cursor, gc_tables_cls = (
1076+
replace_gc_tables_data
1077+
)
10591078

10601079
gc_tables_cls.gc_tables_renamed = gc_tables_renamed
1061-
gc_tables_cls.gc_tables = [table_a, table_b]
1080+
gc_tables_cls.gc_tables = tables_with_schema
10621081

10631082
await gc_tables_cls.restore_tables()
10641083

@@ -1067,19 +1086,27 @@ async def test_restore_gc_tables(gc_tables_renamed, replace_gc_tables_data):
10671086
else:
10681087
mock_cursor.execute.assert_has_awaits(
10691088
[
1070-
mock.call(f"ALTER TABLE {table_a}_temp RENAME TO {table_a};"),
1071-
mock.call(f"ALTER TABLE {table_b}_temp RENAME TO {table_b};"),
1089+
mock.call(
1090+
f'ALTER TABLE "gc"."{tables[0]}_temp" RENAME TO "{tables[0]}";'
1091+
),
1092+
mock.call(
1093+
f'ALTER TABLE "gc"."{tables[1]}_temp" RENAME TO "{tables[1]}";'
1094+
),
10721095
]
10731096
)
10741097

10751098

10761099
@pytest.mark.asyncio
10771100
@pytest.mark.parametrize("gc_tables_renamed", [True, False])
1078-
async def test_cleanup_gc_tables(gc_tables_renamed, replace_gc_tables_data):
1079-
_, _, table_a, table_b, mock_cursor, gc_tables_cls = replace_gc_tables_data
1101+
async def test_cleanup_gc_tables(
1102+
gc_tables_renamed, replace_gc_tables_data, mock_quote_ident
1103+
):
1104+
_, _, tables, tables_with_schema, mock_cursor, gc_tables_cls = (
1105+
replace_gc_tables_data
1106+
)
10801107

10811108
gc_tables_cls.gc_tables_renamed = gc_tables_renamed
1082-
gc_tables_cls.gc_tables = [table_a, table_b]
1109+
gc_tables_cls.gc_tables = tables_with_schema
10831110

10841111
await gc_tables_cls.cleanup_tables()
10851112

@@ -1088,7 +1115,7 @@ async def test_cleanup_gc_tables(gc_tables_renamed, replace_gc_tables_data):
10881115
else:
10891116
mock_cursor.execute.assert_has_awaits(
10901117
[
1091-
mock.call(f"DROP TABLE {table_a}_temp;"),
1092-
mock.call(f"DROP TABLE {table_b}_temp;"),
1118+
mock.call(f'DROP TABLE "gc"."{tables[0]}_temp";'),
1119+
mock.call(f'DROP TABLE "gc"."{tables[1]}_temp";'),
10931120
]
10941121
)

0 commit comments

Comments
 (0)