|
14 | 14 | # KIND, either express or implied. See the License for the |
15 | 15 | # specific language governing permissions and limitations |
16 | 16 | # under the License. |
17 | | -import datetime |
18 | | -from unittest.mock import MagicMock |
| 17 | +import threading |
| 18 | +from datetime import datetime, timedelta |
| 19 | +from typing import Dict |
| 20 | +from unittest.mock import MagicMock, Mock |
19 | 21 | from uuid import uuid4 |
20 | 22 |
|
21 | 23 | import pytest |
22 | 24 |
|
23 | 25 | from pyiceberg.table import CommitTableResponse, Table |
| 26 | +from pyiceberg.table.update.snapshot import ExpireSnapshots |
24 | 27 |
|
25 | 28 |
|
26 | 29 | def test_cannot_expire_protected_head_snapshot(table_v2: Table) -> None: |
@@ -143,7 +146,7 @@ def test_expire_snapshots_by_timestamp_skips_protected(table_v2: Table) -> None: |
143 | 146 | table_v2.catalog = MagicMock() |
144 | 147 |
|
145 | 148 | # Attempt to expire all snapshots before a future timestamp (so both are candidates) |
146 | | - future_datetime = datetime.datetime.now() + datetime.timedelta(days=1) |
| 149 | + future_datetime = datetime.now() + timedelta(days=1) |
147 | 150 |
|
148 | 151 | # Mock the catalog's commit_table to return the current metadata (simulate no change) |
149 | 152 | mock_response = CommitTableResponse( |
@@ -223,3 +226,57 @@ def test_expire_snapshots_by_ids(table_v2: Table) -> None: |
223 | 226 | assert EXPIRE_SNAPSHOT_1 not in remaining_snapshots |
224 | 227 | assert EXPIRE_SNAPSHOT_2 not in remaining_snapshots |
225 | 228 | assert len(table_v2.metadata.snapshots) == 1 |
| 229 | + |
| 230 | + |
def test_thread_safety_fix() -> None:
    """Check that two ExpireSnapshots objects keep independent pending-expiration sets."""
    first = ExpireSnapshots(Mock())
    second = ExpireSnapshots(Mock())

    # Regression guard: the set must be a distinct object per instance. A single
    # set shared by every instance (e.g. a class-level attribute) would make
    # this identity check fail.
    first_ids = first._snapshot_ids_to_expire
    second_ids = second._snapshot_ids_to_expire
    assert first_ids is not second_ids, (
        "ExpireSnapshots instances are sharing the same snapshot set - thread safety bug still exists"
    )

    # Mutate each instance with an ID the other one never receives.
    first._snapshot_ids_to_expire.add(1001)
    second._snapshot_ids_to_expire.add(2001)

    # Neither instance may observe the ID that was added to the other.
    leak_message = "Snapshot IDs are leaking between instances"
    assert 2001 not in first._snapshot_ids_to_expire, leak_message
    assert 1001 not in second._snapshot_ids_to_expire, leak_message
| 251 | + |
| 252 | + |
def test_concurrent_operations() -> None:
    """Run two threads, each with its own ExpireSnapshots, and check their ID sets stay separate."""
    results: Dict[str, set[int]] = {"expire1_snapshots": set(), "expire2_snapshots": set()}

    def populate(result_key: str, *snapshot_ids: int) -> None:
        # Each thread builds a private ExpireSnapshots, fills its pending set,
        # and publishes a copy of that set under its own result key.
        expire = ExpireSnapshots(Mock())
        expire._snapshot_ids_to_expire.update(snapshot_ids)
        results[result_key] = expire._snapshot_ids_to_expire.copy()

    workers = [
        threading.Thread(target=populate, args=("expire1_snapshots", 1001, 1002, 1003)),
        threading.Thread(target=populate, args=("expire2_snapshots", 2001, 2002, 2003)),
    ]

    # Start every thread before joining any, so the workers are allowed to overlap.
    for worker in workers:
        worker.start()
    for worker in workers:
        worker.join()

    # Each worker must have recorded exactly the IDs it added itself.
    assert results["expire1_snapshots"] == {1001, 1002, 1003}, "Worker 1 snapshots contaminated"
    assert results["expire2_snapshots"] == {2001, 2002, 2003}, "Worker 2 snapshots contaminated"
0 commit comments