Skip to content

Commit 2fc6758

Browse files
committed
SQLite Connection Cleanup: Added proper cleanup of SQLAlchemy engine connections using engine.dispose() in test fixtures
Test Resource Management: Added try/finally blocks to ensure cleanup happens even if tests fail. Catalog Connection Handling: Modified both the iceberg_catalog and prepopulated_table fixtures to properly clean up database connections. Mock Catalog Cleanup: Added cleanup for tests that replace the table catalog with mock objects.
1 parent 3a5c8e4 commit 2fc6758

File tree

3 files changed

+306
-84
lines changed

3 files changed

+306
-84
lines changed

pyiceberg/table/maintenance.py

Lines changed: 54 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -213,15 +213,20 @@ def expire_snapshots_with_retention_policy(
213213
This method provides a unified interface for snapshot expiration with various
214214
retention policies to ensure operational resilience while allowing space reclamation.
215215
216+
The method will use table properties as defaults if they are set:
217+
- history.expire.max-snapshot-age-ms: Default for timestamp_ms if not provided
218+
- history.expire.min-snapshots-to-keep: Default for min_snapshots_to_keep if not provided
219+
- history.expire.max-ref-age-ms: Used for ref expiration (branches/tags)
220+
216221
Args:
217222
timestamp_ms: Only snapshots with timestamp_ms < this value will be considered for expiration.
218-
If None, all snapshots are candidates (subject to other constraints).
223+
If None, will use history.expire.max-snapshot-age-ms table property if set.
219224
retain_last_n: Always keep the last N snapshots regardless of age.
220225
Useful when regular snapshot creation occurs and users want to keep
221226
the last few for rollback purposes.
222227
min_snapshots_to_keep: Minimum number of snapshots to keep in total.
223-
Acts as a guardrail to prevent aggressive expiration logic
224-
from removing too many snapshots.
228+
Acts as a guardrail to prevent aggressive expiration logic.
229+
If None, will use history.expire.min-snapshots-to-keep table property if set.
225230
226231
Returns:
227232
List of snapshot IDs that were expired.
@@ -230,48 +235,62 @@ def expire_snapshots_with_retention_policy(
230235
ValueError: If retain_last_n or min_snapshots_to_keep is less than 1.
231236
232237
Examples:
233-
# Keep last 5 snapshots regardless of age
234-
maintenance.expire_snapshots_with_retention_policy(retain_last_n=5)
235-
236-
# Expire snapshots older than timestamp but keep at least 3 total
237-
maintenance.expire_snapshots_with_retention_policy(
238-
timestamp_ms=1234567890000,
239-
min_snapshots_to_keep=3
240-
)
238+
# Use table property defaults
239+
maintenance.expire_snapshots_with_retention_policy()
241240
242-
# Combined policy: expire old snapshots but keep last 10 and at least 5 total
241+
# Override defaults with explicit values
243242
maintenance.expire_snapshots_with_retention_policy(
244243
timestamp_ms=1234567890000,
245244
retain_last_n=10,
246245
min_snapshots_to_keep=5
247246
)
248247
"""
248+
# Get default values from table properties
249+
default_max_age, default_min_snapshots, _ = self._get_expiration_properties()
250+
251+
# Use defaults from table properties if not explicitly provided
252+
if timestamp_ms is None:
253+
timestamp_ms = default_max_age
254+
255+
if min_snapshots_to_keep is None:
256+
min_snapshots_to_keep = default_min_snapshots
257+
258+
# If no expiration criteria are provided, don't expire anything
259+
if timestamp_ms is None and retain_last_n is None and min_snapshots_to_keep is None:
260+
return
261+
249262
if retain_last_n is not None and retain_last_n < 1:
250263
raise ValueError("retain_last_n must be at least 1")
251264

252265
if min_snapshots_to_keep is not None and min_snapshots_to_keep < 1:
253266
raise ValueError("min_snapshots_to_keep must be at least 1")
254267

255268
snapshots_to_expire = self._get_snapshots_to_expire_with_retention(
256-
timestamp_ms=timestamp_ms, retain_last_n=retain_last_n, min_snapshots_to_keep=min_snapshots_to_keep
269+
timestamp_ms=timestamp_ms,
270+
retain_last_n=retain_last_n,
271+
min_snapshots_to_keep=min_snapshots_to_keep
257272
)
258273

259274
if snapshots_to_expire:
260275
self._expire_snapshots_by_ids(snapshots_to_expire)
261276

262-
def _get_protected_snapshot_ids(self) -> Set[int]:
277+
def _get_protected_snapshot_ids(self, table_metadata: Optional[TableMetadata] = None) -> Set[int]:
263278
"""Get the IDs of protected snapshots.
264279
265280
These are the HEAD snapshots of all branches and all tagged snapshots.
266281
These ids are to be excluded from expiration.
267282
268283
Args:
269-
table_metadata: The table metadata to check for protected snapshots.
284+
table_metadata: Optional table metadata to check for protected snapshots.
285+
If not provided, uses the table's current metadata.
270286
271287
Returns:
272288
Set of protected snapshot IDs to exclude from expiration.
273289
"""
274-
return set(self.tbl.inspect.refs()["snapshot_id"].to_pylist())
290+
# Prefer provided metadata, fall back to current table metadata
291+
metadata = table_metadata or self.tbl.metadata
292+
refs = metadata.refs if metadata else {}
293+
return {ref.snapshot_id for ref in refs.values()}
275294

276295
def _get_all_datafiles(self) -> List[DataFile]:
277296
"""Collect all DataFiles in the current snapshot only."""
@@ -359,3 +378,22 @@ def deduplicate_data_files(self) -> List[DataFile]:
359378
self.tbl = self.tbl.refresh()
360379

361380
return removed
381+
382+
def _get_expiration_properties(self) -> tuple[Optional[int], Optional[int], Optional[int]]:
383+
"""Get the default expiration properties from table properties.
384+
385+
Returns:
386+
Tuple of (max_snapshot_age_ms, min_snapshots_to_keep, max_ref_age_ms)
387+
"""
388+
properties = self.tbl.properties
389+
390+
max_snapshot_age_ms = properties.get("history.expire.max-snapshot-age-ms")
391+
max_snapshot_age = int(max_snapshot_age_ms) if max_snapshot_age_ms is not None else None
392+
393+
min_snapshots = properties.get("history.expire.min-snapshots-to-keep")
394+
min_snapshots_to_keep = int(min_snapshots) if min_snapshots is not None else None
395+
396+
max_ref_age = properties.get("history.expire.max-ref-age-ms")
397+
max_ref_age_ms = int(max_ref_age) if max_ref_age is not None else None
398+
399+
return max_snapshot_age, min_snapshots_to_keep, max_ref_age_ms

tests/table/test_dedup_data_file_filepaths.py

Lines changed: 32 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import os
1818
import uuid
1919
from pathlib import Path
20-
from typing import List, Set
20+
from typing import List, Set, Generator
2121

2222
import pyarrow as pa
2323
import pyarrow.parquet as pq
@@ -30,10 +30,16 @@
3030

3131

3232
@pytest.fixture
33-
def iceberg_catalog(tmp_path: Path) -> InMemoryCatalog:
33+
def iceberg_catalog(tmp_path: Path) -> Generator[InMemoryCatalog, None, None]:
3434
catalog = InMemoryCatalog("test.in_memory.catalog", warehouse=tmp_path.absolute().as_posix())
3535
catalog.create_namespace("default")
36-
return catalog
36+
yield catalog
37+
# Clean up SQLAlchemy engine connections
38+
if hasattr(catalog, 'engine'):
39+
try:
40+
catalog.engine.dispose()
41+
except Exception:
42+
pass
3743

3844

3945
@pytest.fixture
@@ -43,7 +49,7 @@ def dupe_data_file_path(tmp_path: Path) -> Path:
4349

4450

4551
@pytest.fixture
46-
def prepopulated_table(iceberg_catalog: InMemoryCatalog, dupe_data_file_path: Path) -> Table:
52+
def prepopulated_table(iceberg_catalog: InMemoryCatalog, dupe_data_file_path: Path) -> Generator[Table, None, None]:
4753
identifier = "default.test_table"
4854
try:
4955
iceberg_catalog.drop_table(identifier)
@@ -85,7 +91,14 @@ def prepopulated_table(iceberg_catalog: InMemoryCatalog, dupe_data_file_path: Pa
8591
tx2.add_files([str(dupe_data_file_path)], check_duplicate_files=False)
8692
tx2.commit_transaction()
8793

88-
return table
94+
yield table
95+
96+
# Cleanup table's catalog connections
97+
if hasattr(table, '_catalog') and hasattr(table._catalog, 'engine'):
98+
try:
99+
table._catalog.engine.dispose()
100+
except Exception:
101+
pass
89102

90103

91104
def test_overwrite_removes_only_selected_datafile(prepopulated_table: Table, dupe_data_file_path: Path) -> None:
@@ -112,11 +125,21 @@ def test_get_all_datafiles_current_snapshot(prepopulated_table: Table, dupe_data
112125

113126

114127
def test_get_all_datafiles_all_snapshots(prepopulated_table: Table, dupe_data_file_path: Path) -> None:
115-
mt = MaintenanceTable(tbl=prepopulated_table)
128+
try:
129+
mt = MaintenanceTable(tbl=prepopulated_table)
116130

117-
datafiles: List[DataFile] = mt._get_all_datafiles()
118-
file_paths: Set[str] = {df.file_path.split("/")[-1] for df in datafiles}
119-
assert dupe_data_file_path.name in file_paths
131+
datafiles: List[DataFile] = mt._get_all_datafiles()
132+
file_paths: Set[str] = {df.file_path.split("/")[-1] for df in datafiles}
133+
assert dupe_data_file_path.name in file_paths
134+
finally:
135+
# Ensure catalog connections are properly closed
136+
if hasattr(prepopulated_table, '_catalog'):
137+
catalog = prepopulated_table._catalog
138+
if hasattr(catalog, '_connection') and catalog._connection is not None:
139+
try:
140+
catalog._connection.close()
141+
except Exception:
142+
pass
120143

121144

122145
def test_deduplicate_data_files_removes_duplicates_in_current_snapshot(

0 commit comments

Comments
 (0)