|
22 | 22 | from abc import abstractmethod |
23 | 23 | from collections import defaultdict |
24 | 24 | from concurrent.futures import Future |
| 25 | +from datetime import datetime |
25 | 26 | from functools import cached_property |
26 | 27 | from typing import TYPE_CHECKING, Callable, Dict, Generic, List, Optional, Set, Tuple |
27 | 28 |
|
|
82 | 83 | ) |
83 | 84 | from pyiceberg.utils.bin_packing import ListPacker |
84 | 85 | from pyiceberg.utils.concurrent import ExecutorFactory |
| 86 | +from pyiceberg.utils.datetime import datetime_to_millis |
85 | 87 | from pyiceberg.utils.properties import property_as_bool, property_as_int |
86 | 88 |
|
87 | 89 | if TYPE_CHECKING: |
@@ -944,13 +946,11 @@ def _get_protected_snapshot_ids(self) -> Set[int]: |
944 | 946 | Returns: |
945 | 947 | Set of protected snapshot IDs to exclude from expiration. |
946 | 948 | """ |
947 | | - protected_ids: Set[int] = set() |
948 | | - |
949 | | - for ref in self._transaction.table_metadata.refs.values(): |
950 | | - if ref.snapshot_ref_type in [SnapshotRefType.TAG, SnapshotRefType.BRANCH]: |
951 | | - protected_ids.add(ref.snapshot_id) |
952 | | - |
953 | | - return protected_ids |
| 949 | + return { |
| 950 | + ref.snapshot_id |
| 951 | + for ref in self._transaction.table_metadata.refs.values() |
| 952 | + if ref.snapshot_ref_type in [SnapshotRefType.TAG, SnapshotRefType.BRANCH] |
| 953 | + } |
954 | 954 |
|
955 | 955 | def by_id(self, snapshot_id: int) -> ExpireSnapshots: |
956 | 956 | """ |
@@ -988,18 +988,19 @@ def by_ids(self, snapshot_ids: List[int]) -> "ExpireSnapshots": |
988 | 988 | self.by_id(snapshot_id) |
989 | 989 | return self |
990 | 990 |
|
991 | | - def older_than(self, timestamp_ms: int) -> "ExpireSnapshots": |
| 991 | + def older_than(self, dt: datetime) -> "ExpireSnapshots": |
992 | 992 | """ |
993 | 993 | Expire all unprotected snapshots with a timestamp older than a given value. |
994 | 994 |
|
995 | 995 | Args: |
996 | | - timestamp_ms (int): Only snapshots with timestamp_ms < this value will be expired. |
| 996 | + dt (datetime): Only snapshots whose timestamp is earlier than this datetime will be expired. |
997 | 997 |
|
998 | 998 | Returns: |
999 | 999 | This for method chaining. |
1000 | 1000 | """ |
1001 | 1001 | protected_ids = self._get_protected_snapshot_ids() |
| 1002 | + expire_from = datetime_to_millis(dt) |
1002 | 1003 | for snapshot in self._transaction.table_metadata.snapshots: |
1003 | | - if snapshot.timestamp_ms < timestamp_ms and snapshot.snapshot_id not in protected_ids: |
| 1004 | + if snapshot.timestamp_ms < expire_from and snapshot.snapshot_id not in protected_ids: |
1004 | 1005 | self._snapshot_ids_to_expire.add(snapshot.snapshot_id) |
1005 | 1006 | return self |
0 commit comments