
Commit 3d48a91

Distributed compute refactoring (#1047)
* Distributed compute refactoring
* Remove tqdm from warehouse cleanup
* Batching refactoring
* Metastore updates
* Tests
1 parent 81e9a35 commit 3d48a91

31 files changed (+839, -424 lines)

.github/workflows/tests-studio.yml

Lines changed: 1 addition & 0 deletions
@@ -98,6 +98,7 @@ jobs:
       - name: Run tests
         # Generate `.test_durations` file with `pytest --store-durations --durations-path ../.github/.test_durations ...`
         run: >
+          PYTHONPATH="$(pwd)/..:${PYTHONPATH}"
           pytest
           --config-file=pyproject.toml -rs
           --splits=6 --group=${{ matrix.group }} --durations-path=../../.github/.test_durations

src/datachain/catalog/catalog.py

Lines changed: 5 additions & 1 deletion
@@ -79,6 +79,7 @@
 QUERY_SCRIPT_INVALID_LAST_STATEMENT_EXIT_CODE = 10
 # exit code we use if query script was canceled
 QUERY_SCRIPT_CANCELED_EXIT_CODE = 11
+QUERY_SCRIPT_SIGTERM_EXIT_CODE = -15  # if query script was terminated by SIGTERM

 # dataset pull
 PULL_DATASET_MAX_THREADS = 5

@@ -1645,7 +1646,10 @@ def raise_termination_signal(sig: int, _: Any) -> NoReturn:
     thread.join()  # wait for the reader thread

     logger.info("Process %s exited with return code %s", proc.pid, proc.returncode)
-    if proc.returncode == QUERY_SCRIPT_CANCELED_EXIT_CODE:
+    if proc.returncode in (
+        QUERY_SCRIPT_CANCELED_EXIT_CODE,
+        QUERY_SCRIPT_SIGTERM_EXIT_CODE,
+    ):
         raise QueryScriptCancelError(
             "Query script was canceled by user",
             return_code=proc.returncode,
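For context (not part of the commit): on POSIX systems Python's subprocess module reports a child killed by a signal with a negative return code equal to the signal number, which is why -15 is now treated the same as the user-cancel exit code 11. A minimal, self-contained sketch:

import signal
import subprocess
import sys
import time

# Spawn a long-running child, terminate it with SIGTERM, observe returncode -15.
proc = subprocess.Popen([sys.executable, "-c", "import time; time.sleep(60)"])
time.sleep(1)      # give the child a moment to start
proc.terminate()   # sends SIGTERM on POSIX
proc.wait()
assert proc.returncode == -signal.SIGTERM  # -15, matches QUERY_SCRIPT_SIGTERM_EXIT_CODE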

src/datachain/cli/__init__.py

Lines changed: 11 additions & 9 deletions
@@ -34,8 +34,10 @@ def main(argv: Optional[list[str]] = None) -> int:
     datachain_parser = get_parser()
     args = datachain_parser.parse_args(argv)

-    if args.command in ("internal-run-udf", "internal-run-udf-worker"):
-        return handle_udf(args.command)
+    if args.command == "internal-run-udf":
+        return handle_udf()
+    if args.command == "internal-run-udf-worker":
+        return handle_udf_runner(args.fd)

     if args.command is None:
         datachain_parser.print_help(sys.stderr)

@@ -303,13 +305,13 @@ def handle_general_exception(exc, args, logging_level):
     return error, 1


-def handle_udf(command):
-    if command == "internal-run-udf":
-        from datachain.query.dispatch import udf_entrypoint
+def handle_udf() -> int:
+    from datachain.query.dispatch import udf_entrypoint

-        return udf_entrypoint()
+    return udf_entrypoint()

-    if command == "internal-run-udf-worker":
-        from datachain.query.dispatch import udf_worker_entrypoint
-
-        return udf_worker_entrypoint()
+
+def handle_udf_runner(fd: Optional[int] = None) -> int:
+    from datachain.query.dispatch import udf_worker_entrypoint
+
+    return udf_worker_entrypoint(fd)
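Illustrative sketch of how a parent process might hand a pipe descriptor to the new worker subcommand on a POSIX system; the real wiring lives in datachain.query.dispatch and is not shown in this diff, so the exact invocation below is an assumption:

import os
import subprocess

# Parent creates a pipe and passes the write end's descriptor to the worker.
read_fd, write_fd = os.pipe()
worker = subprocess.Popen(
    ["datachain", "internal-run-udf-worker", "--fd", str(write_fd)],
    pass_fds=(write_fd,),  # keep the descriptor open and numbered the same in the child
)
os.close(write_fd)  # parent only reads from its end
with os.fdopen(read_fd) as results:
    for line in results:
        print("worker output:", line.rstrip())
worker.wait()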

src/datachain/cli/commands/query.py

Lines changed: 1 addition & 0 deletions
@@ -29,6 +29,7 @@ def query(
         name=os.path.basename(script),
         query=script_content,
         query_type=JobQueryType.PYTHON,
+        status=JobStatus.RUNNING,
         python_version=python_version,
         params=params,
     )

src/datachain/cli/parser/__init__.py

Lines changed: 9 additions & 1 deletion
@@ -549,7 +549,15 @@ def get_parser() -> ArgumentParser:  # noqa: PLR0915
     add_anon_arg(parse_gc)

     subp.add_parser("internal-run-udf", parents=[parent_parser])
-    subp.add_parser("internal-run-udf-worker", parents=[parent_parser])
+    run_udf_worker = subp.add_parser("internal-run-udf-worker", parents=[parent_parser])
+    run_udf_worker.add_argument(
+        "--fd",
+        type=int,
+        action="store",
+        default=None,
+        help="File descriptor to write results to",
+    )
+
     add_completion_parser(subp, [parent_parser])
     return parser
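A quick way to exercise the new flag (assuming this subcommand has no other required arguments):

from datachain.cli.parser import get_parser

args = get_parser().parse_args(["internal-run-udf-worker", "--fd", "7"])
assert args.command == "internal-run-udf-worker"
assert args.fd == 7  # parsed as int; defaults to None when --fd is omitted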

src/datachain/data_storage/job.py

Lines changed: 1 addition & 0 deletions
@@ -3,6 +3,7 @@

 class JobStatus(int, Enum):
     CREATED = 1
+    SCHEDULED = 10
     QUEUED = 2
     INIT = 3
     RUNNING = 4
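Note that SCHEDULED takes the next unused value (10) rather than renumbering the existing members, so integer statuses already persisted in the jobs table keep their meaning. A standalone sketch showing only the members visible in this hunk:

from enum import Enum

class JobStatus(int, Enum):  # partial copy; the real enum has more members
    CREATED = 1
    SCHEDULED = 10
    QUEUED = 2
    INIT = 3
    RUNNING = 4

assert JobStatus(10) is JobStatus.SCHEDULED  # persisted values 1-4 are untouched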

src/datachain/data_storage/metastore.py

Lines changed: 82 additions & 71 deletions
@@ -254,6 +254,7 @@ def create_job(
         name: str,
         query: str,
         query_type: JobQueryType = JobQueryType.PYTHON,
+        status: JobStatus = JobStatus.CREATED,
         workers: int = 1,
         python_version: Optional[str] = None,
         params: Optional[dict[str, str]] = None,
@@ -264,33 +265,35 @@ def create_job(
         """

     @abstractmethod
-    def set_job_status(
+    def get_job(self, job_id: str) -> Optional[Job]:
+        """Returns the job with the given ID."""
+
+    @abstractmethod
+    def update_job(
         self,
         job_id: str,
-        status: JobStatus,
+        status: Optional[JobStatus] = None,
+        exit_code: Optional[int] = None,
         error_message: Optional[str] = None,
         error_stack: Optional[str] = None,
+        finished_at: Optional[datetime] = None,
         metrics: Optional[dict[str, Any]] = None,
-    ) -> None:
-        """Set the status of the given job."""
+    ) -> Optional["Job"]:
+        """Updates job fields."""

     @abstractmethod
-    def get_job_status(self, job_id: str) -> Optional[JobStatus]:
-        """Returns the status of the given job."""
-
-    @abstractmethod
-    def set_job_and_dataset_status(
+    def set_job_status(
         self,
         job_id: str,
-        job_status: JobStatus,
-        dataset_status: DatasetStatus,
+        status: JobStatus,
+        error_message: Optional[str] = None,
+        error_stack: Optional[str] = None,
     ) -> None:
-        """Set the status of the given job and dataset."""
+        """Set the status of the given job."""

     @abstractmethod
-    def get_job_dataset_versions(self, job_id: str) -> list[tuple[str, int]]:
-        """Returns dataset names and versions for the job."""
-        raise NotImplementedError
+    def get_job_status(self, job_id: str) -> Optional[JobStatus]:
+        """Returns the status of the given job."""


 class AbstractDBMetastore(AbstractMetastore):
651654
dataset_version = dataset.get_version(version)
652655

653656
values = {}
657+
version_values: dict = {}
654658
for field, value in kwargs.items():
655659
if field in self._dataset_version_fields[1:]:
656660
if field == "schema":
657-
dataset_version.update(**{field: DatasetRecord.parse_schema(value)})
658661
values[field] = json.dumps(value) if value else None
662+
version_values[field] = DatasetRecord.parse_schema(value)
659663
elif field == "feature_schema":
660664
values[field] = json.dumps(value) if value else None
665+
version_values[field] = value
661666
elif field == "preview" and isinstance(value, list):
662667
values[field] = json.dumps(value, cls=JSONSerialize)
668+
version_values[field] = value
663669
else:
664670
values[field] = value
665-
dataset_version.update(**{field: value})
666-
667-
if not values:
668-
# Nothing to update
669-
return dataset_version
671+
version_values[field] = value
670672

671-
dv = self._datasets_versions
672-
self.db.execute(
673-
self._datasets_versions_update()
674-
.where(dv.c.id == dataset_version.id)
675-
.values(values),
676-
conn=conn,
677-
) # type: ignore [attr-defined]
673+
if values:
674+
dv = self._datasets_versions
675+
self.db.execute(
676+
self._datasets_versions_update()
677+
.where(dv.c.dataset_id == dataset.id and dv.c.version == version)
678+
.values(values),
679+
conn=conn,
680+
) # type: ignore [attr-defined]
681+
dataset_version.update(**version_values)
678682

679683
return dataset_version
680684

@@ -702,7 +706,7 @@ def _get_dataset_query(
         dataset_fields: list[str],
         dataset_version_fields: list[str],
         isouter: bool = True,
-    ):
+    ) -> "Select":
         if not (
             self.db.has_table(self._datasets.name)
             and self.db.has_table(self._datasets_versions.name)
@@ -719,12 +723,12 @@ def _get_dataset_query(
         j = d.join(dv, d.c.id == dv.c.dataset_id, isouter=isouter)
         return query.select_from(j)

-    def _base_dataset_query(self):
+    def _base_dataset_query(self) -> "Select":
         return self._get_dataset_query(
             self._dataset_fields, self._dataset_version_fields
         )

-    def _base_list_datasets_query(self):
+    def _base_list_datasets_query(self) -> "Select":
         return self._get_dataset_query(
             self._dataset_list_fields, self._dataset_list_version_fields, isouter=False
         )
@@ -1018,6 +1022,7 @@ def create_job(
         name: str,
         query: str,
         query_type: JobQueryType = JobQueryType.PYTHON,
+        status: JobStatus = JobStatus.CREATED,
         workers: int = 1,
         python_version: Optional[str] = None,
         params: Optional[dict[str, str]] = None,
@@ -1032,7 +1037,7 @@ def create_job(
             self._jobs_insert().values(
                 id=job_id,
                 name=name,
-                status=JobStatus.CREATED,
+                status=status,
                 created_at=datetime.now(timezone.utc),
                 query=query,
                 query_type=query_type.value,
@@ -1047,25 +1052,65 @@ def create_job(
         )
         return job_id

+    def get_job(self, job_id: str, conn=None) -> Optional[Job]:
+        """Returns the job with the given ID."""
+        query = self._jobs_select(self._jobs).where(self._jobs.c.id == job_id)
+        results = list(self.db.execute(query, conn=conn))
+        if not results:
+            return None
+        return self._parse_job(results[0])
+
+    def update_job(
+        self,
+        job_id: str,
+        status: Optional[JobStatus] = None,
+        exit_code: Optional[int] = None,
+        error_message: Optional[str] = None,
+        error_stack: Optional[str] = None,
+        finished_at: Optional[datetime] = None,
+        metrics: Optional[dict[str, Any]] = None,
+        conn: Optional[Any] = None,
+    ) -> Optional["Job"]:
+        """Updates job fields."""
+        values: dict = {}
+        if status is not None:
+            values["status"] = status
+        if exit_code is not None:
+            values["exit_code"] = exit_code
+        if error_message is not None:
+            values["error_message"] = error_message
+        if error_stack is not None:
+            values["error_stack"] = error_stack
+        if finished_at is not None:
+            values["finished_at"] = finished_at
+        if metrics:
+            values["metrics"] = json.dumps(metrics)
+
+        if values:
+            j = self._jobs
+            self.db.execute(
+                self._jobs_update().where(j.c.id == job_id).values(**values),
+                conn=conn,
+            )  # type: ignore [attr-defined]
+
+        return self.get_job(job_id, conn=conn)
+
     def set_job_status(
         self,
         job_id: str,
         status: JobStatus,
         error_message: Optional[str] = None,
         error_stack: Optional[str] = None,
-        metrics: Optional[dict[str, Any]] = None,
         conn: Optional[Any] = None,
     ) -> None:
         """Set the status of the given job."""
-        values: dict = {"status": status.value}
-        if status.value in JobStatus.finished():
+        values: dict = {"status": status}
+        if status in JobStatus.finished():
             values["finished_at"] = datetime.now(timezone.utc)
         if error_message:
             values["error_message"] = error_message
         if error_stack:
             values["error_stack"] = error_stack
-        if metrics:
-            values["metrics"] = json.dumps(metrics)
         self.db.execute(
             self._jobs_update(self._jobs.c.id == job_id).values(**values),
             conn=conn,
@@ -1086,37 +1131,3 @@ def get_job_status(
         if not results:
             return None
         return results[0][0]
-
-    def set_job_and_dataset_status(
-        self,
-        job_id: str,
-        job_status: JobStatus,
-        dataset_status: DatasetStatus,
-    ) -> None:
-        """Set the status of the given job and dataset."""
-        with self.db.transaction() as conn:
-            self.set_job_status(job_id, status=job_status, conn=conn)
-            dv = self._datasets_versions
-            query = (
-                self._datasets_versions_update()
-                .where(
-                    (dv.c.job_id == job_id) & (dv.c.status != DatasetStatus.COMPLETE)
-                )
-                .values(status=dataset_status)
-            )
-            self.db.execute(query, conn=conn)  # type: ignore[attr-defined]
-
-    def get_job_dataset_versions(self, job_id: str) -> list[tuple[str, int]]:
-        """Returns dataset names and versions for the job."""
-        dv = self._datasets_versions
-        ds = self._datasets
-
-        join_condition = dv.c.dataset_id == ds.c.id
-
-        query = (
-            self._datasets_versions_select(ds.c.name, dv.c.version)
-            .select_from(dv.join(ds, join_condition))
-            .where(dv.c.job_id == job_id)
-        )
-
-        return list(self.db.execute(query))
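Taken together, the new get_job/update_job pair replaces the removed set_job_and_dataset_status and get_job_dataset_versions helpers: a caller persists whichever job fields it knows in one round trip and gets the refreshed row back. A hedged sketch of the calling pattern (the helper function below is an assumption, not code from this commit, and the terminal status member must be supplied by the caller since the full enum is not shown in this diff):

from datetime import datetime, timezone
from typing import Any, Optional

from datachain.data_storage.job import JobStatus
from datachain.data_storage.metastore import AbstractMetastore


def finish_job(
    metastore: AbstractMetastore,
    job_id: str,
    status: JobStatus,  # pass a terminal member; member names are not listed in this diff
    exit_code: int = 0,
    metrics: Optional[dict[str, Any]] = None,
):
    """Record a job's terminal state in a single metastore round trip."""
    return metastore.update_job(
        job_id,
        status=status,
        exit_code=exit_code,
        finished_at=datetime.now(timezone.utc),
        metrics=metrics,
    )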
