Skip to content

Commit f186d41

Browse files
authored
Query performance optimizations: reduce eager-loading queries in /sql (#1824)
* Various query optimizations by reducing eager loading * Add additional test coverage
1 parent 8339b89 commit f186d41

File tree

21 files changed

+442
-97
lines changed

21 files changed

+442
-97
lines changed

datajunction-server/datajunction_server/api/graphql/dataloaders.py

Lines changed: 20 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
DataLoaders for batching and caching GraphQL queries.
33
"""
44

5+
import json
56
from typing import Any
67

78
from sqlalchemy import select
@@ -109,48 +110,52 @@ def create_node_by_name_loader(request: Request) -> DataLoader[str, DBNode | Non
109110

110111

111112
async def batch_load_collection_nodes(
112-
collection_ids: list[int],
113+
keys: list[tuple[int, str]],
113114
request: Request,
114115
) -> list[list[DBNode]]:
115116
"""
116-
Batch load nodes for multiple collections.
117+
Batch load nodes for multiple collections with field-aware eager loading.
117118
118-
This batches multiple collection node lookups into a single query,
119-
avoiding N+1 queries when fetching nodes for multiple collections.
119+
Keys are (collection_id, fields_json) tuples where fields_json is a
120+
JSON-serialized dict of requested GraphQL fields (for load_node_options).
120121
121122
Args:
122-
collection_ids: List of collection IDs
123+
keys: List of (collection_id, fields_json) tuples
123124
request: The Starlette request object for creating sessions
124125
125126
Returns:
126-
List of node lists, one per collection ID, in the same order
127+
List of node lists, one per key, in the same order
127128
"""
129+
collection_ids = [cid for cid, _ in keys]
130+
131+
# Merge all requested fields across all loaders in this batch
132+
all_fields: dict[str, Any] = {}
133+
for _, fields_json in keys:
134+
if fields_json: # pragma: no branch
135+
all_fields.update(json.loads(fields_json))
136+
128137
async with session_context(request) as session:
129-
# Load all requested collections with their nodes in one query
138+
node_options = load_node_options(all_fields)
130139
stmt = (
131140
select(DBCollection)
132141
.where(DBCollection.id.in_(collection_ids))
133-
.options(selectinload(DBCollection.nodes))
142+
.options(selectinload(DBCollection.nodes).options(*node_options))
134143
)
135144
result = await session.execute(stmt)
136145
collections = result.unique().scalars().all()
137146

138-
# Create a lookup map: collection_id -> nodes
139147
collection_nodes_map = {c.id: c.nodes for c in collections}
140-
141-
# Return node lists in the same order as requested collection IDs
142-
# Return empty list if collection not found
143148
return [collection_nodes_map.get(cid, []) for cid in collection_ids]
144149

145150

146151
def create_collection_nodes_loader(
147152
request: Request,
148-
) -> DataLoader[int, list[DBNode]]:
153+
) -> DataLoader[tuple[int, str], list[DBNode]]:
149154
"""
150155
Create a DataLoader for loading nodes by collection ID.
151156
152-
This loader batches multiple collection node lookups within a single request
153-
and caches the results to avoid N+1 queries.
157+
Keys are (collection_id, fields_json) tuples so the loader can
158+
eagerly load only the node relationships the query actually requests.
154159
155160
Args:
156161
request: The Starlette request object

datajunction-server/datajunction_server/api/graphql/scalars/collection.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
Collection GraphQL scalar types.
33
"""
44

5+
import json
56
from datetime import datetime
67
from typing import TYPE_CHECKING
78

@@ -10,6 +11,7 @@
1011

1112
from datajunction_server.api.graphql.scalars.node import Node
1213
from datajunction_server.api.graphql.scalars.user import User
14+
from datajunction_server.api.graphql.utils import extract_fields
1315

1416
if TYPE_CHECKING:
1517
from datajunction_server.database.collection import (
@@ -37,8 +39,8 @@ async def nodes(self, info: Info) -> list[Node]:
3739
Uses dataloader to batch requests efficiently.
3840
"""
3941
loader = info.context["collection_nodes_loader"]
40-
nodes = await loader.load(self.id)
41-
return nodes # type: ignore
42+
node_fields = extract_fields(info)
43+
return await loader.load((self.id, json.dumps(node_fields, sort_keys=True))) # type: ignore
4244

4345
@classmethod
4446
def from_db_collection(

datajunction-server/datajunction_server/api/helpers.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -527,11 +527,14 @@ async def validate_cube(
527527
message=("Metrics and dimensions must be part of a common catalog"),
528528
)
529529

530-
await validate_shared_dimensions(
531-
session,
532-
metric_nodes,
533-
dimension_names,
534-
)
530+
# Only validate shared dimensions if dimensions were actually requested
531+
# This avoids expensive dimension graph loading when dimensions=[]
532+
if dimension_names:
533+
await validate_shared_dimensions(
534+
session,
535+
metric_nodes,
536+
dimension_names,
537+
)
535538
return metrics, metric_nodes, list(dimension_nodes.values()), dimensions, catalog
536539

537540

datajunction-server/datajunction_server/api/sql.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -578,13 +578,23 @@ async def get_sql_for_metrics(
578578
"""
579579
Return SQL for a set of metrics with dimensions and filters
580580
"""
581-
# make sure all metrics exist and have correct node type
582-
nodes = [
583-
await Node.get_by_name(session, node, raise_if_not_exists=True)
584-
for node in metrics
585-
]
586-
non_metric_nodes = [node for node in nodes if node and node.type != NodeType.METRIC]
581+
# Label this session for debugging
582+
session.info["session_label"] = "initial node loading"
583+
584+
# Fetch all metric nodes in a single query (only name/type needed for validation here)
585+
nodes = await Node.get_by_names(session, metrics, options=[])
587586

587+
# Check if all requested nodes exist
588+
found_names = {node.name for node in nodes}
589+
missing_nodes = set(metrics) - found_names
590+
if missing_nodes:
591+
raise DJInvalidInputException(
592+
message=f"The following nodes do not exist: {', '.join(missing_nodes)}",
593+
http_status_code=HTTPStatus.NOT_FOUND,
594+
)
595+
596+
# Validate node types
597+
non_metric_nodes = [node for node in nodes if node and node.type != NodeType.METRIC]
588598
if non_metric_nodes:
589599
raise DJInvalidInputException(
590600
message="All nodes must be of metric type, but some are not: "
@@ -596,6 +606,7 @@ async def get_sql_for_metrics(
596606
cache=cache,
597607
query_type=QueryBuildType.METRICS,
598608
)
609+
599610
return await query_cache_manager.get_or_load(
600611
background_tasks,
601612
request,
@@ -611,4 +622,5 @@ async def get_sql_for_metrics(
611622
use_materialized=use_materialized,
612623
ignore_errors=ignore_errors,
613624
),
625+
session=session, # Pass the session to reuse it
614626
)

datajunction-server/datajunction_server/construction/build_v2.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
from sqlalchemy import text, bindparam, select
1818
from sqlalchemy.ext.asyncio import AsyncSession
19-
from sqlalchemy.orm import joinedload, selectinload
19+
from sqlalchemy.orm import joinedload, selectinload, noload
2020

2121
from datajunction_server.internal.access.authorization import (
2222
AccessChecker,
@@ -723,6 +723,8 @@ async def find_join_paths_batch(
723723
724724
This is O(1) database calls instead of O(nodes * depth) individual queries.
725725
"""
726+
# Filter out empty strings and check if we have any valid dimension names
727+
target_dimension_names = {name for name in target_dimension_names if name}
726728
if not target_dimension_names:
727729
return {} # pragma: no cover
728730

@@ -800,18 +802,24 @@ async def load_dimension_links_and_nodes(
800802
.where(DimensionLink.id.in_(link_ids))
801803
.options(
802804
joinedload(DimensionLink.dimension).options(
805+
noload(Node.created_by),
803806
joinedload(Node.current).options(
807+
noload(NodeRevision.created_by),
804808
selectinload(NodeRevision.columns).options(
805809
joinedload(Column.attributes).joinedload(
806810
ColumnAttribute.attribute_type,
807811
),
808-
joinedload(Column.dimension),
812+
joinedload(Column.dimension).options(
813+
noload(Node.created_by),
814+
),
809815
joinedload(Column.partition),
810816
),
811817
joinedload(NodeRevision.catalog),
812818
selectinload(NodeRevision.availability),
813819
selectinload(NodeRevision.dimension_links).options(
814-
joinedload(DimensionLink.dimension),
820+
joinedload(DimensionLink.dimension).options(
821+
noload(Node.created_by),
822+
),
815823
),
816824
),
817825
),
@@ -1323,6 +1331,8 @@ async def build(self) -> ast.Query:
13231331
Builds SQL for multiple metrics with the requested set of dimensions,
13241332
filter expressions, order by, and limit clauses.
13251333
"""
1334+
# Always add dimensions referenced in the metric queries themselves
1335+
# (e.g., if a metric references a joinable dimension in its SQL definition)
13261336
self.add_dimensions(get_dimensions_referenced_in_metrics(self.metric_nodes))
13271337

13281338
measures_queries = await self.build_measures_queries()

datajunction-server/datajunction_server/construction/build_v3/cube_matcher.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414
from sqlalchemy import and_, select
1515
from sqlalchemy.ext.asyncio import AsyncSession
16-
from sqlalchemy.orm import joinedload, selectinload
16+
from sqlalchemy.orm import joinedload, selectinload, noload
1717

1818
from datajunction_server.construction.build_v3.decomposition import is_derived_metric
1919
from datajunction_server.models.dialect import Dialect
@@ -87,9 +87,13 @@ async def find_matching_cube(
8787
),
8888
)
8989
.options(
90+
noload(Node.created_by), # Prevent User N+1 queries
9091
joinedload(Node.current).options(
91-
selectinload(NodeRevision.cube_elements).selectinload(
92-
Column.node_revision,
92+
noload(NodeRevision.created_by), # Prevent User N+1 queries
93+
selectinload(NodeRevision.cube_elements).options(
94+
selectinload(Column.node_revision).options(
95+
noload(NodeRevision.created_by), # Prevent User N+1 queries
96+
),
9397
),
9498
joinedload(NodeRevision.availability),
9599
selectinload(NodeRevision.materializations),
@@ -225,7 +229,18 @@ async def resolve_dialect_and_engine_for_metrics(
225229
)
226230

227231
# Fallback: use first metric's catalog's default engine
228-
node = await Node.get_by_name(session, metrics[0], raise_if_not_exists=True)
232+
node = await Node.get_by_name(
233+
session,
234+
metrics[0],
235+
raise_if_not_exists=True,
236+
options=[
237+
joinedload(Node.current).options(
238+
noload(NodeRevision.created_by), # Prevent User N+1 queries
239+
joinedload(NodeRevision.catalog),
240+
),
241+
noload(Node.created_by), # Prevent User N+1 queries
242+
],
243+
)
229244
if not node: # pragma: no cover
230245
raise ValueError(f"Metric not found: {metrics[0]}")
231246

datajunction-server/datajunction_server/construction/build_v3/loaders.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
from sqlalchemy import select, text, bindparam
1111
from sqlalchemy.ext.asyncio import AsyncSession
12-
from sqlalchemy.orm import selectinload, joinedload, load_only
12+
from sqlalchemy.orm import selectinload, joinedload, load_only, noload
1313

1414
from datajunction_server.database.dimensionlink import DimensionLink
1515
from datajunction_server.database.node import Node, NodeRevision, Column
@@ -275,7 +275,9 @@ async def load_dimension_links_batch(
275275
.where(DimensionLink.id.in_(link_ids))
276276
.options(
277277
joinedload(DimensionLink.dimension).options(
278+
noload(Node.created_by), # Prevent User N+1 queries
278279
joinedload(Node.current).options(
280+
noload(NodeRevision.created_by), # Prevent User N+1 queries
279281
# Load what's needed for table references, parsing, and type lookups
280282
joinedload(NodeRevision.catalog),
281283
joinedload(NodeRevision.availability),
@@ -375,6 +377,7 @@ async def load_nodes(ctx: BuildContext) -> None:
375377
Node.current_version,
376378
),
377379
joinedload(Node.current).options(
380+
noload(NodeRevision.created_by), # Prevent User N+1 queries
378381
load_only(
379382
NodeRevision.name,
380383
NodeRevision.query,
@@ -391,13 +394,18 @@ async def load_nodes(ctx: BuildContext) -> None:
391394
selectinload(NodeRevision.required_dimensions).options(
392395
# Load the node_revision and node to reconstruct full dimension path
393396
joinedload(Column.node_revision).options(
394-
joinedload(NodeRevision.node),
397+
noload(NodeRevision.created_by), # Prevent User N+1 queries
398+
joinedload(NodeRevision.node).options(
399+
noload(Node.created_by), # Prevent User N+1 queries
400+
),
395401
),
396402
),
397403
joinedload(NodeRevision.availability), # For materialization support
398404
selectinload(NodeRevision.dimension_links).options(
399405
# Load dimension node for link matching in temporal filters
400-
joinedload(DimensionLink.dimension),
406+
joinedload(DimensionLink.dimension).options(
407+
noload(Node.created_by), # Prevent User N+1 queries
408+
),
401409
),
402410
),
403411
)

datajunction-server/datajunction_server/database/node.py

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
Mapped,
3232
joinedload,
3333
mapped_column,
34+
noload,
3435
relationship,
3536
selectinload,
3637
MappedColumn,
@@ -309,6 +310,7 @@ class Node(Base):
309310
secondary="tagnoderelationship",
310311
primaryjoin="TagNodeRelationship.node_id==Node.id",
311312
secondaryjoin="TagNodeRelationship.tag_id==Tag.id",
313+
lazy="selectin",
312314
)
313315

314316
namespace_obj: Mapped[Optional["NodeNamespace"]] = relationship(
@@ -550,9 +552,6 @@ async def get_by_name(
550552
joinedload(Node.current).options(
551553
*NodeRevision.default_load_options(),
552554
),
553-
selectinload(Node.tags),
554-
selectinload(Node.created_by),
555-
selectinload(Node.owners),
556555
]
557556
statement = statement.options(*options)
558557
if not include_inactive:
@@ -581,7 +580,12 @@ async def get_by_names(
581580
"""
582581
Get nodes by names
583582
"""
583+
# Early return if no names provided to avoid useless query
584+
if not names:
585+
return []
586+
584587
statement = select(Node).where(Node.name.in_(names))
588+
585589
options = options or [
586590
joinedload(Node.current).options(
587591
*NodeRevision.default_load_options(),
@@ -1081,7 +1085,7 @@ class NodeRevision(
10811085
secondary="cube",
10821086
primaryjoin="NodeRevision.id==CubeRelationship.cube_id",
10831087
secondaryjoin="Column.id==CubeRelationship.cube_element_id",
1084-
lazy="selectin",
1088+
# No explicit lazy strategy (falls back to SQLAlchemy's default lazy="select") - control loading via query options (selectinload or noload)
10851089
order_by="Column.order",
10861090
)
10871091

@@ -1188,22 +1192,34 @@ def default_load_options(cls):
11881192
joinedload(Column.attributes).joinedload(
11891193
ColumnAttribute.attribute_type,
11901194
),
1191-
joinedload(Column.dimension),
1195+
joinedload(Column.dimension).options(
1196+
noload(Node.created_by),
1197+
),
11921198
joinedload(Column.partition),
11931199
),
11941200
joinedload(NodeRevision.catalog),
1195-
selectinload(NodeRevision.parents),
1201+
selectinload(NodeRevision.parents).options(
1202+
selectinload(Node.current).options(
1203+
noload(NodeRevision.created_by),
1204+
),
1205+
noload(Node.created_by),
1206+
),
11961207
selectinload(NodeRevision.materializations),
11971208
selectinload(NodeRevision.metric_metadata),
11981209
selectinload(NodeRevision.availability),
11991210
selectinload(NodeRevision.dimension_links).options(
12001211
joinedload(DimensionLink.dimension).options(
1201-
selectinload(Node.current),
1212+
selectinload(Node.current).options(
1213+
noload(NodeRevision.created_by),
1214+
),
1215+
noload(Node.created_by),
12021216
),
12031217
joinedload(DimensionLink.node_revision),
12041218
),
12051219
selectinload(NodeRevision.required_dimensions),
12061220
selectinload(NodeRevision.availability),
1221+
# Load created_by for API responses (the /sql/ endpoint passes custom options that noload it instead)
1222+
selectinload(NodeRevision.created_by),
12071223
)
12081224

12091225
@classmethod

0 commit comments

Comments
 (0)