Skip to content

Commit 41af6ff

Browse files
authored
Added update_version argument to DataChain.save() (#1100)
Added `update_version` argument to `DataChain.save()`
1 parent e101351 commit 41af6ff

File tree

5 files changed

+66
-16
lines changed

5 files changed

+66
-16
lines changed

src/datachain/catalog/catalog.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -779,6 +779,7 @@ def create_dataset(
779779
uuid: Optional[str] = None,
780780
description: Optional[str] = None,
781781
attrs: Optional[list[str]] = None,
782+
update_version: Optional[str] = "patch",
782783
) -> "DatasetRecord":
783784
"""
784785
Creates new dataset of a specific version.
@@ -795,6 +796,11 @@ def create_dataset(
795796
try:
796797
dataset = self.get_dataset(name)
797798
default_version = dataset.next_version_patch
799+
if update_version == "major":
800+
default_version = dataset.next_version_major
801+
if update_version == "minor":
802+
default_version = dataset.next_version_minor
803+
798804
if (description or attrs) and (
799805
dataset.description != description or dataset.attrs != attrs
800806
):

src/datachain/lib/dc/datachain.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -461,6 +461,7 @@ def save( # type: ignore[override]
461461
version: Optional[str] = None,
462462
description: Optional[str] = None,
463463
attrs: Optional[list[str]] = None,
464+
update_version: Optional[str] = "patch",
464465
**kwargs,
465466
) -> "Self":
466467
"""Save to a Dataset. It returns the chain itself.
@@ -472,10 +473,22 @@ def save( # type: ignore[override]
472473
description : description of a dataset.
473474
attrs : attributes of a dataset. They can be without value, e.g "NLP",
474475
or with a value, e.g "location=US".
476+
update_version: which part of the dataset version to automatically increase.
477+
Available values: `major`, `minor` or `patch`. Default is `patch`.
475478
"""
476479
if version is not None:
477480
semver.validate(version)
478481

482+
if update_version is not None and update_version not in [
483+
"patch",
484+
"major",
485+
"minor",
486+
]:
487+
raise ValueError(
488+
"update_version can have one of the following values: major, minor or"
489+
" patch"
490+
)
491+
479492
schema = self.signals_schema.clone_without_sys_signals().serialize()
480493
return self._evolve(
481494
query=self._query.save(
@@ -484,6 +497,7 @@ def save( # type: ignore[override]
484497
description=description,
485498
attrs=attrs,
486499
feature_schema=schema,
500+
update_version=update_version,
487501
**kwargs,
488502
)
489503
)

src/datachain/lib/dc/storage.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
Union,
66
)
77

8-
from datachain.error import DatasetNotFoundError
98
from datachain.lib.file import (
109
FileType,
1110
get_file_type,
@@ -132,11 +131,6 @@ def read_storage(
132131

133132
def lst_fn(ds_name, lst_uri):
134133
# disable prefetch for listing, as it pre-downloads all files
135-
try:
136-
version = catalog.get_dataset(ds_name).next_version_major
137-
except DatasetNotFoundError:
138-
version = None
139-
140134
(
141135
read_records(
142136
DataChain.DEFAULT_FILE_RECORD,
@@ -150,7 +144,7 @@ def lst_fn(ds_name, lst_uri):
150144
output={f"{column}": file_type},
151145
)
152146
# for internal listing datasets, we always bump major version
153-
.save(ds_name, listing=True, version=version)
147+
.save(ds_name, listing=True, update_version="major")
154148
)
155149

156150
dc._query.set_listing_fn(

src/datachain/query/dataset.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1689,6 +1689,7 @@ def save(
16891689
feature_schema: Optional[dict] = None,
16901690
description: Optional[str] = None,
16911691
attrs: Optional[list[str]] = None,
1692+
update_version: Optional[str] = "patch",
16921693
**kwargs,
16931694
) -> "Self":
16941695
"""Save the query as a dataset."""
@@ -1723,6 +1724,7 @@ def save(
17231724
columns=columns,
17241725
description=description,
17251726
attrs=attrs,
1727+
update_version=update_version,
17261728
**kwargs,
17271729
)
17281730
version = version or dataset.latest_version

tests/unit/lib/test_datachain.py

Lines changed: 43 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3173,18 +3173,41 @@ def test_delete_dataset_from_studio_not_found(
31733173
assert str(exc_info.value) == error_message
31743174

31753175

3176+
@pytest.mark.parametrize(
3177+
"update_version,versions",
3178+
[
3179+
("patch", ["1.0.0", "1.0.1", "1.0.2"]),
3180+
("minor", ["1.0.0", "1.1.0", "1.2.0"]),
3181+
("major", ["1.0.0", "2.0.0", "3.0.0"]),
3182+
],
3183+
)
3184+
def test_update_versions(test_session, update_version, versions):
3185+
ds_name = "fibonacci"
3186+
chain = dc.read_values(fib=[1, 1, 2, 3, 5, 8], session=test_session)
3187+
chain.save(ds_name, update_version=update_version)
3188+
chain.save(ds_name, update_version=update_version)
3189+
chain.save(ds_name, update_version=update_version)
3190+
assert sorted(
3191+
[
3192+
ds.version
3193+
for ds in dc.datasets(column="dataset", session=test_session).collect(
3194+
"dataset"
3195+
)
3196+
]
3197+
) == sorted(versions)
3198+
3199+
31763200
def test_update_versions_mix_major_minor_patch(test_session):
31773201
ds_name = "fibonacci"
31783202
chain = dc.read_values(fib=[1, 1, 2, 3, 5, 8], session=test_session)
31793203
chain.save(ds_name)
3204+
chain.save(ds_name, update_version="patch")
3205+
chain.save(ds_name, update_version="minor")
3206+
chain.save(ds_name, update_version="major")
3207+
chain.save(ds_name, update_version="minor")
3208+
chain.save(ds_name, update_version="patch")
31803209
chain.save(ds_name)
3181-
chain.save(ds_name, version="1.1.0")
3182-
chain.save(ds_name)
3183-
chain.save(ds_name, version="2.0.0")
3184-
chain.save(ds_name)
3185-
chain.save(ds_name, version="2.1.0")
3186-
chain.save(ds_name)
3187-
chain.save(ds_name)
3210+
chain.save(ds_name, version="3.0.0")
31883211
assert sorted(
31893212
[
31903213
ds.version
@@ -3197,16 +3220,27 @@ def test_update_versions_mix_major_minor_patch(test_session):
31973220
"1.0.0",
31983221
"1.0.1",
31993222
"1.1.0",
3200-
"1.1.1",
32013223
"2.0.0",
3202-
"2.0.1",
32033224
"2.1.0",
32043225
"2.1.1",
32053226
"2.1.2",
3227+
"3.0.0",
32063228
]
32073229
)
32083230

32093231

3232+
def test_update_versions_wrong_value(test_session):
3233+
ds_name = "fibonacci"
3234+
chain = dc.read_values(fib=[1, 1, 2, 3, 5, 8], session=test_session)
3235+
chain.save(ds_name)
3236+
with pytest.raises(ValueError) as excinfo:
3237+
chain.save(ds_name, update_version="wrong")
3238+
3239+
assert str(excinfo.value) == (
3240+
"update_version can have one of the following values: major, minor or patch"
3241+
)
3242+
3243+
32103244
def test_from_dataset_version_int_backward_compatible(test_session):
32113245
ds_name = "numbers"
32123246
dc.read_values(nums=[1], session=test_session).save(ds_name, version="1.0.0")

0 commit comments

Comments
 (0)