Skip to content

Commit 9e8fc80

Browse files
authored
Modify GraphQL upstream/downstream endpoints to use dynamic field loading, which reduces database load for GraphQL requests that only need basic fields. (#1605)
1 parent c4f9b19 commit 9e8fc80

File tree

5 files changed

+194
-11
lines changed

5 files changed

+194
-11
lines changed

datajunction-server/datajunction_server/api/graphql/queries/dag.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,12 @@
88
from strawberry.types import Info
99

1010
from datajunction_server.database.node import Node
11-
from datajunction_server.api.graphql.resolvers.nodes import find_nodes_by
11+
from datajunction_server.api.graphql.resolvers.nodes import (
12+
find_nodes_by,
13+
load_node_options,
14+
)
1215
from datajunction_server.api.graphql.scalars.node import DimensionAttribute
16+
from datajunction_server.api.graphql.utils import extract_fields
1317
from datajunction_server.sql.dag import (
1418
get_common_dimensions,
1519
get_downstream_nodes,
@@ -75,13 +79,19 @@ async def downstream_nodes(
7579
fanout threshold check and BFS fallback work better with single nodes.
7680
"""
7781
session = info.context["session"]
82+
83+
# Build load options based on requested GraphQL fields
84+
fields = extract_fields(info)
85+
options = load_node_options(fields)
86+
7887
all_downstreams: dict[int, Node] = {}
7988
for node_name in node_names:
8089
downstreams = await get_downstream_nodes(
8190
session,
8291
node_name=node_name,
8392
node_type=node_type,
8493
include_deactivated=include_deactivated,
94+
options=options,
8595
)
8696
for node in downstreams:
8797
if node.id not in all_downstreams: # pragma: no cover
@@ -116,9 +126,15 @@ async def upstream_nodes(
116126
Results are deduplicated by node ID.
117127
"""
118128
session = info.context["session"]
129+
130+
# Build load options based on requested GraphQL fields
131+
fields = extract_fields(info)
132+
options = load_node_options(fields)
133+
119134
return await get_upstream_nodes( # type: ignore
120135
session,
121136
node_name=node_names,
122137
node_type=node_type,
123138
include_deactivated=include_deactivated,
139+
options=options,
124140
)

datajunction-server/datajunction_server/api/helpers.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -104,15 +104,17 @@ async def get_node_by_name(
104104
"""
105105
Get a node by name
106106
"""
107+
from datajunction_server.models.node import NodeOutput
108+
107109
statement = select(Node).where(Node.name == name)
108110
if not include_inactive:
109111
statement = statement.where(is_(Node.deactivated_at, None))
110112
if node_type:
111113
statement = statement.where(Node.type == node_type)
112114
if with_current:
113-
statement = statement.options(joinedload(Node.current)).options(
114-
joinedload(Node.tags),
115-
)
115+
# Use full NodeOutput load options to ensure all required fields
116+
# (like dimension_links) are eagerly loaded for serialization
117+
statement = statement.options(*NodeOutput.load_options())
116118
result = await session.execute(statement)
117119
node = result.unique().scalar_one_or_none()
118120
else:

datajunction-server/datajunction_server/sql/dag.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from typing import Dict, List, Tuple, Union, cast
99

1010
from sqlalchemy import and_, func, join, literal, or_, select, distinct
11+
from sqlalchemy.sql.base import ExecutableOption
1112
from sqlalchemy.ext.asyncio import AsyncSession
1213
from sqlalchemy.orm import aliased, joinedload, selectinload
1314
from sqlalchemy.sql.operators import is_
@@ -65,15 +66,20 @@ async def get_downstream_nodes(
6566
include_deactivated: bool = True,
6667
include_cubes: bool = True,
6768
depth: int = -1,
68-
) -> List[Node]:
69+
options: list[ExecutableOption] = None,
70+
) -> list[Node]:
6971
"""
7072
Gets all downstream children of the given node, filterable by node type.
7173
Uses a recursive CTE query to build out all descendants from the node.
7274
"""
75+
# Use full options if none provided (for REST API DAGNodeOutput compatibility)
76+
result_options = options if options is not None else _node_output_options()
77+
78+
# Initial lookup always uses light options (only need node.id)
7379
node = await Node.get_by_name(
7480
session,
7581
node_name,
76-
options=_node_output_options(),
82+
options=[joinedload(Node.current)],
7783
)
7884
if not node:
7985
return []
@@ -160,7 +166,7 @@ async def get_downstream_nodes(
160166
final_select = final_select.where(max_depths.c.max_depth < depth)
161167

162168
statement = final_select.order_by(max_depths.c.max_depth, Node.id).options(
163-
*_node_output_options(),
169+
*result_options,
164170
)
165171
results = (await session.execute(statement)).unique().scalars().all()
166172
return [
@@ -308,6 +314,7 @@ async def get_upstream_nodes(
308314
node_name: Union[str, List[str]],
309315
node_type: NodeType = None,
310316
include_deactivated: bool = True,
317+
options: List = None,
311318
) -> List[Node]:
312319
"""
313320
Gets all upstreams of the given node(s), filterable by node type.
@@ -320,10 +327,14 @@ async def get_upstream_nodes(
320327
# Normalize to list
321328
node_names = [node_name] if isinstance(node_name, str) else node_name
322329

330+
# Use full options if none provided (for REST API DAGNodeOutput compatibility)
331+
result_options = options if options is not None else _node_output_options()
332+
333+
# Initial lookup always uses light options (only need type and current.id)
323334
nodes = await Node.get_by_names(
324335
session,
325336
node_names,
326-
options=_node_output_options(),
337+
options=[joinedload(Node.current)],
327338
)
328339

329340
if not nodes:
@@ -417,19 +428,19 @@ async def get_upstream_nodes(
417428
(Node.current_version == NodeRevision.version)
418429
& (Node.id == NodeRevision.node_id),
419430
)
420-
.options(*_node_output_options())
431+
.options(*result_options)
421432
)
422433

423434
results = list((await session.execute(statement)).unique().scalars().all())
424435

425436
# For metrics, include the immediate parents in the results
426437
# (they are the starting point for the CTE, so not included by default)
427438
if immediate_parent_ids:
428-
# Load parents with full options for consistent output
439+
# Load parents with same options for consistent output
429440
parent_query = (
430441
select(Node)
431442
.where(Node.id.in_(immediate_parent_ids))
432-
.options(*_node_output_options())
443+
.options(*result_options)
433444
)
434445
if not include_deactivated:
435446
parent_query = parent_query.where(is_(Node.deactivated_at, None))

datajunction-server/tests/api/graphql/downstream_nodes_test.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,3 +174,75 @@ async def test_downstream_nodes_deactivated(
174174
{"name": "default.avg_repair_order_discounts", "type": "METRIC"},
175175
{"name": "default.avg_time_to_dispatch", "type": "METRIC"},
176176
]
177+
178+
179+
@pytest.mark.asyncio
async def test_downstream_nodes_with_nested_fields(
    module__client_with_roads: AsyncClient,
) -> None:
    """
    Test downstream nodes query with nested fields that require database joins.

    This tests that load_node_options correctly builds options based on the
    requested GraphQL fields (tags, owners, current.columns, etc.).
    """
    # Query with many nested fields that require joins
    query = """
    {
        downstreamNodes(nodeNames: ["default.repair_order_details"], nodeType: TRANSFORM) {
            name
            type
            tags {
                name
                tagType
            }
            owners {
                username
            }
            current {
                displayName
                status
                description
                columns {
                    name
                    type
                }
                parents {
                    name
                }
            }
        }
    }
    """

    response = await module__client_with_roads.post("/graphql", json={"query": query})
    assert response.status_code == 200
    data = response.json()

    # GraphQL servers answer HTTP 200 even when a resolver fails and report the
    # failure under "errors", so the status code alone does not prove that the
    # requested relationships were loaded successfully.
    assert not data.get("errors"), data.get("errors")

    # Verify we got results
    downstreams = data["data"]["downstreamNodes"]
    assert len(downstreams) > 0

    # Find the repair_orders_fact transform
    repair_orders_fact = next(
        (n for n in downstreams if n["name"] == "default.repair_orders_fact"),
        None,
    )
    assert repair_orders_fact is not None

    # Verify nested fields are populated
    assert repair_orders_fact["type"] == "TRANSFORM"
    assert repair_orders_fact["current"] is not None
    assert repair_orders_fact["current"]["status"] == "VALID"
    assert repair_orders_fact["current"]["displayName"] == "Repair Orders Fact"

    # Verify columns are loaded (requires selectinload)
    columns = repair_orders_fact["current"]["columns"]
    assert columns is not None
    assert len(columns) > 0
    column_names = {c["name"] for c in columns}
    assert "repair_order_id" in column_names

    # Verify parents are loaded (requires selectinload)
    parents = repair_orders_fact["current"]["parents"]
    assert parents is not None
    assert len(parents) > 0

datajunction-server/tests/api/graphql/upstream_nodes_test.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -237,3 +237,85 @@ async def test_upstream_nodes_deactivated(
237237
data = response.json()
238238
upstream_names = {node["name"] for node in data["data"]["upstreamNodes"]}
239239
assert "default.repair_orders_fact" in upstream_names
240+
241+
242+
@pytest.mark.asyncio
async def test_upstream_nodes_with_nested_fields(
    module__client_with_roads: AsyncClient,
) -> None:
    """
    Test upstream nodes query with nested fields that require database joins.

    This tests that load_node_options correctly builds options based on the
    requested GraphQL fields (tags, owners, current.columns, etc.).
    """
    # Query with many nested fields that require joins
    # Use includeDeactivated: true since earlier tests may have deactivated nodes
    query = """
    {
        upstreamNodes(nodeNames: ["default.num_repair_orders"], includeDeactivated: true) {
            name
            type
            tags {
                name
                tagType
            }
            owners {
                username
            }
            current {
                displayName
                status
                description
                columns {
                    name
                    type
                }
                parents {
                    name
                }
            }
        }
    }
    """

    response = await module__client_with_roads.post("/graphql", json={"query": query})
    assert response.status_code == 200
    data = response.json()

    # GraphQL servers answer HTTP 200 even when a resolver fails and report the
    # failure under "errors", so the status code alone does not prove that the
    # requested relationships were loaded successfully.
    assert not data.get("errors"), data.get("errors")

    # Verify we got results
    upstreams = data["data"]["upstreamNodes"]
    assert len(upstreams) > 0

    # Find the repair_orders_fact transform (immediate parent of the metric)
    repair_orders_fact = next(
        (n for n in upstreams if n["name"] == "default.repair_orders_fact"),
        None,
    )
    assert repair_orders_fact is not None

    # Verify nested fields are populated
    assert repair_orders_fact["type"] == "TRANSFORM"
    assert repair_orders_fact["current"] is not None
    assert repair_orders_fact["current"]["status"] == "VALID"
    assert repair_orders_fact["current"]["displayName"] == "Repair Orders Fact"

    # Verify columns are loaded (requires selectinload)
    columns = repair_orders_fact["current"]["columns"]
    assert columns is not None
    assert len(columns) > 0
    column_names = {c["name"] for c in columns}
    assert "repair_order_id" in column_names

    # Verify parents are loaded (requires selectinload)
    parents = repair_orders_fact["current"]["parents"]
    assert parents is not None
    assert len(parents) > 0

    # Find a source node and verify its fields
    source_node = next(
        (n for n in upstreams if n["type"] == "SOURCE"),
        None,
    )
    assert source_node is not None
    assert source_node["current"] is not None
    assert source_node["current"]["columns"] is not None

0 commit comments

Comments
 (0)