@@ -634,43 +634,53 @@ def enter_OperationDefinition(self, *args):
634
634
self .var_def_map = {}
635
635
self .visited_fragment_names = set ()
636
636
637
- def enter_VariableDefinition (self , var_def_ast , * args ):
638
- self .var_def_map [var_def_ast .variable .name .value ] = var_def_ast
637
+ def enter_VariableDefinition (self , node , * args ):
638
+ self .var_def_map [node .variable .name .value ] = node
639
+
640
+ def enter_Variable (self , node , * args ):
641
+ var_name = node .name .value
642
+ var_def = self .var_def_map .get (var_name )
639
643
640
- def enter_Variable (self , variable_ast , * args ):
641
- var_name = variable_ast .name .value
642
- var_def = self .var_def_map [var_name ]
643
644
var_type = var_def and type_from_ast (self .context .get_schema (), var_def .type )
644
645
input_type = self .context .get_input_type ()
645
- if var_type and input_type and not self .var_type_allowed_for_type (self .effective_type (var_type , var_def ), input_type ):
646
- return GraphQlError (self .bad_var_pos_message (var_name , var_type , input_type , [variable_ast ]))
647
646
648
- def enter_FragmentSpread (self , spread_ast , * args ):
649
- if spread_ast .name .value in self .visited_fragment_names :
647
+ if var_type and input_type and not self .var_type_allowed_for_type (self .effective_type (var_type , var_def ),
648
+ input_type ):
649
+ return GraphQLError (self .bad_var_pos_message (var_name , var_type , input_type ),
650
+ [node ])
651
+
652
+ def enter_FragmentSpread (self , node , * args ):
653
+ if node .name .value in self .visited_fragment_names :
650
654
return False
651
- self .visited_fragment_names .add (spread_ast .name .value );
655
+
656
+ self .visited_fragment_names .add (node .name .value )
652
657
653
658
@staticmethod
654
659
def effective_type (var_type , var_def ):
655
660
if not var_def .default_value or isinstance (var_def , GraphQLNonNull ):
656
661
return var_type
662
+
657
663
return GraphQLNonNull (var_type )
658
664
659
665
@staticmethod
660
666
def var_type_allowed_for_type (var_type , expected_type ):
661
667
if isinstance (expected_type , GraphQLNonNull ):
662
668
if isinstance (var_type , GraphQLNonNull ):
663
669
return VariablesInAllowedPosition .var_type_allowed_for_type (var_type .of_type , expected_type .of_type )
670
+
664
671
return False
672
+
665
673
if isinstance (var_type , GraphQLNonNull ):
666
674
return VariablesInAllowedPosition .var_type_allowed_for_type (var_type .of_type , expected_type )
675
+
667
676
if isinstance (var_type , GraphQLList ) and isinstance (expected_type , GraphQLList ):
668
677
return VariablesInAllowedPosition .var_type_allowed_for_type (var_type .of_type , expected_type .of_type )
678
+
669
679
return var_type == expected_type
670
680
671
681
@staticmethod
672
682
def bad_var_pos_message (var_name , var_type , expected_type ):
673
- return 'Variable {} of type {} used in position expecting type {} ' .format (var_name , var_type , expected_type )
683
+ return 'Variable "${}" of type "{}" used in position expecting type "{}". ' .format (var_name , var_type , expected_type )
674
684
675
685
676
686
class OverlappingFieldsCanBeMerged (ValidationRule ):
0 commit comments