Skip to content

Commit c6ec31f

Browse files
authored
UDF Checkpoints cleanup (#1590)
UDF Checkpoints cleanup
1 parent 3ba4d5c commit c6ec31f

24 files changed

+797
-99
lines changed

docs/commands/gc.md

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
# gc
2+
3+
Garbage collect temporary tables, failed dataset versions, and outdated checkpoints.
4+
5+
## Synopsis
6+
7+
```usage
8+
usage: datachain gc [-h] [-v] [-q] [--checkpoint-ttl CHECKPOINT_TTL]
9+
```
10+
11+
## Description
12+
13+
This command cleans up internal DataChain storage by removing:
14+
15+
- **Temporary tables** created during query execution that were not properly cleaned up (e.g., due to crashes or interrupted operations).
16+
- **Failed dataset versions** that were left in an incomplete or failed state.
17+
- **Outdated checkpoints** and their associated UDF tables that have exceeded the time-to-live (TTL) threshold. See [Checkpoints](../guide/checkpoints.md) for more details.
18+
19+
## Options
20+
21+
* `-h`, `--help` - Show the help message and exit.
22+
* `-v`, `--verbose` - Be verbose.
23+
* `-q`, `--quiet` - Be quiet.
24+
* `--checkpoint-ttl` - Time-to-live for checkpoints in seconds. Checkpoints older than this value are considered outdated and eligible for cleanup. Defaults to 4 hours (14400 seconds).
25+
26+
## Examples
27+
28+
1. Run garbage collection:
29+
```bash
30+
datachain gc
31+
```
32+
33+
2. Run garbage collection with a custom checkpoint TTL of 1 hour:
34+
```bash
35+
datachain gc --checkpoint-ttl 3600
36+
```
37+
38+
Example output:
39+
```
40+
Collecting temporary tables...
41+
Removed 3 temporary tables.
42+
Collecting failed dataset versions...
43+
No failed dataset versions to clean up.
44+
Collecting outdated checkpoints...
45+
Removed 5 outdated checkpoints.
46+
```

docs/commands/index.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,3 +33,8 @@ DataChain is a command-line tool for wrangling unstructured AI data at scale. Us
3333
- Cancel running jobs with [`datachain job cancel`](job/cancel.md)
3434

3535
- Check the clusters available for jobs with [`datachain job clusters`](job/clusters.md)
36+
37+
38+
3. **Maintenance**
39+
40+
- Clean up temporary tables, failed versions, and outdated checkpoints with [`datachain gc`](gc.md)

mkdocs.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ nav:
9797
- 📖 CLI Reference:
9898
- Overview: commands/index.md
9999
- Commands:
100+
- gc: commands/gc.md
100101
- auth:
101102
- login: commands/auth/login.md
102103
- logout: commands/auth/logout.md

src/datachain/catalog/catalog.py

Lines changed: 64 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from contextlib import contextmanager, suppress
1111
from copy import copy
1212
from dataclasses import dataclass
13+
from datetime import datetime, timedelta, timezone
1314
from functools import cached_property, reduce
1415
from typing import TYPE_CHECKING, Any
1516
from uuid import uuid4
@@ -24,6 +25,7 @@
2425
from tqdm.auto import tqdm
2526

2627
from datachain.cache import Cache
28+
from datachain.checkpoint import Checkpoint, CheckpointStatus
2729
from datachain.client import Client
2830
from datachain.dataset import (
2931
DATASET_PREFIX,
@@ -72,7 +74,7 @@
7274

7375
DEFAULT_DATASET_DIR = "dataset"
7476

75-
TTL_INT = 4 * 60 * 60
77+
CHECKPOINTS_TTL = 4 * 60 * 60
7678

7779
INDEX_INTERNAL_ERROR_MESSAGE = "Internal error on indexing"
7880
DATASET_INTERNAL_ERROR_MESSAGE = "Internal error on creating dataset"
@@ -2094,3 +2096,64 @@ def index(
20942096
only_index=True,
20952097
):
20962098
pass
2099+
2100+
def cleanup_checkpoints(self, ttl_seconds: int | None = None) -> int:
    """Clean up outdated checkpoints and their associated UDF tables.

    The metastore first atomically flips every checkpoint older than the
    TTL to EXPIRED (a single query, so running jobs stop matching them
    via find_checkpoint) and reports which run groups are left with no
    ACTIVE checkpoints. Output/partial tables are then dropped by their
    exact recorded names, shared input tables are dropped only for the
    fully-inactive run groups, and finally the checkpoint rows are
    marked DELETED.

    Args:
        ttl_seconds: Age threshold in seconds. Defaults to CHECKPOINTS_TTL.

    Returns:
        Number of checkpoints cleaned up.
    """
    ttl = CHECKPOINTS_TTL if ttl_seconds is None else ttl_seconds
    cutoff = datetime.now(timezone.utc) - timedelta(seconds=ttl)

    # One metastore call both expires and collects: expired checkpoints
    # plus the run groups that now have no ACTIVE checkpoints left.
    expired, dead_groups = self.metastore.expire_checkpoints(cutoff)
    if not expired:
        return 0

    logger.info(
        "Cleaning %d expired checkpoints across %d inactive run groups",
        len(expired),
        len(dead_groups),
    )

    # Job-specific output (or partial) tables, by exact recorded name.
    udf_tables = [cp.table_name for cp in expired]
    if udf_tables:
        logger.info(
            "Removing %d UDF output tables: %s", len(udf_tables), udf_tables
        )
        self.warehouse.cleanup_tables(udf_tables)

    # Input tables are shared within a run group, so drop them only once
    # the entire group has gone inactive.
    for gid in dead_groups:
        shared = self.warehouse.db.list_tables(
            pattern=Checkpoint.input_table_pattern(gid)
        )
        if shared:
            logger.info(
                "Removing %d shared input tables: %s",
                len(shared),
                shared,
            )
            self.warehouse.cleanup_tables(shared)

    self.metastore.update_checkpoints(
        [cp.id for cp in expired], status=CheckpointStatus.DELETED
    )

    logger.info(
        "Checkpoint cleanup complete: removed %d checkpoints",
        len(expired),
    )
    return len(expired)

src/datachain/checkpoint.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,13 @@
11
import uuid
22
from dataclasses import dataclass
33
from datetime import datetime
4+
from enum import IntEnum
5+
6+
7+
class CheckpointStatus(IntEnum):
    """Lifecycle state of a checkpoint record (persisted as an integer).

    Observed transitions in this codebase: ACTIVE -> EXPIRED (past TTL)
    -> DELETED (after the associated UDF tables are removed).
    """

    ACTIVE = 0  # current; may still be matched/reused by running jobs
    EXPIRED = 1  # past TTL; excluded from reuse, tables awaiting cleanup
    DELETED = 2  # cleanup finished; associated UDF tables dropped
411

512

613
@dataclass
@@ -24,6 +31,34 @@ class Checkpoint:
2431
hash: str
2532
partial: bool
2633
created_at: datetime
34+
status: int = CheckpointStatus.ACTIVE
35+
36+
@staticmethod
def output_table_name(job_id: str, _hash: str) -> str:
    """Name of the finalized, job-specific UDF output table."""
    return "udf_{}_{}_output".format(job_id, _hash)
40+
41+
@staticmethod
def partial_output_table_name(job_id: str, _hash: str) -> str:
    """Name of the temporary partial-output table, renamed to the
    final output name once the UDF completes."""
    return "udf_{}_{}_output_partial".format(job_id, _hash)
45+
46+
@staticmethod
def input_table_name(group_id: str, _hash: str) -> str:
    """Name of the UDF input table shared by all jobs in a run group."""
    return "udf_{}_{}_input".format(group_id, _hash)
50+
51+
@staticmethod
def input_table_pattern(group_id: str) -> str:
    """SQL LIKE pattern matching every shared input table of a run group."""
    return "udf_{}_%_input".format(group_id)
55+
56+
@property
def table_name(self) -> str:
    """Name of the UDF output table this checkpoint refers to
    (partial table while in progress, final table otherwise)."""
    builder = (
        self.partial_output_table_name if self.partial else self.output_table_name
    )
    return builder(self.job_id, self.hash)
2762

2863
@classmethod
2964
def parse(
@@ -33,11 +68,13 @@ def parse(
3368
_hash: str,
3469
partial: bool,
3570
created_at: datetime,
71+
status: int = CheckpointStatus.ACTIVE,
3672
) -> "Checkpoint":
3773
return cls(
3874
str(id),
3975
job_id,
4076
_hash,
4177
bool(partial),
4278
created_at,
79+
int(status),
4380
)

src/datachain/cli/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def handle_command(args, catalog, client_config) -> int:
9797
"index": lambda: handle_index_command(args, catalog),
9898
"completion": lambda: handle_completion_command(args),
9999
"clear-cache": lambda: clear_cache(catalog),
100-
"gc": lambda: garbage_collect(catalog),
100+
"gc": lambda: garbage_collect(catalog, checkpoint_ttl=args.checkpoint_ttl),
101101
"auth": lambda: process_auth_cli_args(args),
102102
"job": lambda: process_jobs_args(args),
103103
"pipeline": lambda: process_pipeline_args(args, catalog),

src/datachain/cli/commands/misc.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -10,21 +10,28 @@ def clear_cache(catalog: "Catalog"):
1010
catalog.cache.clear()
1111

1212

13-
def garbage_collect(catalog: "Catalog"):
13+
def garbage_collect(catalog: "Catalog", checkpoint_ttl: int | None = None):
14+
print("Collecting temporary tables...")
1415
temp_tables = catalog.get_temp_table_names()
15-
num_versions_removed = catalog.cleanup_failed_dataset_versions()
16-
17-
total_cleaned = len(temp_tables) + num_versions_removed
16+
if temp_tables:
17+
catalog.cleanup_tables(temp_tables)
18+
print(f" Removed {len(temp_tables)} temporary tables.")
19+
else:
20+
print(" No temporary tables to clean up.")
1821

19-
if total_cleaned == 0:
20-
print("Nothing to clean up.")
22+
print("Collecting failed dataset versions...")
23+
num_versions = catalog.cleanup_failed_dataset_versions()
24+
if num_versions:
25+
print(f" Removed {num_versions} failed/incomplete dataset versions.")
2126
else:
22-
if temp_tables:
23-
print(f"Garbage collecting {len(temp_tables)} tables.")
24-
catalog.cleanup_tables(temp_tables)
27+
print(" No failed dataset versions to clean up.")
2528

26-
if num_versions_removed:
27-
print(f"Cleaned {num_versions_removed} failed/incomplete dataset versions.")
29+
print("Collecting outdated checkpoints...")
30+
num_checkpoints = catalog.cleanup_checkpoints(ttl_seconds=checkpoint_ttl)
31+
if num_checkpoints:
32+
print(f" Removed {num_checkpoints} outdated checkpoints.")
33+
else:
34+
print(" No outdated checkpoints to clean up.")
2835

2936

3037
def completion(shell: str) -> str:

src/datachain/cli/parser/__init__.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -478,9 +478,22 @@ def get_parser() -> ArgumentParser: # noqa: PLR0915
478478
parse_gc = subp.add_parser(
479479
"gc",
480480
parents=[parent_parser],
481-
description="Garbage collect temporary tables and failed dataset versions.",
481+
description=(
482+
"Garbage collect temporary tables,"
483+
" failed dataset versions, and outdated checkpoints."
484+
),
482485
formatter_class=CustomHelpFormatter,
483486
)
487+
parse_gc.add_argument(
488+
"--checkpoint-ttl",
489+
type=int,
490+
default=None,
491+
help=(
492+
"Time-to-live for checkpoints in seconds."
493+
" Checkpoints older than this are removed."
494+
" Defaults to 4 hours (14400 seconds)."
495+
),
496+
)
484497
add_anon_arg(parse_gc)
485498

486499
subp.add_parser("internal-run-udf", parents=[parent_parser])

src/datachain/data_storage/db_engine.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -123,15 +123,15 @@ def has_table(self, name: str) -> bool:
123123
return sa.inspect(self.engine).has_table(name)
124124

125125
@abstractmethod
def list_tables(self, pattern: str = "") -> list[str]:
    """
    List all table names, optionally filtered by a SQL LIKE pattern.

    Args:
        pattern: SQL LIKE pattern to filter table names (e.g. 'udf_%').
            NOTE(review): the default "" presumably means "no filter",
            but the exact semantics are up to each concrete
            implementation — confirm against the subclasses.

    Returns:
        List of table names matching the pattern
    """
137137
@abstractmethod

0 commit comments

Comments
 (0)