|
41 | 41 | from pyiceberg.catalog import Catalog, load_catalog
|
42 | 42 | from pyiceberg.catalog.hive import HiveCatalog
|
43 | 43 | from pyiceberg.catalog.sql import SqlCatalog
|
44 |
| -from pyiceberg.exceptions import NoSuchTableError |
| 44 | +from pyiceberg.exceptions import CommitFailedException, NoSuchTableError |
45 | 45 | from pyiceberg.expressions import And, EqualTo, GreaterThanOrEqual, In, LessThan, Not
|
46 | 46 | from pyiceberg.io.pyarrow import _dataframe_to_data_files
|
47 | 47 | from pyiceberg.partitioning import PartitionField, PartitionSpec
|
48 | 48 | from pyiceberg.schema import Schema
|
49 | 49 | from pyiceberg.table import TableProperties
|
| 50 | +from pyiceberg.table.refs import MAIN_BRANCH |
50 | 51 | from pyiceberg.table.sorting import SortDirection, SortField, SortOrder
|
51 | 52 | from pyiceberg.transforms import DayTransform, HourTransform, IdentityTransform
|
52 | 53 | from pyiceberg.types import (
|
@@ -1856,3 +1857,160 @@ def test_avro_compression_codecs(session_catalog: Catalog, arrow_table_with_null
|
1856 | 1857 | with tbl.io.new_input(current_snapshot.manifest_list).open() as f:
|
1857 | 1858 | reader = fastavro.reader(f)
|
1858 | 1859 | assert reader.codec == "null"
|
| 1860 | + |
| 1861 | + |
@pytest.mark.integration
def test_append_to_non_existing_branch(session_catalog: Catalog, arrow_table_with_null: pa.Table) -> None:
    """Appending to an unknown branch on a table with no snapshots must fail the commit."""
    identifier = "default.test_non_existing_branch"
    table = _create_table(session_catalog, identifier, {"format-version": "2"}, [])
    # An empty table has no refs at all, so only the main branch is writable.
    expected_message = f"Table has no snapshots and can only be written to the {MAIN_BRANCH} BRANCH."
    with pytest.raises(CommitFailedException, match=expected_message):
        table.append(arrow_table_with_null, branch="non_existing_branch")
| 1870 | + |
| 1871 | + |
@pytest.mark.integration
def test_append_to_existing_branch(session_catalog: Catalog, arrow_table_with_null: pa.Table) -> None:
    """Appending to an existing branch grows that branch only; main is left untouched."""
    identifier = "default.test_existing_branch_append"
    branch_name = "existing_branch"
    table = _create_table(session_catalog, identifier, {"format-version": "2"}, [arrow_table_with_null])

    assert table.metadata.current_snapshot_id is not None

    # Fork a branch off the current snapshot, then append a second copy of the data to it.
    table.manage_snapshots().create_branch(snapshot_id=table.metadata.current_snapshot_id, branch_name=branch_name).commit()
    table.append(arrow_table_with_null, branch=branch_name)

    # Branch sees both appends (3 + 3 rows); main still sees only the original 3.
    assert len(table.scan().use_ref(branch_name).to_arrow()) == 6
    assert len(table.scan().to_arrow()) == 3

    # The branch head must descend directly from the main head.
    snapshot_on_branch = table.metadata.snapshot_by_name(branch_name)
    snapshot_on_main = table.metadata.snapshot_by_name("main")
    assert snapshot_on_branch is not None
    assert snapshot_on_main is not None
    assert snapshot_on_branch.parent_snapshot_id == snapshot_on_main.snapshot_id
| 1890 | + |
| 1891 | + |
@pytest.mark.integration
def test_delete_to_existing_branch(session_catalog: Catalog, arrow_table_with_null: pa.Table) -> None:
    """Deleting on an existing branch removes rows from that branch only; main is unaffected."""
    identifier = "default.test_existing_branch_delete"
    branch_name = "existing_branch"
    table = _create_table(session_catalog, identifier, {"format-version": "2"}, [arrow_table_with_null])

    assert table.metadata.current_snapshot_id is not None

    # Fork a branch off the current snapshot, then delete the matching row on it.
    table.manage_snapshots().create_branch(snapshot_id=table.metadata.current_snapshot_id, branch_name=branch_name).commit()
    table.delete(delete_filter="int = 9", branch=branch_name)

    # One of the three rows matched the filter on the branch; main keeps all three.
    assert len(table.scan().use_ref(branch_name).to_arrow()) == 2
    assert len(table.scan().to_arrow()) == 3

    # The branch head must descend directly from the main head.
    snapshot_on_branch = table.metadata.snapshot_by_name(branch_name)
    snapshot_on_main = table.metadata.snapshot_by_name("main")
    assert snapshot_on_branch is not None
    assert snapshot_on_main is not None
    assert snapshot_on_branch.parent_snapshot_id == snapshot_on_main.snapshot_id
| 1910 | + |
| 1911 | + |
@pytest.mark.integration
def test_overwrite_to_existing_branch(session_catalog: Catalog, arrow_table_with_null: pa.Table) -> None:
    """Overwriting on a branch rewrites only that branch; the overwrite commits a delete + append pair."""
    identifier = "default.test_existing_branch_overwrite"
    branch_name = "existing_branch"
    table = _create_table(session_catalog, identifier, {"format-version": "2"}, [arrow_table_with_null])

    assert table.metadata.current_snapshot_id is not None

    # Fork a branch off the current snapshot and overwrite it with the same data.
    table.manage_snapshots().create_branch(snapshot_id=table.metadata.current_snapshot_id, branch_name=branch_name).commit()
    table.overwrite(arrow_table_with_null, branch=branch_name)

    # Both refs expose 3 rows: main is untouched, the branch was fully replaced.
    assert len(table.scan().use_ref(branch_name).to_arrow()) == 3
    assert len(table.scan().to_arrow()) == 3

    # Walk the lineage: branch head (append) -> intermediate delete snapshot -> main head,
    # because overwrite is currently implemented as a delete followed by an append.
    branch_head = table.metadata.snapshot_by_name(branch_name)
    assert branch_head is not None and branch_head.parent_snapshot_id is not None
    intermediate_delete = table.metadata.snapshot_by_id(branch_head.parent_snapshot_id)
    assert intermediate_delete is not None
    main_head = table.metadata.snapshot_by_name("main")
    assert main_head is not None
    assert intermediate_delete.parent_snapshot_id == main_head.snapshot_id
| 1934 | + |
| 1935 | + |
@pytest.mark.integration
def test_intertwined_branch_writes(session_catalog: Catalog, arrow_table_with_null: pa.Table) -> None:
    """Interleave writes to two branches and main; each ref must see only its own history.

    Fix: pass the delete predicate via the ``delete_filter=`` keyword, matching how the
    sibling branch-delete tests in this file call ``Table.delete``.
    """
    identifier = "default.test_intertwined_branch_operations"
    branch1 = "existing_branch_1"
    branch2 = "existing_branch_2"

    tbl = _create_table(session_catalog, identifier, {"format-version": "2"}, [arrow_table_with_null])

    assert tbl.metadata.current_snapshot_id is not None

    # Branch 1 forks from the initial snapshot and drops one row.
    tbl.manage_snapshots().create_branch(snapshot_id=tbl.metadata.current_snapshot_id, branch_name=branch1).commit()
    tbl.delete(delete_filter="int = 9", branch=branch1)

    # Main advances independently of branch 1.
    tbl.append(arrow_table_with_null)

    # Branch 2 forks from the advanced main head and is then overwritten.
    tbl.manage_snapshots().create_branch(snapshot_id=tbl.metadata.current_snapshot_id, branch_name=branch2).commit()
    tbl.overwrite(arrow_table_with_null, branch=branch2)

    assert len(tbl.scan().use_ref(branch1).to_arrow()) == 2  # 3 rows minus the deleted one
    assert len(tbl.scan().use_ref(branch2).to_arrow()) == 3  # overwritten with a fresh 3-row table
    assert len(tbl.scan().to_arrow()) == 6  # two appends landed on main
| 1959 | + |
| 1960 | + |
@pytest.mark.integration
def test_branch_spark_write_py_read(session_catalog: Catalog, spark: SparkSession, arrow_table_with_null: pa.Table) -> None:
    """A branch delete performed via Spark SQL is visible to a PyIceberg scan of that branch."""
    identifier = "default.test_branch_spark_write_py_read"
    branch = "existing_spark_branch"
    table = _create_table(session_catalog, identifier, {"format-version": "2"}, [arrow_table_with_null])

    # Create the branch and delete one row on it, both through Spark.
    spark.sql(f"ALTER TABLE {identifier} CREATE BRANCH {branch}")
    spark.sql(
        f"""
        DELETE FROM {identifier}.branch_{branch}
        WHERE int = 9
        """
    )

    # Pick up the ref Spark just wrote.
    table.refresh()

    # Main keeps all 3 rows; the branch lost the matching one.
    assert len(table.scan().to_arrow()) == 3
    assert len(table.scan().use_ref(branch).to_arrow()) == 2
| 1985 | + |
| 1986 | + |
@pytest.mark.integration
def test_branch_py_write_spark_read(session_catalog: Catalog, spark: SparkSession, arrow_table_with_null: pa.Table) -> None:
    """A branch delete performed via PyIceberg is visible to Spark reads of that branch.

    Fix: pass the delete predicate via the ``delete_filter=`` keyword, matching how the
    sibling branch-delete tests in this file call ``Table.delete``.
    """
    identifier = "default.test_branch_py_write_spark_read"
    tbl = _create_table(session_catalog, identifier, {"format-version": "2"}, [arrow_table_with_null])
    branch = "existing_py_branch"

    assert tbl.metadata.current_snapshot_id is not None

    # Create the branch from the current snapshot, then delete one row on it in Python.
    tbl.manage_snapshots().create_branch(snapshot_id=tbl.metadata.current_snapshot_id, branch_name=branch).commit()
    tbl.delete(delete_filter="int = 9", branch=branch)

    # Spark reads: main keeps all rows; the branch is missing the deleted one.
    main_df = spark.sql(
        f"""
        SELECT *
        FROM {identifier}
        """
    )
    branch_df = spark.sql(
        f"""
        SELECT *
        FROM {identifier}.branch_{branch}
        """
    )
    assert main_df.count() == 3
    assert branch_df.count() == 2
0 commit comments