Skip to content

Commit 7865f89

Browse files
authored
fixing dataset dependencies when persist is used (#1112)
1 parent 30bf991 commit 7865f89

File tree

2 files changed

+37
-4
lines changed

2 files changed

+37
-4
lines changed

src/datachain/query/dataset.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1675,13 +1675,27 @@ def generate(
16751675
return query
16761676

16771677
def _add_dependencies(self, dataset: "DatasetRecord", version: str):
1678-
for dependency in self.dependencies:
1679-
ds_dependency_name, ds_dependency_version = dependency
1678+
dependencies: set[DatasetDependencyType] = set()
1679+
for dep_name, dep_version in self.dependencies:
1680+
if Session.is_temp_dataset(dep_name):
1681+
# temp datasets are created for optimization and they will be removed
1682+
# afterwards. Therefore, we should not record them as dependencies, but
1683+
# record their own direct dependencies instead
1684+
for dep in self.catalog.get_dataset_dependencies(
1685+
dep_name, dep_version, indirect=False
1686+
):
1687+
if dep:
1688+
dependencies.add((dep.name, dep.version))
1689+
else:
1690+
dependencies.add((dep_name, dep_version))
1691+
1692+
for dep_name, dep_version in dependencies:
1693+
# ds_dependency_name, ds_dependency_version = dependency
16801694
self.catalog.metastore.add_dataset_dependency(
16811695
dataset.name,
16821696
version,
1683-
ds_dependency_name,
1684-
ds_dependency_version,
1697+
dep_name,
1698+
dep_version,
16851699
)
16861700

16871701
def exec(self) -> "Self":

tests/func/test_datachain.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,25 @@ def test_read_storage_dependencies(cloud_test_catalog, cloud_type):
239239
assert dependencies[0].name == dep_name
240240

241241

242+
def test_persist_not_affects_dependencies(tmp_dir, test_session):
243+
for i in range(4):
244+
(tmp_dir / f"file{i}.txt").write_text(f"file{i}")
245+
246+
uri = tmp_dir.as_uri()
247+
dep_name, _, _ = parse_listing_uri(uri, test_session.catalog.client_config)
248+
chain = dc.read_storage(uri, session=test_session) # .persist()
249+
# calling multiple persists to create temp datasets
250+
chain = chain.persist()
251+
chain = chain.persist()
252+
chain = chain.persist()
253+
chain.save("test-data")
254+
dependencies = test_session.catalog.get_dataset_dependencies("test-data", "1.0.0")
255+
256+
assert len(dependencies) == 1
257+
assert dependencies[0].name == dep_name
258+
assert dependencies[0].type == DatasetDependencyType.STORAGE
259+
260+
242261
@pytest.mark.parametrize("use_cache", [True, False])
243262
@pytest.mark.parametrize("prefetch", [0, 2])
244263
def test_map_file(cloud_test_catalog, use_cache, prefetch, monkeypatch):

0 commit comments

Comments
 (0)