Skip to content

Commit e28815f

Browse files
committed
Snapshots are not being transacted on, but need to re-assign refs
ValueError: Cannot expire snapshot IDs {3051729675574597004} as they are currently referenced by table refs.
1 parent 65365e1 commit e28815f

File tree

4 files changed

+124
-68
lines changed

4 files changed

+124
-68
lines changed

pyiceberg/table/__init__.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@
115115
)
116116
from pyiceberg.table.update.schema import UpdateSchema
117117
from pyiceberg.table.update.snapshot import (
118+
ExpireSnapshots,
118119
ManageSnapshots,
119120
UpdateSnapshot,
120121
_FastAppendFiles,
@@ -1068,6 +1069,23 @@ def manage_snapshots(self) -> ManageSnapshots:
10681069
ms.create_tag(snapshot_id1, "Tag_A").create_tag(snapshot_id2, "Tag_B")
10691070
"""
10701071
return ManageSnapshots(transaction=Transaction(self, autocommit=True))
1072+
1073+
def expire_snapshots(self) -> ExpireSnapshots:
    """
    Shorthand to expire snapshots.

    Returns an ExpireSnapshots builder bound to a fresh auto-committing
    transaction on this table. Mark the snapshots to drop on the builder
    and then commit it:

        table.expire_snapshots().expire_snapshot_id(...).commit()
        table.expire_snapshots().expire_older_than(...).commit()

    NOTE(review): earlier wording also advertised
    ``tx.expire_snapshots()`` inside ``with table.transaction() as tx:``;
    confirm Transaction actually exposes that method before relying on it.
    """
    auto_commit_transaction = Transaction(self, autocommit=True)
    return ExpireSnapshots(auto_commit_transaction)
1086+
1087+
1088+
10711089

10721090
def update_statistics(self) -> UpdateStatistics:
10731091
"""

pyiceberg/table/update/__init__.py

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -575,20 +575,6 @@ def _(update: RemoveStatisticsUpdate, base_metadata: TableMetadata, context: _Ta
575575

576576
return base_metadata.model_copy(update={"statistics": statistics})
577577

578-
@_apply_table_update.register(RemoveSnapshotsUpdate)
def _(update: RemoveSnapshotsUpdate, base_metadata: TableMetadata, context: _TableMetadataUpdateContext) -> TableMetadata:
    """Return a copy of the metadata without the snapshots listed in the update.

    When the update names no snapshot IDs, or the table has no snapshots,
    the metadata is returned unchanged and the update is not recorded on
    the context.
    """
    if len(update.snapshot_ids) == 0 or len(base_metadata.snapshots) == 0:
        return base_metadata

    # Set membership keeps the filter linear in the number of snapshots.
    expired_ids = set(update.snapshot_ids)
    kept = [candidate for candidate in base_metadata.snapshots if candidate.snapshot_id not in expired_ids]

    context.add_update(update)
    return base_metadata.model_copy(update={"snapshots": kept})
591-
592578

593579
def update_table_metadata(
594580
base_metadata: TableMetadata,

pyiceberg/table/update/snapshot.py

Lines changed: 46 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,16 @@
9292
from typing import Optional, Set
9393
from datetime import datetime, timezone
9494

95+
from typing import Dict, Optional, Set
96+
import uuid
97+
from pyiceberg.table.metadata import TableMetadata
98+
from pyiceberg.table.snapshots import Snapshot
99+
from pyiceberg.table.update import (
100+
UpdateTableMetadata,
101+
RemoveSnapshotsUpdate,
102+
UpdatesAndRequirements,
103+
AssertRefSnapshotId,
104+
)
95105

96106
def _new_manifest_file_name(num: int, commit_uuid: uuid.UUID) -> str:
97107
return f"{commit_uuid}-m{num}.avro"
@@ -860,50 +870,56 @@ class ExpireSnapshots(UpdateTableMetadata["ExpireSnapshots"]):
860870
"""
861871
API for removing old snapshots from the table.
862872
"""
863-
864-
_ids_to_remove: List[int] = []
873+
# Class-level defaults for the pending table updates and commit requirements.
# Each is declared once: a repeated declaration only rebinds the same empty
# tuple and makes it unclear which one is authoritative.
_updates: Tuple[TableUpdate, ...] = ()
_requirements: Tuple[TableRequirement, ...] = ()
868878

879+
def __init__(self, transaction: Transaction) -> None:
    """Bind the builder to ``transaction`` and start with no snapshots marked."""
    super().__init__(transaction)
    # A per-instance set: marking the same snapshot twice is harmless, and
    # nothing is shared between builders.
    self._ids_to_remove: Set[int] = set()
    self._transaction = transaction
883+
869884
def _commit(self) -> UpdatesAndRequirements:
    """
    Build the table updates and requirements that expire the marked snapshots.

    Returns:
        An ``(updates, requirements)`` pair. ``updates`` ends with a single
        RemoveSnapshotsUpdate carrying the marked snapshot IDs;
        ``requirements`` ends with one AssertRefSnapshotId per table ref,
        pinning each ref to the snapshot it points at right now. Anything
        already present in ``self._updates`` / ``self._requirements`` is
        kept in front.

    Raises:
        AttributeError: if no transaction is bound to this builder.
        ValueError: if no snapshot was marked, or if a marked snapshot is
            still referenced by a table ref.
    """
    transaction = getattr(self, "_transaction", None)
    if not transaction:
        raise AttributeError("Transaction object is not properly initialized.")

    if not self._ids_to_remove:
        raise ValueError("No snapshot IDs marked for expiration.")

    refs = transaction.table_metadata.refs

    # Ensure snapshots that refs currently point at are not marked for removal.
    referenced_ids = {ref.snapshot_id for ref in refs.values()}
    conflicting_ids = self._ids_to_remove.intersection(referenced_ids)
    if conflicting_ids:
        raise ValueError(f"Cannot expire snapshot IDs {conflicting_ids} as they are currently referenced by table refs.")

    updates = (RemoveSnapshotsUpdate(snapshot_ids=list(self._ids_to_remove)),)

    # Ensure refs haven't changed by commit time (snapshot ID consistency check).
    requirements = tuple(
        AssertRefSnapshotId(snapshot_id=ref.snapshot_id, ref=ref_name)
        for ref_name, ref in refs.items()
    )

    # Return the pair promised by the annotation instead of ``self``, and do
    # not accumulate onto the instance, so calling this twice cannot emit
    # duplicated updates or requirements.
    # NOTE(review): presumably the base class's commit() consumes this pair - confirm.
    return self._updates + updates, self._requirements + requirements
885908

886909
def expire_snapshot_id(self, snapshot_id_to_expire: int) -> ExpireSnapshots:
    """Mark a specific snapshot ID for expiration.

    The ID is looked up on the transaction's table first.

    Returns:
        This builder, so calls can be chained.

    Raises:
        ValueError: if the lookup finds no snapshot for the ID.
    """
    table = self._transaction._table
    if not table.snapshot_by_id(snapshot_id_to_expire):
        raise ValueError(f"Snapshot ID {snapshot_id_to_expire} does not exist.")

    self._ids_to_remove.add(snapshot_id_to_expire)
    return self
891917

892918
def expire_older_than(self, timestamp_ms: int) -> ExpireSnapshots:
    """Mark snapshots older than the given timestamp for expiration.

    Snapshots that a table ref (branch or tag) still points at are skipped
    even when they are older than the cutoff. Marking them would make the
    commit fail with "Cannot expire snapshot IDs ... as they are currently
    referenced by table refs", so a single old tag would otherwise block
    every age-based expiry.

    Args:
        timestamp_ms: Cutoff in epoch milliseconds; snapshots with a
            strictly smaller ``timestamp_ms`` are candidates.

    Returns:
        This builder, so calls can be chained.
    """
    metadata = self._transaction.table_metadata
    protected_ids = {ref.snapshot_id for ref in metadata.refs.values()}
    self._ids_to_remove.update(
        snapshot.snapshot_id
        for snapshot in metadata.snapshots
        if snapshot.timestamp_ms < timestamp_ms and snapshot.snapshot_id not in protected_ids
    )
    return self
898924

899-
# Uncomment and implement cleanup_files if file cleanup is required
900-
# def cleanup_files(self):
901-
# """Remove files no longer referenced by any snapshots."""
902-
# for entry in self._ids_to_remove:
903-
# for manifest in self._transaction._table.snapshot_by_id(entry).manifests(self._transaction._table.io):
904-
# data_files = manifest.fetch_manifest_entry(io=self._transaction._table.io, discard_deleted=True)
905-
# self._transaction._table.io.delete(manifest.manifest_path)
906-
# [self._transaction._table.io.delete(file.data_file.file_path) for file in data_files if file.data_file.file_path is not None]
907-
# return self
908-
909925

tests/table/test_expire_snapshots.py

Lines changed: 60 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from typing import Any, Dict, Optional
66
import pytest
77
from pyiceberg.catalog.memory import InMemoryCatalog
8-
from pyiceberg.catalog.noop import NoopCatalog
98
from pyiceberg.io import load_file_io
109
from pyiceberg.table import Table
1110
from pyiceberg.table.sorting import NullOrder, SortDirection, SortField, SortOrder
@@ -17,9 +16,46 @@
1716
from pyiceberg.schema import Schema
1817
from pyiceberg.types import NestedField, LongType, StringType
1918
from pyiceberg.table.snapshots import Snapshot
20-
from pyiceberg.table.metadata import TableMetadata, TableMetadataV2, new_table_metadata
19+
from pyiceberg.table.metadata import TableMetadata, TableMetadataUtil, TableMetadataV2, new_table_metadata
2120

2221

22+
@pytest.fixture
def mock_table():
    """Provide a Table backed by minimal v2 metadata with no snapshots and no refs."""
    now_ms = int(time.time() * 1000)

    # Single schema with one required long column.
    only_schema = {
        "type": "struct",
        "schema-id": 1,
        "fields": [{"id": 1, "name": "x", "required": True, "type": "long"}],
    }

    raw_metadata = {
        "format-version": 2,
        "table-uuid": "9c12d441-03fe-4693-9a96-a0705ddf69c1",
        "location": "s3://bucket/test/location",
        "last-sequence-number": 0,
        "last-updated-ms": now_ms,
        "last-column-id": 3,
        "current-schema-id": 1,
        "schemas": [only_schema],
        "default-spec-id": 0,
        "partition-specs": [{"spec-id": 0, "fields": []}],
        "last-partition-id": 0,
        "default-sort-order-id": 0,
        "sort-orders": [{"order-id": 0, "fields": []}],
        "snapshots": [],
        "refs": {},
    }

    return Table(
        identifier=("mock_database", "mock_table"),
        metadata=TableMetadataUtil.parse_obj(raw_metadata),
        metadata_location="mock_location",
        io=load_file_io(),
        catalog=InMemoryCatalog("InMemoryCatalog"),
    )
2359

2460

2561
@pytest.fixture
@@ -43,19 +79,24 @@ def generate_snapshot(
4379
snapshot_log = []
4480
initial_snapshot_id = 3051729675574597004
4581

46-
for i in range(2000):
82+
for i in range(5):
4783
snapshot_id = initial_snapshot_id + i
4884
parent_snapshot_id = snapshot_id - 1 if i > 0 else None
4985
timestamp_ms = int(time.time() * 1000) - randint(0, 1000000)
5086
snapshots.append(generate_snapshot(snapshot_id, parent_snapshot_id, timestamp_ms, i))
5187
snapshot_log.append({"snapshot-id": snapshot_id, "timestamp-ms": timestamp_ms})
5288

53-
metadata = {
89+
metadata_dict = {
5490
"format-version": 2,
5591
"table-uuid": "9c12d441-03fe-4693-9a96-a0705ddf69c1",
5692
"location": "s3://bucket/test/location",
5793
"last-sequence-number": 34,
58-
"last-updated-ms": 1602638573590,
94+
"last-updated-ms": snapshots[-1]["timestamp-ms"],
95+
"metadata-log": [
96+
{"metadata-file": "s3://bucket/test/location/metadata/v1.json", "timestamp-ms": 1700000000000},
97+
{"metadata-file": "s3://bucket/test/location/metadata/v2.json", "timestamp-ms": 1700003600000},
98+
{"metadata-file": "s3://bucket/test/location/metadata/v3.json", "timestamp-ms": snapshots[-1]["timestamp-ms"]},
99+
],
59100
"last-column-id": 3,
60101
"current-schema-id": 1,
61102
"schemas": [
@@ -72,46 +113,41 @@ def generate_snapshot(
72113
},
73114
],
74115
"default-spec-id": 0,
75-
"partition-specs": [{"spec-id": 0, "fields": [{"name": "x", "transform": "identity", "source-id": 1, "field-id": 1000}]}],
116+
"partition-specs": [{"spec-id": 0, "fields": []}],
76117
"last-partition-id": 1000,
77118
"default-sort-order-id": 3,
78-
"sort-orders": [
79-
{
80-
"order-id": 3,
81-
"fields": [
82-
{"transform": "identity", "source-id": 2, "direction": "asc", "null-order": "nulls-first"},
83-
{"transform": "bucket[4]", "source-id": 3, "direction": "desc", "null-order": "nulls-last"},
84-
],
85-
}
86-
],
119+
"sort-orders": [{"order-id": 3, "fields": []}],
87120
"properties": {"read.split.target.size": "134217728"},
88-
"current-snapshot-id": initial_snapshot_id + 1999,
121+
"current-snapshot-id": initial_snapshot_id + 4,
89122
"snapshots": snapshots,
90123
"snapshot-log": snapshot_log,
91-
"metadata-log": [{"metadata-file": "s3://bucket/.../v1.json", "timestamp-ms": 1515100}],
92124
"refs": {"test": {"snapshot-id": initial_snapshot_id, "type": "tag", "max-ref-age-ms": 10000000}},
93125
}
94126

127+
metadata = TableMetadataUtil.parse_obj(metadata_dict)
128+
95129
return Table(
96130
identifier=("database", "table"),
97131
metadata=metadata,
98-
metadata_location=f"{metadata['location']}/uuid.metadata.json",
132+
metadata_location=f"{metadata.location}/uuid.metadata.json",
99133
io=load_file_io(),
100-
catalog=NoopCatalog("NoopCatalog"),
134+
catalog=InMemoryCatalog("InMemoryCatalog"),
101135
)
102136

103137

104138

105-
def test_expire_snapshots_removes_correct_snapshots(generate_test_table: Table):
    """
    Expiring one unreferenced snapshot removes it from the table metadata
    and leaves the snapshot referenced by the "test" tag in place.
    """
    # The fixture's "test" tag points at 3051729675574597004; expiring a
    # snapshot that a ref still points at raises ValueError, so target the
    # next snapshot in the chain, which no ref points at.
    referenced_snapshot_id = 3051729675574597004
    snapshot_id_to_expire = referenced_snapshot_id + 1

    # Commit exactly once. Wrapping this in ``with`` and also calling
    # commit() would commit a second time if the context manager commits on
    # exit, which the previous version of this test relied on.
    generate_test_table.expire_snapshots().expire_snapshot_id(snapshot_id_to_expire).commit()

    # Check the remaining snapshots
    remaining_snapshot_ids = {snapshot.snapshot_id for snapshot in generate_test_table.metadata.snapshots}

    assert snapshot_id_to_expire not in remaining_snapshot_ids
    assert referenced_snapshot_id in remaining_snapshot_ids

0 commit comments

Comments
 (0)