Skip to content

Commit 6651507

Browse files
committed
[Validation] Factor out and memoize recursively referenced fragments.
Related GraphQL commit: graphql/graphql-js@eef8d97
1 parent 7c3e769 commit 6651507

File tree

2 files changed

+29
-17
lines changed

2 files changed

+29
-17
lines changed

graphql/core/validation/context.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,39 @@
22

33

44
class ValidationContext(object):
5-
__slots__ = '_schema', '_ast', '_type_info', '_fragments', '_fragment_spreads'
5+
__slots__ = '_schema', '_ast', '_type_info', '_fragments', '_fragment_spreads', '_recursively_referenced_fragments'
66

77
def __init__(self, schema, ast, type_info):
88
self._schema = schema
99
self._ast = ast
1010
self._type_info = type_info
1111
self._fragments = None
1212
self._fragment_spreads = {}
13+
self._recursively_referenced_fragments = {}
1314

1415
def get_schema(self):
1516
return self._schema
1617

18+
def get_recursively_referenced_fragments(self, operation):
19+
fragments = self._recursively_referenced_fragments.get(operation)
20+
if not fragments:
21+
fragments = []
22+
collected_names = set()
23+
nodes_to_visit = [operation]
24+
while nodes_to_visit:
25+
node = nodes_to_visit.pop()
26+
spreads = self.get_fragment_spreads(node)
27+
for spread in spreads:
28+
frag_name = spread.name.value
29+
if frag_name not in collected_names:
30+
collected_names.add(frag_name)
31+
fragment = self.get_fragment(frag_name)
32+
if fragment:
33+
fragments.append(fragment)
34+
nodes_to_visit.append(fragment)
35+
self._recursively_referenced_fragments[operation] = fragments
36+
return fragments
37+
1738
def get_fragment_spreads(self, node):
1839
spreads = self._fragment_spreads.get(node)
1940
if not spreads:

graphql/core/validation/rules/no_unused_fragments.py

Lines changed: 7 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,15 @@
33

44

55
class NoUnusedFragments(ValidationRule):
6-
__slots__ = 'fragment_definitions', 'spreads_within_operation', 'fragment_adjacencies', 'spread_names'
6+
__slots__ = 'fragment_definitions', 'operation_definitions', 'fragment_adjacencies', 'spread_names'
77

88
def __init__(self, context):
99
super(NoUnusedFragments, self).__init__(context)
10-
self.spreads_within_operation = []
10+
self.operation_definitions = []
1111
self.fragment_definitions = []
1212

1313
def enter_OperationDefinition(self, node, key, parent, path, ancestors):
14-
self.spreads_within_operation.append(self.context.get_fragment_spreads(node))
14+
self.operation_definitions.append(node)
1515
return False
1616

1717
def enter_FragmentDefinition(self, node, key, parent, path, ancestors):
@@ -21,19 +21,10 @@ def enter_FragmentDefinition(self, node, key, parent, path, ancestors):
2121
def leave_Document(self, node, key, parent, path, ancestors):
2222
fragment_names_used = set()
2323

24-
def reduce_spread_fragments(spreads):
25-
for spread in spreads:
26-
frag_name = spread.name.value
27-
if frag_name in fragment_names_used:
28-
continue
29-
30-
fragment_names_used.add(frag_name)
31-
fragment = self.context.get_fragment(frag_name)
32-
if fragment:
33-
reduce_spread_fragments(self.context.get_fragment_spreads(fragment))
34-
35-
for spreads in self.spreads_within_operation:
36-
reduce_spread_fragments(spreads)
24+
for operation in self.operation_definitions:
25+
fragments = self.context.get_recursively_referenced_fragments(operation)
26+
for fragment in fragments:
27+
fragment_names_used.add(fragment.name.value)
3728

3829
errors = [
3930
GraphQLError(

0 commit comments

Comments
 (0)