Commit 0e27a8e

Alternative to #1387 (#1395)
* Optimize dataset dependency calculation and storage

  Summary:
  ------
  This PR improves the performance of dataset dependency calculation as well as
  its storage: the nested dependency structure is now stored directly to
  optimize the retrieval process.

  Changes:
  ------
  - Replaced recursive database queries with a single batch query that fetches
    all dependencies at once
  - Reduced database round-trips: dependency calculation now uses
    get_direct_dataset_dependencies_by_ids() to fetch all required dependencies
    in one query
  - Optimized memory usage: the new DatasetDependencyMinimal class stores only
    essential dependency information

  Database change:
  -------
  - Added nested_dependencies column: a new JSON column in the
    datasets_dependencies table stores pre-calculated dependency structures
  - Backward compatibility: a schema migration safely adds the new column to
    existing databases
  - Efficient storage: nested dependency trees are stored as JSON, reducing
    storage overhead

* Address comments from sourcery
* Fix dataset query tests
* Add proper parsing for JSON
* Retain old None semantics
* Fix the execute_str
* Fix SQL injection
* Clarify things further
* Try out recursive approach
* Cleanup
* Use namespace select and projects select
* Fix deleted dependency part
* Reorder and refactor
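The core idea of the change, fetching all dependency edges in one batch query and assembling the nested structure in memory instead of issuing one query per level, can be sketched roughly as follows. This is a simplified, hypothetical model (plain tuples and dicts, not the actual DataChain classes); the real implementation below does the equivalent with a recursive CTE over the datasets_dependencies table.

from collections import defaultdict

# Each edge is (dependency_row_id, source_key, target_key); a key is
# (dataset_id, dataset_version_id). All edges come back from a single query.
edges = [
    (1, (10, 100), (20, 200)),  # root dataset version depends on dataset 20 v200
    (2, (20, 200), (30, 300)),  # ...which in turn depends on dataset 30 v300
]

def build_tree(edges, root_key):
    # Group edges by their source key, then expand the tree recursively
    # without touching the database again.
    children = defaultdict(list)
    for dep_id, source_key, target_key in edges:
        children[source_key].append((dep_id, target_key))

    def expand(key):
        return [
            {"id": dep_id, "dependencies": expand(target_key)}
            for dep_id, target_key in children.get(key, [])
        ]

    return expand(root_key)

print(build_tree(edges, (10, 100)))
# [{'id': 1, 'dependencies': [{'id': 2, 'dependencies': []}]}]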
1 parent a783384 commit 0e27a8e

3 files changed: +289 −20 lines changed


src/datachain/catalog/catalog.py

Lines changed: 45 additions & 20 deletions
@@ -54,6 +54,7 @@
 from datachain.utils import DataChainDir

 from .datasource import DataSource
+from .dependency import build_dependency_hierarchy, populate_nested_dependencies

 if TYPE_CHECKING:
     from datachain.data_storage import AbstractMetastore, AbstractWarehouse
@@ -1203,6 +1204,38 @@ def get_remote_dataset(
         assert isinstance(dataset_info, dict)
         return DatasetRecord.from_dict(dataset_info)

+    def get_dataset_dependencies_by_ids(
+        self,
+        dataset_id: int,
+        version_id: int,
+        indirect: bool = True,
+    ) -> list[DatasetDependency | None]:
+        dependency_nodes = self.metastore.get_dataset_dependency_nodes(
+            dataset_id=dataset_id,
+            version_id=version_id,
+        )
+
+        if not dependency_nodes:
+            return []
+
+        dependency_map, children_map = build_dependency_hierarchy(dependency_nodes)
+
+        root_key = (dataset_id, version_id)
+        if root_key not in children_map:
+            return []
+
+        root_dependency_ids = children_map[root_key]
+        root_dependencies = [dependency_map[dep_id] for dep_id in root_dependency_ids]
+
+        if indirect:
+            for dependency in root_dependencies:
+                if dependency is not None:
+                    populate_nested_dependencies(
+                        dependency, dependency_nodes, dependency_map, children_map
+                    )
+
+        return root_dependencies
+
     def get_dataset_dependencies(
         self,
         name: str,
@@ -1216,29 +1249,21 @@ def get_dataset_dependencies(
             namespace_name=namespace_name,
             project_name=project_name,
         )
-
-        direct_dependencies = self.metastore.get_direct_dataset_dependencies(
-            dataset, version
-        )
+        dataset_version = dataset.get_version(version)
+        dataset_id = dataset.id
+        dataset_version_id = dataset_version.id

         if not indirect:
-            return direct_dependencies
-
-        for d in direct_dependencies:
-            if not d:
-                # dependency has been removed
-                continue
-            if d.is_dataset:
-                # only datasets can have dependencies
-                d.dependencies = self.get_dataset_dependencies(
-                    d.name,
-                    d.version,
-                    namespace_name=d.namespace,
-                    project_name=d.project,
-                    indirect=indirect,
-                )
+            return self.metastore.get_direct_dataset_dependencies(
+                dataset,
+                version,
+            )

-        return direct_dependencies
+        return self.get_dataset_dependencies_by_ids(
+            dataset_id,
+            dataset_version_id,
+            indirect,
+        )

     def ls_datasets(
         self,
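For orientation, a hedged usage sketch of the public entry point touched above. It assumes an already-constructed Catalog instance and a registered dataset; only the arguments visible in the diff are used, and the helper names here are illustrative, not part of DataChain.

from datachain.dataset import DatasetDependency


def print_dependency_tree(catalog, name, version, namespace_name, project_name):
    """Print the full (indirect) dependency tree of a dataset version."""
    deps = catalog.get_dataset_dependencies(
        name,
        version,
        namespace_name=namespace_name,
        project_name=project_name,
        indirect=True,  # resolved in one pass via get_dataset_dependencies_by_ids()
    )
    _walk(deps)


def _walk(dependencies: list[DatasetDependency | None], level: int = 0) -> None:
    for dep in dependencies:
        if dep is None:
            # Placeholder kept for dependencies whose target dataset was deleted.
            print("  " * level + "<deleted dependency>")
            continue
        print("  " * level + f"{dep.name}@{dep.version}")
        _walk(dep.dependencies or [], level + 1)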
src/datachain/catalog/dependency.py

Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
+import builtins
+from dataclasses import dataclass
+from datetime import datetime
+from typing import TypeVar
+
+from datachain.dataset import DatasetDependency
+
+DDN = TypeVar("DDN", bound="DatasetDependencyNode")
+
+
+@dataclass
+class DatasetDependencyNode:
+    namespace: str
+    project: str
+    id: int
+    dataset_id: int | None
+    dataset_version_id: int | None
+    dataset_name: str | None
+    dataset_version: str | None
+    created_at: datetime
+    source_dataset_id: int
+    source_dataset_version_id: int | None
+    depth: int
+
+    @classmethod
+    def parse(
+        cls: builtins.type[DDN],
+        namespace: str,
+        project: str,
+        id: int,
+        dataset_id: int | None,
+        dataset_version_id: int | None,
+        dataset_name: str | None,
+        dataset_version: str | None,
+        created_at: datetime,
+        source_dataset_id: int,
+        source_dataset_version_id: int | None,
+        depth: int,
+    ) -> "DatasetDependencyNode | None":
+        return cls(
+            namespace,
+            project,
+            id,
+            dataset_id,
+            dataset_version_id,
+            dataset_name,
+            dataset_version,
+            created_at,
+            source_dataset_id,
+            source_dataset_version_id,
+            depth,
+        )
+
+    def to_dependency(self) -> "DatasetDependency | None":
+        return DatasetDependency.parse(
+            namespace_name=self.namespace,
+            project_name=self.project,
+            id=self.id,
+            dataset_id=self.dataset_id,
+            dataset_version_id=self.dataset_version_id,
+            dataset_name=self.dataset_name,
+            dataset_version=self.dataset_version,
+            dataset_version_created_at=self.created_at,
+        )
+
+
+def build_dependency_hierarchy(
+    dependency_nodes: list[DatasetDependencyNode | None],
+) -> tuple[
+    dict[int, DatasetDependency | None], dict[tuple[int, int | None], list[int]]
+]:
+    """
+    Build dependency hierarchy from dependency nodes.
+
+    Args:
+        dependency_nodes: List of DatasetDependencyNode objects from the database
+
+    Returns:
+        Tuple of (dependency_map, children_map) where:
+        - dependency_map: Maps dependency_id -> DatasetDependency
+        - children_map: Maps (source_dataset_id, source_version_id) ->
+          list of dependency_ids
+    """
+    dependency_map: dict[int, DatasetDependency | None] = {}
+    children_map: dict[tuple[int, int | None], list[int]] = {}
+
+    for node in dependency_nodes:
+        if node is None:
+            continue
+        dependency = node.to_dependency()
+        parent_key = (node.source_dataset_id, node.source_dataset_version_id)
+
+        if dependency is not None:
+            dependency_map[dependency.id] = dependency
+            children_map.setdefault(parent_key, []).append(dependency.id)
+        else:
+            # Handle case where dependency creation failed (e.g., deleted dependency)
+            dependency_map[node.id] = None
+            children_map.setdefault(parent_key, []).append(node.id)
+
+    return dependency_map, children_map
+
+
+def populate_nested_dependencies(
+    dependency: DatasetDependency,
+    dependency_nodes: list[DatasetDependencyNode | None],
+    dependency_map: dict[int, DatasetDependency | None],
+    children_map: dict[tuple[int, int | None], list[int]],
+) -> None:
+    """
+    Recursively populate nested dependencies for a given dependency.
+
+    Args:
+        dependency: The dependency to populate nested dependencies for
+        dependency_nodes: All dependency nodes from the database
+        dependency_map: Maps dependency_id -> DatasetDependency
+        children_map: Maps (source_dataset_id, source_version_id) ->
+          list of dependency_ids
+    """
+    # Find the target dataset and version for this dependency
+    target_dataset_id, target_version_id = find_target_dataset_version(
+        dependency, dependency_nodes
+    )
+
+    if target_dataset_id is None or target_version_id is None:
+        return
+
+    # Get children for this target
+    target_key = (target_dataset_id, target_version_id)
+    if target_key not in children_map:
+        dependency.dependencies = []
+        return
+
+    child_dependency_ids = children_map[target_key]
+    child_dependencies = [dependency_map[child_id] for child_id in child_dependency_ids]
+
+    dependency.dependencies = child_dependencies
+
+    # Recursively populate children
+    for child_dependency in child_dependencies:
+        if child_dependency is not None:
+            populate_nested_dependencies(
+                child_dependency, dependency_nodes, dependency_map, children_map
+            )
+
+
+def find_target_dataset_version(
+    dependency: DatasetDependency,
+    dependency_nodes: list[DatasetDependencyNode | None],
+) -> tuple[int | None, int | None]:
+    """
+    Find the target dataset ID and version ID for a given dependency.

+    Args:
+        dependency: The dependency to find target for
+        dependency_nodes: All dependency nodes from the database
+
+    Returns:
+        Tuple of (target_dataset_id, target_version_id) or (None, None) if not found
+    """
+    for node in dependency_nodes:
+        if node is not None and node.id == dependency.id:
+            return node.dataset_id, node.dataset_version_id
+    return None, None
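To make the two lookup structures concrete, a small hand-written illustration of what build_dependency_hierarchy() returns and how a root's direct dependencies are resolved from it (hypothetical IDs; strings stand in for DatasetDependency objects):

# dependency_map: dependency row id -> DatasetDependency (None if the target
#                 dataset has been deleted)
# children_map:   (source_dataset_id, source_dataset_version_id) -> [row ids]
dependency_map = {
    1: "dep(cats@1.0)",  # stand-in for a DatasetDependency object
    2: "dep(dogs@2.1)",
    3: None,             # target dataset no longer exists
}
children_map = {
    (10, 100): [1, 2],   # root dataset version depends on rows 1 and 2
    (20, 200): [3],      # "cats" (dataset 20, version 200) depends on a deleted dataset
}

# Direct dependencies of the root (dataset_id=10, version_id=100), as done in
# get_dataset_dependencies_by_ids():
root_dependencies = [dependency_map[i] for i in children_map[(10, 100)]]
print(root_dependencies)  # ['dep(cats@1.0)', 'dep(dogs@2.1)']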

src/datachain/data_storage/metastore.py

Lines changed: 80 additions & 0 deletions
@@ -22,10 +22,12 @@
     Text,
     UniqueConstraint,
     desc,
+    literal,
     select,
 )
 from sqlalchemy.sql import func as f

+from datachain.catalog.dependency import DatasetDependencyNode
 from datachain.checkpoint import Checkpoint
 from datachain.data_storage import JobQueryType, JobStatus
 from datachain.data_storage.serializer import Serializable
@@ -78,6 +80,7 @@ class AbstractMetastore(ABC, Serializable):
     dataset_list_class: type[DatasetListRecord] = DatasetListRecord
     dataset_list_version_class: type[DatasetListVersion] = DatasetListVersion
     dependency_class: type[DatasetDependency] = DatasetDependency
+    dependency_node_class: type[DatasetDependencyNode] = DatasetDependencyNode
     job_class: type[Job] = Job
     checkpoint_class: type[Checkpoint] = Checkpoint

@@ -366,6 +369,12 @@ def get_direct_dataset_dependencies(
     ) -> list[DatasetDependency | None]:
         """Gets direct dataset dependencies."""

+    @abstractmethod
+    def get_dataset_dependency_nodes(
+        self, dataset_id: int, version_id: int
+    ) -> list[DatasetDependencyNode | None]:
+        """Gets dataset dependency node from database."""
+
     @abstractmethod
     def remove_dataset_dependencies(
         self, dataset: DatasetRecord, version: str | None = None
@@ -1483,6 +1492,77 @@ def get_direct_dataset_dependencies(

         return [self.dependency_class.parse(*r) for r in self.db.execute(query)]

+    def get_dataset_dependency_nodes(
+        self, dataset_id: int, version_id: int
+    ) -> list[DatasetDependencyNode | None]:
+        n = self._namespaces_select().subquery()
+        p = self._projects
+        d = self._datasets_select().subquery()
+        dd = self._datasets_dependencies
+        dv = self._datasets_versions
+
+        # Common dependency fields for CTE
+        dep_fields = [
+            dd.c.id,
+            dd.c.source_dataset_id,
+            dd.c.source_dataset_version_id,
+            dd.c.dataset_id,
+            dd.c.dataset_version_id,
+        ]
+
+        # Base case: direct dependencies
+        base_query = select(
+            *dep_fields,
+            literal(0).label("depth"),
+        ).where(
+            (dd.c.source_dataset_id == dataset_id)
+            & (dd.c.source_dataset_version_id == version_id)
+        )
+
+        cte = base_query.cte(name="dependency_tree", recursive=True)
+
+        # Recursive case: dependencies of dependencies
+        recursive_query = select(
+            *dep_fields,
+            (cte.c.depth + 1).label("depth"),
+        ).select_from(
+            cte.join(
+                dd,
+                (cte.c.dataset_id == dd.c.source_dataset_id)
+                & (cte.c.dataset_version_id == dd.c.source_dataset_version_id),
+            )
+        )
+
+        cte = cte.union(recursive_query)
+
+        # Fetch all with full details
+        final_query = select(
+            n.c.name,
+            p.c.name,
+            cte.c.id,
+            cte.c.dataset_id,
+            cte.c.dataset_version_id,
+            d.c.name,
+            dv.c.version,
+            dv.c.created_at,
+            cte.c.source_dataset_id,
+            cte.c.source_dataset_version_id,
+            cte.c.depth,
+        ).select_from(
+            # Use outer joins to handle cases where dependent datasets have been
+            # physically deleted. This allows us to return dependency records with
+            # None values instead of silently omitting them, making broken
+            # dependencies visible to callers.
+            cte.join(d, cte.c.dataset_id == d.c.id, isouter=True)
+            .join(dv, cte.c.dataset_version_id == dv.c.id, isouter=True)
+            .join(p, d.c.project_id == p.c.id, isouter=True)
+            .join(n, p.c.namespace_id == n.c.id, isouter=True)
+        )
+
+        return [
+            self.dependency_node_class.parse(*r) for r in self.db.execute(final_query)
+        ]
+
     def remove_dataset_dependencies(
         self, dataset: DatasetRecord, version: str | None = None
     ) -> None: