17
17
from __future__ import annotations
18
18
19
19
import logging
20
- from typing import TYPE_CHECKING , List , Optional , Set , Union
20
+ from typing import TYPE_CHECKING , Any , List , Optional , Set , Union
21
21
22
- from pyiceberg .manifest import DataFile
23
- from pyiceberg .utils .concurrent import ThreadPoolExecutor
22
+ from pyiceberg .manifest import DataFile , ManifestFile
23
+ from pyiceberg .utils .concurrent import ThreadPoolExecutor # type: ignore[attr-defined]
24
24
25
25
logger = logging .getLogger (__name__ )
26
26
@@ -52,7 +52,7 @@ def expire_snapshot_by_id(self, snapshot_id: int) -> None:
52
52
"""
53
53
with self .tbl .transaction () as txn :
54
54
# Check if snapshot exists
55
- if txn .table_metadata .snapshot_by_id ( snapshot_id ) is None :
55
+ if not any ( snapshot . snapshot_id == snapshot_id for snapshot in txn .table_metadata .snapshots ) :
56
56
raise ValueError (f"Snapshot with ID { snapshot_id } does not exist." )
57
57
58
58
# Check if snapshot is protected
@@ -97,7 +97,7 @@ def expire_snapshots_older_than(self, timestamp_ms: int) -> None:
97
97
"""
98
98
# First check if there are any snapshots to expire to avoid unnecessary transactions
99
99
protected_ids = self ._get_protected_snapshot_ids (self .tbl .metadata )
100
- snapshots_to_expire = []
100
+ snapshots_to_expire : List [ int ] = []
101
101
102
102
for snapshot in self .tbl .metadata .snapshots :
103
103
if snapshot .timestamp_ms < timestamp_ms and snapshot .snapshot_id not in protected_ids :
@@ -110,10 +110,7 @@ def expire_snapshots_older_than(self, timestamp_ms: int) -> None:
110
110
txn ._apply ((RemoveSnapshotsUpdate (snapshot_ids = snapshots_to_expire ),))
111
111
112
112
def expire_snapshots_older_than_with_retention (
113
- self ,
114
- timestamp_ms : int ,
115
- retain_last_n : Optional [int ] = None ,
116
- min_snapshots_to_keep : Optional [int ] = None
113
+ self , timestamp_ms : int , retain_last_n : Optional [int ] = None , min_snapshots_to_keep : Optional [int ] = None
117
114
) -> None :
118
115
"""Expire all unprotected snapshots with a timestamp older than a given value, with retention strategies.
119
116
@@ -123,9 +120,7 @@ def expire_snapshots_older_than_with_retention(
123
120
min_snapshots_to_keep: Minimum number of snapshots to keep in total.
124
121
"""
125
122
snapshots_to_expire = self ._get_snapshots_to_expire_with_retention (
126
- timestamp_ms = timestamp_ms ,
127
- retain_last_n = retain_last_n ,
128
- min_snapshots_to_keep = min_snapshots_to_keep
123
+ timestamp_ms = timestamp_ms , retain_last_n = retain_last_n , min_snapshots_to_keep = min_snapshots_to_keep
129
124
)
130
125
131
126
if snapshots_to_expire :
@@ -147,25 +142,21 @@ def retain_last_n_snapshots(self, n: int) -> None:
147
142
raise ValueError ("Number of snapshots to retain must be at least 1" )
148
143
149
144
protected_ids = self ._get_protected_snapshot_ids (self .tbl .metadata )
150
-
145
+
151
146
# Sort snapshots by timestamp (most recent first)
152
- sorted_snapshots = sorted (
153
- self .tbl .metadata .snapshots ,
154
- key = lambda s : s .timestamp_ms ,
155
- reverse = True
156
- )
157
-
147
+ sorted_snapshots = sorted (self .tbl .metadata .snapshots , key = lambda s : s .timestamp_ms , reverse = True )
148
+
158
149
# Keep the last N snapshots and all protected ones
159
150
snapshots_to_keep = set ()
160
151
snapshots_to_keep .update (protected_ids )
161
-
152
+
162
153
# Add the N most recent snapshots
163
154
for i , snapshot in enumerate (sorted_snapshots ):
164
155
if i < n :
165
156
snapshots_to_keep .add (snapshot .snapshot_id )
166
-
157
+
167
158
# Find snapshots to expire
168
- snapshots_to_expire = []
159
+ snapshots_to_expire : List [ int ] = []
169
160
for snapshot in self .tbl .metadata .snapshots :
170
161
if snapshot .snapshot_id not in snapshots_to_keep :
171
162
snapshots_to_expire .append (snapshot .snapshot_id )
@@ -177,10 +168,7 @@ def retain_last_n_snapshots(self, n: int) -> None:
177
168
txn ._apply ((RemoveSnapshotsUpdate (snapshot_ids = snapshots_to_expire ),))
178
169
179
170
def _get_snapshots_to_expire_with_retention (
180
- self ,
181
- timestamp_ms : Optional [int ] = None ,
182
- retain_last_n : Optional [int ] = None ,
183
- min_snapshots_to_keep : Optional [int ] = None
171
+ self , timestamp_ms : Optional [int ] = None , retain_last_n : Optional [int ] = None , min_snapshots_to_keep : Optional [int ] = None
184
172
) -> List [int ]:
185
173
"""Get snapshots to expire considering retention strategies.
186
174
@@ -193,54 +181,46 @@ def _get_snapshots_to_expire_with_retention(
193
181
List of snapshot IDs to expire.
194
182
"""
195
183
protected_ids = self ._get_protected_snapshot_ids (self .tbl .metadata )
196
-
184
+
197
185
# Sort snapshots by timestamp (most recent first)
198
- sorted_snapshots = sorted (
199
- self .tbl .metadata .snapshots ,
200
- key = lambda s : s .timestamp_ms ,
201
- reverse = True
202
- )
203
-
186
+ sorted_snapshots = sorted (self .tbl .metadata .snapshots , key = lambda s : s .timestamp_ms , reverse = True )
187
+
204
188
# Start with all snapshots that could be expired
205
189
candidates_for_expiration = []
206
190
snapshots_to_keep = set (protected_ids )
207
-
191
+
208
192
# Apply retain_last_n constraint
209
193
if retain_last_n is not None :
210
194
for i , snapshot in enumerate (sorted_snapshots ):
211
195
if i < retain_last_n :
212
196
snapshots_to_keep .add (snapshot .snapshot_id )
213
-
197
+
214
198
# Apply timestamp constraint
215
199
for snapshot in self .tbl .metadata .snapshots :
216
- if (snapshot .snapshot_id not in snapshots_to_keep and
217
- (timestamp_ms is None or snapshot .timestamp_ms < timestamp_ms )):
200
+ if snapshot .snapshot_id not in snapshots_to_keep and (timestamp_ms is None or snapshot .timestamp_ms < timestamp_ms ):
218
201
candidates_for_expiration .append (snapshot )
219
-
202
+
220
203
# Sort candidates by timestamp (oldest first) for potential expiration
221
204
candidates_for_expiration .sort (key = lambda s : s .timestamp_ms )
222
-
205
+
223
206
# Apply min_snapshots_to_keep constraint
224
207
total_snapshots = len (self .tbl .metadata .snapshots )
225
- snapshots_to_expire = []
226
-
208
+ snapshots_to_expire : List [ int ] = []
209
+
227
210
for candidate in candidates_for_expiration :
228
211
# Check if expiring this snapshot would violate min_snapshots_to_keep
229
212
remaining_after_expiration = total_snapshots - len (snapshots_to_expire ) - 1
230
-
213
+
231
214
if min_snapshots_to_keep is None or remaining_after_expiration >= min_snapshots_to_keep :
232
215
snapshots_to_expire .append (candidate .snapshot_id )
233
216
else :
234
217
# Stop expiring to maintain minimum count
235
218
break
236
-
219
+
237
220
return snapshots_to_expire
238
221
239
222
def expire_snapshots_with_retention_policy (
240
- self ,
241
- timestamp_ms : Optional [int ] = None ,
242
- retain_last_n : Optional [int ] = None ,
243
- min_snapshots_to_keep : Optional [int ] = None
223
+ self , timestamp_ms : Optional [int ] = None , retain_last_n : Optional [int ] = None , min_snapshots_to_keep : Optional [int ] = None
244
224
) -> List [int ]:
245
225
"""Comprehensive snapshot expiration with multiple retention strategies.
246
226
@@ -266,13 +246,13 @@ def expire_snapshots_with_retention_policy(
266
246
Examples:
267
247
# Keep last 5 snapshots regardless of age
268
248
maintenance.expire_snapshots_with_retention_policy(retain_last_n=5)
269
-
249
+
270
250
# Expire snapshots older than timestamp but keep at least 3 total
271
251
maintenance.expire_snapshots_with_retention_policy(
272
252
timestamp_ms=1234567890000,
273
253
min_snapshots_to_keep=3
274
254
)
275
-
255
+
276
256
# Combined policy: expire old snapshots but keep last 10 and at least 5 total
277
257
maintenance.expire_snapshots_with_retention_policy(
278
258
timestamp_ms=1234567890000,
@@ -282,14 +262,12 @@ def expire_snapshots_with_retention_policy(
282
262
"""
283
263
if retain_last_n is not None and retain_last_n < 1 :
284
264
raise ValueError ("retain_last_n must be at least 1" )
285
-
265
+
286
266
if min_snapshots_to_keep is not None and min_snapshots_to_keep < 1 :
287
267
raise ValueError ("min_snapshots_to_keep must be at least 1" )
288
268
289
269
snapshots_to_expire = self ._get_snapshots_to_expire_with_retention (
290
- timestamp_ms = timestamp_ms ,
291
- retain_last_n = retain_last_n ,
292
- min_snapshots_to_keep = min_snapshots_to_keep
270
+ timestamp_ms = timestamp_ms , retain_last_n = retain_last_n , min_snapshots_to_keep = min_snapshots_to_keep
293
271
)
294
272
295
273
if snapshots_to_expire :
@@ -326,12 +304,10 @@ def _get_all_datafiles(
326
304
target_file_path : Optional [str ] = None ,
327
305
parallel : bool = True ,
328
306
) -> List [DataFile ]:
329
- """
330
- Collect all DataFiles in the table, optionally filtering by file path.
331
- """
307
+ """Collect all DataFiles in the table, optionally filtering by file path."""
332
308
datafiles : List [DataFile ] = []
333
309
334
- def process_manifest (manifest ) -> list [DataFile ]:
310
+ def process_manifest (manifest : ManifestFile ) -> list [DataFile ]:
335
311
found : list [DataFile ] = []
336
312
for entry in manifest .fetch_manifest_entry (io = self .tbl .io ):
337
313
if hasattr (entry , "data_file" ):
@@ -356,7 +332,7 @@ def process_manifest(manifest) -> list[DataFile]:
356
332
# Only current snapshot
357
333
for chunk in self .tbl .inspect .data_files ().to_pylist ():
358
334
file_path = chunk .get ("file_path" )
359
- partition = chunk .get ("partition" , {})
335
+ partition : dict [ str , Any ] = dict ( chunk .get ("partition" , {}) or {})
360
336
if target_file_path is None or file_path == target_file_path :
361
337
datafiles .append (DataFile (file_path = file_path , partition = partition ))
362
338
return datafiles
@@ -389,16 +365,16 @@ def deduplicate_data_files(
389
365
seen = {}
390
366
duplicates = []
391
367
for df in all_datafiles :
392
- partition = dict ( df .partition ) if hasattr (df .partition , "items " ) else df . partition
368
+ partition : dict [ str , Any ] = df .partition . to_dict ( ) if hasattr (df .partition , "to_dict " ) else {}
393
369
if scan_all_partitions :
394
- key = (df .file_path , tuple (sorted (partition .items ())) if partition else None )
370
+ key = (df .file_path , tuple (sorted (partition .items ())) if partition else () )
395
371
else :
396
- key = df .file_path
372
+ key = ( df .file_path , ()) # Add an empty tuple for partition when scan_all_partitions is False
397
373
if key in seen :
398
374
duplicates .append (df )
399
375
else :
400
376
seen [key ] = df
401
- to_remove = duplicates
377
+ to_remove = duplicates # type: ignore[assignment]
402
378
403
379
# Normalize to DataFile objects
404
380
normalized_to_remove : List [DataFile ] = []
0 commit comments