Commit 8dfa038

Closes:
(1) #2130, with the addition of the new `deduplicate_data_files` function to the `MaintenanceTable` class.
(2) #2151, with the removal of the errant member variable from the `ManageSnapshots` class.
(3) #2150, by adding additional functions to reach parity with the Java API.
1 parent fe73a34 commit 8dfa038
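As a quick orientation, here is a minimal sketch of how the new maintenance entry point added by this commit is meant to be used. The catalog and table names are placeholders, not part of the commit:

    from pyiceberg.catalog import load_catalog

    catalog = load_catalog("default")        # hypothetical catalog name
    table = catalog.load_table("db.events")  # hypothetical table name

    # New in this commit: Table.maintenance returns a MaintenanceTable
    removed = table.maintenance.deduplicate_data_files(scan_all_snapshots=False)
    print(f"Removed {len(removed)} duplicate data files")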

File tree: 7 files changed (+594, -115 lines)


pyiceberg/table/__init__.py

Lines changed: 10 additions & 15 deletions

@@ -80,6 +80,7 @@
 from pyiceberg.schema import Schema
 from pyiceberg.table.inspect import InspectTable
 from pyiceberg.table.locations import LocationProvider, load_location_provider
+from pyiceberg.table.maintenance import MaintenanceTable
 from pyiceberg.table.metadata import (
     INITIAL_SEQUENCE_NUMBER,
     TableMetadata,
@@ -115,12 +116,7 @@
     update_table_metadata,
 )
 from pyiceberg.table.update.schema import UpdateSchema
-from pyiceberg.table.update.snapshot import (
-    ManageSnapshots,
-    UpdateSnapshot,
-    _FastAppendFiles,
-    ExpireSnapshots
-)
+from pyiceberg.table.update.snapshot import ExpireSnapshots, ManageSnapshots, UpdateSnapshot, _FastAppendFiles
 from pyiceberg.table.update.spec import UpdateSpec
 from pyiceberg.table.update.statistics import UpdateStatistics
 from pyiceberg.transforms import IdentityTransform
@@ -908,6 +904,14 @@ def inspect(self) -> InspectTable:
         """
         return InspectTable(self)
 
+    @property
+    def maintenance(self) -> MaintenanceTable:
+        """Return the MaintenanceTable object for maintenance.
+        Returns:
+            MaintenanceTable object based on this Table.
+        """
+        return MaintenanceTable(self)
+
     def refresh(self) -> Table:
         """Refresh the current table metadata.
 
@@ -1079,15 +1083,6 @@ def manage_snapshots(self) -> ManageSnapshots:
             ms.create_tag(snapshot_id1, "Tag_A").create_tag(snapshot_id2, "Tag_B")
         """
         return ManageSnapshots(transaction=Transaction(self, autocommit=True))
-
-    def expire_snapshots(self) -> ExpireSnapshots:
-        """
-        Shorthand to run expire snapshots by id or by a timestamp.
-
-        Use table.expire_snapshots().<operation>().commit() to run a specific operation.
-        Use table.expire_snapshots().<operation-one>().<operation-two>().commit() to run multiple operations.
-        """
-        return ExpireSnapshots(transaction=Transaction(self, autocommit=True))
 
     def update_statistics(self) -> UpdateStatistics:
         """

pyiceberg/table/inspect.py

Lines changed: 13 additions & 10 deletions

@@ -17,14 +17,13 @@
 from __future__ import annotations
 
 from datetime import datetime, timezone
-from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Union
 
 from pyiceberg.conversions import from_bytes
 from pyiceberg.manifest import DataFile, DataFileContent, ManifestContent, PartitionFieldSummary
 from pyiceberg.partitioning import PartitionSpec
 from pyiceberg.table.snapshots import Snapshot, ancestors_of
 from pyiceberg.types import PrimitiveType
-from pyiceberg.utils.concurrent import ExecutorFactory
 from pyiceberg.utils.singleton import _convert_to_hashable_type
 
 if TYPE_CHECKING:
@@ -645,15 +644,19 @@ def data_files(self, snapshot_id: Optional[int] = None) -> "pa.Table":
     def delete_files(self, snapshot_id: Optional[int] = None) -> "pa.Table":
         return self._files(snapshot_id, {DataFileContent.POSITION_DELETES, DataFileContent.EQUALITY_DELETES})
 
-    def all_manifests(self) -> "pa.Table":
+    def all_manifests(self, snapshots: Optional[Union[list[Snapshot], list[int]]] = None) -> "pa.Table":
         import pyarrow as pa
 
-        snapshots = self.tbl.snapshots()
+        # Coerce into Snapshot objects if the user passes in snapshot IDs
+        if snapshots is not None:
+            if isinstance(snapshots[0], int):
+                snapshots = [
+                    snapshot
+                    for snapshot_id in snapshots
+                    if (snapshot := self.tbl.metadata.snapshot_by_id(snapshot_id)) is not None
+                ]
+        else:
+            snapshots = self.tbl.snapshots()
+
         if not snapshots:
             return pa.Table.from_pylist([], schema=self._get_all_manifests_schema())
-
-        executor = ExecutorFactory.get_or_create()
-        manifests_by_snapshots: Iterator["pa.Table"] = executor.map(
-            lambda args: self._generate_manifests_table(*args), [(snapshot, True) for snapshot in snapshots]
-        )
-        return pa.concat_tables(manifests_by_snapshots)
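`all_manifests` now optionally takes the snapshots to cover, either as `Snapshot` objects or as snapshot IDs; with no argument it keeps the previous behavior of covering every snapshot. A short sketch, with illustrative IDs:

    # Previous behavior, still the default: manifests across all snapshots
    table.inspect.all_manifests()

    # New: restrict to specific snapshots, by Snapshot object or by ID
    table.inspect.all_manifests(snapshots=[123, 456])  # illustrative IDs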

pyiceberg/table/maintenance.py

Lines changed: 232 additions & 0 deletions

@@ -0,0 +1,232 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import logging
+from concurrent.futures import ThreadPoolExecutor
+from typing import TYPE_CHECKING, List, Optional, Set, Union
+
+from pyiceberg.manifest import DataFile
+
+logger = logging.getLogger(__name__)
+
+
+if TYPE_CHECKING:
+    from pyiceberg.table import Table
+    from pyiceberg.table.metadata import TableMetadata
+
+
+class MaintenanceTable:
+    tbl: Table
+
+    def __init__(self, tbl: Table) -> None:
+        self.tbl = tbl
+
+        try:
+            import pyarrow as pa  # noqa
+        except ModuleNotFoundError as e:
+            raise ModuleNotFoundError("For metadata operations PyArrow needs to be installed") from e
+
+    def expire_snapshot_by_id(self, snapshot_id: int) -> None:
+        """Expire a single snapshot by its ID.
+
+        Args:
+            snapshot_id: The ID of the snapshot to expire.
+
+        Raises:
+            ValueError: If the snapshot does not exist or is protected.
+        """
+        with self.tbl.transaction() as txn:
+            # Check if the snapshot exists
+            if txn.table_metadata.snapshot_by_id(snapshot_id) is None:
+                raise ValueError(f"Snapshot with ID {snapshot_id} does not exist.")
+
+            # Check if the snapshot is protected
+            protected_ids = self._get_protected_snapshot_ids(txn.table_metadata)
+            if snapshot_id in protected_ids:
+                raise ValueError(f"Snapshot with ID {snapshot_id} is protected and cannot be expired.")
+
+            # Remove the snapshot
+            from pyiceberg.table.update import RemoveSnapshotsUpdate
+
+            txn._apply((RemoveSnapshotsUpdate(snapshot_ids=[snapshot_id]),))
+
+    def expire_snapshots_by_ids(self, snapshot_ids: List[int]) -> None:
+        """Expire multiple snapshots by their IDs.
+
+        Args:
+            snapshot_ids: List of snapshot IDs to expire.
+
+        Raises:
+            ValueError: If any snapshot does not exist or is protected.
+        """
+        with self.tbl.transaction() as txn:
+            protected_ids = self._get_protected_snapshot_ids(txn.table_metadata)
+
+            # Validate all snapshots before expiring any
+            for snapshot_id in snapshot_ids:
+                if txn.table_metadata.snapshot_by_id(snapshot_id) is None:
+                    raise ValueError(f"Snapshot with ID {snapshot_id} does not exist.")
+                if snapshot_id in protected_ids:
+                    raise ValueError(f"Snapshot with ID {snapshot_id} is protected and cannot be expired.")
+
+            # Remove all snapshots
+            from pyiceberg.table.update import RemoveSnapshotsUpdate
+
+            txn._apply((RemoveSnapshotsUpdate(snapshot_ids=snapshot_ids),))
+
+    def expire_snapshots_older_than(self, timestamp_ms: int) -> None:
+        """Expire all unprotected snapshots with a timestamp older than a given value.
+
+        Args:
+            timestamp_ms: Only snapshots with timestamp_ms < this value will be expired.
+        """
+        # First check whether there are any snapshots to expire, to avoid unnecessary transactions
+        protected_ids = self._get_protected_snapshot_ids(self.tbl.metadata)
+        snapshots_to_expire = []
+
+        for snapshot in self.tbl.metadata.snapshots:
+            if snapshot.timestamp_ms < timestamp_ms and snapshot.snapshot_id not in protected_ids:
+                snapshots_to_expire.append(snapshot.snapshot_id)
+
+        if snapshots_to_expire:
+            with self.tbl.transaction() as txn:
+                from pyiceberg.table.update import RemoveSnapshotsUpdate
+
+                txn._apply((RemoveSnapshotsUpdate(snapshot_ids=snapshots_to_expire),))
+
+    def _get_protected_snapshot_ids(self, table_metadata: TableMetadata) -> Set[int]:
+        """Get the IDs of protected snapshots.
+
+        These are the HEAD snapshots of all branches and all tagged snapshots.
+        These IDs are excluded from expiration.
+
+        Args:
+            table_metadata: The table metadata to check for protected snapshots.
+
+        Returns:
+            Set of protected snapshot IDs to exclude from expiration.
+        """
+        from pyiceberg.table.refs import SnapshotRefType
+
+        protected_ids: Set[int] = set()
+        for ref in table_metadata.refs.values():
+            if ref.snapshot_ref_type in [SnapshotRefType.TAG, SnapshotRefType.BRANCH]:
+                protected_ids.add(ref.snapshot_id)
+        return protected_ids
+
+    def _get_all_datafiles(
+        self,
+        scan_all_snapshots: bool = False,
+        target_file_path: Optional[str] = None,
+        parallel: bool = True,
+    ) -> List[DataFile]:
+        """Collect all DataFiles in the table, optionally filtering by file path."""
+        datafiles: List[DataFile] = []
+
+        def process_manifest(manifest) -> list[DataFile]:
+            found: list[DataFile] = []
+            for entry in manifest.fetch_manifest_entry(io=self.tbl.io):
+                if hasattr(entry, "data_file"):
+                    df = entry.data_file
+                    if target_file_path is None or df.file_path == target_file_path:
+                        found.append(df)
+            return found
+
+        if scan_all_snapshots:
+            manifests = []
+            for snapshot in self.tbl.snapshots():
+                manifests.extend(snapshot.manifests(io=self.tbl.io))
+            if parallel:
+                with ThreadPoolExecutor() as executor:
+                    results = executor.map(process_manifest, manifests)
+                for res in results:
+                    datafiles.extend(res)
+            else:
+                for manifest in manifests:
+                    datafiles.extend(process_manifest(manifest))
+        else:
+            # Only the current snapshot
+            for chunk in self.tbl.inspect.data_files().to_pylist():
+                file_path = chunk.get("file_path")
+                partition = chunk.get("partition", {})
+                if target_file_path is None or file_path == target_file_path:
+                    datafiles.append(DataFile(file_path=file_path, partition=partition))
+        return datafiles
+
+    def deduplicate_data_files(
+        self,
+        scan_all_partitions: bool = True,
+        scan_all_snapshots: bool = False,
+        to_remove: Optional[List[Union[DataFile, str]]] = None,
+        parallel: bool = True,
+    ) -> List[DataFile]:
+        """Remove duplicate data files from an Iceberg table.
+
+        Args:
+            scan_all_partitions: If True, scan all partitions for duplicates (uses file_path + partition as key).
+            scan_all_snapshots: If True, scan all snapshots for duplicates; otherwise only the current snapshot.
+            to_remove: List of DataFile objects or file path strings to remove. If None, auto-detect duplicates.
+            parallel: If True, parallelize manifest traversal.
+
+        Returns:
+            List of removed DataFile objects.
+        """
+        removed: List[DataFile] = []
+
+        # Determine what to remove
+        if to_remove is None:
+            # Auto-detect duplicates
+            all_datafiles = self._get_all_datafiles(scan_all_snapshots=scan_all_snapshots, parallel=parallel)
+            seen = {}
+            duplicates = []
+            for df in all_datafiles:
+                partition = dict(df.partition) if hasattr(df.partition, "items") else df.partition
+                if scan_all_partitions:
+                    key = (df.file_path, tuple(sorted(partition.items())) if partition else None)
+                else:
+                    key = df.file_path
+                if key in seen:
+                    duplicates.append(df)
+                else:
+                    seen[key] = df
+            to_remove = duplicates
+
+        # Normalize to DataFile objects
+        normalized_to_remove: List[DataFile] = []
+        all_datafiles = self._get_all_datafiles(scan_all_snapshots=scan_all_snapshots, parallel=parallel)
+        for item in to_remove or []:
+            if isinstance(item, DataFile):
+                normalized_to_remove.append(item)
+            elif isinstance(item, str):
+                # Remove all DataFiles with this file_path
+                for df in all_datafiles:
+                    if df.file_path == item:
+                        normalized_to_remove.append(df)
+            else:
+                raise ValueError(f"Unsupported type in to_remove: {type(item)}")
+
+        # Remove the DataFiles
+        for df in normalized_to_remove:
+            self.tbl.transaction().update_snapshot().overwrite().delete_data_file(df).commit()
+            removed.append(df)
+
+        return removed
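A minimal sketch of the new expiration helpers, assuming `table` is a loaded `Table`; the cutoff computation and snapshot IDs are illustrative. Branch HEADs and tagged snapshots are excluded via `_get_protected_snapshot_ids`, and `expire_snapshots_by_ids` raises `ValueError` if any ID is missing or protected:

    import time

    # Expire all unprotected snapshots older than 7 days
    cutoff_ms = int(time.time() * 1000) - 7 * 24 * 60 * 60 * 1000
    table.maintenance.expire_snapshots_older_than(cutoff_ms)

    # Expire an explicit list of snapshot IDs (illustrative IDs)
    table.maintenance.expire_snapshots_by_ids([123, 456])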
