Skip to content

Commit 7eadf3c

Browse files
committed
update tests
1 parent 6cc9122 commit 7eadf3c

File tree

2 files changed

+62
-13
lines changed

2 files changed

+62
-13
lines changed

dev/provision.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -389,3 +389,51 @@
389389
VALUES (4)
390390
"""
391391
)
392+
393+
spark.sql(
394+
f"""
395+
CREATE OR REPLACE TABLE {catalog_name}.default.test_table_rollback_to_snapshot_id (
396+
timestamp int,
397+
number integer
398+
)
399+
USING iceberg
400+
TBLPROPERTIES (
401+
'format-version'='2'
402+
);
403+
"""
404+
)
405+
406+
spark.sql(
407+
f"""
408+
INSERT INTO {catalog_name}.default.test_table_rollback_to_snapshot_id
409+
VALUES (200, 1)
410+
"""
411+
)
412+
413+
spark.sql(
414+
f"""
415+
INSERT INTO {catalog_name}.default.test_table_rollback_to_snapshot_id
416+
VALUES (202, 2)
417+
"""
418+
)
419+
420+
spark.sql(
421+
f"""
422+
DELETE FROM {catalog_name}.default.test_table_rollback_to_snapshot_id
423+
WHERE number = 2
424+
"""
425+
)
426+
427+
spark.sql(
428+
f"""
429+
INSERT INTO {catalog_name}.default.test_table_rollback_to_snapshot_id
430+
VALUES (204, 3)
431+
"""
432+
)
433+
434+
spark.sql(
435+
f"""
436+
INSERT INTO {catalog_name}.default.test_table_rollback_to_snapshot_id
437+
VALUES (206, 4)
438+
"""
439+
)

tests/integration/test_snapshot_operations.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -49,37 +49,38 @@ def test_manage_snapshots_context_manager(catalog: Catalog) -> None:
4949
tbl = catalog.load_table(identifier)
5050
assert len(tbl.history()) > 3
5151
current_snapshot_id = tbl.current_snapshot().snapshot_id # type: ignore
52-
rollback_snapshot_id = tbl.history()[-4].snapshot_id
52+
expected_snapshot_id = tbl.history()[-4].snapshot_id
5353
with tbl.manage_snapshots() as ms:
5454
ms.create_tag(snapshot_id=current_snapshot_id, tag_name="testing")
55-
ms.rollback_to_snapshot(snapshot_id=rollback_snapshot_id)
55+
ms.set_current_snapshot(snapshot_id=expected_snapshot_id)
56+
ms.create_tag(snapshot_id=expected_snapshot_id, tag_name="testing2")
5657
assert tbl.current_snapshot().snapshot_id is not current_snapshot_id # type: ignore
5758
assert tbl.metadata.refs["testing"].snapshot_id == current_snapshot_id
58-
assert tbl.metadata.refs["main"] == SnapshotRef(snapshot_id=rollback_snapshot_id, snapshot_ref_type="branch")
59+
assert tbl.metadata.refs["main"] == SnapshotRef(snapshot_id=expected_snapshot_id, snapshot_ref_type="branch")
60+
assert tbl.metadata.refs["testing2"].snapshot_id == expected_snapshot_id
5961

6062

6163
@pytest.mark.integration
6264
@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")])
6365
def test_rollback_to_snapshot(catalog: Catalog) -> None:
64-
identifier = "default.test_table_snapshot_operations"
66+
identifier = "default.test_table_rollback_to_snapshot_id"
6567
tbl = catalog.load_table(identifier)
6668
assert len(tbl.history()) > 3
67-
rollback_snapshot_id = tbl.history()[-3].snapshot_id
69+
rollback_snapshot_id = tbl.current_snapshot().parent_snapshot_id # type: ignore
6870
current_snapshot_id = tbl.current_snapshot().snapshot_id # type: ignore
69-
tbl.manage_snapshots().rollback_to_snapshot(snapshot_id=rollback_snapshot_id).commit()
71+
tbl.manage_snapshots().rollback_to_snapshot(snapshot_id=rollback_snapshot_id).commit() # type: ignore
7072
assert tbl.current_snapshot().snapshot_id is not current_snapshot_id # type: ignore
7173
assert tbl.metadata.refs["main"] == SnapshotRef(snapshot_id=rollback_snapshot_id, snapshot_ref_type="branch")
7274

7375

7476
@pytest.mark.integration
7577
@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")])
7678
def test_rollback_to_timestamp(catalog: Catalog) -> None:
77-
identifier = "default.test_table_snapshot_operations"
79+
identifier = "default.test_table_rollback_to_snapshot_id"
7880
tbl = catalog.load_table(identifier)
79-
assert len(tbl.history()) > 3
80-
current_snapshot_id = tbl.current_snapshot().snapshot_id # type: ignore
81-
timestamp = tbl.history()[-2].timestamp_ms
82-
expected_snapshot_id = tbl.history()[-3].snapshot_id
81+
assert len(tbl.history()) > 4
82+
current_snapshot_id, timestamp = tbl.history()[-1].snapshot_id, tbl.history()[-1].timestamp_ms
83+
expected_snapshot_id = tbl.snapshot_by_id(current_snapshot_id).parent_snapshot_id # type: ignore
8384
# not inclusive of rollback_timestamp
8485
tbl.manage_snapshots().rollback_to_timestamp(timestamp=timestamp).commit()
8586
assert tbl.current_snapshot().snapshot_id is not current_snapshot_id # type: ignore
@@ -107,7 +108,7 @@ def test_set_current_snapshot_with_ref_name(catalog: Catalog) -> None:
107108
assert len(tbl.history()) > 3
108109
current_snapshot_id = tbl.current_snapshot().snapshot_id # type: ignore
109110
expected_snapshot_id = tbl.history()[-3].snapshot_id
110-
tbl.manage_snapshots().create_tag(snapshot_id=expected_snapshot_id, tag_name="test-tag19").commit()
111-
tbl.manage_snapshots().set_current_snapshot(ref_name="test-tag19").commit()
111+
tbl.manage_snapshots().create_tag(snapshot_id=expected_snapshot_id, tag_name="test-tag").commit()
112+
tbl.manage_snapshots().set_current_snapshot(ref_name="test-tag").commit()
112113
assert tbl.current_snapshot().snapshot_id is not current_snapshot_id # type: ignore
113114
assert tbl.metadata.refs["main"] == SnapshotRef(snapshot_id=expected_snapshot_id, snapshot_ref_type="branch")

0 commit comments

Comments
 (0)