
Commit dad5a1e

Author: Yingjian Wu (committed)

Commit message:
    improve test
    rebase
1 parent 3a942cf commit dad5a1e

File tree

3 files changed: +169 additions, -129 deletions


tests/integration/test_writes/test_partitioned_writes.py

Lines changed: 58 additions & 0 deletions
@@ -1133,3 +1133,61 @@ def test_append_multiple_partitions(
         """
     )
     assert files_df.count() == 6
+
+
+@pytest.mark.integration
+@pytest.mark.parametrize("format_version", [1, 2])
+def test_stage_only_dynamic_partition_overwrite_files(
+    spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int
+) -> None:
+    identifier = f"default.test_stage_only_dynamic_partition_overwrite_files_v{format_version}"
+    try:
+        session_catalog.drop_table(identifier=identifier)
+    except NoSuchTableError:
+        pass
+    tbl = session_catalog.create_table(
+        identifier=identifier,
+        schema=TABLE_SCHEMA,
+        partition_spec=PartitionSpec(
+            PartitionField(source_id=1, field_id=1001, transform=IdentityTransform(), name="bool"),
+            PartitionField(source_id=4, field_id=1002, transform=IdentityTransform(), name="int"),
+        ),
+        properties={"format-version": str(format_version)},
+    )
+
+    tbl.append(arrow_table_with_null)
+    current_snapshot = tbl.metadata.current_snapshot_id
+    assert current_snapshot is not None
+
+    original_count = len(tbl.scan().to_arrow())
+    assert original_count == 3
+
+    # write to staging snapshot
+    tbl.dynamic_partition_overwrite(arrow_table_with_null.slice(0, 1), branch=None)
+
+    assert current_snapshot == tbl.metadata.current_snapshot_id
+    assert len(tbl.scan().to_arrow()) == original_count
+    snapshots = tbl.snapshots()
+    # dynamic partition overwrite will create 2 snapshots, one delete and another append
+    assert len(snapshots) == 3
+
+    # Write to main branch
+    tbl.append(arrow_table_with_null)
+
+    # Main ref has changed
+    assert current_snapshot != tbl.metadata.current_snapshot_id
+    assert len(tbl.scan().to_arrow()) == 6
+    snapshots = tbl.snapshots()
+    assert len(snapshots) == 4
+
+    rows = spark.sql(
+        f"""
+        SELECT operation, parent_id, snapshot_id
+        FROM {identifier}.snapshots
+        ORDER BY committed_at ASC
+        """
+    ).collect()
+    operations = [row.operation for row in rows]
+    parent_snapshot_id = [row.parent_id for row in rows]
+    assert operations == ["append", "delete", "append", "append"]
+    assert parent_snapshot_id == [None, current_snapshot, current_snapshot, current_snapshot]
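
Taken together, the new test pins down the staging contract of dynamic partition overwrite: with branch=None the delete and append snapshots are committed to table metadata, but no ref, including main, moves. A rough sketch of that pattern, assuming tbl is an already-loaded partitioned Iceberg table and df is a pyarrow.Table matching its schema (names here are illustrative, not taken from the diff):

before_snapshot = tbl.metadata.current_snapshot_id
before_rows = len(tbl.scan().to_arrow())

# Stage the overwrite: snapshots are written to table metadata, but no ref is updated.
tbl.dynamic_partition_overwrite(df, branch=None)

assert tbl.metadata.current_snapshot_id == before_snapshot  # main ref untouched
assert len(tbl.scan().to_arrow()) == before_rows            # reads of main unchanged

# The staged work is still visible in the snapshot log: one delete plus one append.
for snapshot in tbl.snapshots():
    operation = snapshot.summary.operation.value if snapshot.summary else None
    print(snapshot.snapshot_id, operation, snapshot.parent_snapshot_id)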

tests/integration/test_writes/test_writes.py

Lines changed: 47 additions & 129 deletions
@@ -18,6 +18,7 @@
 import math
 import os
 import random
+import re
 import time
 import uuid
 from datetime import date, datetime, timedelta
@@ -44,7 +45,7 @@
 from pyiceberg.catalog.sql import SqlCatalog
 from pyiceberg.exceptions import CommitFailedException, NoSuchTableError
 from pyiceberg.expressions import And, EqualTo, GreaterThanOrEqual, In, LessThan, Not
-from pyiceberg.io.pyarrow import _dataframe_to_data_files
+from pyiceberg.io.pyarrow import _dataframe_to_data_files, UnsupportedPyArrowTypeException
 from pyiceberg.partitioning import PartitionField, PartitionSpec
 from pyiceberg.schema import Schema
 from pyiceberg.table import TableProperties
@@ -2249,18 +2250,27 @@ def test_branch_py_write_spark_read(session_catalog: Catalog, spark: SparkSessio


 @pytest.mark.integration
-def test_nanosecond_support_on_catalog(session_catalog: Catalog) -> None:
+def test_nanosecond_support_on_catalog(
+    session_catalog: Catalog, arrow_table_schema_with_all_timestamp_precisions: pa.Schema
+) -> None:
     identifier = "default.test_nanosecond_support_on_catalog"
-    # Create a pyarrow table with a nanosecond timestamp column
-    table = pa.Table.from_arrays(
-        [
-            pa.array([datetime.now()], type=pa.timestamp("ns")),
-            pa.array([datetime.now()], type=pa.timestamp("ns", tz="America/New_York")),
-        ],
-        names=["timestamp_ns", "timestamptz_ns"],
-    )

-    _create_table(session_catalog, identifier, {"format-version": "3"}, schema=table.schema)
+    catalog = load_catalog("default", type="in-memory")
+    catalog.create_namespace("ns")
+
+    _create_table(session_catalog, identifier, {"format-version": "3"}, schema=arrow_table_schema_with_all_timestamp_precisions)
+
+    with pytest.raises(NotImplementedError, match="Writing V3 is not yet supported"):
+        catalog.create_table(
+            "ns.table1", schema=arrow_table_schema_with_all_timestamp_precisions, properties={"format-version": "3"}
+        )
+
+    with pytest.raises(
+        UnsupportedPyArrowTypeException, match=re.escape("Column 'timestamp_ns' has an unsupported type: timestamp[ns]")
+    ):
+        _create_table(
+            session_catalog, identifier, {"format-version": "2"}, schema=arrow_table_schema_with_all_timestamp_precisions
+        )


 @pytest.mark.parametrize("format_version", [1, 2])
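
The rewritten test drops the hand-built nanosecond table in favour of the shared arrow_table_schema_with_all_timestamp_precisions fixture. The behaviour it asserts can be restated against a minimal standalone schema (a sketch; the fixture's exact columns are not shown in this diff):

import pyarrow as pa

ns_schema = pa.schema(
    [
        pa.field("timestamp_ns", pa.timestamp("ns")),
        pa.field("timestamptz_ns", pa.timestamp("ns", tz="UTC")),
    ]
)

# Per the assertions above:
# - the session catalog accepts a format-version 3 table whose schema contains nanosecond columns;
# - the in-memory catalog refuses format-version 3 with
#   NotImplementedError("Writing V3 is not yet supported");
# - a format-version 2 table rejects the schema with
#   UnsupportedPyArrowTypeException("Column 'timestamp_ns' has an unsupported type: timestamp[ns]").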
@@ -2281,14 +2291,7 @@ def test_stage_only_delete(
     original_count = len(tbl.scan().to_arrow())
     assert original_count == 3

-    files_to_delete = []
-    for file_task in tbl.scan().plan_files():
-        files_to_delete.append(file_task.file)
-    assert len(files_to_delete) > 0
-
-    with tbl.transaction() as txn:
-        with txn.update_snapshot(branch=None).delete() as delete:
-            delete.delete_by_predicate(EqualTo("int", 9))
+    tbl.delete("int = 9", branch=None)

     # a new delete snapshot is added
     snapshots = tbl.snapshots()
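
The hunk above swaps the manual transaction plus delete_by_predicate plumbing for the table-level tbl.delete helper. For reference, the predicate can be given either as a string or as an expression object; a small sketch, assuming the same tbl and that delete accepts the branch argument exercised here:

from pyiceberg.expressions import EqualTo

# Either filter form stages the same delete when branch=None:
tbl.delete("int = 9", branch=None)
# tbl.delete(EqualTo("int", 9), branch=None)  # expression form of the same predicate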
@@ -2298,16 +2301,11 @@ def test_stage_only_delete(
     assert len(tbl.scan().to_arrow()) == original_count

     # Write to main branch
-    with tbl.transaction() as txn:
-        with txn.update_snapshot().fast_append() as fast_append:
-            for data_file in _dataframe_to_data_files(
-                table_metadata=txn.table_metadata, df=arrow_table_with_null, io=txn._table.io
-            ):
-                fast_append.append_data_file(data_file=data_file)
+    tbl.append(arrow_table_with_null)

     # Main ref has changed
     assert current_snapshot != tbl.metadata.current_snapshot_id
-    assert len(tbl.scan().to_arrow()) == 3
+    assert len(tbl.scan().to_arrow()) == 6
     snapshots = tbl.snapshots()
     assert len(snapshots) == 3

@@ -2327,7 +2325,7 @@ def test_stage_only_delete(

 @pytest.mark.integration
 @pytest.mark.parametrize("format_version", [1, 2])
-def test_stage_only_fast_append(
+def test_stage_only_append(
     spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int
 ) -> None:
     identifier = f"default.test_stage_only_fast_append_files_v{format_version}"
@@ -2340,12 +2338,7 @@ def test_stage_only_fast_append(
     assert original_count == 3

     # Write to staging branch
-    with tbl.transaction() as txn:
-        with txn.update_snapshot(branch=None).fast_append() as fast_append:
-            for data_file in _dataframe_to_data_files(
-                table_metadata=txn.table_metadata, df=arrow_table_with_null, io=txn._table.io
-            ):
-                fast_append.append_data_file(data_file=data_file)
+    tbl.append(arrow_table_with_null, branch=None)

     # Main ref has not changed and data is not yet appended
     assert current_snapshot == tbl.metadata.current_snapshot_id
@@ -2355,12 +2348,7 @@ def test_stage_only_fast_append(
     assert len(snapshots) == 2

     # Write to main branch
-    with tbl.transaction() as txn:
-        with txn.update_snapshot().fast_append() as fast_append:
-            for data_file in _dataframe_to_data_files(
-                table_metadata=txn.table_metadata, df=arrow_table_with_null, io=txn._table.io
-            ):
-                fast_append.append_data_file(data_file=data_file)
+    tbl.append(arrow_table_with_null)

     # Main ref has changed
     assert current_snapshot != tbl.metadata.current_snapshot_id
@@ -2382,119 +2370,49 @@ def test_stage_only_fast_append(
     assert parent_snapshot_id == [None, current_snapshot, current_snapshot]


-
-@pytest.mark.integration
-@pytest.mark.parametrize("format_version", [1, 2])
-def test_stage_only_merge_append(
-    spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int
-) -> None:
-    identifier = f"default.test_stage_only_merge_append_files_v{format_version}"
-    tbl = _create_table(session_catalog, identifier, {"format-version": str(format_version)}, [arrow_table_with_null])
-
-    current_snapshot = tbl.metadata.current_snapshot_id
-    assert current_snapshot is not None
-
-    original_count = len(tbl.scan().to_arrow())
-    assert original_count == 3
-
-    with tbl.transaction() as txn:
-        with txn.update_snapshot(branch=None).merge_append() as merge_append:
-            for data_file in _dataframe_to_data_files(
-                table_metadata=txn.table_metadata, df=arrow_table_with_null, io=txn._table.io
-            ):
-                merge_append.append_data_file(data_file=data_file)
-
-    # Main ref has not changed and data is not yet appended
-    assert current_snapshot == tbl.metadata.current_snapshot_id
-    assert len(tbl.scan().to_arrow()) == original_count
-
-    # There should be a new staged snapshot
-    snapshots = tbl.snapshots()
-    assert len(snapshots) == 2
-
-    # Write to main branch
-    with tbl.transaction() as txn:
-        with txn.update_snapshot().fast_append() as fast_append:
-            for data_file in _dataframe_to_data_files(
-                table_metadata=txn.table_metadata, df=arrow_table_with_null, io=txn._table.io
-            ):
-                fast_append.append_data_file(data_file=data_file)
-
-    # Main ref has changed
-    assert current_snapshot != tbl.metadata.current_snapshot_id
-    assert len(tbl.scan().to_arrow()) == 6
-    snapshots = tbl.snapshots()
-    assert len(snapshots) == 3
-
-    rows = spark.sql(
-        f"""
-        SELECT operation, parent_id
-        FROM {identifier}.snapshots
-        ORDER BY committed_at ASC
-        """
-    ).collect()
-    operations = [row.operation for row in rows]
-    parent_snapshot_id = [row.parent_id for row in rows]
-    assert operations == ["append", "append", "append"]
-    # both subsequent parent id should be the first snapshot id
-    assert parent_snapshot_id == [None, current_snapshot, current_snapshot]
-
-
 @pytest.mark.integration
 @pytest.mark.parametrize("format_version", [1, 2])
 def test_stage_only_overwrite_files(
     spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int
 ) -> None:
     identifier = f"default.test_stage_only_overwrite_files_v{format_version}"
     tbl = _create_table(session_catalog, identifier, {"format-version": str(format_version)}, [arrow_table_with_null])
+    first_snapshot = tbl.metadata.current_snapshot_id

-    current_snapshot = tbl.metadata.current_snapshot_id
-    assert current_snapshot is not None
+    # duplicate data with a new insert
+    tbl.append(arrow_table_with_null)

+    second_snapshot = tbl.metadata.current_snapshot_id
+    assert second_snapshot is not None
     original_count = len(tbl.scan().to_arrow())
-    assert original_count == 3
+    assert original_count == 6

-    files_to_delete = []
-    for file_task in tbl.scan().plan_files():
-        files_to_delete.append(file_task.file)
-    assert len(files_to_delete) > 0
-
-    with tbl.transaction() as txn:
-        with txn.update_snapshot(branch=None).overwrite() as overwrite:
-            for data_file in _dataframe_to_data_files(
-                table_metadata=txn.table_metadata, df=arrow_table_with_null, io=txn._table.io
-            ):
-                overwrite.append_data_file(data_file=data_file)
-            overwrite.delete_data_file(files_to_delete[0])
-
-    assert current_snapshot == tbl.metadata.current_snapshot_id
+    # write to non-main branch
+    tbl.overwrite(arrow_table_with_null, branch=None)
+    assert second_snapshot == tbl.metadata.current_snapshot_id
     assert len(tbl.scan().to_arrow()) == original_count
     snapshots = tbl.snapshots()
-    assert len(snapshots) == 2
+    # overwrite will create 2 snapshots
+    assert len(snapshots) == 4

-    # Write to main branch
-    with tbl.transaction() as txn:
-        with txn.update_snapshot().fast_append() as fast_append:
-            for data_file in _dataframe_to_data_files(
-                table_metadata=txn.table_metadata, df=arrow_table_with_null, io=txn._table.io
-            ):
-                fast_append.append_data_file(data_file=data_file)
+    # Write to main branch again
+    tbl.append(arrow_table_with_null)

     # Main ref has changed
-    assert current_snapshot != tbl.metadata.current_snapshot_id
-    assert len(tbl.scan().to_arrow()) == 6
+    assert second_snapshot != tbl.metadata.current_snapshot_id
+    assert len(tbl.scan().to_arrow()) == 9
     snapshots = tbl.snapshots()
-    assert len(snapshots) == 3
+    assert len(snapshots) == 5

     rows = spark.sql(
         f"""
-        SELECT operation, parent_id
+        SELECT operation, parent_id, snapshot_id
         FROM {identifier}.snapshots
         ORDER BY committed_at ASC
         """
     ).collect()
     operations = [row.operation for row in rows]
     parent_snapshot_id = [row.parent_id for row in rows]
-    assert operations == ["append", "overwrite", "append"]
-    # both subsequent parent id should be the first snapshot id
-    assert parent_snapshot_id == [None, current_snapshot, current_snapshot]
+    assert operations == ["append", "append", "delete", "append", "append"]
+
+    assert parent_snapshot_id == [None, first_snapshot, second_snapshot, second_snapshot, second_snapshot]
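
The rewritten overwrite test reads the snapshot log through Spark's {identifier}.snapshots metadata table; the same lineage check can be done from PyIceberg alone, which is the approach the upsert test below takes. A sketch, assuming tbl is the table under test:

sorted_snapshots = sorted(tbl.snapshots(), key=lambda s: s.timestamp_ms)

for snapshot in sorted_snapshots:
    operation = snapshot.summary.operation.value if snapshot.summary else None
    print(snapshot.snapshot_id, operation, snapshot.parent_snapshot_id)

# Expected here: operations ["append", "append", "delete", "append", "append"], with the
# staged delete/append pair and the final main-branch append all parented on the
# pre-staging main snapshot (second_snapshot).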

tests/table/test_upsert.py

Lines changed: 64 additions & 0 deletions
@@ -770,3 +770,67 @@ def test_transaction_multiple_upserts(catalog: Catalog) -> None:
         {"id": 1, "name": "Alicia"},
         {"id": 2, "name": "Bob"},
     ]
+
+
+def test_stage_only_upsert(catalog: Catalog) -> None:
+    identifier = "default.test_stage_only_dynamic_partition_overwrite_files"
+    _drop_table(catalog, identifier)
+
+    schema = Schema(
+        NestedField(1, "city", StringType(), required=True),
+        NestedField(2, "inhabitants", IntegerType(), required=True),
+        # Mark City as the identifier field, also known as the primary-key
+        identifier_field_ids=[1],
+    )
+
+    tbl = catalog.create_table(identifier, schema=schema)
+
+    arrow_schema = pa.schema(
+        [
+            pa.field("city", pa.string(), nullable=False),
+            pa.field("inhabitants", pa.int32(), nullable=False),
+        ]
+    )
+
+    # Write some data
+    df = pa.Table.from_pylist(
+        [
+            {"city": "Amsterdam", "inhabitants": 921402},
+            {"city": "San Francisco", "inhabitants": 808988},
+            {"city": "Drachten", "inhabitants": 45019},
+            {"city": "Paris", "inhabitants": 2103000},
+        ],
+        schema=arrow_schema,
+    )
+
+    tbl.append(df.slice(0, 1))
+    current_snapshot = tbl.metadata.current_snapshot_id
+    assert current_snapshot is not None
+
+    original_count = len(tbl.scan().to_arrow())
+    assert original_count == 1
+
+    # write to staging snapshot
+    upd = tbl.upsert(df, branch=None)
+    assert upd.rows_updated == 0
+    assert upd.rows_inserted == 3
+
+    assert current_snapshot == tbl.metadata.current_snapshot_id
+    assert len(tbl.scan().to_arrow()) == original_count
+    snapshots = tbl.snapshots()
+    assert len(snapshots) == 2
+
+    # Write to main ref
+    tbl.append(df.slice(1, 1))
+    # Main ref has changed
+    assert current_snapshot != tbl.metadata.current_snapshot_id
+    assert len(tbl.scan().to_arrow()) == 2
+    snapshots = tbl.snapshots()
+    assert len(snapshots) == 3
+
+    sorted_snapshots = sorted(tbl.snapshots(), key=lambda s: s.timestamp_ms)
+    operations = [snapshot.summary.operation.value if snapshot.summary else None for snapshot in sorted_snapshots]
+    parent_snapshot_id = [snapshot.parent_snapshot_id for snapshot in sorted_snapshots]
+    assert operations == ["append", "append", "append"]
+    # both subsequent parent id should be the first snapshot id
+    assert parent_snapshot_id == [None, current_snapshot, current_snapshot]
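
Because the staged upsert never becomes the current snapshot, its data can only be inspected by time-travelling to it by snapshot id. A short sketch of such a check (an assumption about how one might verify the staged data; the test itself does not do this), picking up right after the upsert(df, branch=None) call above:

# Right after the staged upsert there are two snapshots: the current one on main
# and the staged append, which no ref points at.
staged = [s for s in tbl.snapshots() if s.snapshot_id != tbl.metadata.current_snapshot_id]
assert len(staged) == 1

# A snapshot-id scan shows the staged state: the original row plus the three inserted cities.
staged_rows = tbl.scan(snapshot_id=staged[0].snapshot_id).to_arrow()
assert len(staged_rows) == 4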

0 commit comments
