Commit 7181ae1

Author: Yingjian Wu (committed)
implement stageOnly Commit
1 parent 8052652 commit 7181ae1

File tree

3 files changed (+215, -21 lines)

pyiceberg/table/__init__.py

Lines changed: 6 additions & 2 deletions
@@ -430,7 +430,9 @@ def update_schema(self, allow_incompatible_changes: bool = False, case_sensitive
             name_mapping=self.table_metadata.name_mapping(),
         )
 
-    def update_snapshot(self, snapshot_properties: Dict[str, str] = EMPTY_DICT, branch: Optional[str] = None) -> UpdateSnapshot:
+    def update_snapshot(
+        self, snapshot_properties: Dict[str, str] = EMPTY_DICT, branch: Optional[str] = None, stage_only: bool = False
+    ) -> UpdateSnapshot:
         """Create a new UpdateSnapshot to produce a new snapshot for the table.
 
         Returns:
@@ -439,7 +441,9 @@ def update_snapshot(self, snapshot_properties: Dict[str, str] = EMPTY_DICT, bran
         if branch is None:
             branch = MAIN_BRANCH
 
-        return UpdateSnapshot(self, io=self._table.io, branch=branch, snapshot_properties=snapshot_properties)
+        return UpdateSnapshot(
+            self, io=self._table.io, branch=branch, snapshot_properties=snapshot_properties, stage_only=stage_only
+        )
 
     def update_statistics(self) -> UpdateStatistics:
         """

pyiceberg/table/update/snapshot.py

Lines changed: 39 additions & 19 deletions
@@ -109,6 +109,7 @@ class _SnapshotProducer(UpdateTableMetadata[U], Generic[U]):
     _deleted_data_files: Set[DataFile]
     _compression: AvroCompressionCodec
     _target_branch = MAIN_BRANCH
+    _stage_only = False
 
     def __init__(
         self,
@@ -118,6 +119,7 @@ def __init__(
         commit_uuid: Optional[uuid.UUID] = None,
         snapshot_properties: Dict[str, str] = EMPTY_DICT,
         branch: str = MAIN_BRANCH,
+        stage_only: bool = False,
     ) -> None:
         super().__init__(transaction)
         self.commit_uuid = commit_uuid or uuid.uuid4()
@@ -137,6 +139,7 @@ def __init__(
         self._parent_snapshot_id = (
             snapshot.snapshot_id if (snapshot := self._transaction.table_metadata.snapshot_by_name(self._target_branch)) else None
         )
+        self._stage_only = stage_only
 
     def _validate_target_branch(self, branch: str) -> str:
         # Default is already set to MAIN_BRANCH. So branch name can't be None.
@@ -292,25 +295,33 @@ def _commit(self) -> UpdatesAndRequirements:
             schema_id=self._transaction.table_metadata.current_schema_id,
         )
 
-        return (
-            (
-                AddSnapshotUpdate(snapshot=snapshot),
-                SetSnapshotRefUpdate(
-                    snapshot_id=self._snapshot_id,
-                    parent_snapshot_id=self._parent_snapshot_id,
-                    ref_name=self._target_branch,
-                    type=SnapshotRefType.BRANCH,
+        add_snapshot_update = AddSnapshotUpdate(snapshot=snapshot)
+
+        if self._stage_only:
+            return (
+                (add_snapshot_update,),
+                (),
+            )
+        else:
+            return (
+                (
+                    add_snapshot_update,
+                    SetSnapshotRefUpdate(
+                        snapshot_id=self._snapshot_id,
+                        parent_snapshot_id=self._parent_snapshot_id,
+                        ref_name=self._target_branch,
+                        type=SnapshotRefType.BRANCH,
+                    ),
                 ),
-            ),
-            (
-                AssertRefSnapshotId(
-                    snapshot_id=self._transaction.table_metadata.refs[self._target_branch].snapshot_id
-                    if self._target_branch in self._transaction.table_metadata.refs
-                    else None,
-                    ref=self._target_branch,
+                (
+                    AssertRefSnapshotId(
+                        snapshot_id=self._transaction.table_metadata.refs[self._target_branch].snapshot_id
+                        if self._target_branch in self._transaction.table_metadata.refs
+                        else None,
+                        ref=self._target_branch,
+                    ),
                 ),
-            ),
-        )
+            )
 
     @property
     def snapshot_id(self) -> int:
@@ -360,8 +371,9 @@ def __init__(
         branch: str,
         commit_uuid: Optional[uuid.UUID] = None,
         snapshot_properties: Dict[str, str] = EMPTY_DICT,
+        stage_only: bool = False,
     ):
-        super().__init__(operation, transaction, io, commit_uuid, snapshot_properties, branch)
+        super().__init__(operation, transaction, io, commit_uuid, snapshot_properties, branch, stage_only)
         self._predicate = AlwaysFalse()
         self._case_sensitive = True
 
@@ -530,10 +542,11 @@ def __init__(
         branch: str,
         commit_uuid: Optional[uuid.UUID] = None,
         snapshot_properties: Dict[str, str] = EMPTY_DICT,
+        stage_only: bool = False,
     ) -> None:
         from pyiceberg.table import TableProperties
 
-        super().__init__(operation, transaction, io, commit_uuid, snapshot_properties, branch)
+        super().__init__(operation, transaction, io, commit_uuid, snapshot_properties, branch, stage_only)
         self._target_size_bytes = property_as_int(
             self._transaction.table_metadata.properties,
             TableProperties.MANIFEST_TARGET_SIZE_BYTES,
@@ -649,19 +662,22 @@ class UpdateSnapshot:
     _transaction: Transaction
     _io: FileIO
     _branch: str
+    _stage_only: bool
     _snapshot_properties: Dict[str, str]
 
     def __init__(
         self,
         transaction: Transaction,
         io: FileIO,
         branch: str,
+        stage_only: bool = False,
        snapshot_properties: Dict[str, str] = EMPTY_DICT,
     ) -> None:
         self._transaction = transaction
         self._io = io
         self._snapshot_properties = snapshot_properties
         self._branch = branch
+        self._stage_only = stage_only
 
     def fast_append(self) -> _FastAppendFiles:
         return _FastAppendFiles(
@@ -670,6 +686,7 @@ def fast_append(self) -> _FastAppendFiles:
             io=self._io,
             branch=self._branch,
             snapshot_properties=self._snapshot_properties,
+            stage_only=self._stage_only,
         )
 
     def merge_append(self) -> _MergeAppendFiles:
@@ -679,6 +696,7 @@ def merge_append(self) -> _MergeAppendFiles:
             io=self._io,
             branch=self._branch,
             snapshot_properties=self._snapshot_properties,
+            stage_only=self._stage_only,
         )
 
     def overwrite(self, commit_uuid: Optional[uuid.UUID] = None) -> _OverwriteFiles:
@@ -691,6 +709,7 @@ def overwrite(self, commit_uuid: Optional[uuid.UUID] = None) -> _OverwriteFiles:
             io=self._io,
             branch=self._branch,
             snapshot_properties=self._snapshot_properties,
+            stage_only=self._stage_only,
        )
 
     def delete(self) -> _DeleteFiles:
@@ -700,6 +719,7 @@ def delete(self) -> _DeleteFiles:
             io=self._io,
             branch=self._branch,
             snapshot_properties=self._snapshot_properties,
+            stage_only=self._stage_only,
         )
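
With _stage_only set, _commit emits only the AddSnapshotUpdate and drops both the SetSnapshotRefUpdate and the AssertRefSnapshotId requirement, so the snapshot lands in table metadata while every ref stays where it was. A hedged sketch of inspecting and later consuming such a staged snapshot; the create_tag call assumes pyiceberg's existing ManageSnapshots API and is illustrative only, not part of this commit:

    # Sketch: the staged snapshot is the newest entry in metadata, but main is unchanged.
    staged_snapshot = tbl.snapshots()[-1]
    assert staged_snapshot.snapshot_id != tbl.metadata.current_snapshot_id

    # Illustrative follow-up (assumes ManageSnapshots.create_tag):
    tbl.manage_snapshots().create_tag(snapshot_id=staged_snapshot.snapshot_id, tag_name="staged").commit()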

tests/integration/test_writes/test_writes.py

Lines changed: 170 additions & 0 deletions
@@ -2261,3 +2261,173 @@ def test_nanosecond_support_on_catalog(session_catalog: Catalog) -> None:
     )
 
     _create_table(session_catalog, identifier, {"format-version": "3"}, schema=table.schema)
+
+
+@pytest.mark.parametrize("format_version", [1, 2])
+def test_stage_only_delete(
+    spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int
+) -> None:
+    identifier = f"default.test_stage_only_delete_files_v{format_version}"
+    iceberg_spec = PartitionSpec(
+        *[PartitionField(source_id=4, field_id=1001, transform=IdentityTransform(), name="integer_partition")]
+    )
+    tbl = _create_table(
+        session_catalog, identifier, {"format-version": str(format_version)}, [arrow_table_with_null], iceberg_spec
+    )
+
+    current_snapshot = tbl.metadata.current_snapshot_id
+    assert current_snapshot is not None
+
+    original_count = len(tbl.scan().to_arrow())
+    assert original_count == 3
+
+    files_to_delete = []
+    for file_task in tbl.scan().plan_files():
+        files_to_delete.append(file_task.file)
+    assert len(files_to_delete) > 0
+
+    with tbl.transaction() as txn:
+        with txn.update_snapshot(stage_only=True).delete() as delete:
+            delete.delete_by_predicate(EqualTo("int", 9))
+
+    # a new delete snapshot is added
+    snapshots = tbl.snapshots()
+    assert len(snapshots) == 2
+
+    rows = spark.sql(
+        f"""
+        SELECT operation, summary
+        FROM {identifier}.snapshots
+        ORDER BY committed_at ASC
+    """
+    ).collect()
+    operations = [row.operation for row in rows]
+    assert operations == ["append", "delete"]
+
+    # snapshot main ref has not changed
+    assert current_snapshot == tbl.metadata.current_snapshot_id
+    assert len(tbl.scan().to_arrow()) == original_count
+
+
+@pytest.mark.integration
+@pytest.mark.parametrize("format_version", [1, 2])
+def test_stage_only_fast_append(
+    spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int
+) -> None:
+    identifier = f"default.test_stage_only_fast_append_files_v{format_version}"
+    tbl = _create_table(session_catalog, identifier, {"format-version": str(format_version)}, [arrow_table_with_null])
+
+    current_snapshot = tbl.metadata.current_snapshot_id
+    assert current_snapshot is not None
+
+    original_count = len(tbl.scan().to_arrow())
+    assert original_count == 3
+
+    with tbl.transaction() as txn:
+        with txn.update_snapshot(stage_only=True).fast_append() as fast_append:
+            for data_file in _dataframe_to_data_files(
+                table_metadata=txn.table_metadata, df=arrow_table_with_null, io=txn._table.io
+            ):
+                fast_append.append_data_file(data_file=data_file)
+
+    # Main ref has not changed and data is not yet appended
+    assert current_snapshot == tbl.metadata.current_snapshot_id
+    assert len(tbl.scan().to_arrow()) == original_count
+
+    # There should be a new staged snapshot
+    snapshots = tbl.snapshots()
+    assert len(snapshots) == 2
+
+    rows = spark.sql(
+        f"""
+        SELECT operation, summary
+        FROM {identifier}.snapshots
+        ORDER BY committed_at ASC
+    """
+    ).collect()
+    operations = [row.operation for row in rows]
+    assert operations == ["append", "append"]
+
+
+@pytest.mark.integration
+@pytest.mark.parametrize("format_version", [1, 2])
+def test_stage_only_merge_append(
+    spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int
+) -> None:
+    identifier = f"default.test_stage_only_merge_append_files_v{format_version}"
+    tbl = _create_table(session_catalog, identifier, {"format-version": str(format_version)}, [arrow_table_with_null])
+
+    current_snapshot = tbl.metadata.current_snapshot_id
+    assert current_snapshot is not None
+
+    original_count = len(tbl.scan().to_arrow())
+    assert original_count == 3
+
+    with tbl.transaction() as txn:
+        with txn.update_snapshot(stage_only=True).merge_append() as merge_append:
+            for data_file in _dataframe_to_data_files(
+                table_metadata=txn.table_metadata, df=arrow_table_with_null, io=txn._table.io
+            ):
+                merge_append.append_data_file(data_file=data_file)
+
+    # Main ref has not changed and data is not yet appended
+    assert current_snapshot == tbl.metadata.current_snapshot_id
+    assert len(tbl.scan().to_arrow()) == original_count
+
+    # There should be a new staged snapshot
+    snapshots = tbl.snapshots()
+    assert len(snapshots) == 2
+
+    rows = spark.sql(
+        f"""
+        SELECT operation, summary
+        FROM {identifier}.snapshots
+        ORDER BY committed_at ASC
+    """
+    ).collect()
+    operations = [row.operation for row in rows]
+    assert operations == ["append", "append"]
+
+
+@pytest.mark.integration
+@pytest.mark.parametrize("format_version", [1, 2])
+def test_stage_only_overwrite_files(
+    spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int
+) -> None:
+    identifier = f"default.test_stage_only_overwrite_files_v{format_version}"
+    tbl = _create_table(session_catalog, identifier, {"format-version": str(format_version)}, [arrow_table_with_null])
+
+    current_snapshot = tbl.metadata.current_snapshot_id
+    assert current_snapshot is not None
+
+    original_count = len(tbl.scan().to_arrow())
+    assert original_count == 3
+
+    files_to_delete = []
+    for file_task in tbl.scan().plan_files():
+        files_to_delete.append(file_task.file)
+    assert len(files_to_delete) > 0
+
+    with tbl.transaction() as txn:
+        with txn.update_snapshot(stage_only=True).overwrite() as overwrite:
+            for data_file in _dataframe_to_data_files(
+                table_metadata=txn.table_metadata, df=arrow_table_with_null, io=txn._table.io
+            ):
+                overwrite.append_data_file(data_file=data_file)
+            overwrite.delete_data_file(files_to_delete[0])
+
+    assert current_snapshot == tbl.metadata.current_snapshot_id
+    assert len(tbl.scan().to_arrow()) == original_count
+
+    snapshots = tbl.snapshots()
+    assert len(snapshots) == 2
+
+    rows = spark.sql(
+        f"""
+        SELECT operation, summary
+        FROM {identifier}.snapshots
+        ORDER BY committed_at ASC
+    """
+    ).collect()
+    operations = [row.operation for row in rows]
+    assert operations == ["append", "overwrite"]
