Skip to content

Commit 07a5e4d

Browse files
authored
Add metastore DB tests (#1122)
1 parent 76b5aef commit 07a5e4d

File tree

3 files changed

+978
-52
lines changed

3 files changed

+978
-52
lines changed

src/datachain/catalog/catalog.py

Lines changed: 2 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -911,11 +911,7 @@ def update_dataset_version_with_warehouse_info(
911911
values["num_objects"] = None
912912
values["size"] = None
913913
values["preview"] = None
914-
self.metastore.update_dataset_version(
915-
dataset,
916-
version,
917-
**values,
918-
)
914+
self.metastore.update_dataset_version(dataset, version, **values)
919915
return
920916

921917
if not dataset_version.num_objects:
@@ -935,11 +931,7 @@ def update_dataset_version_with_warehouse_info(
935931
if not values:
936932
return
937933

938-
self.metastore.update_dataset_version(
939-
dataset,
940-
version,
941-
**values,
942-
)
934+
self.metastore.update_dataset_version(dataset, version, **values)
943935

944936
def update_dataset(
945937
self, dataset: DatasetRecord, conn=None, **kwargs

src/datachain/data_storage/metastore.py

Lines changed: 87 additions & 42 deletions
Original file line number | Diff line number | Diff line change
@@ -36,6 +36,7 @@
3636
)
3737
from datachain.error import (
3838
DatasetNotFoundError,
39+
DatasetVersionNotFoundError,
3940
TableMissingError,
4041
)
4142
from datachain.job import Job
@@ -273,7 +274,6 @@ def update_job(
273274
self,
274275
job_id: str,
275276
status: Optional[JobStatus] = None,
276-
exit_code: Optional[int] = None,
277277
error_message: Optional[str] = None,
278278
error_stack: Optional[str] = None,
279279
finished_at: Optional[datetime] = None,
@@ -620,22 +620,36 @@ def update_dataset(
620620
self, dataset: DatasetRecord, conn=None, **kwargs
621621
) -> DatasetRecord:
622622
"""Updates dataset fields."""
623-
values = {}
624-
dataset_values = {}
623+
values: dict[str, Any] = {}
624+
dataset_values: dict[str, Any] = {}
625625
for field, value in kwargs.items():
626-
if field in self._dataset_fields[1:]:
627-
if field in ["attrs", "schema"]:
628-
values[field] = json.dumps(value) if value else None
626+
if field in ("id", "created_at") or field not in self._dataset_fields:
627+
continue # these fields are read-only or not applicable
628+
629+
if value is None and field in ("name", "status", "sources", "query_script"):
630+
raise ValueError(f"Field {field} cannot be None")
631+
if field == "name" and not value:
632+
raise ValueError("name cannot be empty")
633+
634+
if field == "attrs":
635+
if value is None:
636+
values[field] = None
629637
else:
630-
values[field] = value
631-
if field == "schema":
632-
dataset_values[field] = DatasetRecord.parse_schema(value)
638+
values[field] = json.dumps(value)
639+
dataset_values[field] = value
640+
elif field == "schema":
641+
if value is None:
642+
values[field] = None
643+
dataset_values[field] = None
633644
else:
634-
dataset_values[field] = value
645+
values[field] = json.dumps(value)
646+
dataset_values[field] = DatasetRecord.parse_schema(value)
647+
else:
648+
values[field] = value
649+
dataset_values[field] = value
635650

636651
if not values:
637-
# Nothing to update
638-
return dataset
652+
return dataset # nothing to update
639653

640654
d = self._datasets
641655
self.db.execute(
@@ -651,36 +665,70 @@ def update_dataset_version(
651665
self, dataset: DatasetRecord, version: str, conn=None, **kwargs
652666
) -> DatasetVersion:
653667
"""Updates dataset fields."""
654-
dataset_version = dataset.get_version(version)
655-
656-
values = {}
657-
version_values: dict = {}
668+
values: dict[str, Any] = {}
669+
version_values: dict[str, Any] = {}
658670
for field, value in kwargs.items():
659-
if field in self._dataset_version_fields[1:]:
660-
if field == "schema":
661-
values[field] = json.dumps(value) if value else None
662-
version_values[field] = DatasetRecord.parse_schema(value)
663-
elif field == "feature_schema":
664-
values[field] = json.dumps(value) if value else None
665-
version_values[field] = value
666-
elif field == "preview" and isinstance(value, list):
667-
values[field] = json.dumps(value, cls=JSONSerialize)
668-
version_values[field] = value
671+
if (
672+
field in ("id", "created_at")
673+
or field not in self._dataset_version_fields
674+
):
675+
continue # these fields are read-only or not applicable
676+
677+
if value is None and field in (
678+
"status",
679+
"sources",
680+
"query_script",
681+
"error_message",
682+
"error_stack",
683+
"script_output",
684+
"uuid",
685+
):
686+
raise ValueError(f"Field {field} cannot be None")
687+
688+
if field == "schema":
689+
values[field] = json.dumps(value) if value else None
690+
version_values[field] = (
691+
DatasetRecord.parse_schema(value) if value else None
692+
)
693+
elif field == "feature_schema":
694+
if value is None:
695+
values[field] = None
696+
else:
697+
values[field] = json.dumps(value)
698+
version_values[field] = value
699+
elif field == "preview":
700+
if value is None:
701+
values[field] = None
702+
elif not isinstance(value, list):
703+
raise ValueError(
704+
f"Field '{field}' must be a list, got {type(value).__name__}"
705+
)
669706
else:
670-
values[field] = value
671-
version_values[field] = value
707+
values[field] = json.dumps(value, cls=JSONSerialize)
708+
version_values["_preview_data"] = value
709+
else:
710+
values[field] = value
711+
version_values[field] = value
672712

673-
if values:
674-
dv = self._datasets_versions
675-
self.db.execute(
676-
self._datasets_versions_update()
677-
.where(dv.c.dataset_id == dataset.id, dv.c.version == version)
678-
.values(values),
679-
conn=conn,
680-
) # type: ignore [attr-defined]
681-
dataset_version.update(**version_values)
713+
if not values:
714+
return dataset.get_version(version)
715+
716+
dv = self._datasets_versions
717+
self.db.execute(
718+
self._datasets_versions_update()
719+
.where(dv.c.dataset_id == dataset.id, dv.c.version == version)
720+
.values(values),
721+
conn=conn,
722+
) # type: ignore [attr-defined]
723+
724+
for v in dataset.versions:
725+
if v.version == version:
726+
v.update(**version_values)
727+
return v
682728

683-
return dataset_version
729+
raise DatasetVersionNotFoundError(
730+
f"Dataset {dataset.name} does not have version {version}"
731+
)
684732

685733
def _parse_dataset(self, rows) -> Optional[DatasetRecord]:
686734
versions = [self.dataset_class.parse(*r) for r in rows]
@@ -812,7 +860,7 @@ def update_dataset_status(
812860
update_data["error_message"] = error_message
813861
update_data["error_stack"] = error_stack
814862

815-
self.update_dataset(dataset, conn=conn, **update_data)
863+
dataset = self.update_dataset(dataset, conn=conn, **update_data)
816864

817865
if version:
818866
self.update_dataset_version(dataset, version, conn=conn, **update_data)
@@ -1064,7 +1112,6 @@ def update_job(
10641112
self,
10651113
job_id: str,
10661114
status: Optional[JobStatus] = None,
1067-
exit_code: Optional[int] = None,
10681115
error_message: Optional[str] = None,
10691116
error_stack: Optional[str] = None,
10701117
finished_at: Optional[datetime] = None,
@@ -1075,8 +1122,6 @@ def update_job(
10751122
values: dict = {}
10761123
if status is not None:
10771124
values["status"] = status
1078-
if exit_code is not None:
1079-
values["exit_code"] = exit_code
10801125
if error_message is not None:
10811126
values["error_message"] = error_message
10821127
if error_stack is not None:

0 commit comments

Comments
 (0)