Skip to content

Commit 8c906d2

Browse files
committed
refactor: streamline data file retrieval in MaintenanceTable and enhance deduplication tests
1 parent a9a01ee commit 8c906d2

File tree

3 files changed: +11 additions, -123 deletions

pyiceberg/table/maintenance.py

Lines changed: 3 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -294,27 +294,9 @@ def _get_protected_snapshot_ids(self, table_metadata: TableMetadata) -> Set[int]
294294

295295
def _get_all_datafiles(self) -> List[DataFile]:
296296
"""Collect all DataFiles in the current snapshot only."""
297-
datafiles: List[DataFile] = []
298-
299-
current_snapshot = self.tbl.current_snapshot()
300-
if not current_snapshot:
301-
return datafiles
302-
303-
def process_manifest(manifest: ManifestFile) -> list[DataFile]:
304-
found: list[DataFile] = []
305-
for entry in manifest.fetch_manifest_entry(io=self.tbl.io, discard_deleted=True):
306-
if hasattr(entry, "data_file"):
307-
found.append(entry.data_file)
308-
return found
309-
310-
# Scan only the current snapshot's manifests
311-
manifests = current_snapshot.manifests(io=self.tbl.io)
312-
with ThreadPoolExecutor() as executor:
313-
results = executor.map(process_manifest, manifests)
314-
for res in results:
315-
datafiles.extend(res)
316-
317-
return datafiles
297+
data_file_structs = self.tbl.inspect.data_files()
298+
data_files = [DataFile(df) for df in data_file_structs]
299+
return data_files
318300

319301
def deduplicate_data_files(self) -> List[DataFile]:
320302
"""

pyiceberg/table/update/snapshot.py

Lines changed: 1 addition & 101 deletions
Original file line number | Diff line number | Diff line change
@@ -903,104 +903,4 @@ def remove_branch(self, branch_name: str) -> ManageSnapshots:
903903
Returns:
904904
This for method chaining
905905
"""
906-
return self._remove_ref_snapshot(ref_name=branch_name)
907-
908-
909-
class ExpireSnapshots(UpdateTableMetadata["ExpireSnapshots"]):
910-
"""
911-
Expire snapshots by ID.
912-
913-
Use table.expire_snapshots().<operation>().commit() to run a specific operation.
914-
Use table.expire_snapshots().<operation-one>().<operation-two>().commit() to run multiple operations.
915-
Pending changes are applied on commit.
916-
"""
917-
918-
_snapshot_ids_to_expire: Set[int] = set()
919-
_updates: Tuple[TableUpdate, ...] = ()
920-
_requirements: Tuple[TableRequirement, ...] = ()
921-
922-
def _commit(self) -> UpdatesAndRequirements:
923-
"""
924-
Commit the staged updates and requirements.
925-
926-
This will remove the snapshots with the given IDs, but will always skip protected snapshots (branch/tag heads).
927-
928-
Returns:
929-
Tuple of updates and requirements to be committed,
930-
as required by the calling parent apply functions.
931-
"""
932-
# Remove any protected snapshot IDs from the set to expire, just in case
933-
protected_ids = self._get_protected_snapshot_ids()
934-
self._snapshot_ids_to_expire -= protected_ids
935-
update = RemoveSnapshotsUpdate(snapshot_ids=self._snapshot_ids_to_expire)
936-
self._updates += (update,)
937-
return self._updates, self._requirements
938-
939-
def _get_protected_snapshot_ids(self) -> Set[int]:
940-
"""
941-
Get the IDs of protected snapshots.
942-
943-
These are the HEAD snapshots of all branches and all tagged snapshots. These ids are to be excluded from expiration.
944-
945-
Returns:
946-
Set of protected snapshot IDs to exclude from expiration.
947-
"""
948-
protected_ids: Set[int] = set()
949-
950-
for ref in self._transaction.table_metadata.refs.values():
951-
if ref.snapshot_ref_type in [SnapshotRefType.TAG, SnapshotRefType.BRANCH]:
952-
protected_ids.add(ref.snapshot_id)
953-
954-
return protected_ids
955-
956-
def expire_snapshot_by_id(self, snapshot_id: int) -> ExpireSnapshots:
957-
"""
958-
Expire a snapshot by its ID.
959-
960-
This will mark the snapshot for expiration.
961-
962-
Args:
963-
snapshot_id (int): The ID of the snapshot to expire.
964-
Returns:
965-
This for method chaining.
966-
"""
967-
if self._transaction.table_metadata.snapshot_by_id(snapshot_id) is None:
968-
raise ValueError(f"Snapshot with ID {snapshot_id} does not exist.")
969-
970-
if snapshot_id in self._get_protected_snapshot_ids():
971-
raise ValueError(f"Snapshot with ID {snapshot_id} is protected and cannot be expired.")
972-
973-
self._snapshot_ids_to_expire.add(snapshot_id)
974-
975-
return self
976-
977-
def expire_snapshots_by_ids(self, snapshot_ids: List[int]) -> "ExpireSnapshots":
978-
"""
979-
Expire multiple snapshots by their IDs.
980-
981-
This will mark the snapshots for expiration.
982-
983-
Args:
984-
snapshot_ids (List[int]): List of snapshot IDs to expire.
985-
Returns:
986-
This for method chaining.
987-
"""
988-
for snapshot_id in snapshot_ids:
989-
self.expire_snapshot_by_id(snapshot_id)
990-
return self
991-
992-
def expire_snapshots_older_than(self, timestamp_ms: int) -> "ExpireSnapshots":
993-
"""
994-
Expire all unprotected snapshots with a timestamp older than a given value.
995-
996-
Args:
997-
timestamp_ms (int): Only snapshots with timestamp_ms < this value will be expired.
998-
999-
Returns:
1000-
This for method chaining.
1001-
"""
1002-
protected_ids = self._get_protected_snapshot_ids()
1003-
for snapshot in self._transaction.table_metadata.snapshots:
1004-
if snapshot.timestamp_ms < timestamp_ms and snapshot.snapshot_id not in protected_ids:
1005-
self._snapshot_ids_to_expire.add(snapshot.snapshot_id)
1006-
return self
906+
return self._remove_ref_snapshot(ref_name=branch_name)

tests/table/test_dedup_data_file_filepaths.py

Lines changed: 7 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -124,6 +124,9 @@ def test_deduplicate_data_files_removes_duplicates_in_current_snapshot(
124124
) -> None:
125125
mt = MaintenanceTable(tbl=prepopulated_table)
126126

127+
print("=== Before deduplication ===")
128+
check_data_files(prepopulated_table)
129+
127130
all_datafiles: List[DataFile] = mt._get_all_datafiles()
128131
file_names: List[str] = [os.path.basename(df.file_path) for df in all_datafiles]
129132
# There should be more than one reference before deduplication
@@ -132,8 +135,11 @@ def test_deduplicate_data_files_removes_duplicates_in_current_snapshot(
132135
)
133136
removed: List[DataFile] = mt.deduplicate_data_files()
134137

138+
print("=== After deduplication ===")
139+
check_data_files(prepopulated_table)
140+
135141
all_datafiles_after: List[DataFile] = mt._get_all_datafiles()
136142
file_names_after: List[str] = [os.path.basename(df.file_path) for df in all_datafiles_after]
137143
# Only one reference should remain after deduplication
138144
assert file_names_after.count(dupe_data_file_path.name) == 1
139-
assert all(isinstance(df, DataFile) for df in removed)
145+
assert all(isinstance(df, DataFile) for df in removed)

0 commit comments

Comments
 (0)