Skip to content

Commit 08dee72

Browse files
author
Yingjian Wu
committed
implement stageOnly Commit
1 parent 4f02298 commit 08dee72

File tree

6 files changed

+323
-47
lines changed

6 files changed

+323
-47
lines changed

pyiceberg/table/__init__.py

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -398,7 +398,9 @@ def _build_partition_predicate(self, partition_records: Set[Record]) -> BooleanE
398398
expr = Or(expr, match_partition_expression)
399399
return expr
400400

401-
def _append_snapshot_producer(self, snapshot_properties: Dict[str, str], branch: Optional[str]) -> _FastAppendFiles:
401+
def _append_snapshot_producer(
402+
self, snapshot_properties: Dict[str, str], branch: Optional[str] = MAIN_BRANCH
403+
) -> _FastAppendFiles:
402404
"""Determine the append type based on table properties.
403405
404406
Args:
@@ -439,9 +441,6 @@ def update_snapshot(
439441
Returns:
440442
A new UpdateSnapshot
441443
"""
442-
if branch is None:
443-
branch = MAIN_BRANCH
444-
445444
return UpdateSnapshot(self, io=self._table.io, branch=branch, snapshot_properties=snapshot_properties)
446445

447446
def update_statistics(self) -> UpdateStatistics:
@@ -453,7 +452,7 @@ def update_statistics(self) -> UpdateStatistics:
453452
"""
454453
return UpdateStatistics(transaction=self)
455454

456-
def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT, branch: Optional[str] = None) -> None:
455+
def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT, branch: Optional[str] = MAIN_BRANCH) -> None:
457456
"""
458457
Shorthand API for appending a PyArrow table to a table transaction.
459458
@@ -492,7 +491,7 @@ def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT,
492491
append_files.append_data_file(data_file)
493492

494493
def dynamic_partition_overwrite(
495-
self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT, branch: Optional[str] = None
494+
self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT, branch: Optional[str] = MAIN_BRANCH
496495
) -> None:
497496
"""
498497
Shorthand for overwriting existing partitions with a PyArrow table.
@@ -559,7 +558,7 @@ def overwrite(
559558
overwrite_filter: Union[BooleanExpression, str] = ALWAYS_TRUE,
560559
snapshot_properties: Dict[str, str] = EMPTY_DICT,
561560
case_sensitive: bool = True,
562-
branch: Optional[str] = None,
561+
branch: Optional[str] = MAIN_BRANCH,
563562
) -> None:
564563
"""
565564
Shorthand for adding a table overwrite with a PyArrow table to the transaction.
@@ -619,7 +618,7 @@ def delete(
619618
delete_filter: Union[str, BooleanExpression],
620619
snapshot_properties: Dict[str, str] = EMPTY_DICT,
621620
case_sensitive: bool = True,
622-
branch: Optional[str] = None,
621+
branch: Optional[str] = MAIN_BRANCH,
623622
) -> None:
624623
"""
625624
Shorthand for deleting record from a table.
@@ -722,7 +721,7 @@ def upsert(
722721
when_matched_update_all: bool = True,
723722
when_not_matched_insert_all: bool = True,
724723
case_sensitive: bool = True,
725-
branch: Optional[str] = None,
724+
branch: Optional[str] = MAIN_BRANCH,
726725
) -> UpsertResult:
727726
"""Shorthand API for performing an upsert to an iceberg table.
728727
@@ -807,7 +806,7 @@ def upsert(
807806
case_sensitive=case_sensitive,
808807
)
809808

810-
if branch is not None:
809+
if branch in self.table_metadata.refs:
811810
matched_iceberg_record_batches_scan = matched_iceberg_record_batches_scan.use_ref(branch)
812811

813812
matched_iceberg_record_batches = matched_iceberg_record_batches_scan.to_arrow_batch_reader()
@@ -1303,7 +1302,7 @@ def upsert(
13031302
when_matched_update_all: bool = True,
13041303
when_not_matched_insert_all: bool = True,
13051304
case_sensitive: bool = True,
1306-
branch: Optional[str] = None,
1305+
branch: Optional[str] = MAIN_BRANCH,
13071306
) -> UpsertResult:
13081307
"""Shorthand API for performing an upsert to an iceberg table.
13091308
@@ -1350,7 +1349,7 @@ def upsert(
13501349
branch=branch,
13511350
)
13521351

1353-
def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT, branch: Optional[str] = None) -> None:
1352+
def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT, branch: Optional[str] = MAIN_BRANCH) -> None:
13541353
"""
13551354
Shorthand API for appending a PyArrow table to the table.
13561355
@@ -1363,7 +1362,7 @@ def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT,
13631362
tx.append(df=df, snapshot_properties=snapshot_properties, branch=branch)
13641363

13651364
def dynamic_partition_overwrite(
1366-
self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT, branch: Optional[str] = None
1365+
self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT, branch: Optional[str] = MAIN_BRANCH
13671366
) -> None:
13681367
"""Shorthand for dynamic overwriting the table with a PyArrow table.
13691368
@@ -1382,7 +1381,7 @@ def overwrite(
13821381
overwrite_filter: Union[BooleanExpression, str] = ALWAYS_TRUE,
13831382
snapshot_properties: Dict[str, str] = EMPTY_DICT,
13841383
case_sensitive: bool = True,
1385-
branch: Optional[str] = None,
1384+
branch: Optional[str] = MAIN_BRANCH,
13861385
) -> None:
13871386
"""
13881387
Shorthand for overwriting the table with a PyArrow table.
@@ -1415,7 +1414,7 @@ def delete(
14151414
delete_filter: Union[BooleanExpression, str] = ALWAYS_TRUE,
14161415
snapshot_properties: Dict[str, str] = EMPTY_DICT,
14171416
case_sensitive: bool = True,
1418-
branch: Optional[str] = None,
1417+
branch: Optional[str] = MAIN_BRANCH,
14191418
) -> None:
14201419
"""
14211420
Shorthand for deleting rows from the table.

pyiceberg/table/metadata.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -295,8 +295,10 @@ def new_snapshot_id(self) -> int:
295295

296296
return snapshot_id
297297

298-
def snapshot_by_name(self, name: str) -> Optional[Snapshot]:
298+
def snapshot_by_name(self, name: Optional[str]) -> Optional[Snapshot]:
299299
"""Return the snapshot referenced by the given name or null if no such reference exists."""
300+
if name is None:
301+
name = MAIN_BRANCH
300302
if ref := self.refs.get(name):
301303
return self.snapshot_by_id(ref.snapshot_id)
302304
return None

pyiceberg/table/update/snapshot.py

Lines changed: 38 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ class _SnapshotProducer(UpdateTableMetadata[U], Generic[U]):
110110
_manifest_num_counter: itertools.count[int]
111111
_deleted_data_files: Set[DataFile]
112112
_compression: AvroCompressionCodec
113-
_target_branch = MAIN_BRANCH
113+
_target_branch: Optional[str]
114114

115115
def __init__(
116116
self,
@@ -119,7 +119,7 @@ def __init__(
119119
io: FileIO,
120120
commit_uuid: Optional[uuid.UUID] = None,
121121
snapshot_properties: Dict[str, str] = EMPTY_DICT,
122-
branch: str = MAIN_BRANCH,
122+
branch: Optional[str] = MAIN_BRANCH,
123123
) -> None:
124124
super().__init__(transaction)
125125
self.commit_uuid = commit_uuid or uuid.uuid4()
@@ -140,14 +140,13 @@ def __init__(
140140
snapshot.snapshot_id if (snapshot := self._transaction.table_metadata.snapshot_by_name(self._target_branch)) else None
141141
)
142142

143-
def _validate_target_branch(self, branch: str) -> str:
144-
# Default is already set to MAIN_BRANCH. So branch name can't be None.
145-
if branch is None:
146-
raise ValueError("Invalid branch name: null")
147-
if branch in self._transaction.table_metadata.refs:
148-
ref = self._transaction.table_metadata.refs[branch]
149-
if ref.snapshot_ref_type != SnapshotRefType.BRANCH:
150-
raise ValueError(f"{branch} is a tag, not a branch. Tags cannot be targets for producing snapshots")
143+
def _validate_target_branch(self, branch: Optional[str]) -> Optional[str]:
144+
# If branch is None, the write is committed as a staged snapshot (no branch ref is updated)
145+
if branch is not None:
146+
if branch in self._transaction.table_metadata.refs:
147+
ref = self._transaction.table_metadata.refs[branch]
148+
if ref.snapshot_ref_type != SnapshotRefType.BRANCH:
149+
raise ValueError(f"{branch} is a tag, not a branch. Tags cannot be targets for producing snapshots")
151150
return branch
152151

153152
def append_data_file(self, data_file: DataFile) -> _SnapshotProducer[U]:
@@ -294,25 +293,33 @@ def _commit(self) -> UpdatesAndRequirements:
294293
schema_id=self._transaction.table_metadata.current_schema_id,
295294
)
296295

297-
return (
298-
(
299-
AddSnapshotUpdate(snapshot=snapshot),
300-
SetSnapshotRefUpdate(
301-
snapshot_id=self._snapshot_id,
302-
parent_snapshot_id=self._parent_snapshot_id,
303-
ref_name=self._target_branch,
304-
type=SnapshotRefType.BRANCH,
296+
add_snapshot_update = AddSnapshotUpdate(snapshot=snapshot)
297+
298+
if self._target_branch is None:
299+
return (
300+
(add_snapshot_update,),
301+
(),
302+
)
303+
else:
304+
return (
305+
(
306+
add_snapshot_update,
307+
SetSnapshotRefUpdate(
308+
snapshot_id=self._snapshot_id,
309+
parent_snapshot_id=self._parent_snapshot_id,
310+
ref_name=self._target_branch,
311+
type=SnapshotRefType.BRANCH,
312+
),
305313
),
306-
),
307-
(
308-
AssertRefSnapshotId(
309-
snapshot_id=self._transaction.table_metadata.refs[self._target_branch].snapshot_id
310-
if self._target_branch in self._transaction.table_metadata.refs
311-
else None,
312-
ref=self._target_branch,
314+
(
315+
AssertRefSnapshotId(
316+
snapshot_id=self._transaction.table_metadata.refs[self._target_branch].snapshot_id
317+
if self._target_branch in self._transaction.table_metadata.refs
318+
else None,
319+
ref=self._target_branch,
320+
),
313321
),
314-
),
315-
)
322+
)
316323

317324
@property
318325
def snapshot_id(self) -> int:
@@ -359,7 +366,7 @@ def __init__(
359366
operation: Operation,
360367
transaction: Transaction,
361368
io: FileIO,
362-
branch: str,
369+
branch: Optional[str] = MAIN_BRANCH,
363370
commit_uuid: Optional[uuid.UUID] = None,
364371
snapshot_properties: Dict[str, str] = EMPTY_DICT,
365372
):
@@ -530,7 +537,7 @@ def __init__(
530537
operation: Operation,
531538
transaction: Transaction,
532539
io: FileIO,
533-
branch: str,
540+
branch: Optional[str] = MAIN_BRANCH,
534541
commit_uuid: Optional[uuid.UUID] = None,
535542
snapshot_properties: Dict[str, str] = EMPTY_DICT,
536543
) -> None:
@@ -651,14 +658,14 @@ def _get_entries(manifest: ManifestFile) -> List[ManifestEntry]:
651658
class UpdateSnapshot:
652659
_transaction: Transaction
653660
_io: FileIO
654-
_branch: str
661+
_branch: Optional[str]
655662
_snapshot_properties: Dict[str, str]
656663

657664
def __init__(
658665
self,
659666
transaction: Transaction,
660667
io: FileIO,
661-
branch: str,
668+
branch: Optional[str] = MAIN_BRANCH,
662669
snapshot_properties: Dict[str, str] = EMPTY_DICT,
663670
) -> None:
664671
self._transaction = transaction

tests/integration/test_writes/test_partitioned_writes.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1151,3 +1151,61 @@ def test_append_multiple_partitions(
11511151
"""
11521152
)
11531153
assert files_df.count() == 6
1154+
1155+
1156+
@pytest.mark.integration
1157+
@pytest.mark.parametrize("format_version", [1, 2])
1158+
def test_stage_only_dynamic_partition_overwrite_files(
1159+
spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int
1160+
) -> None:
1161+
identifier = f"default.test_stage_only_dynamic_partition_overwrite_files_v{format_version}"
1162+
try:
1163+
session_catalog.drop_table(identifier=identifier)
1164+
except NoSuchTableError:
1165+
pass
1166+
tbl = session_catalog.create_table(
1167+
identifier=identifier,
1168+
schema=TABLE_SCHEMA,
1169+
partition_spec=PartitionSpec(
1170+
PartitionField(source_id=1, field_id=1001, transform=IdentityTransform(), name="bool"),
1171+
PartitionField(source_id=4, field_id=1002, transform=IdentityTransform(), name="int"),
1172+
),
1173+
properties={"format-version": str(format_version)},
1174+
)
1175+
1176+
tbl.append(arrow_table_with_null)
1177+
current_snapshot = tbl.metadata.current_snapshot_id
1178+
assert current_snapshot is not None
1179+
1180+
original_count = len(tbl.scan().to_arrow())
1181+
assert original_count == 3
1182+
1183+
# write to staging snapshot
1184+
tbl.dynamic_partition_overwrite(arrow_table_with_null.slice(0, 1), branch=None)
1185+
1186+
assert current_snapshot == tbl.metadata.current_snapshot_id
1187+
assert len(tbl.scan().to_arrow()) == original_count
1188+
snapshots = tbl.snapshots()
1189+
# dynamic partition overwrite will create 2 snapshots, one delete and another append
1190+
assert len(snapshots) == 3
1191+
1192+
# Write to main branch
1193+
tbl.append(arrow_table_with_null)
1194+
1195+
# Main ref has changed
1196+
assert current_snapshot != tbl.metadata.current_snapshot_id
1197+
assert len(tbl.scan().to_arrow()) == 6
1198+
snapshots = tbl.snapshots()
1199+
assert len(snapshots) == 4
1200+
1201+
rows = spark.sql(
1202+
f"""
1203+
SELECT operation, parent_id, snapshot_id
1204+
FROM {identifier}.snapshots
1205+
ORDER BY committed_at ASC
1206+
"""
1207+
).collect()
1208+
operations = [row.operation for row in rows]
1209+
parent_snapshot_id = [row.parent_id for row in rows]
1210+
assert operations == ["append", "delete", "append", "append"]
1211+
assert parent_snapshot_id == [None, current_snapshot, current_snapshot, current_snapshot]

0 commit comments

Comments
 (0)