diff --git a/dev/provision.py b/dev/provision.py index 6c8fe366d7..d831aa1560 100644 --- a/dev/provision.py +++ b/dev/provision.py @@ -389,3 +389,51 @@ VALUES (4) """ ) + + spark.sql( + f""" + CREATE OR REPLACE TABLE {catalog_name}.default.test_table_rollback_to_snapshot_id ( + timestamp int, + number integer + ) + USING iceberg + TBLPROPERTIES ( + 'format-version'='2' + ); + """ + ) + + spark.sql( + f""" + INSERT INTO {catalog_name}.default.test_table_rollback_to_snapshot_id + VALUES (200, 1) + """ + ) + + spark.sql( + f""" + INSERT INTO {catalog_name}.default.test_table_rollback_to_snapshot_id + VALUES (202, 2) + """ + ) + + spark.sql( + f""" + DELETE FROM {catalog_name}.default.test_table_rollback_to_snapshot_id + WHERE number = 2 + """ + ) + + spark.sql( + f""" + INSERT INTO {catalog_name}.default.test_table_rollback_to_snapshot_id + VALUES (204, 3) + """ + ) + + spark.sql( + f""" + INSERT INTO {catalog_name}.default.test_table_rollback_to_snapshot_id + VALUES (206, 4) + """ + ) diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index 8c1493974b..4db79f90fa 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -106,13 +106,14 @@ NameMapping, update_mapping, ) -from pyiceberg.table.refs import MAIN_BRANCH, SnapshotRef +from pyiceberg.table.refs import MAIN_BRANCH, SnapshotRef, SnapshotRefType from pyiceberg.table.snapshots import ( Operation, Snapshot, SnapshotLogEntry, SnapshotSummaryCollector, Summary, + ancestor_right_before_timestamp, ancestors_of, update_snapshot_summaries, ) @@ -299,7 +300,12 @@ def __exit__(self, _: Any, value: Any, traceback: Any) -> None: """Close and commit the transaction.""" self.commit_transaction() - def _apply(self, updates: Tuple[TableUpdate, ...], requirements: Tuple[TableRequirement, ...] = ()) -> Transaction: + def _apply( + self, + updates: Tuple[TableUpdate, ...], + requirements: Tuple[TableRequirement, ...] 
= (), + commit_transaction_if_autocommit: bool = True, + ) -> Transaction: """Check if the requirements are met, and applies the updates to the metadata.""" for requirement in requirements: requirement.validate(self.table_metadata) @@ -309,7 +315,7 @@ def _apply(self, updates: Tuple[TableUpdate, ...], requirements: Tuple[TableRequ self.table_metadata = update_table_metadata(self.table_metadata, updates) - if self._autocommit: + if self._autocommit and commit_transaction_if_autocommit: self.commit_transaction() self._updates = () self._requirements = () @@ -402,39 +408,6 @@ def set_ref_snapshot( requirements = (AssertRefSnapshotId(snapshot_id=parent_snapshot_id, ref="main"),) return self._apply(updates, requirements) - def _set_ref_snapshot( - self, - snapshot_id: int, - ref_name: str, - type: str, - max_ref_age_ms: Optional[int] = None, - max_snapshot_age_ms: Optional[int] = None, - min_snapshots_to_keep: Optional[int] = None, - ) -> UpdatesAndRequirements: - """Update a ref to a snapshot. - - Returns: - The updates and requirements for the set-snapshot-ref staged - """ - updates = ( - SetSnapshotRefUpdate( - snapshot_id=snapshot_id, - ref_name=ref_name, - type=type, - max_ref_age_ms=max_ref_age_ms, - max_snapshot_age_ms=max_snapshot_age_ms, - min_snapshots_to_keep=min_snapshots_to_keep, - ), - ) - requirements = ( - AssertRefSnapshotId( - snapshot_id=self.table_metadata.refs[ref_name].snapshot_id if ref_name in self.table_metadata.refs else None, - ref=ref_name, - ), - ) - - return updates, requirements - def update_schema(self, allow_incompatible_changes: bool = False, case_sensitive: bool = True) -> UpdateSchema: """Create a new UpdateSchema to alter the columns of this table. 
@@ -1975,6 +1948,48 @@ def _commit(self) -> UpdatesAndRequirements: """Apply the pending changes and commit.""" return self._updates, self._requirements + def _commit_if_ref_updates_exist(self) -> None: + self._transaction._apply(*self._commit(), commit_transaction_if_autocommit=False) + self._updates, self._requirements = (), () + + def _set_ref_snapshot( + self, + snapshot_id: int, + ref_name: str, + type: str, + max_ref_age_ms: Optional[int] = None, + max_snapshot_age_ms: Optional[int] = None, + min_snapshots_to_keep: Optional[int] = None, + ) -> ManageSnapshots: + """Update a ref to a snapshot. + + Stages the updates and requirements for the set-snapshot-ref + + Returns: + This for method chaining + """ + updates = ( + SetSnapshotRefUpdate( + snapshot_id=snapshot_id, + ref_name=ref_name, + type=type, + max_ref_age_ms=max_ref_age_ms, + max_snapshot_age_ms=max_snapshot_age_ms, + min_snapshots_to_keep=min_snapshots_to_keep, + ), + ) + requirements = ( + AssertRefSnapshotId( + snapshot_id=self._transaction.table_metadata.refs[ref_name].snapshot_id + if ref_name in self._transaction.table_metadata.refs + else None, + ref=ref_name, + ), + ) + self._updates += updates + self._requirements += requirements + return self + def create_tag(self, snapshot_id: int, tag_name: str, max_ref_age_ms: Optional[int] = None) -> ManageSnapshots: """ Create a new tag pointing to the given snapshot id. 
@@ -1987,15 +2002,12 @@ def create_tag(self, snapshot_id: int, tag_name: str, max_ref_age_ms: Optional[i Returns: This for method chaining """ - update, requirement = self._transaction._set_ref_snapshot( + return self._set_ref_snapshot( snapshot_id=snapshot_id, ref_name=tag_name, type="tag", max_ref_age_ms=max_ref_age_ms, ) - self._updates += update - self._requirements += requirement - return self def create_branch( self, @@ -2017,7 +2029,7 @@ def create_branch( Returns: This for method chaining """ - update, requirement = self._transaction._set_ref_snapshot( + return self._set_ref_snapshot( snapshot_id=snapshot_id, ref_name=branch_name, type="branch", @@ -2025,8 +2037,71 @@ def create_branch( max_snapshot_age_ms=max_snapshot_age_ms, min_snapshots_to_keep=min_snapshots_to_keep, ) - self._updates += update - self._requirements += requirement + + def rollback_to_snapshot(self, snapshot_id: int) -> ManageSnapshots: + """Rollback the table to the given snapshot id. + + The snapshot needs to be an ancestor of the current table state. + + Args: + snapshot_id (int): rollback to this snapshot_id that used to be current. + Returns: + This for method chaining + """ + self._commit_if_ref_updates_exist() + if self._transaction._table.snapshot_by_id(snapshot_id) is None: + raise ValidationError(f"Cannot roll back to unknown snapshot id: {snapshot_id}") + if snapshot_id not in { + ancestor.snapshot_id + for ancestor in ancestors_of(self._transaction._table.current_snapshot(), self._transaction.table_metadata) + }: + raise ValidationError(f"Cannot roll back to snapshot, not an ancestor of the current state: {snapshot_id}") + return self._set_ref_snapshot(snapshot_id=snapshot_id, ref_name=MAIN_BRANCH, type=str(SnapshotRefType.BRANCH)) + + def rollback_to_timestamp(self, timestamp: int) -> ManageSnapshots: + """Rollback the table to the snapshot right before the given timestamp. + + The snapshot needs to be an ancestor of the current table state. 
+ + Args: + timestamp (int): rollback to the snapshot that used to be current right before this timestamp. + Returns: + This for method chaining + """ + self._commit_if_ref_updates_exist() + if ( + snapshot := ancestor_right_before_timestamp( + self._transaction._table.current_snapshot(), self._transaction.table_metadata, timestamp + ) + ) is None: + raise ValidationError(f"Cannot roll back, no valid snapshot older than: {timestamp}") + return self._set_ref_snapshot(snapshot_id=snapshot.snapshot_id, ref_name=MAIN_BRANCH, type=str(SnapshotRefType.BRANCH)) + + def set_current_snapshot(self, snapshot_id: Optional[int] = None, ref_name: Optional[str] = None) -> ManageSnapshots: + """Set the table to a specific snapshot identified either by its id or the branch/tag it is on, not both. + + The snapshot is not required to be an ancestor of the current table state. + + Args: + snapshot_id (Optional[int]): id of the snapshot to be set as current + ref_name (Optional[str]): branch/tag where the snapshot to be set as current exists. 
+ Returns: + This for method chaining + """ + self._commit_if_ref_updates_exist() + if (not snapshot_id or ref_name) and (snapshot_id or not ref_name): + raise ValidationError("Either snapshot_id or ref must be provided") + else: + if snapshot_id is None: + if ref_name not in self._transaction.table_metadata.refs: + raise ValidationError(f"Cannot set snapshot current to unknown ref {ref_name}") + target_snapshot_id = self._transaction.table_metadata.refs[ref_name].snapshot_id + else: + target_snapshot_id = snapshot_id + if (snapshot := self._transaction._table.snapshot_by_id(target_snapshot_id)) is None: + raise ValidationError(f"Cannot set snapshot current with snapshot id: {snapshot_id} or ref_name: {ref_name}") + + self._set_ref_snapshot(snapshot_id=snapshot.snapshot_id, ref_name=MAIN_BRANCH, type=str(SnapshotRefType.BRANCH)) return self diff --git a/pyiceberg/table/snapshots.py b/pyiceberg/table/snapshots.py index 842d42522a..70923ff407 100644 --- a/pyiceberg/table/snapshots.py +++ b/pyiceberg/table/snapshots.py @@ -421,6 +421,17 @@ def set_when_positive(properties: Dict[str, str], num: int, property_name: str) properties[property_name] = str(num) +def ancestor_right_before_timestamp( + current_snapshot: Optional[Snapshot], table_metadata: TableMetadata, timestamp_ms: int +) -> Optional[Snapshot]: + """Get the ancestor right before the given timestamp.""" + if current_snapshot: + for ancestor in ancestors_of(current_snapshot, table_metadata): + if ancestor.timestamp_ms < timestamp_ms: + return ancestor + return None + + def ancestors_of(current_snapshot: Optional[Snapshot], table_metadata: TableMetadata) -> Iterable[Snapshot]: """Get the ancestors of and including the given snapshot.""" snapshot = current_snapshot diff --git a/tests/integration/test_snapshot_operations.py b/tests/integration/test_snapshot_operations.py index 639193383e..d08b483b95 100644 --- a/tests/integration/test_snapshot_operations.py +++ b/tests/integration/test_snapshot_operations.py @@ 
-17,7 +17,8 @@ import pytest from pyiceberg.catalog import Catalog -from pyiceberg.table.refs import SnapshotRef +from pyiceberg.table.refs import MAIN_BRANCH, SnapshotRef, SnapshotRefType +from pyiceberg.table.snapshots import ancestors_of @pytest.mark.integration @@ -28,7 +29,7 @@ def test_create_tag(catalog: Catalog) -> None: assert len(tbl.history()) > 3 tag_snapshot_id = tbl.history()[-3].snapshot_id tbl.manage_snapshots().create_tag(snapshot_id=tag_snapshot_id, tag_name="tag123").commit() - assert tbl.metadata.refs["tag123"] == SnapshotRef(snapshot_id=tag_snapshot_id, snapshot_ref_type="tag") + assert tbl.metadata.refs["tag123"] == SnapshotRef(snapshot_id=tag_snapshot_id, snapshot_ref_type=str(SnapshotRefType.TAG)) @pytest.mark.integration @@ -39,4 +40,105 @@ def test_create_branch(catalog: Catalog) -> None: assert len(tbl.history()) > 2 branch_snapshot_id = tbl.history()[-2].snapshot_id tbl.manage_snapshots().create_branch(snapshot_id=branch_snapshot_id, branch_name="branch123").commit() - assert tbl.metadata.refs["branch123"] == SnapshotRef(snapshot_id=branch_snapshot_id, snapshot_ref_type="branch") + assert tbl.metadata.refs["branch123"] == SnapshotRef( + snapshot_id=branch_snapshot_id, snapshot_ref_type=str(SnapshotRefType.BRANCH) + ) + + +@pytest.mark.integration +@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")]) +def test_manage_snapshots_context_manager(catalog: Catalog) -> None: + identifier = "default.test_table_snapshot_operations" + tbl = catalog.load_table(identifier) + assert len(tbl.history()) > 3 + current_snapshot_id = tbl.current_snapshot().snapshot_id # type: ignore + expected_snapshot_id = tbl.history()[-4].snapshot_id + with tbl.manage_snapshots() as ms: + ms.create_tag(snapshot_id=current_snapshot_id, tag_name="testing") + ms.set_current_snapshot(snapshot_id=expected_snapshot_id) + ms.create_branch(snapshot_id=expected_snapshot_id, branch_name="testing2") + assert 
tbl.current_snapshot().snapshot_id != current_snapshot_id # type: ignore + assert tbl.metadata.refs["testing"].snapshot_id == current_snapshot_id + assert tbl.metadata.refs[MAIN_BRANCH] == SnapshotRef( + snapshot_id=expected_snapshot_id, snapshot_ref_type=str(SnapshotRefType.BRANCH) + ) + assert tbl.metadata.refs["testing2"].snapshot_id == expected_snapshot_id + + +# Maintain relative order of tests for following apis like rollback, set_current_snapshot, etc. +@pytest.mark.integration +@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")]) +def test_rollback_to_snapshot(catalog: Catalog) -> None: + identifier = "default.test_table_rollback_to_snapshot_id" + tbl = catalog.load_table(identifier) + assert len(tbl.history()) > 3 + rollback_snapshot_id = tbl.current_snapshot().parent_snapshot_id # type: ignore + current_snapshot_id = tbl.current_snapshot().snapshot_id # type: ignore + tbl.manage_snapshots().rollback_to_snapshot(snapshot_id=rollback_snapshot_id).commit() # type: ignore + assert tbl.current_snapshot().snapshot_id != current_snapshot_id # type: ignore + assert tbl.metadata.refs[MAIN_BRANCH] == SnapshotRef( + snapshot_id=rollback_snapshot_id, snapshot_ref_type=str(SnapshotRefType.BRANCH) + ) + + +@pytest.mark.integration +@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")]) +def test_rollback_to_timestamp(catalog: Catalog) -> None: + identifier = "default.test_table_rollback_to_snapshot_id" + tbl = catalog.load_table(identifier) + assert len(tbl.history()) > 4 + ancestors = list(ancestor for ancestor in ancestors_of(tbl.current_snapshot(), tbl.metadata)) # noqa + ancestor_to_rollback_to = ancestors[-1] + expected_snapshot_id, timestamp = ancestor_to_rollback_to.snapshot_id, ancestor_to_rollback_to.timestamp_ms + 1 + # not inclusive of rollback_timestamp + 
tbl.manage_snapshots().rollback_to_timestamp(timestamp=timestamp).commit() + assert tbl.metadata.refs[MAIN_BRANCH] == SnapshotRef( + snapshot_id=expected_snapshot_id, snapshot_ref_type=str(SnapshotRefType.BRANCH) + ) + + +@pytest.mark.integration +@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")]) +def test_set_current_snapshot_with_snapshot_id(catalog: Catalog) -> None: + identifier = "default.test_table_snapshot_operations" + tbl = catalog.load_table(identifier) + assert len(tbl.history()) > 3 + current_snapshot_id = tbl.current_snapshot().snapshot_id # type: ignore + expected_snapshot_id = tbl.history()[-3].snapshot_id + tbl.manage_snapshots().set_current_snapshot(snapshot_id=expected_snapshot_id).commit() + assert tbl.current_snapshot().snapshot_id != current_snapshot_id # type: ignore + assert tbl.metadata.refs[MAIN_BRANCH] == SnapshotRef( + snapshot_id=expected_snapshot_id, snapshot_ref_type=str(SnapshotRefType.BRANCH) + ) + + +@pytest.mark.integration +@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")]) +def test_set_current_snapshot_with_ref_name(catalog: Catalog) -> None: + identifier = "default.test_table_snapshot_operations" + tbl = catalog.load_table(identifier) + assert len(tbl.history()) > 3 + current_snapshot_id = tbl.current_snapshot().snapshot_id # type: ignore + expected_snapshot_id = tbl.history()[-3].snapshot_id + tbl.manage_snapshots().create_tag(snapshot_id=expected_snapshot_id, tag_name="test-tag").commit() + tbl.manage_snapshots().set_current_snapshot(ref_name="test-tag").commit() + assert tbl.current_snapshot().snapshot_id != current_snapshot_id # type: ignore + assert tbl.metadata.refs[MAIN_BRANCH] == SnapshotRef( + snapshot_id=expected_snapshot_id, snapshot_ref_type=str(SnapshotRefType.BRANCH) + ) + + +# Always test set_ref_snapshot last. 
+@pytest.mark.integration +@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")]) +def test_set_ref_snapshot(catalog: Catalog) -> None: + identifier = "default.test_table_snapshot_operations" + tbl = catalog.load_table(identifier) + assert len(tbl.history()) > 3 + target_snapshot_id = tbl.history()[-2].snapshot_id + tbl.manage_snapshots()._set_ref_snapshot( + snapshot_id=target_snapshot_id, ref_name=MAIN_BRANCH, type=str(SnapshotRefType.BRANCH) + ).commit() + assert tbl.metadata.refs[MAIN_BRANCH] == SnapshotRef( + snapshot_id=target_snapshot_id, snapshot_ref_type=str(SnapshotRefType.BRANCH) + ) diff --git a/tests/table/test_init.py b/tests/table/test_init.py index d7c4ffeeaf..92329df7bd 100644 --- a/tests/table/test_init.py +++ b/tests/table/test_init.py @@ -706,30 +706,6 @@ def test_update_metadata_add_snapshot(table_v2: Table) -> None: assert new_metadata.last_updated_ms == new_snapshot.timestamp_ms -def test_update_metadata_set_ref_snapshot(table_v2: Table) -> None: - update, _ = table_v2.transaction()._set_ref_snapshot( - snapshot_id=3051729675574597004, - ref_name="main", - type="branch", - max_ref_age_ms=123123123, - max_snapshot_age_ms=12312312312, - min_snapshots_to_keep=1, - ) - - new_metadata = update_table_metadata(table_v2.metadata, update) - assert len(new_metadata.snapshot_log) == 3 - assert new_metadata.snapshot_log[2].snapshot_id == 3051729675574597004 - assert new_metadata.current_snapshot_id == 3051729675574597004 - assert new_metadata.last_updated_ms > table_v2.metadata.last_updated_ms - assert new_metadata.refs["main"] == SnapshotRef( - snapshot_id=3051729675574597004, - snapshot_ref_type="branch", - min_snapshots_to_keep=1, - max_snapshot_age_ms=12312312312, - max_ref_age_ms=123123123, - ) - - def test_update_metadata_set_snapshot_ref(table_v2: Table) -> None: update = SetSnapshotRefUpdate( ref_name="main",