Skip to content

Commit 0250e24

Browse files
author
Yingjian Wu
committed
Implement stage-only commit (branch=None writes a staged snapshot without updating a branch ref)
Add and improve tests for stage-only commits.
1 parent 8052652 commit 0250e24

File tree

6 files changed

+345
-58
lines changed

6 files changed

+345
-58
lines changed

pyiceberg/table/__init__.py

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -397,7 +397,9 @@ def _build_partition_predicate(self, partition_records: Set[Record]) -> BooleanE
397397
expr = Or(expr, match_partition_expression)
398398
return expr
399399

400-
def _append_snapshot_producer(self, snapshot_properties: Dict[str, str], branch: Optional[str]) -> _FastAppendFiles:
400+
def _append_snapshot_producer(
401+
self, snapshot_properties: Dict[str, str], branch: Optional[str] = MAIN_BRANCH
402+
) -> _FastAppendFiles:
401403
"""Determine the append type based on table properties.
402404
403405
Args:
@@ -430,15 +432,14 @@ def update_schema(self, allow_incompatible_changes: bool = False, case_sensitive
430432
name_mapping=self.table_metadata.name_mapping(),
431433
)
432434

433-
def update_snapshot(self, snapshot_properties: Dict[str, str] = EMPTY_DICT, branch: Optional[str] = None) -> UpdateSnapshot:
435+
def update_snapshot(
436+
self, snapshot_properties: Dict[str, str] = EMPTY_DICT, branch: Optional[str] = MAIN_BRANCH
437+
) -> UpdateSnapshot:
434438
"""Create a new UpdateSnapshot to produce a new snapshot for the table.
435439
436440
Returns:
437441
A new UpdateSnapshot
438442
"""
439-
if branch is None:
440-
branch = MAIN_BRANCH
441-
442443
return UpdateSnapshot(self, io=self._table.io, branch=branch, snapshot_properties=snapshot_properties)
443444

444445
def update_statistics(self) -> UpdateStatistics:
@@ -450,7 +451,7 @@ def update_statistics(self) -> UpdateStatistics:
450451
"""
451452
return UpdateStatistics(transaction=self)
452453

453-
def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT, branch: Optional[str] = None) -> None:
454+
def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT, branch: Optional[str] = MAIN_BRANCH) -> None:
454455
"""
455456
Shorthand API for appending a PyArrow table to a table transaction.
456457
@@ -495,7 +496,7 @@ def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT,
495496
append_files.append_data_file(data_file)
496497

497498
def dynamic_partition_overwrite(
498-
self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT, branch: Optional[str] = None
499+
self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT, branch: Optional[str] = MAIN_BRANCH
499500
) -> None:
500501
"""
501502
Shorthand for overwriting existing partitions with a PyArrow table.
@@ -562,7 +563,7 @@ def overwrite(
562563
overwrite_filter: Union[BooleanExpression, str] = ALWAYS_TRUE,
563564
snapshot_properties: Dict[str, str] = EMPTY_DICT,
564565
case_sensitive: bool = True,
565-
branch: Optional[str] = None,
566+
branch: Optional[str] = MAIN_BRANCH,
566567
) -> None:
567568
"""
568569
Shorthand for adding a table overwrite with a PyArrow table to the transaction.
@@ -628,7 +629,7 @@ def delete(
628629
delete_filter: Union[str, BooleanExpression],
629630
snapshot_properties: Dict[str, str] = EMPTY_DICT,
630631
case_sensitive: bool = True,
631-
branch: Optional[str] = None,
632+
branch: Optional[str] = MAIN_BRANCH,
632633
) -> None:
633634
"""
634635
Shorthand for deleting record from a table.
@@ -731,7 +732,7 @@ def upsert(
731732
when_matched_update_all: bool = True,
732733
when_not_matched_insert_all: bool = True,
733734
case_sensitive: bool = True,
734-
branch: Optional[str] = None,
735+
branch: Optional[str] = MAIN_BRANCH,
735736
) -> UpsertResult:
736737
"""Shorthand API for performing an upsert to an iceberg table.
737738
@@ -816,7 +817,7 @@ def upsert(
816817
case_sensitive=case_sensitive,
817818
)
818819

819-
if branch is not None:
820+
if branch in self.table_metadata.refs:
820821
matched_iceberg_record_batches_scan = matched_iceberg_record_batches_scan.use_ref(branch)
821822

822823
matched_iceberg_record_batches = matched_iceberg_record_batches_scan.to_arrow_batch_reader()
@@ -1307,7 +1308,7 @@ def upsert(
13071308
when_matched_update_all: bool = True,
13081309
when_not_matched_insert_all: bool = True,
13091310
case_sensitive: bool = True,
1310-
branch: Optional[str] = None,
1311+
branch: Optional[str] = MAIN_BRANCH,
13111312
) -> UpsertResult:
13121313
"""Shorthand API for performing an upsert to an iceberg table.
13131314
@@ -1354,7 +1355,7 @@ def upsert(
13541355
branch=branch,
13551356
)
13561357

1357-
def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT, branch: Optional[str] = None) -> None:
1358+
def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT, branch: Optional[str] = MAIN_BRANCH) -> None:
13581359
"""
13591360
Shorthand API for appending a PyArrow table to the table.
13601361
@@ -1367,7 +1368,7 @@ def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT,
13671368
tx.append(df=df, snapshot_properties=snapshot_properties, branch=branch)
13681369

13691370
def dynamic_partition_overwrite(
1370-
self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT, branch: Optional[str] = None
1371+
self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT, branch: Optional[str] = MAIN_BRANCH
13711372
) -> None:
13721373
"""Shorthand for dynamic overwriting the table with a PyArrow table.
13731374
@@ -1386,7 +1387,7 @@ def overwrite(
13861387
overwrite_filter: Union[BooleanExpression, str] = ALWAYS_TRUE,
13871388
snapshot_properties: Dict[str, str] = EMPTY_DICT,
13881389
case_sensitive: bool = True,
1389-
branch: Optional[str] = None,
1390+
branch: Optional[str] = MAIN_BRANCH,
13901391
) -> None:
13911392
"""
13921393
Shorthand for overwriting the table with a PyArrow table.
@@ -1419,7 +1420,7 @@ def delete(
14191420
delete_filter: Union[BooleanExpression, str] = ALWAYS_TRUE,
14201421
snapshot_properties: Dict[str, str] = EMPTY_DICT,
14211422
case_sensitive: bool = True,
1422-
branch: Optional[str] = None,
1423+
branch: Optional[str] = MAIN_BRANCH,
14231424
) -> None:
14241425
"""
14251426
Shorthand for deleting rows from the table.

pyiceberg/table/metadata.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -295,8 +295,10 @@ def new_snapshot_id(self) -> int:
295295

296296
return snapshot_id
297297

298-
def snapshot_by_name(self, name: str) -> Optional[Snapshot]:
298+
def snapshot_by_name(self, name: Optional[str]) -> Optional[Snapshot]:
299299
"""Return the snapshot referenced by the given name or null if no such reference exists."""
300+
if name is None:
301+
name = MAIN_BRANCH
300302
if ref := self.refs.get(name):
301303
return self.snapshot_by_id(ref.snapshot_id)
302304
return None

pyiceberg/table/update/snapshot.py

Lines changed: 38 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ class _SnapshotProducer(UpdateTableMetadata[U], Generic[U]):
108108
_manifest_num_counter: itertools.count[int]
109109
_deleted_data_files: Set[DataFile]
110110
_compression: AvroCompressionCodec
111-
_target_branch = MAIN_BRANCH
111+
_target_branch: Optional[str]
112112

113113
def __init__(
114114
self,
@@ -117,7 +117,7 @@ def __init__(
117117
io: FileIO,
118118
commit_uuid: Optional[uuid.UUID] = None,
119119
snapshot_properties: Dict[str, str] = EMPTY_DICT,
120-
branch: str = MAIN_BRANCH,
120+
branch: Optional[str] = MAIN_BRANCH,
121121
) -> None:
122122
super().__init__(transaction)
123123
self.commit_uuid = commit_uuid or uuid.uuid4()
@@ -138,14 +138,13 @@ def __init__(
138138
snapshot.snapshot_id if (snapshot := self._transaction.table_metadata.snapshot_by_name(self._target_branch)) else None
139139
)
140140

141-
def _validate_target_branch(self, branch: str) -> str:
142-
# Default is already set to MAIN_BRANCH. So branch name can't be None.
143-
if branch is None:
144-
raise ValueError("Invalid branch name: null")
145-
if branch in self._transaction.table_metadata.refs:
146-
ref = self._transaction.table_metadata.refs[branch]
147-
if ref.snapshot_ref_type != SnapshotRefType.BRANCH:
148-
raise ValueError(f"{branch} is a tag, not a branch. Tags cannot be targets for producing snapshots")
141+
def _validate_target_branch(self, branch: Optional[str]) -> Optional[str]:
142+
# If branch is None, the write is staged: the snapshot is added without updating any branch ref
143+
if branch is not None:
144+
if branch in self._transaction.table_metadata.refs:
145+
ref = self._transaction.table_metadata.refs[branch]
146+
if ref.snapshot_ref_type != SnapshotRefType.BRANCH:
147+
raise ValueError(f"{branch} is a tag, not a branch. Tags cannot be targets for producing snapshots")
149148
return branch
150149

151150
def append_data_file(self, data_file: DataFile) -> _SnapshotProducer[U]:
@@ -292,25 +291,33 @@ def _commit(self) -> UpdatesAndRequirements:
292291
schema_id=self._transaction.table_metadata.current_schema_id,
293292
)
294293

295-
return (
296-
(
297-
AddSnapshotUpdate(snapshot=snapshot),
298-
SetSnapshotRefUpdate(
299-
snapshot_id=self._snapshot_id,
300-
parent_snapshot_id=self._parent_snapshot_id,
301-
ref_name=self._target_branch,
302-
type=SnapshotRefType.BRANCH,
294+
add_snapshot_update = AddSnapshotUpdate(snapshot=snapshot)
295+
296+
if self._target_branch is None:
297+
return (
298+
(add_snapshot_update,),
299+
(),
300+
)
301+
else:
302+
return (
303+
(
304+
add_snapshot_update,
305+
SetSnapshotRefUpdate(
306+
snapshot_id=self._snapshot_id,
307+
parent_snapshot_id=self._parent_snapshot_id,
308+
ref_name=self._target_branch,
309+
type=SnapshotRefType.BRANCH,
310+
),
303311
),
304-
),
305-
(
306-
AssertRefSnapshotId(
307-
snapshot_id=self._transaction.table_metadata.refs[self._target_branch].snapshot_id
308-
if self._target_branch in self._transaction.table_metadata.refs
309-
else None,
310-
ref=self._target_branch,
312+
(
313+
AssertRefSnapshotId(
314+
snapshot_id=self._transaction.table_metadata.refs[self._target_branch].snapshot_id
315+
if self._target_branch in self._transaction.table_metadata.refs
316+
else None,
317+
ref=self._target_branch,
318+
),
311319
),
312-
),
313-
)
320+
)
314321

315322
@property
316323
def snapshot_id(self) -> int:
@@ -357,7 +364,7 @@ def __init__(
357364
operation: Operation,
358365
transaction: Transaction,
359366
io: FileIO,
360-
branch: str,
367+
branch: Optional[str] = MAIN_BRANCH,
361368
commit_uuid: Optional[uuid.UUID] = None,
362369
snapshot_properties: Dict[str, str] = EMPTY_DICT,
363370
):
@@ -527,7 +534,7 @@ def __init__(
527534
operation: Operation,
528535
transaction: Transaction,
529536
io: FileIO,
530-
branch: str,
537+
branch: Optional[str] = MAIN_BRANCH,
531538
commit_uuid: Optional[uuid.UUID] = None,
532539
snapshot_properties: Dict[str, str] = EMPTY_DICT,
533540
) -> None:
@@ -648,14 +655,14 @@ def _get_entries(manifest: ManifestFile) -> List[ManifestEntry]:
648655
class UpdateSnapshot:
649656
_transaction: Transaction
650657
_io: FileIO
651-
_branch: str
658+
_branch: Optional[str]
652659
_snapshot_properties: Dict[str, str]
653660

654661
def __init__(
655662
self,
656663
transaction: Transaction,
657664
io: FileIO,
658-
branch: str,
665+
branch: Optional[str] = MAIN_BRANCH,
659666
snapshot_properties: Dict[str, str] = EMPTY_DICT,
660667
) -> None:
661668
self._transaction = transaction

tests/integration/test_writes/test_partitioned_writes.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1133,3 +1133,61 @@ def test_append_multiple_partitions(
11331133
"""
11341134
)
11351135
assert files_df.count() == 6
1136+
1137+
1138+
@pytest.mark.integration
1139+
@pytest.mark.parametrize("format_version", [1, 2])
1140+
def test_stage_only_dynamic_partition_overwrite_files(
1141+
spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int
1142+
) -> None:
1143+
identifier = f"default.test_stage_only_dynamic_partition_overwrite_files_v{format_version}"
1144+
try:
1145+
session_catalog.drop_table(identifier=identifier)
1146+
except NoSuchTableError:
1147+
pass
1148+
tbl = session_catalog.create_table(
1149+
identifier=identifier,
1150+
schema=TABLE_SCHEMA,
1151+
partition_spec=PartitionSpec(
1152+
PartitionField(source_id=1, field_id=1001, transform=IdentityTransform(), name="bool"),
1153+
PartitionField(source_id=4, field_id=1002, transform=IdentityTransform(), name="int"),
1154+
),
1155+
properties={"format-version": str(format_version)},
1156+
)
1157+
1158+
tbl.append(arrow_table_with_null)
1159+
current_snapshot = tbl.metadata.current_snapshot_id
1160+
assert current_snapshot is not None
1161+
1162+
original_count = len(tbl.scan().to_arrow())
1163+
assert original_count == 3
1164+
1165+
# Write to a staged snapshot (no branch ref is updated)
1166+
tbl.dynamic_partition_overwrite(arrow_table_with_null.slice(0, 1), branch=None)
1167+
1168+
assert current_snapshot == tbl.metadata.current_snapshot_id
1169+
assert len(tbl.scan().to_arrow()) == original_count
1170+
snapshots = tbl.snapshots()
1171+
# A dynamic partition overwrite creates 2 snapshots: one delete followed by one append
1172+
assert len(snapshots) == 3
1173+
1174+
# Write to main branch
1175+
tbl.append(arrow_table_with_null)
1176+
1177+
# Main ref has changed
1178+
assert current_snapshot != tbl.metadata.current_snapshot_id
1179+
assert len(tbl.scan().to_arrow()) == 6
1180+
snapshots = tbl.snapshots()
1181+
assert len(snapshots) == 4
1182+
1183+
rows = spark.sql(
1184+
f"""
1185+
SELECT operation, parent_id, snapshot_id
1186+
FROM {identifier}.snapshots
1187+
ORDER BY committed_at ASC
1188+
"""
1189+
).collect()
1190+
operations = [row.operation for row in rows]
1191+
parent_snapshot_id = [row.parent_id for row in rows]
1192+
assert operations == ["append", "delete", "append", "append"]
1193+
assert parent_snapshot_id == [None, current_snapshot, current_snapshot, current_snapshot]

0 commit comments

Comments
 (0)