Skip to content

Commit 404b2cd

Browse files
author
Yingjian Wu
committed
wip
1 parent 1958e5c commit 404b2cd

File tree

4 files changed

+37
-51
lines changed

4 files changed

+37
-51
lines changed

pyiceberg/table/__init__.py

Lines changed: 15 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -399,7 +399,9 @@ def _build_partition_predicate(self, partition_records: Set[Record]) -> BooleanE
399399
expr = Or(expr, match_partition_expression)
400400
return expr
401401

402-
def _append_snapshot_producer(self, snapshot_properties: Dict[str, str], branch: Optional[str]) -> _FastAppendFiles:
402+
def _append_snapshot_producer(
403+
self, snapshot_properties: Dict[str, str], branch: Optional[str] = MAIN_BRANCH
404+
) -> _FastAppendFiles:
403405
"""Determine the append type based on table properties.
404406
405407
Args:
@@ -433,19 +435,14 @@ def update_schema(self, allow_incompatible_changes: bool = False, case_sensitive
433435
)
434436

435437
def update_snapshot(
436-
self, snapshot_properties: Dict[str, str] = EMPTY_DICT, branch: Optional[str] = None, stage_only: bool = False
438+
self, snapshot_properties: Dict[str, str] = EMPTY_DICT, branch: Optional[str] = MAIN_BRANCH
437439
) -> UpdateSnapshot:
438440
"""Create a new UpdateSnapshot to produce a new snapshot for the table.
439441
440442
Returns:
441443
A new UpdateSnapshot
442444
"""
443-
if branch is None:
444-
branch = MAIN_BRANCH
445-
446-
return UpdateSnapshot(
447-
self, io=self._table.io, branch=branch, snapshot_properties=snapshot_properties, stage_only=stage_only
448-
)
445+
return UpdateSnapshot(self, io=self._table.io, branch=branch, snapshot_properties=snapshot_properties)
449446

450447
def update_statistics(self) -> UpdateStatistics:
451448
"""
@@ -456,7 +453,7 @@ def update_statistics(self) -> UpdateStatistics:
456453
"""
457454
return UpdateStatistics(transaction=self)
458455

459-
def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT, branch: Optional[str] = None) -> None:
456+
def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT, branch: Optional[str] = MAIN_BRANCH) -> None:
460457
"""
461458
Shorthand API for appending a PyArrow table to a table transaction.
462459
@@ -498,7 +495,7 @@ def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT,
498495
append_files.append_data_file(data_file)
499496

500497
def dynamic_partition_overwrite(
501-
self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT, branch: Optional[str] = None
498+
self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT, branch: Optional[str] = MAIN_BRANCH
502499
) -> None:
503500
"""
504501
Shorthand for overwriting existing partitions with a PyArrow table.
@@ -562,7 +559,7 @@ def overwrite(
562559
overwrite_filter: Union[BooleanExpression, str] = ALWAYS_TRUE,
563560
snapshot_properties: Dict[str, str] = EMPTY_DICT,
564561
case_sensitive: bool = True,
565-
branch: Optional[str] = None,
562+
branch: Optional[str] = MAIN_BRANCH,
566563
) -> None:
567564
"""
568565
Shorthand for adding a table overwrite with a PyArrow table to the transaction.
@@ -625,7 +622,7 @@ def delete(
625622
delete_filter: Union[str, BooleanExpression],
626623
snapshot_properties: Dict[str, str] = EMPTY_DICT,
627624
case_sensitive: bool = True,
628-
branch: Optional[str] = None,
625+
branch: Optional[str] = MAIN_BRANCH,
629626
) -> None:
630627
"""
631628
Shorthand for deleting record from a table.
@@ -728,7 +725,7 @@ def upsert(
728725
when_matched_update_all: bool = True,
729726
when_not_matched_insert_all: bool = True,
730727
case_sensitive: bool = True,
731-
branch: Optional[str] = None,
728+
branch: Optional[str] = MAIN_BRANCH,
732729
) -> UpsertResult:
733730
"""Shorthand API for performing an upsert to an iceberg table.
734731
@@ -1295,7 +1292,7 @@ def upsert(
12951292
when_matched_update_all: bool = True,
12961293
when_not_matched_insert_all: bool = True,
12971294
case_sensitive: bool = True,
1298-
branch: Optional[str] = None,
1295+
branch: Optional[str] = MAIN_BRANCH,
12991296
) -> UpsertResult:
13001297
"""Shorthand API for performing an upsert to an iceberg table.
13011298
@@ -1342,7 +1339,7 @@ def upsert(
13421339
branch=branch,
13431340
)
13441341

1345-
def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT, branch: Optional[str] = None) -> None:
1342+
def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT, branch: Optional[str] = MAIN_BRANCH) -> None:
13461343
"""
13471344
Shorthand API for appending a PyArrow table to the table.
13481345
@@ -1355,7 +1352,7 @@ def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT,
13551352
tx.append(df=df, snapshot_properties=snapshot_properties, branch=branch)
13561353

13571354
def dynamic_partition_overwrite(
1358-
self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT, branch: Optional[str] = None
1355+
self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT, branch: Optional[str] = MAIN_BRANCH
13591356
) -> None:
13601357
"""Shorthand for dynamic overwriting the table with a PyArrow table.
13611358
@@ -1374,7 +1371,7 @@ def overwrite(
13741371
overwrite_filter: Union[BooleanExpression, str] = ALWAYS_TRUE,
13751372
snapshot_properties: Dict[str, str] = EMPTY_DICT,
13761373
case_sensitive: bool = True,
1377-
branch: Optional[str] = None,
1374+
branch: Optional[str] = MAIN_BRANCH,
13781375
) -> None:
13791376
"""
13801377
Shorthand for overwriting the table with a PyArrow table.
@@ -1407,7 +1404,7 @@ def delete(
14071404
delete_filter: Union[BooleanExpression, str] = ALWAYS_TRUE,
14081405
snapshot_properties: Dict[str, str] = EMPTY_DICT,
14091406
case_sensitive: bool = True,
1410-
branch: Optional[str] = None,
1407+
branch: Optional[str] = MAIN_BRANCH,
14111408
) -> None:
14121409
"""
14131410
Shorthand for deleting rows from the table.

pyiceberg/table/metadata.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -295,8 +295,10 @@ def new_snapshot_id(self) -> int:
295295

296296
return snapshot_id
297297

298-
def snapshot_by_name(self, name: str) -> Optional[Snapshot]:
298+
def snapshot_by_name(self, name: Optional[str]) -> Optional[Snapshot]:
299299
"""Return the snapshot referenced by the given name or null if no such reference exists."""
300+
if name is None:
301+
name = MAIN_BRANCH
300302
if ref := self.refs.get(name):
301303
return self.snapshot_by_id(ref.snapshot_id)
302304
return None

pyiceberg/table/update/snapshot.py

Lines changed: 15 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -108,8 +108,7 @@ class _SnapshotProducer(UpdateTableMetadata[U], Generic[U]):
108108
_manifest_num_counter: itertools.count[int]
109109
_deleted_data_files: Set[DataFile]
110110
_compression: AvroCompressionCodec
111-
_target_branch = MAIN_BRANCH
112-
_stage_only = False
111+
_target_branch: Optional[str]
113112

114113
def __init__(
115114
self,
@@ -118,8 +117,7 @@ def __init__(
118117
io: FileIO,
119118
commit_uuid: Optional[uuid.UUID] = None,
120119
snapshot_properties: Dict[str, str] = EMPTY_DICT,
121-
branch: str = MAIN_BRANCH,
122-
stage_only: bool = False,
120+
branch: Optional[str] = MAIN_BRANCH,
123121
) -> None:
124122
super().__init__(transaction)
125123
self.commit_uuid = commit_uuid or uuid.uuid4()
@@ -139,16 +137,14 @@ def __init__(
139137
self._parent_snapshot_id = (
140138
snapshot.snapshot_id if (snapshot := self._transaction.table_metadata.snapshot_by_name(self._target_branch)) else None
141139
)
142-
self._stage_only = stage_only
143140

144-
def _validate_target_branch(self, branch: str) -> str:
141+
def _validate_target_branch(self, branch: Optional[str]) -> Optional[str]:
145142
# Default is already set to MAIN_BRANCH. So branch name can't be None.
146-
if branch is None:
147-
raise ValueError("Invalid branch name: null")
148-
if branch in self._transaction.table_metadata.refs:
149-
ref = self._transaction.table_metadata.refs[branch]
150-
if ref.snapshot_ref_type != SnapshotRefType.BRANCH:
151-
raise ValueError(f"{branch} is a tag, not a branch. Tags cannot be targets for producing snapshots")
143+
if branch is not None:
144+
if branch in self._transaction.table_metadata.refs:
145+
ref = self._transaction.table_metadata.refs[branch]
146+
if ref.snapshot_ref_type != SnapshotRefType.BRANCH:
147+
raise ValueError(f"{branch} is a tag, not a branch. Tags cannot be targets for producing snapshots")
152148
return branch
153149

154150
def append_data_file(self, data_file: DataFile) -> _SnapshotProducer[U]:
@@ -297,7 +293,7 @@ def _commit(self) -> UpdatesAndRequirements:
297293

298294
add_snapshot_update = AddSnapshotUpdate(snapshot=snapshot)
299295

300-
if self._stage_only:
296+
if self._target_branch is None:
301297
return (
302298
(add_snapshot_update,),
303299
(),
@@ -368,12 +364,11 @@ def __init__(
368364
operation: Operation,
369365
transaction: Transaction,
370366
io: FileIO,
371-
branch: str,
367+
branch: Optional[str] = MAIN_BRANCH,
372368
commit_uuid: Optional[uuid.UUID] = None,
373369
snapshot_properties: Dict[str, str] = EMPTY_DICT,
374-
stage_only: bool = False,
375370
):
376-
super().__init__(operation, transaction, io, commit_uuid, snapshot_properties, branch, stage_only)
371+
super().__init__(operation, transaction, io, commit_uuid, snapshot_properties, branch)
377372
self._predicate = AlwaysFalse()
378373
self._case_sensitive = True
379374

@@ -539,14 +534,13 @@ def __init__(
539534
operation: Operation,
540535
transaction: Transaction,
541536
io: FileIO,
542-
branch: str,
537+
branch: Optional[str] = MAIN_BRANCH,
543538
commit_uuid: Optional[uuid.UUID] = None,
544539
snapshot_properties: Dict[str, str] = EMPTY_DICT,
545-
stage_only: bool = False,
546540
) -> None:
547541
from pyiceberg.table import TableProperties
548542

549-
super().__init__(operation, transaction, io, commit_uuid, snapshot_properties, branch, stage_only)
543+
super().__init__(operation, transaction, io, commit_uuid, snapshot_properties, branch)
550544
self._target_size_bytes = property_as_int(
551545
self._transaction.table_metadata.properties,
552546
TableProperties.MANIFEST_TARGET_SIZE_BYTES,
@@ -661,23 +655,20 @@ def _get_entries(manifest: ManifestFile) -> List[ManifestEntry]:
661655
class UpdateSnapshot:
662656
_transaction: Transaction
663657
_io: FileIO
664-
_branch: str
665-
_stage_only: bool
658+
_branch: Optional[str]
666659
_snapshot_properties: Dict[str, str]
667660

668661
def __init__(
669662
self,
670663
transaction: Transaction,
671664
io: FileIO,
672-
branch: str,
673-
stage_only: bool = False,
665+
branch: Optional[str] = MAIN_BRANCH,
674666
snapshot_properties: Dict[str, str] = EMPTY_DICT,
675667
) -> None:
676668
self._transaction = transaction
677669
self._io = io
678670
self._snapshot_properties = snapshot_properties
679671
self._branch = branch
680-
self._stage_only = stage_only
681672

682673
def fast_append(self) -> _FastAppendFiles:
683674
return _FastAppendFiles(
@@ -686,7 +677,6 @@ def fast_append(self) -> _FastAppendFiles:
686677
io=self._io,
687678
branch=self._branch,
688679
snapshot_properties=self._snapshot_properties,
689-
stage_only=self._stage_only,
690680
)
691681

692682
def merge_append(self) -> _MergeAppendFiles:
@@ -696,7 +686,6 @@ def merge_append(self) -> _MergeAppendFiles:
696686
io=self._io,
697687
branch=self._branch,
698688
snapshot_properties=self._snapshot_properties,
699-
stage_only=self._stage_only,
700689
)
701690

702691
def overwrite(self, commit_uuid: Optional[uuid.UUID] = None) -> _OverwriteFiles:
@@ -709,7 +698,6 @@ def overwrite(self, commit_uuid: Optional[uuid.UUID] = None) -> _OverwriteFiles:
709698
io=self._io,
710699
branch=self._branch,
711700
snapshot_properties=self._snapshot_properties,
712-
stage_only=self._stage_only,
713701
)
714702

715703
def delete(self) -> _DeleteFiles:
@@ -719,7 +707,6 @@ def delete(self) -> _DeleteFiles:
719707
io=self._io,
720708
branch=self._branch,
721709
snapshot_properties=self._snapshot_properties,
722-
stage_only=self._stage_only,
723710
)
724711

725712

tests/integration/test_writes/test_writes.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2125,7 +2125,7 @@ def test_stage_only_delete(
21252125
assert len(files_to_delete) > 0
21262126

21272127
with tbl.transaction() as txn:
2128-
with txn.update_snapshot(stage_only=True).delete() as delete:
2128+
with txn.update_snapshot(branch=None).delete() as delete:
21292129
delete.delete_by_predicate(EqualTo("int", 9))
21302130

21312131
# a new delete snapshot is added
@@ -2162,7 +2162,7 @@ def test_stage_only_fast_append(
21622162
assert original_count == 3
21632163

21642164
with tbl.transaction() as txn:
2165-
with txn.update_snapshot(stage_only=True).fast_append() as fast_append:
2165+
with txn.update_snapshot(branch=None).fast_append() as fast_append:
21662166
for data_file in _dataframe_to_data_files(
21672167
table_metadata=txn.table_metadata, df=arrow_table_with_null, io=txn._table.io
21682168
):
@@ -2202,7 +2202,7 @@ def test_stage_only_merge_append(
22022202
assert original_count == 3
22032203

22042204
with tbl.transaction() as txn:
2205-
with txn.update_snapshot(stage_only=True).merge_append() as merge_append:
2205+
with txn.update_snapshot(branch=None).merge_append() as merge_append:
22062206
for data_file in _dataframe_to_data_files(
22072207
table_metadata=txn.table_metadata, df=arrow_table_with_null, io=txn._table.io
22082208
):
@@ -2247,7 +2247,7 @@ def test_stage_only_overwrite_files(
22472247
assert len(files_to_delete) > 0
22482248

22492249
with tbl.transaction() as txn:
2250-
with txn.update_snapshot(stage_only=True).overwrite() as overwrite:
2250+
with txn.update_snapshot(branch=None).overwrite() as overwrite:
22512251
for data_file in _dataframe_to_data_files(
22522252
table_metadata=txn.table_metadata, df=arrow_table_with_null, io=txn._table.io
22532253
):

0 commit comments

Comments
 (0)