
Commit 65365e1

Added the builder method to __init__.py, updated the snapshot API with the new ExpireSnapshots class, and updated tests.

1 parent f995daa commit 65365e1

4 files changed: +152 -121 lines changed


pyiceberg/table/update/__init__.py

Lines changed: 14 additions & 0 deletions
@@ -575,6 +575,20 @@ def _(update: RemoveStatisticsUpdate, base_metadata: TableMetadata, context: _Ta
 
     return base_metadata.model_copy(update={"statistics": statistics})
 
+@_apply_table_update.register(RemoveSnapshotsUpdate)
+def _(update: RemoveSnapshotsUpdate, base_metadata: TableMetadata, context: _TableMetadataUpdateContext) -> TableMetadata:
+    if len(update.snapshot_ids) == 0 or len(base_metadata.snapshots) == 0:
+        return base_metadata
+
+    retained_snapshots = []
+    ids_to_remove = set(update.snapshot_ids)
+    for snapshot in base_metadata.snapshots:
+        if snapshot.snapshot_id not in ids_to_remove:
+            retained_snapshots.append(snapshot)
+
+    context.add_update(update)
+    return base_metadata.model_copy(update={"snapshots": retained_snapshots})
+
 
 def update_table_metadata(
     base_metadata: TableMetadata,
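
For orientation, a sketch of how the new handler is reached: `update_table_metadata` (whose signature begins in the context lines above) folds each staged update into the metadata through the `@_apply_table_update.register` dispatch. The snippet below is illustrative only; it assumes `table` is an already-loaded pyiceberg Table and that the entry point accepts a tuple of updates as its second argument.

    # Hypothetical usage: drop snapshots 1 and 2 from existing table metadata.
    # Assumes `table` is a previously loaded pyiceberg Table.
    from pyiceberg.table.update import RemoveSnapshotsUpdate, update_table_metadata

    new_metadata = update_table_metadata(
        base_metadata=table.metadata,
        updates=(RemoveSnapshotsUpdate(snapshot_ids=[1, 2]),),
    )
    assert all(s.snapshot_id not in (1, 2) for s in new_metadata.snapshots)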

pyiceberg/table/update/note.md

Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
+In snapshot.py you define the "API", i.e. the logic that collects the requested changes and stages them. Then
+__init__.py has a decorator that dispatches on the update type to actually apply the metadata changes.
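
A minimal, self-contained sketch of that collect-then-dispatch pattern (illustrative names and a simplified metadata shape; the real registrations live in pyiceberg/table/update/__init__.py):

    from dataclasses import dataclass
    from functools import singledispatch
    from typing import List

    @dataclass(frozen=True)
    class RemoveSnapshotsUpdate:
        snapshot_ids: List[int]

    @singledispatch
    def _apply_table_update(update, base_metadata):
        raise ValueError(f"Unsupported update: {update!r}")

    @_apply_table_update.register(RemoveSnapshotsUpdate)
    def _(update: RemoveSnapshotsUpdate, base_metadata):
        # The handler is selected by the runtime type of `update`;
        # metadata here is simplified to a bare list of snapshot IDs.
        ids_to_remove = set(update.snapshot_ids)
        return [s for s in base_metadata if s not in ids_to_remove]

    print(_apply_table_update(RemoveSnapshotsUpdate([1, 2]), [1, 2, 3]))  # -> [3]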

pyiceberg/table/update/snapshot.py

Lines changed: 37 additions & 62 deletions
@@ -66,6 +66,7 @@
     AddSnapshotUpdate,
     AssertRefSnapshotId,
     RemoveSnapshotRefUpdate,
+    RemoveSnapshotsUpdate,
     SetSnapshotRefUpdate,
     TableRequirement,
     TableMetadata,
@@ -84,9 +85,8 @@
 from pyiceberg.utils.properties import property_as_bool, property_as_int
 
 if TYPE_CHECKING:
-    from pyiceberg.table import Transaction
+    from pyiceberg.table import Table
 
-    from pyiceberg.table import Table
     from pyiceberg.table.metadata import Snapshot
     from pyiceberg.table.update import UpdateTableMetadata
     from typing import Optional, Set
@@ -748,13 +748,13 @@ class ManageSnapshots(UpdateTableMetadata["ManageSnapshots"]):
         ms.create_tag(snapshot_id1, "Tag_A").create_tag(snapshot_id2, "Tag_B")
     """
 
-    _snapshot_ids_to_expire: Set[int] = set()
-
     _updates: Tuple[TableUpdate, ...] = ()
     _requirements: Tuple[TableRequirement, ...] = ()
 
     def _commit(self) -> UpdatesAndRequirements:
         """Apply the pending changes and commit."""
+        if not hasattr(self._transaction, "_apply"):
+            raise AttributeError("Transaction object is not properly initialized.")
         return self._updates, self._requirements
 
     def _remove_ref_snapshot(self, ref_name: str) -> ManageSnapshots:
@@ -855,80 +855,55 @@ def remove_branch(self, branch_name: str) -> ManageSnapshots:
         """
         return self._remove_ref_snapshot(ref_name=branch_name)
 
+
+class ExpireSnapshots(UpdateTableMetadata["ExpireSnapshots"]):
+    """
+    API for removing old snapshots from the table.
+    """
+
+    _ids_to_remove: List[int] = []
+
+    _updates: Tuple[TableUpdate, ...] = ()
+    _requirements: Tuple[TableRequirement, ...] = ()
+
+    def _commit(self) -> UpdatesAndRequirements:
+        return (RemoveSnapshotsUpdate(snapshot_ids=self._ids_to_remove),), ()
+
     def _get_snapshot_ref_name(self, snapshot_id: int) -> Optional[str]:
         """Get the reference name of a snapshot."""
         for ref_name, snapshot in self._transaction.table_metadata.refs.items():
             if snapshot.snapshot_id == snapshot_id:
                 return ref_name
         return None
 
-    def _check_forward_ref(self, snapshot_id: int) -> bool:
-        """Check if the snapshot ID is a forward reference."""
-        # Ensure that remaining snapshots correctly reference their parent
-        for ref in self._transaction.table_metadata.refs.values():
-            if ref.snapshot_id == snapshot_id:
-                parent_snapshot_id = ref.parent_snapshot_id
-                if parent_snapshot_id is not None and parent_snapshot_id not in self._transaction.table_metadata.snapshots:
-                    return False
-        return True
-
     def _find_dependant_snapshot(self, snapshot_id: int) -> Optional[int]:
-        """Find any dependant snapshot."""
+        """Find any dependent snapshot."""
        for ref in self._transaction.table_metadata.refs.values():
             if ref.snapshot_id == snapshot_id:
                 return ref.parent_snapshot_id
         return None
 
-    def exipre_snapshot_by_id(self, snapshot_id: int) -> ManageSnapshots:
-        """Explicitly expire a snapshot by its ID."""
-        self._snapshot_ids_to_expire.add(snapshot_id)
+    def expire_snapshot_id(self, snapshot_id_to_expire: int) -> ExpireSnapshots:
+        """Mark a specific snapshot ID for expiration."""
+        if self._transaction._table.snapshot_by_id(snapshot_id_to_expire):
+            self._ids_to_remove.append(snapshot_id_to_expire)
         return self
 
-    def expire_snapshots(self) -> ManageSnapshots:
-        """Expire the snapshots that are marked for expiration."""
-        # iterate over each snapshot requested to be expired
-        for snapshot_id in self._snapshot_ids_to_expire:
-            # remove the reference to the snapshot in the table metadata
-            # and stage the changes
-            update, requirement = self._remove_ref_snapshot(
-                ref_name=self._get_snapshot_ref_name(snapshot_id=snapshot_id),
-            )
-
-            # stage the updates and requirements to be committed
-            self._updates += update
-            self._requirements += requirement
-
-            # check if there is a dependent snapshot
-            dependant_snapshot_id = self._find_dependant_snapshot(snapshot_id=snapshot_id)
-            if dependant_snapshot_id is not None:
-                # remove the reference to the dependent snapshot in the table metadata
-                # and stage the changes
-                update, requirement = self._transaction._set_ref_snapshot(
-                    ref_name=self._get_snapshot_ref_name(snapshot_id=dependant_snapshot_id),
-                    snapshot_id=dependant_snapshot_id
-                )
-                self._updates += update
-                self._requirements += requirement
-
-        # clean up the unused files
-
+    def expire_older_than(self, timestamp_ms: int) -> ExpireSnapshots:
+        """Mark snapshots older than the given timestamp for expiration."""
+        for snapshot in self._transaction.table_metadata.snapshots:
+            if snapshot.timestamp_ms < timestamp_ms:
+                self._ids_to_remove.append(snapshot.snapshot_id)
         return self
 
-    def cleanup_files(self):
-        """Remove files no longer referenced by any snapshots."""
-        # Remove the manifest files for the expired snapshots
-        for entry in self._snapshot_ids_to_expire:
-
-            # remove the manifest files for the expired snapshots
-            for manifest in self._transaction._table.snapshot_by_id(entry).manifests(self._transaction._table.io):
-                # get a list of all parquet files in the manifest that are orphaned
-                data_files = manifest.fetch_manifest_entry(io=self._transaction._table.io, discard_deleted=True)
-
-                # remove the manifest
-                self._transaction._table.io.delete(manifest.manifest_path)
-
-                # remove the data files
-                [self._transaction._table.io.delete(file.data_file.file_path) for file in data_files if file.data_file.file_path is not None]
-        return self
+    # Uncomment and implement cleanup_files if file cleanup is required
+    # def cleanup_files(self):
+    #     """Remove files no longer referenced by any snapshots."""
+    #     for entry in self._ids_to_remove:
+    #         for manifest in self._transaction._table.snapshot_by_id(entry).manifests(self._transaction._table.io):
+    #             data_files = manifest.fetch_manifest_entry(io=self._transaction._table.io, discard_deleted=True)
+    #             self._transaction._table.io.delete(manifest.manifest_path)
+    #             [self._transaction._table.io.delete(file.data_file.file_path) for file in data_files if file.data_file.file_path is not None]
    #     return self
 
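
Taken together, the builder stages snapshot IDs and `_commit` emits them as a single RemoveSnapshotsUpdate. A hedged usage sketch (assuming `table` is an already-loaded pyiceberg Table; committing on context-manager exit matches how the test below drives it):

    from pyiceberg.table.update.snapshot import ExpireSnapshots

    # Illustrative only; `table` is assumed to be a previously loaded Table.
    with ExpireSnapshots(table.transaction()) as expire:
        # Both methods return self, so staging calls can be chained.
        expire.expire_snapshot_id(3051729675574597004).expire_older_than(1_700_000_000_000)
    # On exit, _commit() returns one RemoveSnapshotsUpdate carrying the staged IDs.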

tests/table/test_expire_snapshots.py

Lines changed: 99 additions & 59 deletions
@@ -1,77 +1,117 @@
-# pylint:disable=redefined-outer-name
-# pylint:disable=redefined-outer-name
-from unittest.mock import Mock
+from datetime import datetime, timezone
+from pathlib import PosixPath
+from random import randint
+import time
+from typing import Any, Dict, Optional
 import pytest
-
+from pyiceberg.catalog.memory import InMemoryCatalog
+from pyiceberg.catalog.noop import NoopCatalog
+from pyiceberg.io import load_file_io
+from pyiceberg.table import Table
+from pyiceberg.table.sorting import NullOrder, SortDirection, SortField, SortOrder
+from pyiceberg.table.update.snapshot import ExpireSnapshots
+from pyiceberg.transforms import IdentityTransform
+from pyiceberg.types import BooleanType, FloatType, IntegerType, ListType, LongType, MapType, StructType
+from tests.catalog.test_base import InMemoryCatalog, Table
 from pyiceberg.table import Table
-from pyiceberg.table.metadata import new_table_metadata
-from pyiceberg.table.snapshots import Snapshot, SnapshotLogEntry
-from pyiceberg.table.update.snapshot import ManageSnapshots
-
 from pyiceberg.schema import Schema
-from pyiceberg.partitioning import PartitionSpec
-from pyiceberg.table.sorting import SortOrder
+from pyiceberg.types import NestedField, LongType, StringType
+from pyiceberg.table.snapshots import Snapshot
+from pyiceberg.table.metadata import TableMetadata, TableMetadataV2, new_table_metadata
 
 
 
-@pytest.fixture
-def mock_table():
-    """
-    Creates a mock Iceberg table with predefined metadata, snapshots, and snapshot log entries.
-    The mock table includes:
-    - Snapshots with unique IDs, timestamps, and manifest lists.
-    - A snapshot log that tracks the history of snapshots with their IDs and timestamps.
-    - Table metadata including schema, partition spec, sort order, location, properties, and UUID.
-    - A current snapshot ID and last updated timestamp.
-    Returns:
-        Mock: A mock object representing an Iceberg table with the specified metadata and attributes.
-    """
-    snapshots = [
-        Snapshot(snapshot_id=1, timestamp_ms=1000, manifest_list="manifest1.avro"),
-        Snapshot(snapshot_id=2, timestamp_ms=2000, manifest_list="manifest2.avro"),
-        Snapshot(snapshot_id=3, timestamp_ms=3000, manifest_list="manifest3.avro"),
-    ]
-    snapshot_log = [
-        SnapshotLogEntry(snapshot_id=1, timestamp_ms=1000),
-        SnapshotLogEntry(snapshot_id=2, timestamp_ms=2000),
-        SnapshotLogEntry(snapshot_id=3, timestamp_ms=3000),
-    ]
 
-    metadata = new_table_metadata(
-        schema=Schema(fields=[]),
-        partition_spec=PartitionSpec(spec_id=0, fields=[]),
-        sort_order=SortOrder(order_id=0, fields=[]),
-        location="s3://example-bucket/path/",
-        properties={},
-        table_uuid="12345678-1234-1234-1234-123456789abc",
-    ).model_copy(
-        update={
-            "snapshots": snapshots,
-            "snapshot_log": snapshot_log,
-            "current_snapshot_id": 3,
-            "last_updated_ms": 3000,
+@pytest.fixture
+def generate_test_table() -> Table:
+    def generate_snapshot(
+        snapshot_id: int,
+        parent_snapshot_id: Optional[int] = None,
+        timestamp_ms: Optional[int] = None,
+        sequence_number: int = 0,
+    ) -> Dict[str, Any]:
+        return {
+            "snapshot-id": snapshot_id,
+            "parent-snapshot-id": parent_snapshot_id,
+            "timestamp-ms": timestamp_ms or int(time.time() * 1000),
+            "sequence-number": sequence_number,
+            "summary": {"operation": "append"},
+            "manifest-list": f"s3://a/b/{snapshot_id}.avro",
         }
-    )
 
-    table = Mock(spec=Table)
-    table.metadata = metadata
-    table.identifier = ("db", "table")
+    snapshots = []
+    snapshot_log = []
+    initial_snapshot_id = 3051729675574597004
+
+    for i in range(2000):
+        snapshot_id = initial_snapshot_id + i
+        parent_snapshot_id = snapshot_id - 1 if i > 0 else None
+        timestamp_ms = int(time.time() * 1000) - randint(0, 1000000)
+        snapshots.append(generate_snapshot(snapshot_id, parent_snapshot_id, timestamp_ms, i))
+        snapshot_log.append({"snapshot-id": snapshot_id, "timestamp-ms": timestamp_ms})
 
+    metadata = {
+        "format-version": 2,
+        "table-uuid": "9c12d441-03fe-4693-9a96-a0705ddf69c1",
+        "location": "s3://bucket/test/location",
+        "last-sequence-number": 34,
+        "last-updated-ms": 1602638573590,
+        "last-column-id": 3,
+        "current-schema-id": 1,
+        "schemas": [
+            {"type": "struct", "schema-id": 0, "fields": [{"id": 1, "name": "x", "required": True, "type": "long"}]},
+            {
+                "type": "struct",
+                "schema-id": 1,
+                "identifier-field-ids": [1, 2],
+                "fields": [
+                    {"id": 1, "name": "x", "required": True, "type": "long"},
+                    {"id": 2, "name": "y", "required": True, "type": "long", "doc": "comment"},
+                    {"id": 3, "name": "z", "required": True, "type": "long"},
+                ],
+            },
+        ],
+        "default-spec-id": 0,
+        "partition-specs": [{"spec-id": 0, "fields": [{"name": "x", "transform": "identity", "source-id": 1, "field-id": 1000}]}],
+        "last-partition-id": 1000,
+        "default-sort-order-id": 3,
+        "sort-orders": [
+            {
+                "order-id": 3,
+                "fields": [
+                    {"transform": "identity", "source-id": 2, "direction": "asc", "null-order": "nulls-first"},
+                    {"transform": "bucket[4]", "source-id": 3, "direction": "desc", "null-order": "nulls-last"},
+                ],
+            }
+        ],
+        "properties": {"read.split.target.size": "134217728"},
+        "current-snapshot-id": initial_snapshot_id + 1999,
+        "snapshots": snapshots,
+        "snapshot-log": snapshot_log,
+        "metadata-log": [{"metadata-file": "s3://bucket/.../v1.json", "timestamp-ms": 1515100}],
+        "refs": {"test": {"snapshot-id": initial_snapshot_id, "type": "tag", "max-ref-age-ms": 10000000}},
+    }
+
+    return Table(
+        identifier=("database", "table"),
+        metadata=metadata,
+        metadata_location=f"{metadata['location']}/uuid.metadata.json",
+        io=load_file_io(),
+        catalog=NoopCatalog("NoopCatalog"),
+    )
 
-    return table
 
-def test_expire_snapshots_removes_correct_snapshots(mock_table: Mock):
+
+def test_expire_snapshots_removes_correct_snapshots(generate_test_table):
     """
     Test case for the `ExpireSnapshots` class to ensure that the correct snapshots
     are removed and the delete function is called the expected number of times.
-
     """
 
-    with ManageSnapshots(mock_table) as transaction:
-        # Mock the transaction to return the mock table
-        transaction.exipre_snapshot_by_id(1).exipre_snapshot_by_id(2).expire_snapshots().cleanup_files()
-
+    # Use the fixture-provided table
+    with ExpireSnapshots(generate_test_table.transaction()) as manage_snapshots:
+        manage_snapshots.expire_snapshot_id(3051729675574597004)
 
-    for snapshot in mock_table.metadata.snapshots:
-        # Verify that the snapshot is removed from the metadata
-        assert snapshot.snapshot_id not in [1, 2]
+    # Check the remaining snapshots
+    remaining_snapshot_ids = {snapshot.snapshot_id for snapshot in generate_test_table.metadata.snapshots}
+    assert not remaining_snapshot_ids.issubset({3051729675574597004})
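
A companion test for the time-based path could reuse the same fixture. The sketch below is hypothetical and inspects the builder's private staging list rather than committed metadata, since the NoopCatalog in the fixture does not persist commits:

    def test_expire_older_than_stages_old_snapshots(generate_test_table):
        # Every generated snapshot is timestamped in the past, so a cutoff of
        # "now + 1 ms" should stage all of them for removal.
        cutoff_ms = int(time.time() * 1000) + 1
        expire = ExpireSnapshots(generate_test_table.transaction())
        expire.expire_older_than(cutoff_ms)
        staged = set(expire._ids_to_remove)
        assert {snapshot.snapshot_id for snapshot in generate_test_table.metadata.snapshots} <= staged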
