Skip to content

Commit bf78a4e

Browse files
authored
Refactor dataset dependency query to make select column configurable (#1405)
* Refactor dataset dependency query to make select column configurable # Changes: - Added new abstract method `_dataset_dependency_nodes_select_columns()` to make the column selection for dataset dependency queries more maintainable and extensible across different database backends - Introduced a depth limit of 100 in the recursive CTE to prevent infinite loops * Move to constant
1 parent 654d9af commit bf78a4e

File tree

2 files changed

+60
-24
lines changed

2 files changed

+60
-24
lines changed

src/datachain/data_storage/metastore.py

Lines changed: 35 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -56,13 +56,15 @@
5656
from datachain.utils import JSONSerialize
5757

5858
if TYPE_CHECKING:
59-
from sqlalchemy import Delete, Insert, Select, Update
59+
from sqlalchemy import CTE, Delete, Insert, Select, Subquery, Update
6060
from sqlalchemy.schema import SchemaItem
61+
from sqlalchemy.sql.elements import ColumnElement
6162

6263
from datachain.data_storage import schema
6364
from datachain.data_storage.db_engine import DatabaseEngine
6465

6566
logger = logging.getLogger("datachain")
67+
DEPTH_LIMIT_DEFAULT = 100
6668

6769

6870
class AbstractMetastore(ABC, Serializable):
@@ -1463,6 +1465,18 @@ def _dataset_dependencies_select_columns(self) -> list["SchemaItem"]:
14631465
Returns a list of columns to select in a query for fetching dataset dependencies
14641466
"""
14651467

1468+
@abstractmethod
1469+
def _dataset_dependency_nodes_select_columns(
1470+
self,
1471+
namespaces_subquery: "Subquery",
1472+
dependency_tree_cte: "CTE",
1473+
datasets_subquery: "Subquery",
1474+
) -> list["ColumnElement"]:
1475+
"""
1476+
Returns a list of columns to select in a query for fetching
1477+
dataset dependency nodes.
1478+
"""
1479+
14661480
def get_direct_dataset_dependencies(
14671481
self, dataset: DatasetRecord, version: str
14681482
) -> list[DatasetDependency | None]:
@@ -1493,7 +1507,7 @@ def get_direct_dataset_dependencies(
14931507
return [self.dependency_class.parse(*r) for r in self.db.execute(query)]
14941508

14951509
def get_dataset_dependency_nodes(
1496-
self, dataset_id: int, version_id: int
1510+
self, dataset_id: int, version_id: int, depth_limit: int = DEPTH_LIMIT_DEFAULT
14971511
) -> list[DatasetDependencyNode | None]:
14981512
n = self._namespaces_select().subquery()
14991513
p = self._projects
@@ -1522,33 +1536,31 @@ def get_dataset_dependency_nodes(
15221536
cte = base_query.cte(name="dependency_tree", recursive=True)
15231537

15241538
# Recursive case: dependencies of dependencies
1525-
recursive_query = select(
1526-
*dep_fields,
1527-
(cte.c.depth + 1).label("depth"),
1528-
).select_from(
1529-
cte.join(
1530-
dd,
1531-
(cte.c.dataset_id == dd.c.source_dataset_id)
1532-
& (cte.c.dataset_version_id == dd.c.source_dataset_version_id),
1539+
# Limit depth to 100 to prevent infinite loops in case of circular dependencies
1540+
recursive_query = (
1541+
select(
1542+
*dep_fields,
1543+
(cte.c.depth + 1).label("depth"),
15331544
)
1545+
.select_from(
1546+
cte.join(
1547+
dd,
1548+
(cte.c.dataset_id == dd.c.source_dataset_id)
1549+
& (cte.c.dataset_version_id == dd.c.source_dataset_version_id),
1550+
)
1551+
)
1552+
.where(cte.c.depth < depth_limit)
15341553
)
15351554

15361555
cte = cte.union(recursive_query)
15371556

15381557
# Fetch all with full details
1539-
final_query = select(
1540-
n.c.name,
1541-
p.c.name,
1542-
cte.c.id,
1543-
cte.c.dataset_id,
1544-
cte.c.dataset_version_id,
1545-
d.c.name,
1546-
dv.c.version,
1547-
dv.c.created_at,
1548-
cte.c.source_dataset_id,
1549-
cte.c.source_dataset_version_id,
1550-
cte.c.depth,
1551-
).select_from(
1558+
select_cols = self._dataset_dependency_nodes_select_columns(
1559+
namespaces_subquery=n,
1560+
dependency_tree_cte=cte,
1561+
datasets_subquery=d,
1562+
)
1563+
final_query = self._datasets_dependencies_select(*select_cols).select_from(
15521564
# Use outer joins to handle cases where dependent datasets have been
15531565
# physically deleted. This allows us to return dependency records with
15541566
# None values instead of silently omitting them, making broken

src/datachain/data_storage/sqlite.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,10 @@
2020
from sqlalchemy.dialects import sqlite
2121
from sqlalchemy.schema import CreateIndex, CreateTable, DropTable
2222
from sqlalchemy.sql import func
23-
from sqlalchemy.sql.elements import BinaryExpression, BooleanClauseList
23+
from sqlalchemy.sql.elements import (
24+
BinaryExpression,
25+
BooleanClauseList,
26+
)
2427
from sqlalchemy.sql.expression import bindparam, cast
2528
from sqlalchemy.sql.selectable import Select
2629
from tqdm.auto import tqdm
@@ -41,6 +44,7 @@
4144
from datachain.utils import DataChainDir, batched, batched_it
4245

4346
if TYPE_CHECKING:
47+
from sqlalchemy import CTE, Subquery
4448
from sqlalchemy.dialects.sqlite import Insert
4549
from sqlalchemy.engine.base import Engine
4650
from sqlalchemy.schema import SchemaItem
@@ -539,6 +543,26 @@ def _dataset_dependencies_select_columns(self) -> list["SchemaItem"]:
539543
self._datasets_versions.c.created_at,
540544
]
541545

546+
def _dataset_dependency_nodes_select_columns(
547+
self,
548+
namespaces_subquery: "Subquery",
549+
dependency_tree_cte: "CTE",
550+
datasets_subquery: "Subquery",
551+
) -> list["ColumnElement"]:
552+
return [
553+
namespaces_subquery.c.name,
554+
self._projects.c.name,
555+
dependency_tree_cte.c.id,
556+
dependency_tree_cte.c.dataset_id,
557+
dependency_tree_cte.c.dataset_version_id,
558+
datasets_subquery.c.name,
559+
self._datasets_versions.c.version,
560+
self._datasets_versions.c.created_at,
561+
dependency_tree_cte.c.source_dataset_id,
562+
dependency_tree_cte.c.source_dataset_version_id,
563+
dependency_tree_cte.c.depth,
564+
]
565+
542566
#
543567
# Jobs
544568
#

0 commit comments

Comments
 (0)