Skip to content

Commit 7093ec1

Browse files
committed
cherry-pick: Support Remove Branch or Tag APIs (apache#822)
1 parent 29c9a96 commit 7093ec1

File tree

3 files changed

+94
-0
lines changed

3 files changed

+94
-0
lines changed

pyiceberg/table/update/__init__.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -466,6 +466,23 @@ def _(update: SetSnapshotRefUpdate, base_metadata: TableMetadata, context: _Tabl
466466
return base_metadata.model_copy(update=metadata_updates)
467467

468468

469+
@_apply_table_update.register(RemoveSnapshotRefUpdate)
470+
def _(update: RemoveSnapshotRefUpdate, base_metadata: TableMetadata, context: _TableMetadataUpdateContext) -> TableMetadata:
471+
if (existing_ref := base_metadata.refs.get(update.ref_name, None)) is None:
472+
return base_metadata
473+
474+
if base_metadata.snapshot_by_id(existing_ref.snapshot_id) is None:
475+
raise ValueError(f"Cannot remove {update.ref_name} ref with unknown snapshot {existing_ref.snapshot_id}")
476+
477+
if update.ref_name == MAIN_BRANCH:
478+
raise ValueError("Cannot remove main branch")
479+
480+
metadata_refs = {**base_metadata.refs}
481+
metadata_refs.pop(update.ref_name, None)
482+
context.add_update(update)
483+
return base_metadata.model_copy(update={"refs": metadata_refs})
484+
485+
469486
@_apply_table_update.register(AddSortOrderUpdate)
470487
def _(update: AddSortOrderUpdate, base_metadata: TableMetadata, context: _TableMetadataUpdateContext) -> TableMetadata:
471488
context.add_update(update)

pyiceberg/table/update/snapshot.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@
6565
from pyiceberg.table.update import (
6666
AddSnapshotUpdate,
6767
AssertRefSnapshotId,
68+
RemoveSnapshotRefUpdate,
6869
SetSnapshotRefUpdate,
6970
TableRequirement,
7071
TableUpdate,
@@ -749,6 +750,28 @@ def _commit(self) -> UpdatesAndRequirements:
749750
"""Apply the pending changes and commit."""
750751
return self._updates, self._requirements
751752

753+
def _remove_ref_snapshot(self, ref_name: str) -> ManageSnapshots:
754+
"""Remove a snapshot ref.
755+
756+
Args:
757+
ref_name: branch / tag name to remove
758+
Stages the updates and requirements for the remove-snapshot-ref.
759+
Returns
760+
This method for chaining
761+
"""
762+
updates = (RemoveSnapshotRefUpdate(ref_name=ref_name),)
763+
requirements = (
764+
AssertRefSnapshotId(
765+
snapshot_id=self._transaction.table_metadata.refs[ref_name].snapshot_id
766+
if ref_name in self._transaction.table_metadata.refs
767+
else None,
768+
ref=ref_name,
769+
),
770+
)
771+
self._updates += updates
772+
self._requirements += requirements
773+
return self
774+
752775
def create_tag(self, snapshot_id: int, tag_name: str, max_ref_age_ms: Optional[int] = None) -> ManageSnapshots:
753776
"""
754777
Create a new tag pointing to the given snapshot id.
@@ -771,6 +794,17 @@ def create_tag(self, snapshot_id: int, tag_name: str, max_ref_age_ms: Optional[i
771794
self._requirements += requirement
772795
return self
773796

797+
def remove_tag(self, tag_name: str) -> ManageSnapshots:
798+
"""
799+
Remove a tag.
800+
801+
Args:
802+
tag_name (str): name of tag to remove
803+
Returns:
804+
This for method chaining
805+
"""
806+
return self._remove_ref_snapshot(ref_name=tag_name)
807+
774808
def create_branch(
775809
self,
776810
snapshot_id: int,
@@ -802,3 +836,14 @@ def create_branch(
802836
self._updates += update
803837
self._requirements += requirement
804838
return self
839+
840+
def remove_branch(self, branch_name: str) -> ManageSnapshots:
841+
"""
842+
Remove a branch.
843+
844+
Args:
845+
branch_name (str): name of branch to remove
846+
Returns:
847+
This for method chaining
848+
"""
849+
return self._remove_ref_snapshot(ref_name=branch_name)

tests/integration/test_snapshot_operations.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,3 +40,35 @@ def test_create_branch(catalog: Catalog) -> None:
4040
branch_snapshot_id = tbl.history()[-2].snapshot_id
4141
tbl.manage_snapshots().create_branch(snapshot_id=branch_snapshot_id, branch_name="branch123").commit()
4242
assert tbl.metadata.refs["branch123"] == SnapshotRef(snapshot_id=branch_snapshot_id, snapshot_ref_type="branch")
43+
44+
45+
@pytest.mark.integration
46+
@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")])
47+
def test_remove_tag(catalog: Catalog) -> None:
48+
identifier = "default.test_table_snapshot_operations"
49+
tbl = catalog.load_table(identifier)
50+
assert len(tbl.history()) > 3
51+
# first, create the tag to remove
52+
tag_name = "tag_to_remove"
53+
tag_snapshot_id = tbl.history()[-3].snapshot_id
54+
tbl.manage_snapshots().create_tag(snapshot_id=tag_snapshot_id, tag_name=tag_name).commit()
55+
assert tbl.metadata.refs[tag_name] == SnapshotRef(snapshot_id=tag_snapshot_id, snapshot_ref_type="tag")
56+
# now, remove the tag
57+
tbl.manage_snapshots().remove_tag(tag_name=tag_name).commit()
58+
assert tbl.metadata.refs.get(tag_name, None) is None
59+
60+
61+
@pytest.mark.integration
62+
@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")])
63+
def test_remove_branch(catalog: Catalog) -> None:
64+
identifier = "default.test_table_snapshot_operations"
65+
tbl = catalog.load_table(identifier)
66+
assert len(tbl.history()) > 2
67+
# first, create the branch to remove
68+
branch_name = "branch_to_remove"
69+
branch_snapshot_id = tbl.history()[-2].snapshot_id
70+
tbl.manage_snapshots().create_branch(snapshot_id=branch_snapshot_id, branch_name=branch_name).commit()
71+
assert tbl.metadata.refs[branch_name] == SnapshotRef(snapshot_id=branch_snapshot_id, snapshot_ref_type="branch")
72+
# now, remove the branch
73+
tbl.manage_snapshots().remove_branch(branch_name=branch_name).commit()
74+
assert tbl.metadata.refs.get(branch_name, None) is None

0 commit comments

Comments
 (0)