diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py
index 6c6da2a9b7..f6e54a3370 100644
--- a/pyiceberg/table/__init__.py
+++ b/pyiceberg/table/__init__.py
@@ -432,7 +432,9 @@ def update_schema(self, allow_incompatible_changes: bool = False, case_sensitive
             name_mapping=self.table_metadata.name_mapping(),
         )
 
-    def update_snapshot(self, snapshot_properties: Dict[str, str] = EMPTY_DICT, branch: Optional[str] = None) -> UpdateSnapshot:
+    def update_snapshot(
+        self, snapshot_properties: Dict[str, str] = EMPTY_DICT, branch: Optional[str] = None, stage_only: bool = False
+    ) -> UpdateSnapshot:
         """Create a new UpdateSnapshot to produce a new snapshot for the table.
 
         Returns:
@@ -441,7 +443,9 @@ def update_snapshot(self, snapshot_properties: Dict[str, str] = EMPTY_DICT, bran
         if branch is None:
             branch = MAIN_BRANCH
 
-        return UpdateSnapshot(self, io=self._table.io, branch=branch, snapshot_properties=snapshot_properties)
+        return UpdateSnapshot(
+            self, io=self._table.io, branch=branch, snapshot_properties=snapshot_properties, stage_only=stage_only
+        )
 
     def update_statistics(self) -> UpdateStatistics:
         """
diff --git a/pyiceberg/table/update/snapshot.py b/pyiceberg/table/update/snapshot.py
index 3ffb275ded..a9322578be 100644
--- a/pyiceberg/table/update/snapshot.py
+++ b/pyiceberg/table/update/snapshot.py
@@ -109,6 +109,7 @@ class _SnapshotProducer(UpdateTableMetadata[U], Generic[U]):
     _deleted_data_files: Set[DataFile]
     _compression: AvroCompressionCodec
     _target_branch = MAIN_BRANCH
+    _stage_only = False
 
     def __init__(
         self,
@@ -118,6 +119,7 @@ def __init__(
         commit_uuid: Optional[uuid.UUID] = None,
         snapshot_properties: Dict[str, str] = EMPTY_DICT,
         branch: str = MAIN_BRANCH,
+        stage_only: bool = False,
     ) -> None:
         super().__init__(transaction)
         self.commit_uuid = commit_uuid or uuid.uuid4()
@@ -137,6 +139,7 @@ def __init__(
         self._parent_snapshot_id = (
             snapshot.snapshot_id if (snapshot := self._transaction.table_metadata.snapshot_by_name(self._target_branch)) else None
         )
+        self._stage_only = stage_only
 
     def _validate_target_branch(self, branch: str) -> str:
         # Default is already set to MAIN_BRANCH. So branch name can't be None.
@@ -292,25 +295,33 @@ def _commit(self) -> UpdatesAndRequirements:
             schema_id=self._transaction.table_metadata.current_schema_id,
         )
 
-        return (
-            (
-                AddSnapshotUpdate(snapshot=snapshot),
-                SetSnapshotRefUpdate(
-                    snapshot_id=self._snapshot_id,
-                    parent_snapshot_id=self._parent_snapshot_id,
-                    ref_name=self._target_branch,
-                    type=SnapshotRefType.BRANCH,
+        add_snapshot_update = AddSnapshotUpdate(snapshot=snapshot)
+
+        if self._stage_only:
+            return (
+                (add_snapshot_update,),
+                (),
+            )
+        else:
+            return (
+                (
+                    add_snapshot_update,
+                    SetSnapshotRefUpdate(
+                        snapshot_id=self._snapshot_id,
+                        parent_snapshot_id=self._parent_snapshot_id,
+                        ref_name=self._target_branch,
+                        type=SnapshotRefType.BRANCH,
+                    ),
                 ),
-            ),
-            (
-                AssertRefSnapshotId(
-                    snapshot_id=self._transaction.table_metadata.refs[self._target_branch].snapshot_id
-                    if self._target_branch in self._transaction.table_metadata.refs
-                    else None,
-                    ref=self._target_branch,
+                (
+                    AssertRefSnapshotId(
+                        snapshot_id=self._transaction.table_metadata.refs[self._target_branch].snapshot_id
+                        if self._target_branch in self._transaction.table_metadata.refs
+                        else None,
+                        ref=self._target_branch,
+                    ),
                 ),
-            ),
-        )
+            )
 
     @property
     def snapshot_id(self) -> int:
@@ -360,8 +371,9 @@ def __init__(
         branch: str,
         commit_uuid: Optional[uuid.UUID] = None,
         snapshot_properties: Dict[str, str] = EMPTY_DICT,
+        stage_only: bool = False,
     ):
-        super().__init__(operation, transaction, io, commit_uuid, snapshot_properties, branch)
+        super().__init__(operation, transaction, io, commit_uuid, snapshot_properties, branch, stage_only)
         self._predicate = AlwaysFalse()
         self._case_sensitive = True
 
@@ -530,10 +542,11 @@ def __init__(
         branch: str,
         commit_uuid: Optional[uuid.UUID] = None,
         snapshot_properties: Dict[str, str] = EMPTY_DICT,
+        stage_only: bool = False,
     ) -> None:
         from pyiceberg.table import TableProperties
 
-        super().__init__(operation, transaction, io, commit_uuid, snapshot_properties, branch)
+        super().__init__(operation, transaction, io, commit_uuid, snapshot_properties, branch, stage_only)
         self._target_size_bytes = property_as_int(
             self._transaction.table_metadata.properties,
             TableProperties.MANIFEST_TARGET_SIZE_BYTES,
@@ -649,6 +662,7 @@ class UpdateSnapshot:
     _transaction: Transaction
     _io: FileIO
    _branch: str
+    _stage_only: bool
     _snapshot_properties: Dict[str, str]
 
     def __init__(
@@ -656,12 +670,14 @@ def __init__(
         transaction: Transaction,
         io: FileIO,
         branch: str,
+        stage_only: bool = False,
         snapshot_properties: Dict[str, str] = EMPTY_DICT,
     ) -> None:
         self._transaction = transaction
         self._io = io
         self._snapshot_properties = snapshot_properties
         self._branch = branch
+        self._stage_only = stage_only
 
     def fast_append(self) -> _FastAppendFiles:
         return _FastAppendFiles(
@@ -670,6 +686,7 @@ def fast_append(self) -> _FastAppendFiles:
             io=self._io,
             branch=self._branch,
             snapshot_properties=self._snapshot_properties,
+            stage_only=self._stage_only,
         )
 
     def merge_append(self) -> _MergeAppendFiles:
@@ -679,6 +696,7 @@ def merge_append(self) -> _MergeAppendFiles:
             io=self._io,
             branch=self._branch,
             snapshot_properties=self._snapshot_properties,
+            stage_only=self._stage_only,
         )
 
     def overwrite(self, commit_uuid: Optional[uuid.UUID] = None) -> _OverwriteFiles:
@@ -691,6 +709,7 @@ def overwrite(self, commit_uuid: Optional[uuid.UUID] = None) -> _OverwriteFiles:
             io=self._io,
             branch=self._branch,
             snapshot_properties=self._snapshot_properties,
+            stage_only=self._stage_only,
         )
 
     def delete(self) -> _DeleteFiles:
@@ -700,6 +719,7 @@ def delete(self) -> _DeleteFiles:
             io=self._io,
             branch=self._branch,
             snapshot_properties=self._snapshot_properties,
+            stage_only=self._stage_only,
         )
diff --git a/tests/integration/test_writes/test_writes.py b/tests/integration/test_writes/test_writes.py
index e63883c1db..8d9026cd5a 100644
--- a/tests/integration/test_writes/test_writes.py
+++ b/tests/integration/test_writes/test_writes.py
@@ -2098,3 +2098,174 @@ def test_branch_py_write_spark_read(session_catalog: Catalog, spark: SparkSessio
     )
     assert main_df.count() == 3
     assert branch_df.count() == 2
+
+
+@pytest.mark.integration
+@pytest.mark.parametrize("format_version", [1, 2])
+def test_stage_only_delete(
+    spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int
+) -> None:
+    identifier = f"default.test_stage_only_delete_files_v{format_version}"
+    iceberg_spec = PartitionSpec(
+        *[PartitionField(source_id=4, field_id=1001, transform=IdentityTransform(), name="integer_partition")]
+    )
+    tbl = _create_table(
+        session_catalog, identifier, {"format-version": str(format_version)}, [arrow_table_with_null], iceberg_spec
+    )
+
+    current_snapshot = tbl.metadata.current_snapshot_id
+    assert current_snapshot is not None
+
+    original_count = len(tbl.scan().to_arrow())
+    assert original_count == 3
+
+    files_to_delete = []
+    for file_task in tbl.scan().plan_files():
+        files_to_delete.append(file_task.file)
+    assert len(files_to_delete) > 0
+
+    with tbl.transaction() as txn:
+        with txn.update_snapshot(stage_only=True).delete() as delete:
+            delete.delete_by_predicate(EqualTo("int", 9))
+
+    # a new delete snapshot is added
+    snapshots = tbl.snapshots()
+    assert len(snapshots) == 2
+
+    rows = spark.sql(
+        f"""
+        SELECT operation, summary
+        FROM {identifier}.snapshots
+        ORDER BY committed_at ASC
+    """
+    ).collect()
+    operations = [row.operation for row in rows]
+    assert operations == ["append", "delete"]
+
+    # snapshot main ref has not changed
+    assert current_snapshot == tbl.metadata.current_snapshot_id
+    assert len(tbl.scan().to_arrow()) == original_count
+
+
+@pytest.mark.integration
+@pytest.mark.parametrize("format_version", [1, 2])
+def test_stage_only_fast_append(
+    spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int
+) -> None:
+    identifier = f"default.test_stage_only_fast_append_files_v{format_version}"
+    tbl = _create_table(session_catalog, identifier, {"format-version": str(format_version)}, [arrow_table_with_null])
+
+    current_snapshot = tbl.metadata.current_snapshot_id
+    assert current_snapshot is not None
+
+    original_count = len(tbl.scan().to_arrow())
+    assert original_count == 3
+
+    with tbl.transaction() as txn:
+        with txn.update_snapshot(stage_only=True).fast_append() as fast_append:
+            for data_file in _dataframe_to_data_files(
+                table_metadata=txn.table_metadata, df=arrow_table_with_null, io=txn._table.io
+            ):
+                fast_append.append_data_file(data_file=data_file)
+
+    # Main ref has not changed and data is not yet appended
+    assert current_snapshot == tbl.metadata.current_snapshot_id
+    assert len(tbl.scan().to_arrow()) == original_count
+
+    # There should be a new staged snapshot
+    snapshots = tbl.snapshots()
+    assert len(snapshots) == 2
+
+    rows = spark.sql(
+        f"""
+        SELECT operation, summary
+        FROM {identifier}.snapshots
+        ORDER BY committed_at ASC
+    """
+    ).collect()
+    operations = [row.operation for row in rows]
+    assert operations == ["append", "append"]
+
+
+@pytest.mark.integration
+@pytest.mark.parametrize("format_version", [1, 2])
+def test_stage_only_merge_append(
+    spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int
+) -> None:
+    identifier = f"default.test_stage_only_merge_append_files_v{format_version}"
+    tbl = _create_table(session_catalog, identifier, {"format-version": str(format_version)}, [arrow_table_with_null])
+
+    current_snapshot = tbl.metadata.current_snapshot_id
+    assert current_snapshot is not None
+
+    original_count = len(tbl.scan().to_arrow())
+    assert original_count == 3
+
+    with tbl.transaction() as txn:
+        with txn.update_snapshot(stage_only=True).merge_append() as merge_append:
+            for data_file in _dataframe_to_data_files(
+                table_metadata=txn.table_metadata, df=arrow_table_with_null, io=txn._table.io
+            ):
+                merge_append.append_data_file(data_file=data_file)
+
+    # Main ref has not changed and data is not yet appended
+    assert current_snapshot == tbl.metadata.current_snapshot_id
+    assert len(tbl.scan().to_arrow()) == original_count
+
+    # There should be a new staged snapshot
+    snapshots = tbl.snapshots()
+    assert len(snapshots) == 2
+
+    rows = spark.sql(
+        f"""
+        SELECT operation, summary
+        FROM {identifier}.snapshots
+        ORDER BY committed_at ASC
+    """
+    ).collect()
+    operations = [row.operation for row in rows]
+    assert operations == ["append", "append"]
+
+
+@pytest.mark.integration
+@pytest.mark.parametrize("format_version", [1, 2])
+def test_stage_only_overwrite_files(
+    spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int
+) -> None:
+    identifier = f"default.test_stage_only_overwrite_files_v{format_version}"
+    tbl = _create_table(session_catalog, identifier, {"format-version": str(format_version)}, [arrow_table_with_null])
+
+    current_snapshot = tbl.metadata.current_snapshot_id
+    assert current_snapshot is not None
+
+    original_count = len(tbl.scan().to_arrow())
+    assert original_count == 3
+
+    files_to_delete = []
+    for file_task in tbl.scan().plan_files():
+        files_to_delete.append(file_task.file)
+    assert len(files_to_delete) > 0
+
+    with tbl.transaction() as txn:
+        with txn.update_snapshot(stage_only=True).overwrite() as overwrite:
+            for data_file in _dataframe_to_data_files(
+                table_metadata=txn.table_metadata, df=arrow_table_with_null, io=txn._table.io
+            ):
+                overwrite.append_data_file(data_file=data_file)
+            overwrite.delete_data_file(files_to_delete[0])
+
+    assert current_snapshot == tbl.metadata.current_snapshot_id
+    assert len(tbl.scan().to_arrow()) == original_count
+
+    snapshots = tbl.snapshots()
+    assert len(snapshots) == 2
+
+    rows = spark.sql(
+        f"""
+        SELECT operation, summary
+        FROM {identifier}.snapshots
+        ORDER BY committed_at ASC
+    """
+    ).collect()
+    operations = [row.operation for row in rows]
+    assert operations == ["append", "overwrite"]
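Usage sketch: a minimal example of how client code might stage a snapshot with the new flag, mirroring the fast-append integration test above. The catalog name, table identifier, example dataframe, the import path for _dataframe_to_data_files, and the trailing manage_snapshots().create_tag(...) follow-up are illustrative assumptions; only update_snapshot(stage_only=True) and the surrounding transaction pattern come from this change.

# Sketch (assumptions: catalog "default", table "default.my_table",
# _dataframe_to_data_files import path, and the create_tag() follow-up).
import pyarrow as pa

from pyiceberg.catalog import load_catalog
from pyiceberg.io.pyarrow import _dataframe_to_data_files  # same helper the tests use (path assumed)

catalog = load_catalog("default")
tbl = catalog.load_table("default.my_table")
df = pa.table({"int": [1, 2, 3]})  # illustrative data; must match the table schema

before = tbl.metadata.current_snapshot_id

with tbl.transaction() as txn:
    # stage_only=True adds the snapshot to table metadata but emits no
    # SetSnapshotRefUpdate, so no branch ref moves when the transaction commits.
    with txn.update_snapshot(stage_only=True).fast_append() as fast_append:
        for data_file in _dataframe_to_data_files(table_metadata=txn.table_metadata, df=df, io=txn._table.io):
            fast_append.append_data_file(data_file=data_file)

assert tbl.metadata.current_snapshot_id == before  # main is untouched
staged = tbl.snapshots()[-1]  # the staged snapshot is still recorded in metadata

# Possible follow-up (assumed API): publish the staged snapshot under a named ref.
tbl.manage_snapshots().create_tag(snapshot_id=staged.snapshot_id, tag_name="staged-append").commit()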