Skip to content

Commit 27c3ece

Browse files
committed
Removed: unrelated changes, Added: logic to expire snapshot method.
Implemented logic to protect the HEAD branches or Tagged branches from being expired by the `expire_snapshot_by_id` method.
1 parent 2c3153e commit 27c3ece

File tree

3 files changed

+24
-10
lines changed

3 files changed

+24
-10
lines changed

pyiceberg/table/update/snapshot.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@
5555
from pyiceberg.partitioning import (
5656
PartitionSpec,
5757
)
58+
from pyiceberg.table.refs import SnapshotRefType
5859
from pyiceberg.table.snapshots import (
5960
Operation,
6061
Snapshot,
@@ -857,7 +858,7 @@ class ExpireSnapshots(UpdateTableMetadata["ExpireSnapshots"]):
857858
Use table.expire_snapshots().<operation-one>().<operation-two>().commit() to run multiple operations.
858859
Pending changes are applied on commit.
859860
"""
860-
861+
861862
_snapshot_ids_to_expire = set()
862863
_updates: Tuple[TableUpdate, ...] = ()
863864
_requirements: Tuple[TableRequirement, ...] = ()
@@ -875,6 +876,21 @@ def _commit(self) -> UpdatesAndRequirements:
875876
self._updates += (update,)
876877
return self._updates, self._requirements
877878

879+
def _get_protected_snapshot_ids(self):
880+
"""
881+
Get the IDs of protected snapshots. These are the HEAD snapshots of all branches
882+
and all tagged snapshots. These ids are to be excluded from expiration.
883+
Returns:
884+
Set of protected snapshot IDs to exclude from expiration.
885+
"""
886+
protected_ids = set()
887+
888+
for ref in self._transaction.table_metadata.refs.values():
889+
if ref.snapshot_ref_type in [SnapshotRefType.TAG, SnapshotRefType.BRANCH]:
890+
protected_ids.add(ref.snapshot_id)
891+
892+
return protected_ids
893+
878894
def expire_snapshot_by_id(self, snapshot_id: int) -> ExpireSnapshots:
879895
"""
880896
Expire a snapshot by its ID.
@@ -885,7 +901,13 @@ def expire_snapshot_by_id(self, snapshot_id: int) -> ExpireSnapshots:
885901
Returns:
886902
This for method chaining.
887903
"""
904+
888905
if self._transaction.table_metadata.snapshot_by_id(snapshot_id) is None:
889906
raise ValueError(f"Snapshot with ID {snapshot_id} does not exist.")
907+
908+
if snapshot_id in self._get_protected_snapshot_ids():
909+
raise ValueError(f"Snapshot with ID {snapshot_id} is protected and cannot be expired.")
910+
890911
self._snapshot_ids_to_expire.add(snapshot_id)
891-
return self
912+
913+
return self

tests/expressions/test_literals.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -744,7 +744,6 @@ def test_invalid_decimal_conversions() -> None:
744744
def test_invalid_string_conversions() -> None:
745745
assert_invalid_conversions(
746746
literal("abc"),
747-
[FixedType(1), BinaryType()],
748747
)
749748

750749

tests/integration/test_partition_evolution.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -140,13 +140,6 @@ def test_add_hour(catalog: Catalog) -> None:
140140
_validate_new_partition_fields(table, 1000, 1, 1000, PartitionField(2, 1000, HourTransform(), "hour_transform"))
141141

142142

143-
@pytest.mark.integration
144-
@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")])
145-
def test_add_hour_string_transform(catalog: Catalog) -> None:
146-
table = _table(catalog)
147-
table.update_spec().add_field("event_ts", "hour", "str_hour_transform").commit()
148-
_validate_new_partition_fields(table, 1000, 1, 1000, PartitionField(2, 1000, HourTransform(), "str_hour_transform"))
149-
150143

151144
@pytest.mark.integration
152145
@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")])

0 commit comments

Comments
 (0)