Skip to content

Commit 3aae23a

Browse files
committed
Merge pull request #36 from jhgg/clean-up-no-undefined-variables
Clean up no undefined variables
2 parents e161cf3 + 82b3209 commit 3aae23a

File tree

1 file changed

+11
-9
lines changed

1 file changed

+11
-9
lines changed

graphql/core/validation/rules.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -309,11 +309,12 @@ def collect_fragment_spread_nodes(self):
309309

310310

311311
class NoUndefinedVariables(ValidationRule):
312+
visit_spread_fragments = True
313+
312314
def __init__(self, context):
313315
self.operation = None
314-
self.visited_fragment_names = {}
315-
self.defined_variable_names = {}
316-
self.visit_spread_fragments = True
316+
self.visited_fragment_names = set()
317+
self.defined_variable_names = set()
317318
super(NoUndefinedVariables, self).__init__(context)
318319

319320
@staticmethod
@@ -328,22 +329,22 @@ def undefined_var_by_op_message(var_name, op_name):
328329

329330
def enter_OperationDefinition(self, node, *args):
330331
self.operation = node
331-
self.visited_fragment_names = {}
332-
self.defined_variable_names = {}
332+
self.visited_fragment_names = set()
333+
self.defined_variable_names = set()
333334

334335
def enter_VariableDefinition(self, node, *args):
335-
self.defined_variable_names[node.variable.name.value] = True
336+
self.defined_variable_names.add(node.variable.name.value)
336337

337338
def enter_Variable(self, variable, key, parent, path, ancestors):
338339
var_name = variable.name.value
339340
if var_name not in self.defined_variable_names:
340-
is_fragment = lambda node: isinstance(node, ast.FragmentDefinition)
341-
within_fragment = any(is_fragment(node) for node in ancestors)
341+
within_fragment = any(isinstance(node, ast.FragmentDefinition) for node in ancestors)
342342
if within_fragment and self.operation and self.operation.name:
343343
return GraphQLError(
344344
self.undefined_var_by_op_message(var_name, self.operation.name.value),
345345
[variable, self.operation]
346346
)
347+
347348
return GraphQLError(
348349
self.undefined_var_message(var_name),
349350
[variable]
@@ -352,7 +353,8 @@ def enter_Variable(self, variable, key, parent, path, ancestors):
352353
def enter_FragmentSpread(self, spread_ast, *args):
353354
if spread_ast.name.value in self.visited_fragment_names:
354355
return False
355-
self.visited_fragment_names[spread_ast.name.value] = True
356+
357+
self.visited_fragment_names.add(spread_ast.name.value)
356358

357359

358360
class NoUnusedVariables(ValidationRule):

0 commit comments

Comments
 (0)