Skip to content

Commit 053f245

Browse files
committed
refactor(incremental): simplify incremental graph by allowing mutations
Replicates graphql/graphql-js@f6227a8
1 parent 333d0d7 commit 053f245

File tree

4 files changed

+84
-132
lines changed

4 files changed

+84
-132
lines changed

src/graphql/execution/execute.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2336,7 +2336,7 @@ def add_new_deferred_fragments(
23362336

23372337
# Instantiate the new record.
23382338
deferred_fragment_record = DeferredFragmentRecord(
2339-
parent, path, new_defer_usage.label
2339+
path, new_defer_usage.label, parent
23402340
)
23412341

23422342
# Update the map.

src/graphql/execution/incremental_graph.py

Lines changed: 53 additions & 123 deletions
Original file line numberDiff line numberDiff line change
@@ -13,20 +13,19 @@
1313
)
1414
from collections import deque
1515
from contextlib import suppress
16-
from dataclasses import dataclass, field
1716
from typing import TYPE_CHECKING, Any, cast
1817

1918
from ..pyutils import BoxedAwaitableOrValue, Undefined, is_awaitable
2019
from .types import (
2120
BareStreamItemsResult,
2221
StreamItemsResult,
2322
StreamRecord,
23+
is_deferred_fragment_record,
2424
is_deferred_grouped_field_set_record,
2525
)
2626

2727
if TYPE_CHECKING:
2828
from collections.abc import AsyncGenerator, Awaitable, Generator, Iterable, Sequence
29-
from typing import TypeGuard
3029

3130
from ..error.graphql_error import GraphQLError
3231
from .types import (
@@ -41,40 +40,13 @@
4140
__all__ = ["IncrementalGraph"]
4241

4342

44-
@dataclass(frozen=True, repr=False, eq=False)
45-
class DeferredFragmentNode:
46-
"""A node representing a deferred fragment in the incremental graph."""
47-
48-
deferred_fragment_record: DeferredFragmentRecord
49-
deferred_grouped_field_set_records: dict[DeferredGroupedFieldSetRecord, None] = (
50-
field(default_factory=dict)
51-
)
52-
reconcilable_results: dict[ReconcilableDeferredGroupedFieldSetResult, None] = field(
53-
default_factory=dict
54-
)
55-
children: dict[SubsequentResultNode, None] = field(default_factory=dict)
56-
57-
58-
SubsequentResultNode = DeferredFragmentNode | StreamRecord
59-
60-
61-
def is_deferred_fragment_node(
62-
node: SubsequentResultNode | None,
63-
) -> TypeGuard[DeferredFragmentNode]:
64-
"""Check whether the given node is a deferred fragment node."""
65-
return isinstance(node, DeferredFragmentNode)
66-
67-
6843
class IncrementalGraph:
6944
"""Helper class to execute incremental Graphs.
7045
7146
For internal use only.
7247
"""
7348

74-
_root_nodes: dict[SubsequentResultNode, None]
75-
_deferred_fragment_nodes: dict[DeferredFragmentRecord, DeferredFragmentNode]
76-
_new_pending: dict[SubsequentResultNode, None]
77-
_new_incremental_data_records: dict[IncrementalDataRecord, None]
49+
_root_nodes: dict[SubsequentResultRecord, None]
7850
_completed_queue: list[IncrementalDataRecordResult]
7951
_next_queue: list[Future[Iterable[IncrementalDataRecordResult]]]
8052

@@ -83,9 +55,6 @@ class IncrementalGraph:
8355
def __init__(self) -> None:
8456
"""Initialize the IncrementalGraph."""
8557
self._root_nodes = {}
86-
self._deferred_fragment_nodes = {}
87-
self._new_pending = {}
88-
self._new_incremental_data_records = {}
8958
self._completed_queue = []
9059
self._next_queue = []
9160
self._tasks = set()
@@ -94,7 +63,7 @@ def get_new_root_nodes(
9463
self, incremental_data_records: Sequence[IncrementalDataRecord]
9564
) -> list[SubsequentResultRecord]:
9665
"""Get new root nodes."""
97-
initial_result_children: dict[SubsequentResultNode, None] = {}
66+
initial_result_children: dict[SubsequentResultRecord, None] = {}
9867
self._add_incremental_data_records(
9968
incremental_data_records, None, initial_result_children
10069
)
@@ -104,17 +73,19 @@ def add_completed_reconcilable_deferred_grouped_field_set(
10473
self, reconcilable_result: ReconcilableDeferredGroupedFieldSetResult
10574
) -> None:
10675
"""Add a completed reconcilable deferred grouped field set result."""
107-
record = reconcilable_result.deferred_grouped_field_set_record
108-
deferred = record.deferred_fragment_records
109-
for defererred_fragment_node in self._fragments_to_nodes(deferred):
110-
del defererred_fragment_node.deferred_grouped_field_set_records[
111-
reconcilable_result.deferred_grouped_field_set_record
76+
deferred_record = reconcilable_result.deferred_grouped_field_set_record
77+
deferred_records = deferred_record.deferred_fragment_records
78+
for defererred_fragment_record in deferred_records:
79+
del defererred_fragment_record.deferred_grouped_field_set_records[
80+
deferred_record
11281
]
113-
defererred_fragment_node.reconcilable_results[reconcilable_result] = None
82+
defererred_fragment_record.reconcilable_results[reconcilable_result] = None
11483

11584
incremental_data_records = reconcilable_result.incremental_data_records
11685
if incremental_data_records is not None:
117-
self._add_incremental_data_records(incremental_data_records, deferred)
86+
self._add_incremental_data_records(
87+
incremental_data_records, deferred_records
88+
)
11889

11990
async def completed_incremental_data(
12091
self,
@@ -150,26 +121,22 @@ def complete_deferred_fragment(
150121
| None
151122
):
152123
"""Complete a deferred fragment."""
153-
deferred_fragment_nodes = self._deferred_fragment_nodes
154-
try:
155-
deferred_fragment_node = deferred_fragment_nodes[deferred_fragment_record]
156-
except KeyError: # pragma: no cover
124+
if deferred_fragment_record not in self._root_nodes:
125+
return None # pragma: no cover
126+
if deferred_fragment_record.deferred_grouped_field_set_records:
157127
return None
158-
if deferred_fragment_node.deferred_grouped_field_set_records:
159-
return None
160-
reconcilable_results = list(deferred_fragment_node.reconcilable_results)
161-
self._remove_root_node(deferred_fragment_node)
128+
reconcilable_results = list(deferred_fragment_record.reconcilable_results)
129+
self._remove_root_node(deferred_fragment_record)
162130
for reconcilable_result in reconcilable_results:
163-
record = reconcilable_result.deferred_grouped_field_set_record
164-
for other_deferred_fragment_node in self._fragments_to_nodes(
165-
record.deferred_fragment_records
166-
):
131+
deferred_record = reconcilable_result.deferred_grouped_field_set_record
132+
deferred_records = deferred_record.deferred_fragment_records
133+
for other_deferred_fragment_record in deferred_records:
167134
with suppress(KeyError):
168-
del other_deferred_fragment_node.reconcilable_results[
135+
del other_deferred_fragment_record.reconcilable_results[
169136
reconcilable_result
170137
]
171138
new_root_nodes = self._promote_non_empty_to_root(
172-
deferred_fragment_node.children
139+
deferred_fragment_record.children
173140
)
174141
return new_root_nodes, reconcilable_results
175142

@@ -178,16 +145,9 @@ def remove_deferred_fragment(
178145
deferred_fragment_record: DeferredFragmentRecord,
179146
) -> bool:
180147
"""Check if deferred fragment exists and remove it in that case."""
181-
deferred_fragment_nodes = self._deferred_fragment_nodes
182-
try:
183-
deferred_fragment_node = deferred_fragment_nodes[deferred_fragment_record]
184-
except KeyError: # pragma: no cover
148+
if deferred_fragment_record not in self._root_nodes:
185149
return False
186-
self._remove_root_node(deferred_fragment_node)
187-
del deferred_fragment_nodes[deferred_fragment_record]
188-
for child in deferred_fragment_node.children: # pragma: no cover
189-
if is_deferred_fragment_node(child):
190-
self.remove_deferred_fragment(child.deferred_fragment_record)
150+
self._remove_root_node(deferred_fragment_record)
191151
return True
192152

193153
def remove_stream(self, stream_record: StreamRecord) -> None:
@@ -199,28 +159,29 @@ def stop_incremental_data(self) -> None:
199159
for future in self._next_queue:
200160
future.cancel() # pragma: no cover
201161

202-
def _remove_root_node(self, subsequent_result_node: SubsequentResultNode) -> None:
162+
def _remove_root_node(
163+
self, subsequent_result_record: SubsequentResultRecord
164+
) -> None:
203165
"""Remove root node."""
204-
del self._root_nodes[subsequent_result_node]
166+
del self._root_nodes[subsequent_result_record]
205167
if not self._root_nodes:
206168
self.stop_incremental_data()
207169

208170
def _add_incremental_data_records(
209171
self,
210172
incremental_data_records: Sequence[IncrementalDataRecord],
211173
parents: Sequence[DeferredFragmentRecord] | None = None,
212-
initial_result_children: dict[SubsequentResultNode, None] | None = None,
174+
initial_result_children: dict[SubsequentResultRecord, None] | None = None,
213175
) -> None:
214176
"""Add incremental data records."""
215177
for incremental_data_record in incremental_data_records:
216178
if is_deferred_grouped_field_set_record(incremental_data_record):
217-
for (
218-
deferred_fragment_record
219-
) in incremental_data_record.deferred_fragment_records:
220-
deferred_fragment_node = self._add_deferred_fragment_node(
179+
deferred_records = incremental_data_record.deferred_fragment_records
180+
for deferred_fragment_record in deferred_records:
181+
self._add_deferred_fragment_node(
221182
deferred_fragment_record, initial_result_children
222183
)
223-
deferred_fragment_node.deferred_grouped_field_set_records[
184+
deferred_fragment_record.deferred_grouped_field_set_records[
224185
incremental_data_record
225186
] = None
226187
if self._completes_root_node(incremental_data_record):
@@ -234,23 +195,21 @@ def _add_incremental_data_records(
234195
] = None
235196
else:
236197
for parent in parents:
237-
deferred_fragment_node = self._add_deferred_fragment_node(
238-
parent, initial_result_children
198+
self._add_deferred_fragment_node(parent, initial_result_children)
199+
parent.children[cast("StreamRecord", incremental_data_record)] = (
200+
None
239201
)
240-
deferred_fragment_node.children[
241-
cast("StreamRecord", incremental_data_record)
242-
] = None
243202

244203
def _promote_non_empty_to_root(
245-
self, maybe_empty_new_root_nodes: dict[SubsequentResultNode, None]
204+
self, maybe_empty_new_root_nodes: dict[SubsequentResultRecord, None]
246205
) -> list[SubsequentResultRecord]:
247206
"""Promote non-empty nodes to root nodes."""
248207
new_root_nodes: list[SubsequentResultRecord] = []
249208
# use a deque to simulate how JavaScripts iterates over a changing set
250209
unprocessed_nodes = deque(maybe_empty_new_root_nodes)
251210
while unprocessed_nodes:
252211
node = unprocessed_nodes.popleft()
253-
if is_deferred_fragment_node(node):
212+
if is_deferred_fragment_record(node):
254213
if node.deferred_grouped_field_set_records:
255214
for (
256215
deferred_grouped_field_set_record
@@ -262,9 +221,8 @@ def _promote_non_empty_to_root(
262221
deferred_grouped_field_set_record
263222
)
264223
self._root_nodes[node] = None
265-
new_root_nodes.append(node.deferred_fragment_record)
224+
new_root_nodes.append(node)
266225
continue
267-
del self._deferred_fragment_nodes[node.deferred_fragment_record]
268226
for child in node.children:
269227
if child not in maybe_empty_new_root_nodes: # pragma: no branch
270228
maybe_empty_new_root_nodes[cast("StreamRecord", child)] = None
@@ -280,54 +238,26 @@ def _completes_root_node(
280238
) -> bool:
281239
"""Check whether the given record completes a root node."""
282240
root_nodes = self._root_nodes
283-
return any(
284-
node in root_nodes
285-
for node in self._fragments_to_nodes(
286-
deferred_grouped_field_set_record.deferred_fragment_records
287-
)
288-
)
289-
290-
def _fragments_to_nodes(
291-
self,
292-
deferred_fragment_records: Sequence[DeferredFragmentRecord],
293-
) -> list[DeferredFragmentNode]:
294-
"""Get deferred fragment nodes for the given records."""
295-
return [
296-
node
297-
for node in (
298-
self._deferred_fragment_nodes.get(deferred_fragment_record)
299-
for deferred_fragment_record in deferred_fragment_records
300-
)
301-
if is_deferred_fragment_node(node)
302-
]
241+
deferred_records = deferred_grouped_field_set_record.deferred_fragment_records
242+
return any(record in root_nodes for record in deferred_records)
303243

304244
def _add_deferred_fragment_node(
305245
self,
306246
deferred_fragment_record: DeferredFragmentRecord,
307-
initial_result_children: dict[SubsequentResultNode, None] | None = None,
308-
) -> DeferredFragmentNode:
247+
initial_result_children: dict[SubsequentResultRecord, None] | None = None,
248+
) -> None:
309249
"""Add a deferred fragment node."""
310-
try:
311-
deferred_fragment_node = self._deferred_fragment_nodes[
312-
deferred_fragment_record
313-
]
314-
except KeyError as key_error:
315-
deferred_fragment_node = DeferredFragmentNode(deferred_fragment_record)
316-
self._deferred_fragment_nodes[deferred_fragment_record] = (
317-
deferred_fragment_node
318-
)
319-
parent = deferred_fragment_record.parent
320-
if parent is None:
250+
if deferred_fragment_record in self._root_nodes:
251+
return
252+
parent = deferred_fragment_record.parent
253+
if parent is None:
254+
if initial_result_children is None: # pragma: no cover
321255
msg = "Invalid state while adding deferred fragment node."
322-
if initial_result_children is None: # pragma: no cover
323-
raise RuntimeError(msg) from key_error
324-
initial_result_children[deferred_fragment_node] = None
325-
else:
326-
parent_node = self._add_deferred_fragment_node(
327-
parent, initial_result_children
328-
)
329-
parent_node.children[deferred_fragment_node] = None
330-
return deferred_fragment_node
256+
raise RuntimeError(msg)
257+
initial_result_children[deferred_fragment_record] = None
258+
return
259+
parent.children[deferred_fragment_record] = None
260+
self._add_deferred_fragment_node(parent, initial_result_children)
331261

332262
def _on_deferred_grouped_field_set(
333263
self, deferred_grouped_field_set_record: DeferredGroupedFieldSetRecord

src/graphql/execution/types.py

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
"SubsequentIncrementalExecutionResult",
5252
"SubsequentResultRecord",
5353
"is_cancellable_stream_record",
54+
"is_deferred_fragment_record",
5455
"is_deferred_grouped_field_set_record",
5556
"is_deferred_grouped_field_set_result",
5657
"is_non_reconcilable_deferred_grouped_field_set_result",
@@ -796,36 +797,57 @@ def __init__(
796797
class DeferredFragmentRecord:
797798
"""Deferred fragment record"""
798799

799-
parent: DeferredFragmentRecord | None
800800
path: Path | None
801801
label: str | None
802802
id: str | None
803+
parent: DeferredFragmentRecord | None
804+
deferred_grouped_field_set_records: dict[DeferredGroupedFieldSetRecord, None]
805+
reconcilable_results: dict[ReconcilableDeferredGroupedFieldSetResult, None]
806+
children: dict[SubsequentResultRecord, None]
803807

804-
__slots__ = "id", "label", "parent", "path"
808+
__slots__ = (
809+
"children",
810+
"deferred_grouped_field_set_records",
811+
"id",
812+
"label",
813+
"parent",
814+
"path",
815+
"reconcilable_results",
816+
)
805817

806818
def __init__(
807819
self,
808-
parent: DeferredFragmentRecord | None = None,
809820
path: Path | None = None,
810821
label: str | None = None,
822+
parent: DeferredFragmentRecord | None = None,
811823
) -> None:
812-
self.parent = parent
813824
self.path = path
814825
self.label = label
826+
self.parent = parent
815827
self.id = None
828+
self.deferred_grouped_field_set_records = {}
829+
self.reconcilable_results = {}
830+
self.children = {}
816831

817832
def __repr__(self) -> str:
818833
name = self.__class__.__name__
819834
args: list[str] = []
820-
if self.parent:
821-
args.append("parent")
822835
if self.path:
823836
args.append(f"path={self.path.as_list()!r}")
824837
if self.label:
825838
args.append(f"label={self.label!r}")
839+
if self.parent:
840+
args.append("parent")
826841
return f"{name}({', '.join(args)})"
827842

828843

844+
def is_deferred_fragment_record(
845+
subsequent_result_record: SubsequentResultRecord,
846+
) -> TypeGuard[DeferredFragmentRecord]:
847+
"""Check if the subsequent result record is a deferred fragment record."""
848+
return isinstance(subsequent_result_record, DeferredFragmentRecord)
849+
850+
829851
class StreamItemResult(NamedTuple):
830852
"""Stream item result"""
831853

0 commit comments

Comments
 (0)