|
10 | 10 | from contextlib import contextmanager, suppress |
11 | 11 | from copy import copy |
12 | 12 | from dataclasses import dataclass |
| 13 | +from datetime import datetime, timedelta, timezone |
13 | 14 | from functools import cached_property, reduce |
14 | 15 | from typing import TYPE_CHECKING, Any |
15 | 16 | from uuid import uuid4 |
|
24 | 25 | from tqdm.auto import tqdm |
25 | 26 |
|
26 | 27 | from datachain.cache import Cache |
| 28 | +from datachain.checkpoint import Checkpoint, CheckpointStatus |
27 | 29 | from datachain.client import Client |
28 | 30 | from datachain.dataset import ( |
29 | 31 | DATASET_PREFIX, |
|
72 | 74 |
|
73 | 75 | DEFAULT_DATASET_DIR = "dataset" |
74 | 76 |
|
75 | | -TTL_INT = 4 * 60 * 60 |
| 77 | +CHECKPOINTS_TTL = 4 * 60 * 60 |
76 | 78 |
|
77 | 79 | INDEX_INTERNAL_ERROR_MESSAGE = "Internal error on indexing" |
78 | 80 | DATASET_INTERNAL_ERROR_MESSAGE = "Internal error on creating dataset" |
@@ -2094,3 +2096,64 @@ def index( |
2094 | 2096 | only_index=True, |
2095 | 2097 | ): |
2096 | 2098 | pass |
| 2099 | + |
| 2100 | + def cleanup_checkpoints(self, ttl_seconds: int | None = None) -> int: |
| 2101 | + """Clean up outdated checkpoints and their associated UDF tables. |
| 2102 | +
|
| 2103 | + Algorithm: |
| 2104 | + 1. Atomically mark expired checkpoints as EXPIRED (single query) — |
| 2105 | + prevents running jobs from using them via find_checkpoint. |
| 2106 | + Then collect all EXPIRED checkpoints and determine which run |
| 2107 | + groups have no remaining ACTIVE checkpoints. |
| 2108 | + 2. Remove output/partial tables using exact names from checkpoints |
| 2109 | + 3. For fully-inactive run groups: remove shared input tables |
| 2110 | + 4. Mark checkpoints as DELETED |
| 2111 | + """ |
| 2112 | + if ttl_seconds is None: |
| 2113 | + ttl_seconds = CHECKPOINTS_TTL |
| 2114 | + |
| 2115 | + ttl_threshold = datetime.now(timezone.utc) - timedelta(seconds=ttl_seconds) |
| 2116 | + |
| 2117 | + # Expire + collect everything in one metastore call |
| 2118 | + checkpoints, inactive_group_ids = self.metastore.expire_checkpoints( |
| 2119 | + ttl_threshold |
| 2120 | + ) |
| 2121 | + if not checkpoints: |
| 2122 | + return 0 |
| 2123 | + |
| 2124 | + logger.info( |
| 2125 | + "Cleaning %d expired checkpoints across %d inactive run groups", |
| 2126 | + len(checkpoints), |
| 2127 | + len(inactive_group_ids), |
| 2128 | + ) |
| 2129 | + |
| 2130 | + # Remove output/partial tables using exact names from checkpoints |
| 2131 | + output_tables = [ch.table_name for ch in checkpoints] |
| 2132 | + if output_tables: |
| 2133 | + logger.info( |
| 2134 | + "Removing %d UDF output tables: %s", len(output_tables), output_tables |
| 2135 | + ) |
| 2136 | + self.warehouse.cleanup_tables(output_tables) |
| 2137 | + |
| 2138 | + # Shared input tables — only when entire run group is inactive |
| 2139 | + for group_id in inactive_group_ids: |
| 2140 | + input_tables = self.warehouse.db.list_tables( |
| 2141 | + pattern=Checkpoint.input_table_pattern(group_id) |
| 2142 | + ) |
| 2143 | + if input_tables: |
| 2144 | + logger.info( |
| 2145 | + "Removing %d shared input tables: %s", |
| 2146 | + len(input_tables), |
| 2147 | + input_tables, |
| 2148 | + ) |
| 2149 | + self.warehouse.cleanup_tables(input_tables) |
| 2150 | + |
| 2151 | + self.metastore.update_checkpoints( |
| 2152 | + [ch.id for ch in checkpoints], status=CheckpointStatus.DELETED |
| 2153 | + ) |
| 2154 | + |
| 2155 | + logger.info( |
| 2156 | + "Checkpoint cleanup complete: removed %d checkpoints", |
| 2157 | + len(checkpoints), |
| 2158 | + ) |
| 2159 | + return len(checkpoints) |
0 commit comments