Skip to content

Commit b7bdb6c

Browse files
committed
add public and private APIs, register RemoveSnapshotRefUpdate with apply metadata fn
1 parent a8d3f17 commit b7bdb6c

File tree

2 files changed

+96
-0
lines changed

2 files changed

+96
-0
lines changed

pyiceberg/table/__init__.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -435,6 +435,24 @@ def _set_ref_snapshot(
435435

436436
return updates, requirements
437437

438+
def _remove_ref_snapshot(self, ref_name: str) -> UpdatesAndRequirements:
    """Build the update and requirement needed to drop a snapshot ref.

    Args:
        ref_name: Name of the branch or tag to remove.

    Returns:
        A pair of one-element tuples: the RemoveSnapshotRefUpdate for ``ref_name``
        and an AssertRefSnapshotId requirement carrying the snapshot id the ref
        currently has in the table metadata, or None when the ref is absent.
    """
    removal = RemoveSnapshotRefUpdate(ref_name=ref_name)

    # The requirement records the snapshot id seen at the time the removal is
    # built; a ref missing from the metadata is recorded as None.
    if ref_name in self.table_metadata.refs:
        expected_snapshot_id = self.table_metadata.refs[ref_name].snapshot_id
    else:
        expected_snapshot_id = None

    guard = AssertRefSnapshotId(snapshot_id=expected_snapshot_id, ref=ref_name)
    return (removal,), (guard,)
455+
438456
def update_schema(self, allow_incompatible_changes: bool = False, case_sensitive: bool = True) -> UpdateSchema:
439457
"""Create a new UpdateSchema to alter the columns of this table.
440458
@@ -1023,6 +1041,23 @@ def _(update: SetSnapshotRefUpdate, base_metadata: TableMetadata, context: _Tabl
10231041
return base_metadata.model_copy(update=metadata_updates)
10241042

10251043

1044+
@_apply_table_update.register(RemoveSnapshotRefUpdate)
def _(update: RemoveSnapshotRefUpdate, base_metadata: TableMetadata, context: _TableMetadataUpdateContext) -> TableMetadata:
    """Apply a RemoveSnapshotRefUpdate to the table metadata.

    Args:
        update: The update naming the ref to remove.
        base_metadata: The metadata the update is applied to.
        context: Update context; the update is recorded on it only when a ref is actually removed.

    Returns:
        ``base_metadata`` unchanged when the ref does not exist, otherwise a copy
        whose ``refs`` no longer contain the named ref.

    Raises:
        ValueError: If the ref points at a snapshot unknown to ``base_metadata``,
            or if the ref is the main branch.
    """
    existing_ref = base_metadata.refs.get(update.ref_name)
    if existing_ref is None:
        # Nothing to remove: return the metadata as-is without recording the update.
        return base_metadata

    referenced_snapshot = base_metadata.snapshot_by_id(existing_ref.snapshot_id)
    if referenced_snapshot is None:
        raise ValueError(f"Cannot remove {update.ref_name} ref with unknown snapshot {existing_ref.snapshot_id}")

    if update.ref_name == MAIN_BRANCH:
        raise ValueError("Cannot remove main branch")

    # Build a new mapping rather than mutating the refs held by base_metadata.
    remaining_refs = {name: ref for name, ref in base_metadata.refs.items() if name != update.ref_name}
    context.add_update(update)
    return base_metadata.model_copy(update={"refs": remaining_refs})
1059+
1060+
10261061
@_apply_table_update.register(AddSortOrderUpdate)
10271062
def _(update: AddSortOrderUpdate, base_metadata: TableMetadata, context: _TableMetadataUpdateContext) -> TableMetadata:
10281063
context.add_update(update)
@@ -1997,6 +2032,21 @@ def create_tag(self, snapshot_id: int, tag_name: str, max_ref_age_ms: Optional[i
19972032
self._requirements += requirement
19982033
return self
19992034

2035+
def remove_tag(self, tag_name: str) -> ManageSnapshots:
    """
    Queue the removal of a tag.

    The name is forwarded unchanged to the transaction's ``_remove_ref_snapshot``;
    the resulting updates and requirements are appended to the pending ones.

    Args:
        tag_name (str): name of tag to remove

    Returns:
        This for method chaining
    """
    tag_updates, tag_requirements = self._transaction._remove_ref_snapshot(ref_name=tag_name)
    self._updates += tag_updates
    self._requirements += tag_requirements
    return self
2049+
20002050
def create_branch(
20012051
self,
20022052
snapshot_id: int,
@@ -2029,6 +2079,20 @@ def create_branch(
20292079
self._requirements += requirement
20302080
return self
20312081

2082+
def remove_branch(self, branch_name: str) -> ManageSnapshots:
    """
    Queue the removal of a branch.

    The name is forwarded unchanged to the transaction's ``_remove_ref_snapshot``;
    the resulting updates and requirements are appended to the pending ones.

    Args:
        branch_name (str): name of branch to remove

    Returns:
        This for method chaining
    """
    branch_updates, branch_requirements = self._transaction._remove_ref_snapshot(ref_name=branch_name)
    self._updates += branch_updates
    self._requirements += branch_requirements
    return self
2095+
20322096

20332097
class UpdateSchema(UpdateTableMetadata["UpdateSchema"]):
20342098
_schema: Schema

tests/integration/test_snapshot_operations.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,3 +40,35 @@ def test_create_branch(catalog: Catalog) -> None:
4040
branch_snapshot_id = tbl.history()[-2].snapshot_id
4141
tbl.manage_snapshots().create_branch(snapshot_id=branch_snapshot_id, branch_name="branch123").commit()
4242
assert tbl.metadata.refs["branch123"] == SnapshotRef(snapshot_id=branch_snapshot_id, snapshot_ref_type="branch")
43+
44+
45+
@pytest.mark.integration
@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")])
def test_remove_tag(catalog: Catalog) -> None:
    """Create a tag on an older snapshot, remove it, and check the ref is gone."""
    table = catalog.load_table("default.test_table_snapshot_operations")
    assert len(table.history()) > 3

    # Set-up: the tag has to exist before it can be removed.
    name = "tag_to_remove"
    tagged_snapshot_id = table.history()[-3].snapshot_id
    table.manage_snapshots().create_tag(snapshot_id=tagged_snapshot_id, tag_name=name).commit()
    expected_ref = SnapshotRef(snapshot_id=tagged_snapshot_id, snapshot_ref_type="tag")
    assert table.metadata.refs[name] == expected_ref

    # Exercise: remove the tag and verify it no longer appears in the refs.
    table.manage_snapshots().remove_tag(tag_name=name).commit()
    assert table.metadata.refs.get(name) is None
59+
60+
61+
@pytest.mark.integration
@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")])
def test_remove_branch(catalog: Catalog) -> None:
    """Create a branch on an older snapshot, remove it, and check the ref is gone."""
    table = catalog.load_table("default.test_table_snapshot_operations")
    assert len(table.history()) > 2

    # Set-up: the branch has to exist before it can be removed.
    name = "branch_to_remove"
    branched_snapshot_id = table.history()[-2].snapshot_id
    table.manage_snapshots().create_branch(snapshot_id=branched_snapshot_id, branch_name=name).commit()
    expected_ref = SnapshotRef(snapshot_id=branched_snapshot_id, snapshot_ref_type="branch")
    assert table.metadata.refs[name] == expected_ref

    # Exercise: remove the branch and verify it no longer appears in the refs.
    table.manage_snapshots().remove_branch(branch_name=name).commit()
    assert table.metadata.refs.get(name) is None

0 commit comments

Comments
 (0)