Skip to content

Commit a0eda41

Browse files
shcheklein, Copilot
authored
Make pull and read_dataset from Studio atomic (#1573)
* add basic pull failure and cleanup tests * cleanup remote read_dataset tests * add cleanup on failure to read_dataset remote * remote read: kill in the middle cleanup * first pass to implement this * address review findings * add more tests * keep addressing concurrency edge cases * more tests, more edge cases, cleanup messages * use file lock to wait on concurrent pulls * address more review comments * fix test failing on SaaS * Update src/datachain/utils.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update src/datachain/catalog/catalog.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * address more review comments * address more reviews * Use `except Exception:` instead of bare `except:` for cleanup-and-reraise (#1585) * Initial plan * Use `except Exception:` instead of bare `except:` for better interrupt handling Co-authored-by: shcheklein <3659196+shcheklein@users.noreply.github.com> --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: shcheklein <3659196+shcheklein@users.noreply.github.com> * address more reviews * add more tests for pull * fix test coverage * address PR reviews * debug CI failure * fix lock file tests --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: Copilot <198982749+Copilot@users.noreply.github.com> Co-authored-by: shcheklein <3659196+shcheklein@users.noreply.github.com>
1 parent 47878dd commit a0eda41

File tree

11 files changed

+1626
-337
lines changed

11 files changed

+1626
-337
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ dependencies = [
5151
"huggingface_hub",
5252
"iterative-telemetry>=0.0.10",
5353
"platformdirs",
54+
"filelock",
5455
"dvc-studio-client>=0.21,<1",
5556
"tabulate",
5657
"websockets",

src/datachain/catalog/catalog.py

Lines changed: 235 additions & 185 deletions
Large diffs are not rendered by default.

src/datachain/data_storage/metastore.py

Lines changed: 71 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from abc import ABC, abstractmethod
66
from collections.abc import Iterator
77
from contextlib import contextmanager, nullcontext, suppress
8-
from datetime import datetime, timezone
8+
from datetime import datetime, timedelta, timezone
99
from functools import cached_property, reduce
1010
from itertools import groupby
1111
from typing import TYPE_CHECKING, Any
@@ -32,6 +32,7 @@
3232
cast,
3333
desc,
3434
literal,
35+
or_,
3536
select,
3637
)
3738
from sqlalchemy.sql import func as f
@@ -71,6 +72,11 @@
7172
from datachain.namespace import Namespace
7273
from datachain.project import Project
7374

75+
# Versions with no job_id (e.g. from pull_dataset) are only eligible
76+
# for gc cleanup if they are older than this threshold, to avoid
77+
# cleaning up in-flight operations.
78+
STALE_CREATED_THRESHOLD_HOURS = 1
79+
7480
if TYPE_CHECKING:
7581
from sqlalchemy import CTE, Delete, Insert, Select, Subquery, Update
7682
from sqlalchemy.schema import SchemaItem
@@ -335,20 +341,18 @@ def get_incomplete_dataset_versions(
335341
self, job_id: str | None = None
336342
) -> list[tuple[DatasetRecord, str]]:
337343
"""
338-
Get failed/incomplete dataset versions that are in complete job. This is
339-
used to get versions to cleanup.
344+
Get incomplete dataset versions to clean up.
340345
341-
Returns dataset versions that:
342-
- Have status CREATED or FAILED (incomplete/failed)
343-
- Belong to jobs that are not running (COMPLETE, FAILED, CANCELED)
346+
When job_id is provided, returns versions belonging to that specific
347+
job (used during job failure cleanup).
344348
345-
Cleans both CREATED and FAILED to handle edge cases:
346-
- FAILED: Explicitly marked failed versions
347-
- CREATED: Orphaned versions from crashes/bugs (before failure marking)
349+
When job_id is None, returns all incomplete dataset versions
350+
whose associated job is finished, plus versions with no job_id
351+
that are older than STALE_CREATED_THRESHOLD_HOURS (used by gc).
348352
349353
Returns:
350354
List of (DatasetRecord, version_string) tuples. Each DatasetRecord
351-
contains only one version (the failed version to clean).
355+
contains only one version (the incomplete version to clean).
352356
"""
353357

354358
@abstractmethod
@@ -373,6 +377,14 @@ def list_datasets_by_prefix(
373377
projects.
374378
"""
375379

380+
def get_dataset_by_version_uuid(
    self,
    uuid: str,
    include_incomplete: bool = False,
) -> DatasetRecord:
    """Gets a dataset that contains a version with the given UUID.

    Args:
        uuid: UUID of the dataset version to look up.
        include_incomplete: when True, presumably also matches versions
            that are not yet finalized — confirm against the concrete
            metastore's `_base_dataset_query` semantics.

    Returns:
        DatasetRecord containing the matching version.

    Raises:
        NotImplementedError: this is a default stub; concrete metastore
            implementations are expected to override it.
    """
    raise NotImplementedError
387+
376388
@abstractmethod
377389
def get_dataset(
378390
self,
@@ -1540,6 +1552,20 @@ def list_datasets_by_prefix(
15401552
query = query.where(self._datasets.c.name.startswith(prefix))
15411553
yield from self._parse_dataset_list(self.db.execute(query))
15421554

1555+
def get_dataset_by_version_uuid(
    self,
    uuid: str,
    include_incomplete: bool = False,
) -> DatasetRecord:
    """Gets a dataset that contains a version with the given UUID."""
    versions = self._datasets_versions
    # Narrow the base dataset query down to the single version row
    # whose uuid matches.
    query = self._base_dataset_query(include_incomplete=include_incomplete).where(
        versions.c.uuid == uuid
    )
    dataset = self._parse_dataset(self.db.execute(query))
    if not dataset:
        raise DatasetNotFoundError(f"Dataset with version uuid {uuid} not found.")
    return dataset
1568+
15431569
def get_dataset(
15441570
self,
15451571
name: str, # normal, not full dataset name
@@ -1613,29 +1639,50 @@ def get_incomplete_dataset_versions(
16131639
dv = self._datasets_versions
16141640
j = self._jobs
16151641

1616-
# Query dataset + version info for failed versions from non-running jobs
1642+
select_cols = (
1643+
*(getattr(n.c, f) for f in self._namespaces_fields),
1644+
*(getattr(p.c, f) for f in self._projects_fields),
1645+
*(getattr(d.c, f) for f in self._dataset_fields),
1646+
*(getattr(dv.c, f) for f in self._dataset_version_fields),
1647+
)
1648+
base_from = (
1649+
n.join(p, n.c.id == p.c.namespace_id)
1650+
.join(d, p.c.id == d.c.project_id)
1651+
.join(dv, d.c.id == dv.c.dataset_id)
1652+
)
1653+
1654+
# LEFT JOIN on jobs so versions with job_id=NULL are included.
1655+
# Only skip versions whose job is still running.
16171656
query = (
1618-
self._datasets_select(
1619-
*(getattr(n.c, f) for f in self._namespaces_fields),
1620-
*(getattr(p.c, f) for f in self._projects_fields),
1621-
*(getattr(d.c, f) for f in self._dataset_fields),
1622-
*(getattr(dv.c, f) for f in self._dataset_version_fields),
1623-
)
1657+
self._datasets_select(*select_cols)
16241658
.select_from(
1625-
n.join(p, n.c.id == p.c.namespace_id)
1626-
.join(d, p.c.id == d.c.project_id)
1627-
.join(dv, d.c.id == dv.c.dataset_id)
1628-
.join(j, cast(dv.c.job_id, j.c.id.type) == j.c.id)
1659+
base_from.join(
1660+
j,
1661+
cast(dv.c.job_id, j.c.id.type) == j.c.id,
1662+
isouter=True,
1663+
)
16291664
)
16301665
.where(
16311666
dv.c.status.in_([DatasetStatus.CREATED, DatasetStatus.FAILED]),
1632-
j.c.status.in_(
1633-
[JobStatus.COMPLETE, JobStatus.FAILED, JobStatus.CANCELED]
1667+
or_(
1668+
# job is finished
1669+
j.c.status.in_(
1670+
[JobStatus.COMPLETE, JobStatus.FAILED, JobStatus.CANCELED]
1671+
),
1672+
# or no job at all (e.g. pull_dataset) — but only
1673+
# if old enough to not be an in-flight operation
1674+
and_(
1675+
dv.c.job_id.is_(None),
1676+
dv.c.created_at
1677+
< datetime.now(timezone.utc)
1678+
- timedelta(hours=STALE_CREATED_THRESHOLD_HOURS),
1679+
),
16341680
),
16351681
)
16361682
)
1683+
16371684
if job_id:
1638-
query = query.where(j.c.id == job_id)
1685+
query = query.where(dv.c.job_id == job_id)
16391686

16401687
# Parse results and return (dataset, version) tuples
16411688
results = []

src/datachain/data_storage/sqlite.py

Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,21 @@ def get_retry_sleep_sec(retry_count: int) -> int:
8585
return RETRY_START_SEC * (RETRY_FACTOR**retry_count)
8686

8787

88+
SQLITE_BUSY = 5
89+
SQLITE_LOCKED = 6
90+
91+
92+
def _is_sqlite_lock_error(exc: sqlite3.OperationalError) -> bool:
93+
"""Return True if the OperationalError is a transient lock/busy error."""
94+
code = getattr(exc, "sqlite_errorcode", None)
95+
if code is not None:
96+
# Python >=3.11: use the precise error code
97+
return code in (SQLITE_BUSY, SQLITE_LOCKED)
98+
# Python 3.10: fall back to message matching
99+
msg = str(exc).lower()
100+
return "locked" in msg or "busy" in msg
101+
102+
88103
def retry_sqlite_locks(func):
89104
# This retries the database modification in case of concurrent access
90105
@wraps(func)
@@ -94,6 +109,8 @@ def wrapper(*args, **kwargs):
94109
try:
95110
return func(*args, **kwargs)
96111
except sqlite3.OperationalError as operror:
112+
if not _is_sqlite_lock_error(operror):
113+
raise
97114
exc = operror
98115
sleep(get_retry_sleep_sec(retry_count))
99116
raise exc
@@ -158,14 +175,24 @@ def _connect(
158175
# ensure we run SA on_connect init (e.g it registers regexp function),
159176
# also makes sure that it's consistent. Otherwise in some cases it
160177
# seems we are getting different results if engine object is used in a
161-
# different thread first and enine is not used in the Main thread.
178+
# different thread first and engine is not used in the Main thread.
162179
engine.connect().close()
163180

164181
db.isolation_level = None # Use autocommit mode
165182
db.execute("PRAGMA foreign_keys = ON")
166183
db.execute("PRAGMA cache_size = -102400") # 100 MiB
167-
# Enable Write-Ahead Log Journaling
168-
db.execute("PRAGMA journal_mode = WAL")
184+
# Switching to WAL requires an exclusive lock, so retry briefly
185+
# in case another process is initializing the same DB file.
186+
for _ in range(5):
187+
try:
188+
db.execute("PRAGMA journal_mode = WAL")
189+
break
190+
except sqlite3.OperationalError as e:
191+
if not _is_sqlite_lock_error(e):
192+
raise
193+
sleep(1)
194+
else:
195+
db.execute("PRAGMA journal_mode = WAL") # final attempt, let it raise
169196
db.execute("PRAGMA synchronous = NORMAL")
170197
db.execute("PRAGMA case_sensitive_like = ON")
171198

@@ -847,10 +874,6 @@ def get_buffer(
847874
)
848875
return self.buffers[table.name]
849876

850-
def insert_dataset_rows(self, df, dataset: DatasetRecord, version: str) -> int:
851-
dr = self.dataset_rows(dataset, version)
852-
return self.db.insert_dataframe(dr.table.name, df)
853-
854877
def instr(self, source, target) -> "ColumnElement":
855878
return cast(func.instr(source, target), sqlalchemy.Boolean)
856879

src/datachain/data_storage/warehouse.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -391,6 +391,14 @@ def create_dataset_rows_table(
391391
) -> sa.Table:
392392
"""Creates a dataset rows table for the given dataset name and columns"""
393393

394+
def insert_dataframe_to_table(self, table_name: str, df) -> int:
    """
    Insert dataframe into any table by name.

    This is used for inserting data into temporary staging tables.

    Args:
        table_name: name of an existing table to insert into.
        df: dataframe whose columns match the table schema
            (presumably a pandas DataFrame — confirm against
            `self.db.insert_dataframe`).

    Returns:
        The count reported by the underlying `insert_dataframe` call.
    """
    return self.db.insert_dataframe(table_name, df)
401+
394402
def drop_dataset_rows_table(
395403
self,
396404
dataset: DatasetRecord,
@@ -538,10 +546,6 @@ def insert_rows_done(self, table: sa.Table) -> None:
538546
"""Signal that row inserts are complete by flushing and closing the buffer."""
539547
self.close_buffer(table)
540548

541-
@abstractmethod
542-
def insert_dataset_rows(self, df, dataset: DatasetRecord, version: str) -> int:
543-
"""Inserts dataset rows directly into dataset table"""
544-
545549
@abstractmethod
546550
def instr(self, source, target) -> sa.ColumnElement:
547551
"""

src/datachain/utils.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,91 @@ def find(cls, create: bool = True) -> "Self":
122122
return instance
123123

124124

125+
@contextmanager
def interprocess_file_lock(
    lock_path: str,
    *,
    wait_message: str | None = None,
    timeout: float = -1,
) -> Iterator[None]:
    """Acquire an inter-process lock backed by a file.

    Intended for local-only concurrency control (multiple CLI processes sharing
    the same DataChainDir). Locks are released automatically by the OS when the
    process exits, including on SIGKILL.

    Uses `filelock.FileLock` (OS-level file locking).
    """

    from filelock import FileLock, Timeout

    parent_dir = osp.dirname(lock_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    lock = FileLock(lock_path)
    pid_path = f"{lock_path}.pid"

    def _owner_pid() -> int | None:
        # Best effort: the pid file may be missing, empty or corrupt.
        try:
            with open(pid_path, encoding="utf-8") as f:
                contents = f.read().strip()
            return int(contents) if contents else None
        except Exception:  # noqa: BLE001
            return None

    def _record_pid() -> None:
        # Advisory only — failing to record our pid must not break locking.
        try:
            with open(pid_path, "w", encoding="utf-8") as f:
                f.write(str(os.getpid()))
        except Exception:
            logger.debug(
                "Failed to write PID into lock file %s",
                pid_path,
                exc_info=True,
            )

    def _announce_wait(pid: int | None) -> None:
        if not wait_message:
            return
        pid_str = f" (pid={pid})" if pid is not None else ""
        if pid is None:
            check_hint = f"If this looks stuck, delete: {lock_path} (and {pid_path})"
        else:
            check_hint = (
                f"If this looks stuck, first check the PID is running "
                f"(e.g. `ps -p {pid}`), then if you are sure no process is "
                f"running delete: {lock_path} (and {pid_path})"
            )
        print(f"{wait_message}{pid_str}\n{check_hint}")

    acquired = False
    try:
        try:
            # Non-blocking first attempt when we want to tell the user who
            # currently holds the lock; otherwise block right away.
            lock.acquire(timeout=0 if wait_message else timeout)
        except Timeout:
            if not wait_message:
                raise
            _announce_wait(_owner_pid())
            lock.acquire(timeout=timeout)
        acquired = True
        _record_pid()
        yield
    finally:
        if acquired:
            # Remove our pid marker first so the next holder writes a fresh
            # one, then release the OS-level lock no matter what.
            try:
                os.remove(pid_path)
            except OSError:
                logger.debug(
                    "Failed to remove PID file %s during lock cleanup",
                    pid_path,
                    exc_info=True,
                )
            finally:
                lock.release()
208+
209+
125210
@dataclass
126211
class DatasetIdentifier:
127212
namespace: str

0 commit comments

Comments
 (0)