Skip to content

Commit 33c8931

Browse files
vinjai and kevinjqliu authored
Feature: Write to branches (#941)
Fixes: #306 --------- Co-authored-by: Kevin Liu <[email protected]>
1 parent e054b77 commit 33c8931

File tree

8 files changed

+409
-111
lines changed

8 files changed

+409
-111
lines changed

pyiceberg/cli/console.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
from pyiceberg.cli.output import ConsoleOutput, JsonOutput, Output
3434
from pyiceberg.exceptions import NoSuchNamespaceError, NoSuchPropertyException, NoSuchTableError
3535
from pyiceberg.table import TableProperties
36-
from pyiceberg.table.refs import SnapshotRef
36+
from pyiceberg.table.refs import SnapshotRef, SnapshotRefType
3737
from pyiceberg.utils.properties import property_as_int
3838

3939

@@ -417,7 +417,7 @@ def list_refs(ctx: Context, identifier: str, type: str, verbose: bool) -> None:
417417
refs = table.refs()
418418
if type:
419419
type = type.lower()
420-
if type not in {"branch", "tag"}:
420+
if type not in {SnapshotRefType.BRANCH, SnapshotRefType.TAG}:
421421
raise ValueError(f"Type must be either branch or tag, got: {type}")
422422

423423
relevant_refs = [
@@ -431,7 +431,7 @@ def list_refs(ctx: Context, identifier: str, type: str, verbose: bool) -> None:
431431

432432
def _retention_properties(ref: SnapshotRef, table_properties: Dict[str, str]) -> Dict[str, str]:
433433
retention_properties = {}
434-
if ref.snapshot_ref_type == "branch":
434+
if ref.snapshot_ref_type == SnapshotRefType.BRANCH:
435435
default_min_snapshots_to_keep = property_as_int(
436436
table_properties,
437437
TableProperties.MIN_SNAPSHOTS_TO_KEEP,

pyiceberg/table/__init__.py

Lines changed: 72 additions & 25 deletions
Large diffs are not rendered by default.

pyiceberg/table/update/__init__.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
from pyiceberg.partitioning import PARTITION_FIELD_ID_START, PartitionSpec
3030
from pyiceberg.schema import Schema
3131
from pyiceberg.table.metadata import SUPPORTED_TABLE_FORMAT_VERSION, TableMetadata, TableMetadataUtil
32-
from pyiceberg.table.refs import MAIN_BRANCH, SnapshotRef
32+
from pyiceberg.table.refs import MAIN_BRANCH, SnapshotRef, SnapshotRefType
3333
from pyiceberg.table.snapshots import (
3434
MetadataLogEntry,
3535
Snapshot,
@@ -139,7 +139,7 @@ class AddSnapshotUpdate(IcebergBaseModel):
139139
class SetSnapshotRefUpdate(IcebergBaseModel):
140140
action: Literal["set-snapshot-ref"] = Field(default="set-snapshot-ref")
141141
ref_name: str = Field(alias="ref-name")
142-
type: Literal["tag", "branch"]
142+
type: Literal[SnapshotRefType.TAG, SnapshotRefType.BRANCH]
143143
snapshot_id: int = Field(alias="snapshot-id")
144144
max_ref_age_ms: Annotated[Optional[int], Field(alias="max-ref-age-ms", default=None)]
145145
max_snapshot_age_ms: Annotated[Optional[int], Field(alias="max-snapshot-age-ms", default=None)]
@@ -702,6 +702,10 @@ class AssertRefSnapshotId(ValidatableTableRequirement):
702702
def validate(self, base_metadata: Optional[TableMetadata]) -> None:
703703
if base_metadata is None:
704704
raise CommitFailedException("Requirement failed: current table metadata is missing")
705+
elif len(base_metadata.snapshots) == 0 and self.ref != MAIN_BRANCH:
706+
raise CommitFailedException(
707+
f"Requirement failed: Table has no snapshots and can only be written to the {MAIN_BRANCH} BRANCH."
708+
)
705709
elif snapshot_ref := base_metadata.refs.get(self.ref):
706710
ref_type = snapshot_ref.snapshot_ref_type
707711
if self.snapshot_id is None:

pyiceberg/table/update/snapshot.py

Lines changed: 112 additions & 66 deletions
Large diffs are not rendered by default.

pyiceberg/utils/concurrent.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,11 @@
2525
class ExecutorFactory:
2626
_instance: Optional[Executor] = None
2727

28+
@staticmethod
29+
def max_workers() -> Optional[int]:
30+
"""Return the max number of workers configured."""
31+
return Config().get_int("max-workers")
32+
2833
@staticmethod
2934
def get_or_create() -> Executor:
3035
"""Return the same executor in each call."""
@@ -33,8 +38,3 @@ def get_or_create() -> Executor:
3338
ExecutorFactory._instance = ThreadPoolExecutor(max_workers=max_workers)
3439

3540
return ExecutorFactory._instance
36-
37-
@staticmethod
38-
def max_workers() -> Optional[int]:
39-
"""Return the max number of workers configured."""
40-
return Config().get_int("max-workers")

tests/integration/test_deletes.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -894,3 +894,32 @@ def test_overwrite_with_filter_case_insensitive(test_table: Table) -> None:
894894
test_table.overwrite(df=new_table, overwrite_filter=f"Idx == {record_to_overwrite['idx']}", case_sensitive=False)
895895
assert record_to_overwrite not in test_table.scan().to_arrow().to_pylist()
896896
assert new_record_to_insert in test_table.scan().to_arrow().to_pylist()
897+
898+
899+
@pytest.mark.integration
900+
@pytest.mark.parametrize("format_version", [1, 2])
901+
@pytest.mark.filterwarnings("ignore:Delete operation did not match any records")
902+
def test_delete_on_empty_table(spark: SparkSession, session_catalog: RestCatalog, format_version: int) -> None:
903+
identifier = f"default.test_delete_on_empty_table_{format_version}"
904+
905+
run_spark_commands(
906+
spark,
907+
[
908+
f"DROP TABLE IF EXISTS {identifier}",
909+
f"""
910+
CREATE TABLE {identifier} (
911+
volume int
912+
)
913+
USING iceberg
914+
TBLPROPERTIES('format-version' = {format_version})
915+
""",
916+
],
917+
)
918+
919+
tbl = session_catalog.load_table(identifier)
920+
921+
# Perform a delete operation on the empty table
922+
tbl.delete(AlwaysTrue())
923+
924+
# Assert that no new snapshot was created because no rows were deleted
925+
assert len(tbl.snapshots()) == 0

tests/integration/test_writes/test_writes.py

Lines changed: 159 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,12 +41,13 @@
4141
from pyiceberg.catalog import Catalog, load_catalog
4242
from pyiceberg.catalog.hive import HiveCatalog
4343
from pyiceberg.catalog.sql import SqlCatalog
44-
from pyiceberg.exceptions import NoSuchTableError
44+
from pyiceberg.exceptions import CommitFailedException, NoSuchTableError
4545
from pyiceberg.expressions import And, EqualTo, GreaterThanOrEqual, In, LessThan, Not
4646
from pyiceberg.io.pyarrow import _dataframe_to_data_files
4747
from pyiceberg.partitioning import PartitionField, PartitionSpec
4848
from pyiceberg.schema import Schema
4949
from pyiceberg.table import TableProperties
50+
from pyiceberg.table.refs import MAIN_BRANCH
5051
from pyiceberg.table.sorting import SortDirection, SortField, SortOrder
5152
from pyiceberg.transforms import DayTransform, HourTransform, IdentityTransform
5253
from pyiceberg.types import (
@@ -1856,3 +1857,160 @@ def test_avro_compression_codecs(session_catalog: Catalog, arrow_table_with_null
18561857
with tbl.io.new_input(current_snapshot.manifest_list).open() as f:
18571858
reader = fastavro.reader(f)
18581859
assert reader.codec == "null"
1860+
1861+
1862+
@pytest.mark.integration
1863+
def test_append_to_non_existing_branch(session_catalog: Catalog, arrow_table_with_null: pa.Table) -> None:
1864+
identifier = "default.test_non_existing_branch"
1865+
tbl = _create_table(session_catalog, identifier, {"format-version": "2"}, [])
1866+
with pytest.raises(
1867+
CommitFailedException, match=f"Table has no snapshots and can only be written to the {MAIN_BRANCH} BRANCH."
1868+
):
1869+
tbl.append(arrow_table_with_null, branch="non_existing_branch")
1870+
1871+
1872+
@pytest.mark.integration
1873+
def test_append_to_existing_branch(session_catalog: Catalog, arrow_table_with_null: pa.Table) -> None:
1874+
identifier = "default.test_existing_branch_append"
1875+
branch = "existing_branch"
1876+
tbl = _create_table(session_catalog, identifier, {"format-version": "2"}, [arrow_table_with_null])
1877+
1878+
assert tbl.metadata.current_snapshot_id is not None
1879+
1880+
tbl.manage_snapshots().create_branch(snapshot_id=tbl.metadata.current_snapshot_id, branch_name=branch).commit()
1881+
tbl.append(arrow_table_with_null, branch=branch)
1882+
1883+
assert len(tbl.scan().use_ref(branch).to_arrow()) == 6
1884+
assert len(tbl.scan().to_arrow()) == 3
1885+
branch_snapshot = tbl.metadata.snapshot_by_name(branch)
1886+
assert branch_snapshot is not None
1887+
main_snapshot = tbl.metadata.snapshot_by_name("main")
1888+
assert main_snapshot is not None
1889+
assert branch_snapshot.parent_snapshot_id == main_snapshot.snapshot_id
1890+
1891+
1892+
@pytest.mark.integration
1893+
def test_delete_to_existing_branch(session_catalog: Catalog, arrow_table_with_null: pa.Table) -> None:
1894+
identifier = "default.test_existing_branch_delete"
1895+
branch = "existing_branch"
1896+
tbl = _create_table(session_catalog, identifier, {"format-version": "2"}, [arrow_table_with_null])
1897+
1898+
assert tbl.metadata.current_snapshot_id is not None
1899+
1900+
tbl.manage_snapshots().create_branch(snapshot_id=tbl.metadata.current_snapshot_id, branch_name=branch).commit()
1901+
tbl.delete(delete_filter="int = 9", branch=branch)
1902+
1903+
assert len(tbl.scan().use_ref(branch).to_arrow()) == 2
1904+
assert len(tbl.scan().to_arrow()) == 3
1905+
branch_snapshot = tbl.metadata.snapshot_by_name(branch)
1906+
assert branch_snapshot is not None
1907+
main_snapshot = tbl.metadata.snapshot_by_name("main")
1908+
assert main_snapshot is not None
1909+
assert branch_snapshot.parent_snapshot_id == main_snapshot.snapshot_id
1910+
1911+
1912+
@pytest.mark.integration
1913+
def test_overwrite_to_existing_branch(session_catalog: Catalog, arrow_table_with_null: pa.Table) -> None:
1914+
identifier = "default.test_existing_branch_overwrite"
1915+
branch = "existing_branch"
1916+
tbl = _create_table(session_catalog, identifier, {"format-version": "2"}, [arrow_table_with_null])
1917+
1918+
assert tbl.metadata.current_snapshot_id is not None
1919+
1920+
tbl.manage_snapshots().create_branch(snapshot_id=tbl.metadata.current_snapshot_id, branch_name=branch).commit()
1921+
tbl.overwrite(arrow_table_with_null, branch=branch)
1922+
1923+
assert len(tbl.scan().use_ref(branch).to_arrow()) == 3
1924+
assert len(tbl.scan().to_arrow()) == 3
1925+
branch_snapshot = tbl.metadata.snapshot_by_name(branch)
1926+
assert branch_snapshot is not None and branch_snapshot.parent_snapshot_id is not None
1927+
delete_snapshot = tbl.metadata.snapshot_by_id(branch_snapshot.parent_snapshot_id)
1928+
assert delete_snapshot is not None
1929+
main_snapshot = tbl.metadata.snapshot_by_name("main")
1930+
assert main_snapshot is not None
1931+
assert (
1932+
delete_snapshot.parent_snapshot_id == main_snapshot.snapshot_id
1933+
) # Currently overwrite is a delete followed by an append operation
1934+
1935+
1936+
@pytest.mark.integration
1937+
def test_intertwined_branch_writes(session_catalog: Catalog, arrow_table_with_null: pa.Table) -> None:
1938+
identifier = "default.test_intertwined_branch_operations"
1939+
branch1 = "existing_branch_1"
1940+
branch2 = "existing_branch_2"
1941+
1942+
tbl = _create_table(session_catalog, identifier, {"format-version": "2"}, [arrow_table_with_null])
1943+
1944+
assert tbl.metadata.current_snapshot_id is not None
1945+
1946+
tbl.manage_snapshots().create_branch(snapshot_id=tbl.metadata.current_snapshot_id, branch_name=branch1).commit()
1947+
1948+
tbl.delete("int = 9", branch=branch1)
1949+
1950+
tbl.append(arrow_table_with_null)
1951+
1952+
tbl.manage_snapshots().create_branch(snapshot_id=tbl.metadata.current_snapshot_id, branch_name=branch2).commit()
1953+
1954+
tbl.overwrite(arrow_table_with_null, branch=branch2)
1955+
1956+
assert len(tbl.scan().use_ref(branch1).to_arrow()) == 2
1957+
assert len(tbl.scan().use_ref(branch2).to_arrow()) == 3
1958+
assert len(tbl.scan().to_arrow()) == 6
1959+
1960+
1961+
@pytest.mark.integration
1962+
def test_branch_spark_write_py_read(session_catalog: Catalog, spark: SparkSession, arrow_table_with_null: pa.Table) -> None:
1963+
# Initialize table with branch
1964+
identifier = "default.test_branch_spark_write_py_read"
1965+
tbl = _create_table(session_catalog, identifier, {"format-version": "2"}, [arrow_table_with_null])
1966+
branch = "existing_spark_branch"
1967+
1968+
# Create branch in Spark
1969+
spark.sql(f"ALTER TABLE {identifier} CREATE BRANCH {branch}")
1970+
1971+
# Spark Write
1972+
spark.sql(
1973+
f"""
1974+
DELETE FROM {identifier}.branch_{branch}
1975+
WHERE int = 9
1976+
"""
1977+
)
1978+
1979+
# Refresh table to get new refs
1980+
tbl.refresh()
1981+
1982+
# Python Read
1983+
assert len(tbl.scan().to_arrow()) == 3
1984+
assert len(tbl.scan().use_ref(branch).to_arrow()) == 2
1985+
1986+
1987+
@pytest.mark.integration
1988+
def test_branch_py_write_spark_read(session_catalog: Catalog, spark: SparkSession, arrow_table_with_null: pa.Table) -> None:
1989+
# Initialize table with branch
1990+
identifier = "default.test_branch_py_write_spark_read"
1991+
tbl = _create_table(session_catalog, identifier, {"format-version": "2"}, [arrow_table_with_null])
1992+
branch = "existing_py_branch"
1993+
1994+
assert tbl.metadata.current_snapshot_id is not None
1995+
1996+
# Create branch
1997+
tbl.manage_snapshots().create_branch(snapshot_id=tbl.metadata.current_snapshot_id, branch_name=branch).commit()
1998+
1999+
# Python Write
2000+
tbl.delete("int = 9", branch=branch)
2001+
2002+
# Spark Read
2003+
main_df = spark.sql(
2004+
f"""
2005+
SELECT *
2006+
FROM {identifier}
2007+
"""
2008+
)
2009+
branch_df = spark.sql(
2010+
f"""
2011+
SELECT *
2012+
FROM {identifier}.branch_{branch}
2013+
"""
2014+
)
2015+
assert main_df.count() == 3
2016+
assert branch_df.count() == 2

tests/table/test_init.py

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@
5050
_match_deletes_to_data_file,
5151
)
5252
from pyiceberg.table.metadata import INITIAL_SEQUENCE_NUMBER, TableMetadataUtil, TableMetadataV2, _generate_snapshot_id
53-
from pyiceberg.table.refs import SnapshotRef
53+
from pyiceberg.table.refs import MAIN_BRANCH, SnapshotRef, SnapshotRefType
5454
from pyiceberg.table.snapshots import (
5555
MetadataLogEntry,
5656
Operation,
@@ -1000,28 +1000,42 @@ def test_assert_table_uuid(table_v2: Table) -> None:
10001000

10011001
def test_assert_ref_snapshot_id(table_v2: Table) -> None:
10021002
base_metadata = table_v2.metadata
1003-
AssertRefSnapshotId(ref="main", snapshot_id=base_metadata.current_snapshot_id).validate(base_metadata)
1003+
AssertRefSnapshotId(ref=MAIN_BRANCH, snapshot_id=base_metadata.current_snapshot_id).validate(base_metadata)
10041004

10051005
with pytest.raises(CommitFailedException, match="Requirement failed: current table metadata is missing"):
1006-
AssertRefSnapshotId(ref="main", snapshot_id=1).validate(None)
1006+
AssertRefSnapshotId(ref=MAIN_BRANCH, snapshot_id=1).validate(None)
10071007

10081008
with pytest.raises(
10091009
CommitFailedException,
1010-
match="Requirement failed: branch main was created concurrently",
1010+
match=f"Requirement failed: branch {MAIN_BRANCH} was created concurrently",
10111011
):
1012-
AssertRefSnapshotId(ref="main", snapshot_id=None).validate(base_metadata)
1012+
AssertRefSnapshotId(ref=MAIN_BRANCH, snapshot_id=None).validate(base_metadata)
10131013

10141014
with pytest.raises(
10151015
CommitFailedException,
1016-
match="Requirement failed: branch main has changed: expected id 1, found 3055729675574597004",
1016+
match=f"Requirement failed: branch {MAIN_BRANCH} has changed: expected id 1, found 3055729675574597004",
10171017
):
1018-
AssertRefSnapshotId(ref="main", snapshot_id=1).validate(base_metadata)
1018+
AssertRefSnapshotId(ref=MAIN_BRANCH, snapshot_id=1).validate(base_metadata)
1019+
1020+
non_existing_ref = "not_exist_branch_or_tag"
1021+
assert table_v2.refs().get("not_exist_branch_or_tag") is None
1022+
1023+
with pytest.raises(
1024+
CommitFailedException,
1025+
match=f"Requirement failed: branch or tag {non_existing_ref} is missing, expected 1",
1026+
):
1027+
AssertRefSnapshotId(ref=non_existing_ref, snapshot_id=1).validate(base_metadata)
1028+
1029+
# existing Tag in metadata: test
1030+
ref_tag = table_v2.refs().get("test")
1031+
assert ref_tag is not None
1032+
assert ref_tag.snapshot_ref_type == SnapshotRefType.TAG, "TAG test should be present in table to be tested"
10191033

10201034
with pytest.raises(
10211035
CommitFailedException,
1022-
match="Requirement failed: branch or tag not_exist is missing, expected 1",
1036+
match="Requirement failed: tag test has changed: expected id 3055729675574597004, found 3051729675574597004",
10231037
):
1024-
AssertRefSnapshotId(ref="not_exist", snapshot_id=1).validate(base_metadata)
1038+
AssertRefSnapshotId(ref="test", snapshot_id=3055729675574597004).validate(base_metadata)
10251039

10261040

10271041
def test_assert_last_assigned_field_id(table_v2: Table) -> None:

0 commit comments

Comments (0)