Skip to content

Commit 5f0b62b

Browse files
committed
Added methods needed to expire snapshots by id, and optionally cleanup data
1 parent 0a94d96 commit 5f0b62b

File tree

3 files changed

+144
-152
lines changed

3 files changed

+144
-152
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,3 +50,5 @@ htmlcov
5050
pyiceberg/avro/decoder_fast.c
5151
pyiceberg/avro/*.html
5252
pyiceberg/avro/*.so
53+
.vscode/settings.json
54+
pyiceberg/table/update/expire_snapshot.md

pyiceberg/table/update/snapshot.py

Lines changed: 70 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,7 @@ def _summary(self, snapshot_properties: Dict[str, str] = EMPTY_DICT) -> Summary:
247247
truncate_full_table=self._operation == Operation.OVERWRITE,
248248
)
249249

250-
def _commit(self) -> UpdatesAndRequirements:
250+
def commit(self, base_metadata: TableMetadata) -> UpdatesAndRequirements:
251251
new_manifests = self._manifests()
252252
next_sequence_number = self._transaction.table_metadata.next_sequence_number()
253253

@@ -748,6 +748,8 @@ class ManageSnapshots(UpdateTableMetadata["ManageSnapshots"]):
748748
ms.create_tag(snapshot_id1, "Tag_A").create_tag(snapshot_id2, "Tag_B")
749749
"""
750750

751+
_snapshot_ids_to_expire: Set[int] = set()
752+
751753
_updates: Tuple[TableUpdate, ...] = ()
752754
_requirements: Tuple[TableRequirement, ...] = ()
753755

@@ -853,102 +855,80 @@ def remove_branch(self, branch_name: str) -> ManageSnapshots:
853855
"""
854856
return self._remove_ref_snapshot(ref_name=branch_name)
855857

856-
857-
class ExpireSnapshots(UpdateTableMetadata["ExpireSnapshots"]):
858-
def __init__(self, table: Table):
859-
super().__init__(table)
860-
self._expire_older_than: Optional[int] = None
861-
self._snapshot_ids_to_expire: Set[int] = set()
862-
self._retain_last: Optional[int] = None
863-
self._delete_func: Optional[Callable[[str], None]] = None
864-
865-
def expire_older_than(self, timestamp_ms: int) -> "ExpireSnapshots":
866-
"""Expire snapshots older than the given timestamp."""
867-
self._expire_older_than = timestamp_ms
868-
return self
869-
870-
def expire_snapshot_id(self, snapshot_id: int) -> "ExpireSnapshots":
858+
def _get_snapshot_ref_name(self, snapshot_id: int) -> Optional[str]:
859+
"""Get the reference name of a snapshot."""
860+
for ref_name, snapshot in self._transaction.table_metadata.refs.items():
861+
if snapshot.snapshot_id == snapshot_id:
862+
return ref_name
863+
return None
864+
865+
def _check_forward_ref(self, snapshot_id: int) -> bool:
866+
"""Check if the snapshot ID is a forward reference."""
867+
# Ensure that remaining snapshots correctly reference their parent
868+
for ref in self._transaction.table_metadata.refs.values():
869+
if ref.snapshot_id == snapshot_id:
870+
parent_snapshot_id = ref.parent_snapshot_id
871+
if parent_snapshot_id is not None and parent_snapshot_id not in self._transaction.table_metadata.snapshots:
872+
return False
873+
return True
874+
875+
def _find_dependant_snapshot(self, snapshot_id: int) -> Optional[int]:
876+
"""Find any dependant snapshot."""
877+
for ref in self._transaction.table_metadata.refs.values():
878+
if ref.snapshot_id == snapshot_id:
879+
return ref.parent_snapshot_id
880+
return None
881+
882+
def exipre_snapshot_by_id(self, snapshot_id: int) -> ManageSnapshots:
871883
"""Explicitly expire a snapshot by its ID."""
872884
self._snapshot_ids_to_expire.add(snapshot_id)
873885
return self
874886

875-
def retain_last(self, num_snapshots: int) -> "ExpireSnapshots":
876-
"""Retain the last N snapshots."""
877-
if num_snapshots < 1:
878-
raise ValueError("Number of snapshots to retain must be at least 1.")
879-
self._retain_last = num_snapshots
880-
return self
881-
882-
def delete_with(self, delete_func: Callable[[str], None]) -> "ExpireSnapshots":
883-
"""Set a custom delete function for cleaning up files."""
884-
self._delete_func = delete_func
885-
return self
886-
887-
def _commit(self, base_metadata: TableMetadata) -> UpdatesAndRequirements:
888-
snapshots_to_expire = set()
889-
890-
# Identify snapshots by timestamp
891-
if self._expire_older_than is not None:
892-
snapshots_to_expire.update(
893-
s.snapshot_id for s in base_metadata.snapshots
894-
if s.timestamp_ms < self._expire_older_than
887+
def expire_snapshots(self) -> ManageSnapshots:
888+
"""Expire the snapshots that are marked for expiration."""
889+
# iterate over each snapshot requested to be expired
890+
for snapshot_id in self._snapshot_ids_to_expire:
891+
# remove the reference to the snapshot in the table metadata
892+
# and stage the changes
893+
update, requirement = self._remove_ref_snapshot(
894+
ref_name=self._get_snapshot_ref_name(snapshot_id=snapshot_id),
895895
)
896896

897-
# Explicitly added snapshot IDs
898-
snapshots_to_expire.update(self._snapshot_ids_to_expire)
899-
900-
# Retain the last N snapshots
901-
if self._retain_last is not None:
902-
sorted_snapshots = sorted(base_metadata.snapshots, key=lambda s: s.timestamp_ms, reverse=True)
903-
retained_snapshots = {s.snapshot_id for s in sorted_snapshots[:self._retain_last]}
904-
snapshots_to_expire.difference_update(retained_snapshots)
905-
906-
if not snapshots_to_expire:
907-
print("No snapshots identified for expiration.")
908-
return base_metadata # No change, return original metadata
909-
910-
print(f"Expiring snapshots: {snapshots_to_expire}")
911-
912-
# Filter snapshots
913-
remaining_snapshots = [
914-
snapshot for snapshot in base_metadata.snapshots
915-
if snapshot.snapshot_id not in snapshots_to_expire
916-
]
917-
918-
# Update snapshot log
919-
remaining_snapshot_log = [
920-
log for log in base_metadata.snapshot_log
921-
if log.snapshot_id not in snapshots_to_expire
922-
]
923-
924-
# Determine the new current snapshot ID
925-
new_current_snapshot_id = (
926-
max(remaining_snapshots, key=lambda s: s.timestamp_ms).snapshot_id
927-
if remaining_snapshots else None
928-
)
929-
930-
# Return new metadata object reflecting the expired snapshots
931-
updated_metadata = base_metadata.model_copy(
932-
update={
933-
"snapshots": remaining_snapshots,
934-
"snapshot_log": remaining_snapshot_log,
935-
"current_snapshot_id": new_current_snapshot_id
936-
}
937-
)
897+
# return the updates and requirements to be committed
898+
self._updates += update
899+
self._requirements += requirement
900+
901+
# check if there is a dependant snapshot
902+
dependant_snapshot_id = self._find_dependant_snapshot(snapshot_id=snapshot_id)
903+
if dependant_snapshot_id is not None:
904+
# remove the reference to the dependant snapshot in the table metadata
905+
# and stage the changes
906+
update, requirement = self._transaction._set_ref_snapshot(
907+
ref_name=self._get_snapshot_ref_name(snapshot_id=dependant_snapshot_id),
908+
snapshot_id=dependant_snapshot_id
909+
)
910+
self._updates += update
911+
self._requirements += requirement
938912

939-
# Cleanup orphaned files (manifests/data files)
940-
self._cleanup_files(snapshots_to_expire, base_metadata)
913+
# clean up the unused files
941914

942-
return updated_metadata
915+
return self
943916

944-
def _cleanup_files(self, expired_snapshot_ids: Set[int], metadata: TableMetadata):
917+
def cleanup_files(self):
945918
"""Remove files no longer referenced by any snapshots."""
946-
print(f"Cleaning up resources for expired snapshots: {expired_snapshot_ids}")
947-
if self._delete_func:
948-
# Use the custom delete function if provided
949-
for snapshot_id in expired_snapshot_ids:
950-
self._delete_func(f"Snapshot {snapshot_id}")
951-
else:
952-
# Default cleanup logic (placeholder)
953-
for snapshot_id in expired_snapshot_ids:
954-
print(f"Default cleanup for snapshot {snapshot_id}")
919+
# Remove the manifest files for the expired snapshots
920+
for entry in self._snapshot_ids_to_expire:
921+
922+
# remove the manifest files for the expired snapshots
923+
for manifest in self._transaction._table.snapshot_by_id(entry).manifests(self._transaction._table.io):
924+
# get a list of all Parquet files in the manifest that are orphaned
925+
data_files = manifest.fetch_manifest_entry(io=self._transaction._table.io, discard_deleted=True)
926+
927+
# remove the manifest
928+
self._transaction._table.io.delete(manifest.manifest_path)
929+
930+
# remove the data files
931+
[self._transaction._table.io.delete(file.data_file.file_path) for file in data_files if file.data_file.file_path is not None]
932+
return self
933+
934+

tests/table/test_expire_snapshots.py

Lines changed: 72 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,31 @@
1-
# pylint:disable=redefined-outer-name,eval-used
1+
# pylint:disable=redefined-outer-name
2+
# pylint:disable=redefined-outer-name
23
from unittest.mock import Mock
34
import pytest
45

5-
from pyiceberg.manifest import DataFile, DataFileContent, ManifestContent, ManifestFile
6-
from pyiceberg.partitioning import PartitionField, PartitionSpec
7-
from pyiceberg.schema import Schema
86
from pyiceberg.table import Table
9-
from pyiceberg.table.metadata import TableMetadata
10-
from pyiceberg.table.snapshots import Operation, Snapshot, SnapshotLogEntry, SnapshotSummaryCollector, Summary, update_snapshot_summaries
7+
from pyiceberg.table.metadata import new_table_metadata
8+
from pyiceberg.table.snapshots import Snapshot, SnapshotLogEntry
119
from pyiceberg.table.update.snapshot import ExpireSnapshots
12-
from pyiceberg.transforms import IdentityTransform
13-
from pyiceberg.typedef import Record
14-
from pyiceberg.types import (
15-
BooleanType,
16-
IntegerType,
17-
NestedField,
18-
StringType,
19-
)
10+
11+
from pyiceberg.schema import Schema
12+
from pyiceberg.partitioning import PartitionSpec
13+
from pyiceberg.table.sorting import SortOrder
14+
15+
2016

2117
@pytest.fixture
2218
def mock_table():
23-
"""Fixture to create a mock table with metadata and snapshots."""
19+
"""
20+
Creates a mock Iceberg table with predefined metadata, snapshots, and snapshot log entries.
21+
The mock table includes:
22+
- Snapshots with unique IDs, timestamps, and manifest lists.
23+
- A snapshot log that tracks the history of snapshots with their IDs and timestamps.
24+
- Table metadata including schema, partition spec, sort order, location, properties, and UUID.
25+
- A current snapshot ID and last updated timestamp.
26+
Returns:
27+
Mock: A mock object representing an Iceberg table with the specified metadata and attributes.
28+
"""
2429
snapshots = [
2530
Snapshot(snapshot_id=1, timestamp_ms=1000, manifest_list="manifest1.avro"),
2631
Snapshot(snapshot_id=2, timestamp_ms=2000, manifest_list="manifest2.avro"),
@@ -31,56 +36,61 @@ def mock_table():
3136
SnapshotLogEntry(snapshot_id=2, timestamp_ms=2000),
3237
SnapshotLogEntry(snapshot_id=3, timestamp_ms=3000),
3338
]
34-
metadata = TableMetadata(
35-
snapshots=snapshots,
36-
snapshot_log=snapshot_log,
37-
current_snapshot_id=3,
39+
40+
metadata = new_table_metadata(
41+
schema=Schema(fields=[]),
42+
partition_spec=PartitionSpec(spec_id=0, fields=[]),
43+
sort_order=SortOrder(order_id=0, fields=[]),
44+
location="s3://example-bucket/path/",
45+
properties={},
46+
table_uuid="12345678-1234-1234-1234-123456789abc",
47+
).model_copy(
48+
update={
49+
"snapshots": snapshots,
50+
"snapshot_log": snapshot_log,
51+
"current_snapshot_id": 3,
52+
"last_updated_ms": 3000,
53+
}
3854
)
55+
3956
table = Mock(spec=Table)
4057
table.metadata = metadata
41-
return table
42-
43-
44-
def test_expire_older_than(mock_table):
45-
"""Test expiring snapshots older than a given timestamp."""
46-
expire_snapshots = ExpireSnapshots(mock_table)
47-
expire_snapshots.expire_older_than(2500)._commit(mock_table.metadata)
48-
49-
remaining_snapshot_ids = {s.snapshot_id for s in mock_table.metadata.snapshots}
50-
assert remaining_snapshot_ids == {3}, "Only the latest snapshot should remain."
58+
table.identifier = ("db", "table")
5159

5260

53-
def test_retain_last(mock_table):
54-
"""Test retaining the last N snapshots."""
55-
expire_snapshots = ExpireSnapshots(mock_table)
56-
expire_snapshots.retain_last(2)._commit(mock_table.metadata)
57-
58-
remaining_snapshot_ids = {s.snapshot_id for s in mock_table.metadata.snapshots}
59-
assert remaining_snapshot_ids == {2, 3}, "The last two snapshots should remain."
60-
61-
62-
def test_expire_specific_snapshot(mock_table):
63-
"""Test explicitly expiring a specific snapshot."""
64-
expire_snapshots = ExpireSnapshots(mock_table)
65-
expire_snapshots.expire_snapshot_id(2)._commit(mock_table.metadata)
66-
67-
remaining_snapshot_ids = {s.snapshot_id for s in mock_table.metadata.snapshots}
68-
assert remaining_snapshot_ids == {1, 3}, "Snapshot 2 should be expired."
69-
70-
71-
def test_custom_delete_function(mock_table):
72-
"""Test using a custom delete function for cleanup."""
73-
delete_func = Mock()
74-
expire_snapshots = ExpireSnapshots(mock_table)
75-
expire_snapshots.expire_snapshot_id(1).delete_with(delete_func)._commit(mock_table.metadata)
76-
77-
delete_func.assert_called_once_with("Snapshot 1"), "Custom delete function should be called for expired snapshot."
78-
79-
80-
def test_no_snapshots_to_expire(mock_table):
81-
"""Test when no snapshots are identified for expiration."""
82-
expire_snapshots = ExpireSnapshots(mock_table)
83-
expire_snapshots.expire_older_than(500)._commit(mock_table.metadata)
61+
return table
8462

85-
remaining_snapshot_ids = {s.snapshot_id for s in mock_table.metadata.snapshots}
86-
assert remaining_snapshot_ids == {1, 2, 3}, "No snapshots should be expired."
63+
def test_expire_snapshots_removes_correct_snapshots(mock_table: Mock):
64+
"""
65+
Test case for the `ExpireSnapshots` class to ensure that the correct snapshots
66+
are removed and the delete function is called the expected number of times.
67+
Args:
68+
mock_table (Mock): A mock object representing the table.
69+
Test Steps:
70+
1. Create a mock delete function and a mock transaction.
71+
2. Instantiate the `ExpireSnapshots` class with the mock transaction.
72+
3. Configure the `ExpireSnapshots` instance to expire snapshots with IDs 1 and 2,
73+
and set the delete function to the mock delete function.
74+
4. Commit the changes using the `_commit` method with the mock table's metadata.
75+
5. Validate that the mock delete function is called for the correct snapshots.
76+
6. Verify that the delete function is called exactly twice.
77+
7. Ensure that the updated metadata returned by `_commit` is not `None`.
78+
"""
79+
mock_delete_func = Mock()
80+
mock_transaction = Mock()
81+
82+
expire_snapshots = ExpireSnapshots(mock_transaction)
83+
expire_snapshots \
84+
.expire_snapshot_id(1) \
85+
.expire_snapshot_id(2) \
86+
.delete_with(mock_delete_func)
87+
88+
updated_metadata = expire_snapshots._commit(mock_table.metadata)
89+
90+
# Validate delete calls
91+
mock_delete_func.assert_any_call(mock_table.return_value.snapshots[0])
92+
mock_delete_func.assert_any_call(mock_table.metadata.snapshots[1])
93+
assert mock_delete_func.call_count == 2
94+
95+
# Verify updated metadata returned
96+
assert updated_metadata is not None

0 commit comments

Comments
 (0)