Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions src/psycopack/_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,6 +486,30 @@ def swap_pk_sequence_name(self, *, first_table: str, second_table: str) -> None:
self.rename_sequence(seq_from=second_seq, seq_to=first_seq)
self.rename_sequence(seq_from=temp_seq, seq_to=second_seq)

def transfer_pk_sequence_value(
self, *, source_table: str, dest_table: str, convert_pk_to_bigint: bool
) -> None:
source_seq = self.introspector.get_pk_sequence_name(table=source_table)
dest_seq = self.introspector.get_pk_sequence_name(table=dest_table)
value = self.introspector.get_pk_sequence_value(seq=source_seq)

if convert_pk_to_bigint and value < 0:
# special case handling where negative PK values were used before bigint conversion
value = 2**31 # reset to positive, specifically the first bigint value

# TODO: try to correctly restore a negative PK sequence value if we revert swap
# while doing a bigint conversion

self.cur.execute(
psycopg.sql.SQL("SELECT setval('{schema}.{sequence}', {value});")
.format(
schema=psycopg.sql.Identifier(self.schema),
sequence=psycopg.sql.Identifier(dest_seq),
value=psycopg.sql.SQL(str(value)),
)
.as_string(self.conn)
)

def acquire_access_exclusive_lock(self, *, table: str) -> None:
self.cur.execute(
psycopg.sql.SQL("LOCK TABLE {schema}.{table} IN ACCESS EXCLUSIVE MODE;")
Expand Down
15 changes: 15 additions & 0 deletions src/psycopack/_introspect.py
Original file line number Diff line number Diff line change
Expand Up @@ -613,6 +613,21 @@ def get_pk_sequence_name(self, *, table: str) -> str:
assert isinstance(seq, str)
return seq

def get_pk_sequence_value(self, *, seq: str) -> int:
self.cur.execute(
psycopg.sql.SQL("SELECT last_value FROM {schema}.{sequence};")
.format(
schema=psycopg.sql.Identifier(self.schema),
sequence=psycopg.sql.Identifier(seq),
)
.as_string(self.conn)
)
result = self.cur.fetchone()
assert result is not None
value = result[0]
assert isinstance(value, int)
return value

def get_backfill_batch(self, *, table: str) -> BackfillBatch | None:
self.cur.execute(
psycopg.sql.SQL(
Expand Down
11 changes: 11 additions & 0 deletions src/psycopack/_repack.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,11 @@ def swap(self) -> None:
self.command.swap_pk_sequence_name(
first_table=self.table, second_table=self.copy_table
)
self.command.transfer_pk_sequence_value(
source_table=self.table,
dest_table=self.copy_table,
convert_pk_to_bigint=self.convert_pk_to_bigint,
)
self.command.rename_table(
table_from=self.table, table_to=self.repacked_name
)
Expand Down Expand Up @@ -434,6 +439,12 @@ def revert_swap(self) -> None:
self.command.swap_pk_sequence_name(
first_table=self.table, second_table=self.repacked_name
)
self.command.transfer_pk_sequence_value(
source_table=self.table,
dest_table=self.repacked_name,
convert_pk_to_bigint=self.convert_pk_to_bigint,
)

self.command.rename_table(table_from=self.table, table_to=self.copy_table)
self.command.rename_table(
table_from=self.repacked_name, table_to=self.table
Expand Down
9 changes: 6 additions & 3 deletions tests/factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ def create_table_for_repacking(
with_exclusion_constraint: bool = False,
pk_type: str = "SERIAL",
pk_name: str = "id",
pk_start: int = 1,
ommit_sequence: bool = False,
schema: str = "public",
) -> None:
Expand Down Expand Up @@ -45,7 +46,9 @@ def create_table_for_repacking(
):
# Create a sequence manually.
seq = f"{table_name}_seq"
cur.execute(f"CREATE SEQUENCE {schema}.{seq};")
cur.execute(
f"CREATE SEQUENCE {schema}.{seq} MINVALUE {pk_start} START WITH {pk_start};"
)
pk_type = f"{pk_type} DEFAULT NEXTVAL('{schema}.{seq}')"

cur.execute(
Expand Down Expand Up @@ -191,7 +194,7 @@ def create_table_for_repacking(
cur.execute(
dedent(f"""
INSERT INTO {schema}.referring_table ({table_name}_{pk_name})
SELECT generate_series(1, {referring_table_rows});
SELECT generate_series({pk_start}, {pk_start + referring_table_rows - 1});
""")
)
cur.execute(
Expand All @@ -213,7 +216,7 @@ def create_table_for_repacking(
cur.execute(
dedent(f"""
INSERT INTO {schema}.not_valid_referring_table ({table_name}_{pk_name})
SELECT generate_series(1, {referring_table_rows});
SELECT generate_series({pk_start}, {pk_start + referring_table_rows - 1});
""")
)

Expand Down
40 changes: 40 additions & 0 deletions tests/test_repack.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ class _TableInfo:
referring_fks: list[_introspect.ReferringForeignKey]
constraints: list[_introspect.Constraint]
pk_seq: str
pk_seq_val: int | None


def _collect_table_info(
Expand All @@ -60,13 +61,18 @@ def _collect_table_info(
table=table, types=["c", "f", "n", "p", "u", "t", "x"]
)
pk_seq = introspector.get_pk_sequence_name(table=table)
if pk_seq:
pk_seq_val = introspector.get_pk_sequence_value(seq=pk_seq)
else:
pk_seq_val = None

return _TableInfo(
oid=oid,
indexes=indexes,
referring_fks=referring_fks,
constraints=constraints,
pk_seq=pk_seq,
pk_seq_val=pk_seq_val,
)


Expand Down Expand Up @@ -127,6 +133,10 @@ def _assert_repack(
assert table_before.referring_fks == table_after.referring_fks
assert table_before.constraints == table_after.constraints
assert table_before.pk_seq == table_after.pk_seq
if table_before.pk_seq_val is None or table_before.pk_seq_val > 0:
assert table_before.pk_seq_val == table_after.pk_seq_val
else:
assert table_after.pk_seq_val == 2**31

# All functions and triggers are removed.
trigger_info = _get_trigger_info(repack, cur)
Expand Down Expand Up @@ -1283,6 +1293,36 @@ def test_repack_full_with_serial_pk(
)


def test_when_table_has_negative_pk_values(
connection: _psycopg.Connection,
) -> None:
with _cur.get_cursor(connection, logged=True) as cur:
factories.create_table_for_repacking(
connection=connection,
cur=cur,
table_name="to_repack",
rows=100,
pk_type="integer",
pk_start=-200,
)
table_before = _collect_table_info(table="to_repack", connection=connection)
repack = Psycopack(
table="to_repack",
batch_size=1,
conn=connection,
cur=cur,
convert_pk_to_bigint=True,
)
repack.full()
table_after = _collect_table_info(table="to_repack", connection=connection)
_assert_repack(
table_before=table_before,
table_after=table_after,
repack=repack,
cur=cur,
)


def test_when_table_has_large_value_being_inserted(
connection: _psycopg.Connection,
) -> None:
Expand Down
Loading