@@ -309,11 +309,12 @@ def collect_fragment_spread_nodes(self):
309
309
310
310
311
311
class NoUndefinedVariables (ValidationRule ):
312
+ visit_spread_fragments = True
313
+
312
314
def __init__ (self , context ):
313
315
self .operation = None
314
- self .visited_fragment_names = {}
315
- self .defined_variable_names = {}
316
- self .visit_spread_fragments = True
316
+ self .visited_fragment_names = set ()
317
+ self .defined_variable_names = set ()
317
318
super (NoUndefinedVariables , self ).__init__ (context )
318
319
319
320
@staticmethod
@@ -328,22 +329,22 @@ def undefined_var_by_op_message(var_name, op_name):
328
329
329
330
def enter_OperationDefinition (self , node , * args ):
330
331
self .operation = node
331
- self .visited_fragment_names = {}
332
- self .defined_variable_names = {}
332
+ self .visited_fragment_names = set ()
333
+ self .defined_variable_names = set ()
333
334
334
335
def enter_VariableDefinition (self , node , * args ):
335
- self .defined_variable_names [ node .variable .name .value ] = True
336
+ self .defined_variable_names . add ( node .variable .name .value )
336
337
337
338
def enter_Variable (self , variable , key , parent , path , ancestors ):
338
339
var_name = variable .name .value
339
340
if var_name not in self .defined_variable_names :
340
- is_fragment = lambda node : isinstance (node , ast .FragmentDefinition )
341
- within_fragment = any (is_fragment (node ) for node in ancestors )
341
+ within_fragment = any (isinstance (node , ast .FragmentDefinition ) for node in ancestors )
342
342
if within_fragment and self .operation and self .operation .name :
343
343
return GraphQLError (
344
344
self .undefined_var_by_op_message (var_name , self .operation .name .value ),
345
345
[variable , self .operation ]
346
346
)
347
+
347
348
return GraphQLError (
348
349
self .undefined_var_message (var_name ),
349
350
[variable ]
@@ -352,7 +353,8 @@ def enter_Variable(self, variable, key, parent, path, ancestors):
352
353
def enter_FragmentSpread (self , spread_ast , * args ):
353
354
if spread_ast .name .value in self .visited_fragment_names :
354
355
return False
355
- self .visited_fragment_names [spread_ast .name .value ] = True
356
+
357
+ self .visited_fragment_names .add (spread_ast .name .value )
356
358
357
359
358
360
class NoUnusedVariables (ValidationRule ):
@@ -511,7 +513,18 @@ def duplicate_arg_message(field):
511
513
512
514
513
515
class ArgumentsOfCorrectType (ValidationRule ):
514
- pass
516
+ def enter_Argument (self , node , * args ):
517
+ arg_def = self .context .get_argument ()
518
+ if arg_def and not is_valid_literal_value (arg_def .type , node .value ):
519
+ return GraphQLError (
520
+ self .bad_value_message (node .name .value , arg_def .type ,
521
+ print_ast (node .value )),
522
+ [node .value ]
523
+ )
524
+
525
+ @staticmethod
526
+ def bad_value_message (arg_name , type , value ):
527
+ return 'Argument "{}" expected type "{}" but got: {}.' .format (arg_name , type , value )
515
528
516
529
517
530
class ProvidedNonNullArguments (ValidationRule ):
0 commit comments