Skip to content

Commit 311c442

Browse files
committed
feat: update maintenance features with deduplication and retention strategies
1 parent 0e6d45c commit 311c442

File tree

9 files changed

+121
-171
lines changed

9 files changed

+121
-171
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,3 +50,4 @@ htmlcov
5050
pyiceberg/avro/decoder_fast.c
5151
pyiceberg/avro/*.html
5252
pyiceberg/avro/*.so
53+
ruff.toml

mkdocs/docs/api.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1164,7 +1164,7 @@ maintenance.expire_snapshots_with_retention_policy(
11641164
)
11651165
```
11661166

1167-
#### Use Cases
1167+
#### Retention Use Cases
11681168

11691169
- **Operational Resilience**: Always keep recent snapshots for rollback.
11701170
- **Space Reclamation**: Remove old, unneeded snapshots.

pyiceberg/table/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -907,6 +907,7 @@ def inspect(self) -> InspectTable:
907907
@property
908908
def maintenance(self) -> MaintenanceTable:
909909
"""Return the MaintenanceTable object for maintenance.
910+
910911
Returns:
911912
MaintenanceTable object based on this Table.
912913
"""

pyiceberg/table/maintenance.py

Lines changed: 38 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,10 @@
1717
from __future__ import annotations
1818

1919
import logging
20-
from typing import TYPE_CHECKING, List, Optional, Set, Union
20+
from typing import TYPE_CHECKING, Any, List, Optional, Set, Union
2121

22-
from pyiceberg.manifest import DataFile
23-
from pyiceberg.utils.concurrent import ThreadPoolExecutor
22+
from pyiceberg.manifest import DataFile, ManifestFile
23+
from pyiceberg.utils.concurrent import ThreadPoolExecutor # type: ignore[attr-defined]
2424

2525
logger = logging.getLogger(__name__)
2626

@@ -52,7 +52,7 @@ def expire_snapshot_by_id(self, snapshot_id: int) -> None:
5252
"""
5353
with self.tbl.transaction() as txn:
5454
# Check if snapshot exists
55-
if txn.table_metadata.snapshot_by_id(snapshot_id) is None:
55+
if not any(snapshot.snapshot_id == snapshot_id for snapshot in txn.table_metadata.snapshots):
5656
raise ValueError(f"Snapshot with ID {snapshot_id} does not exist.")
5757

5858
# Check if snapshot is protected
@@ -97,7 +97,7 @@ def expire_snapshots_older_than(self, timestamp_ms: int) -> None:
9797
"""
9898
# First check if there are any snapshots to expire to avoid unnecessary transactions
9999
protected_ids = self._get_protected_snapshot_ids(self.tbl.metadata)
100-
snapshots_to_expire = []
100+
snapshots_to_expire: List[int] = []
101101

102102
for snapshot in self.tbl.metadata.snapshots:
103103
if snapshot.timestamp_ms < timestamp_ms and snapshot.snapshot_id not in protected_ids:
@@ -110,10 +110,7 @@ def expire_snapshots_older_than(self, timestamp_ms: int) -> None:
110110
txn._apply((RemoveSnapshotsUpdate(snapshot_ids=snapshots_to_expire),))
111111

112112
def expire_snapshots_older_than_with_retention(
113-
self,
114-
timestamp_ms: int,
115-
retain_last_n: Optional[int] = None,
116-
min_snapshots_to_keep: Optional[int] = None
113+
self, timestamp_ms: int, retain_last_n: Optional[int] = None, min_snapshots_to_keep: Optional[int] = None
117114
) -> None:
118115
"""Expire all unprotected snapshots with a timestamp older than a given value, with retention strategies.
119116
@@ -123,9 +120,7 @@ def expire_snapshots_older_than_with_retention(
123120
min_snapshots_to_keep: Minimum number of snapshots to keep in total.
124121
"""
125122
snapshots_to_expire = self._get_snapshots_to_expire_with_retention(
126-
timestamp_ms=timestamp_ms,
127-
retain_last_n=retain_last_n,
128-
min_snapshots_to_keep=min_snapshots_to_keep
123+
timestamp_ms=timestamp_ms, retain_last_n=retain_last_n, min_snapshots_to_keep=min_snapshots_to_keep
129124
)
130125

131126
if snapshots_to_expire:
@@ -147,25 +142,21 @@ def retain_last_n_snapshots(self, n: int) -> None:
147142
raise ValueError("Number of snapshots to retain must be at least 1")
148143

149144
protected_ids = self._get_protected_snapshot_ids(self.tbl.metadata)
150-
145+
151146
# Sort snapshots by timestamp (most recent first)
152-
sorted_snapshots = sorted(
153-
self.tbl.metadata.snapshots,
154-
key=lambda s: s.timestamp_ms,
155-
reverse=True
156-
)
157-
147+
sorted_snapshots = sorted(self.tbl.metadata.snapshots, key=lambda s: s.timestamp_ms, reverse=True)
148+
158149
# Keep the last N snapshots and all protected ones
159150
snapshots_to_keep = set()
160151
snapshots_to_keep.update(protected_ids)
161-
152+
162153
# Add the N most recent snapshots
163154
for i, snapshot in enumerate(sorted_snapshots):
164155
if i < n:
165156
snapshots_to_keep.add(snapshot.snapshot_id)
166-
157+
167158
# Find snapshots to expire
168-
snapshots_to_expire = []
159+
snapshots_to_expire: List[int] = []
169160
for snapshot in self.tbl.metadata.snapshots:
170161
if snapshot.snapshot_id not in snapshots_to_keep:
171162
snapshots_to_expire.append(snapshot.snapshot_id)
@@ -177,10 +168,7 @@ def retain_last_n_snapshots(self, n: int) -> None:
177168
txn._apply((RemoveSnapshotsUpdate(snapshot_ids=snapshots_to_expire),))
178169

179170
def _get_snapshots_to_expire_with_retention(
180-
self,
181-
timestamp_ms: Optional[int] = None,
182-
retain_last_n: Optional[int] = None,
183-
min_snapshots_to_keep: Optional[int] = None
171+
self, timestamp_ms: Optional[int] = None, retain_last_n: Optional[int] = None, min_snapshots_to_keep: Optional[int] = None
184172
) -> List[int]:
185173
"""Get snapshots to expire considering retention strategies.
186174
@@ -193,54 +181,46 @@ def _get_snapshots_to_expire_with_retention(
193181
List of snapshot IDs to expire.
194182
"""
195183
protected_ids = self._get_protected_snapshot_ids(self.tbl.metadata)
196-
184+
197185
# Sort snapshots by timestamp (most recent first)
198-
sorted_snapshots = sorted(
199-
self.tbl.metadata.snapshots,
200-
key=lambda s: s.timestamp_ms,
201-
reverse=True
202-
)
203-
186+
sorted_snapshots = sorted(self.tbl.metadata.snapshots, key=lambda s: s.timestamp_ms, reverse=True)
187+
204188
# Start with all snapshots that could be expired
205189
candidates_for_expiration = []
206190
snapshots_to_keep = set(protected_ids)
207-
191+
208192
# Apply retain_last_n constraint
209193
if retain_last_n is not None:
210194
for i, snapshot in enumerate(sorted_snapshots):
211195
if i < retain_last_n:
212196
snapshots_to_keep.add(snapshot.snapshot_id)
213-
197+
214198
# Apply timestamp constraint
215199
for snapshot in self.tbl.metadata.snapshots:
216-
if (snapshot.snapshot_id not in snapshots_to_keep and
217-
(timestamp_ms is None or snapshot.timestamp_ms < timestamp_ms)):
200+
if snapshot.snapshot_id not in snapshots_to_keep and (timestamp_ms is None or snapshot.timestamp_ms < timestamp_ms):
218201
candidates_for_expiration.append(snapshot)
219-
202+
220203
# Sort candidates by timestamp (oldest first) for potential expiration
221204
candidates_for_expiration.sort(key=lambda s: s.timestamp_ms)
222-
205+
223206
# Apply min_snapshots_to_keep constraint
224207
total_snapshots = len(self.tbl.metadata.snapshots)
225-
snapshots_to_expire = []
226-
208+
snapshots_to_expire: List[int] = []
209+
227210
for candidate in candidates_for_expiration:
228211
# Check if expiring this snapshot would violate min_snapshots_to_keep
229212
remaining_after_expiration = total_snapshots - len(snapshots_to_expire) - 1
230-
213+
231214
if min_snapshots_to_keep is None or remaining_after_expiration >= min_snapshots_to_keep:
232215
snapshots_to_expire.append(candidate.snapshot_id)
233216
else:
234217
# Stop expiring to maintain minimum count
235218
break
236-
219+
237220
return snapshots_to_expire
238221

239222
def expire_snapshots_with_retention_policy(
240-
self,
241-
timestamp_ms: Optional[int] = None,
242-
retain_last_n: Optional[int] = None,
243-
min_snapshots_to_keep: Optional[int] = None
223+
self, timestamp_ms: Optional[int] = None, retain_last_n: Optional[int] = None, min_snapshots_to_keep: Optional[int] = None
244224
) -> List[int]:
245225
"""Comprehensive snapshot expiration with multiple retention strategies.
246226
@@ -266,13 +246,13 @@ def expire_snapshots_with_retention_policy(
266246
Examples:
267247
# Keep last 5 snapshots regardless of age
268248
maintenance.expire_snapshots_with_retention_policy(retain_last_n=5)
269-
249+
270250
# Expire snapshots older than timestamp but keep at least 3 total
271251
maintenance.expire_snapshots_with_retention_policy(
272252
timestamp_ms=1234567890000,
273253
min_snapshots_to_keep=3
274254
)
275-
255+
276256
# Combined policy: expire old snapshots but keep last 10 and at least 5 total
277257
maintenance.expire_snapshots_with_retention_policy(
278258
timestamp_ms=1234567890000,
@@ -282,14 +262,12 @@ def expire_snapshots_with_retention_policy(
282262
"""
283263
if retain_last_n is not None and retain_last_n < 1:
284264
raise ValueError("retain_last_n must be at least 1")
285-
265+
286266
if min_snapshots_to_keep is not None and min_snapshots_to_keep < 1:
287267
raise ValueError("min_snapshots_to_keep must be at least 1")
288268

289269
snapshots_to_expire = self._get_snapshots_to_expire_with_retention(
290-
timestamp_ms=timestamp_ms,
291-
retain_last_n=retain_last_n,
292-
min_snapshots_to_keep=min_snapshots_to_keep
270+
timestamp_ms=timestamp_ms, retain_last_n=retain_last_n, min_snapshots_to_keep=min_snapshots_to_keep
293271
)
294272

295273
if snapshots_to_expire:
@@ -326,12 +304,10 @@ def _get_all_datafiles(
326304
target_file_path: Optional[str] = None,
327305
parallel: bool = True,
328306
) -> List[DataFile]:
329-
"""
330-
Collect all DataFiles in the table, optionally filtering by file path.
331-
"""
307+
"""Collect all DataFiles in the table, optionally filtering by file path."""
332308
datafiles: List[DataFile] = []
333309

334-
def process_manifest(manifest) -> list[DataFile]:
310+
def process_manifest(manifest: ManifestFile) -> list[DataFile]:
335311
found: list[DataFile] = []
336312
for entry in manifest.fetch_manifest_entry(io=self.tbl.io):
337313
if hasattr(entry, "data_file"):
@@ -356,7 +332,7 @@ def process_manifest(manifest) -> list[DataFile]:
356332
# Only current snapshot
357333
for chunk in self.tbl.inspect.data_files().to_pylist():
358334
file_path = chunk.get("file_path")
359-
partition = chunk.get("partition", {})
335+
partition: dict[str, Any] = dict(chunk.get("partition", {}) or {})
360336
if target_file_path is None or file_path == target_file_path:
361337
datafiles.append(DataFile(file_path=file_path, partition=partition))
362338
return datafiles
@@ -389,16 +365,16 @@ def deduplicate_data_files(
389365
seen = {}
390366
duplicates = []
391367
for df in all_datafiles:
392-
partition = dict(df.partition) if hasattr(df.partition, "items") else df.partition
368+
partition: dict[str, Any] = df.partition.to_dict() if hasattr(df.partition, "to_dict") else {}
393369
if scan_all_partitions:
394-
key = (df.file_path, tuple(sorted(partition.items())) if partition else None)
370+
key = (df.file_path, tuple(sorted(partition.items())) if partition else ())
395371
else:
396-
key = df.file_path
372+
key = (df.file_path, ()) # Add an empty tuple for partition when scan_all_partitions is False
397373
if key in seen:
398374
duplicates.append(df)
399375
else:
400376
seen[key] = df
401-
to_remove = duplicates
377+
to_remove = duplicates # type: ignore[assignment]
402378

403379
# Normalize to DataFile objects
404380
normalized_to_remove: List[DataFile] = []

pyiceberg/table/update/snapshot.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -82,10 +82,7 @@
8282
from pyiceberg.utils.properties import property_as_bool, property_as_int
8383

8484
if TYPE_CHECKING:
85-
pass
86-
87-
88-
from pyiceberg.table.metadata import Snapshot
85+
from pyiceberg.table import Transaction
8986

9087

9188
def _new_manifest_file_name(num: int, commit_uuid: uuid.UUID) -> str:

ruff.toml

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -34,13 +34,7 @@ exclude = [
3434
".svn",
3535
".tox",
3636
".venv",
37-
"__pypackages__",
38-
"_build",
39-
"buck-out",
40-
"build",
41-
"dist",
42-
"node_modules",
43-
"venv",
37+
"vendor",
4438
]
4539

4640
# Ignore _all_ violations.

tests/expressions/test_literals.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -744,6 +744,21 @@ def test_invalid_decimal_conversions() -> None:
744744
def test_invalid_string_conversions() -> None:
745745
assert_invalid_conversions(
746746
literal("abc"),
747+
[
748+
BooleanType(),
749+
IntegerType(),
750+
LongType(),
751+
FloatType(),
752+
DoubleType(),
753+
DateType(),
754+
TimeType(),
755+
TimestampType(),
756+
TimestamptzType(),
757+
DecimalType(9, 2),
758+
UUIDType(),
759+
FixedType(1),
760+
BinaryType(),
761+
],
747762
)
748763

749764

0 commit comments

Comments
 (0)