|
1 | 1 | from ..utils import type_from_ast, is_valid_literal_value
|
2 | 2 | from ..error import GraphQLError
|
3 |
| -from ..type.definition import is_composite_type, is_input_type, is_leaf_type, GraphQLNonNull |
| 3 | +from ..type.definition import is_composite_type, is_input_type, is_leaf_type, GraphQLNonNull, GraphQLList |
4 | 4 | from ..language import ast
|
5 | 5 | from ..language.visitor import Visitor, visit
|
6 | 6 | from ..language.printer import print_ast
|
@@ -623,7 +623,54 @@ def bad_value_for_default_arg_message(var_name, type, value):
|
623 | 623 |
|
624 | 624 |
|
625 | 625 | class VariablesInAllowedPosition(ValidationRule):
|
626 |
| - pass |
| 626 | + visit_spread_fragments = True |
| 627 | + |
| 628 | + def __init__(self, context): |
| 629 | + super(VariablesInAllowedPosition, self).__init__(context) |
| 630 | + self.var_def_map = {} |
| 631 | + self.visited_fragment_names = {} |
| 632 | + |
| 633 | + def enter_OperationDefinition(self, *args): |
| 634 | + self.var_def_map = {} |
| 635 | + self.visited_fragment_names = {} |
| 636 | + |
| 637 | + def enter_VariableDefinition(self, var_def_ast, *args): |
| 638 | + self.var_def_map[var_def_ast.variable.name.value] = var_def_ast |
| 639 | + |
| 640 | + def enter_Variable(self, variable_ast, *args): |
| 641 | + var_name = variable_ast.name.value |
| 642 | + var_def = self.var_def_map[var_name] |
| 643 | + var_type = var_def and type_from_ast(self.context.get_schema(), var_def.type) |
| 644 | + input_type = self.context.get_input_type() |
| 645 | + if var_type and input_type and not self.var_type_allowed_for_type(self.effective_type(var_type, var_def), input_type): |
| 646 | + return GraphQlError(self.bad_var_pos_message(var_name, var_type, input_type, [variable_ast])) |
| 647 | + |
| 648 | + def enter_FragmentSpread(self, spread_ast, *args): |
| 649 | + if spread_ast.name.value in self.visited_fragment_names: |
| 650 | + return False |
| 651 | + self.visited_fragment_names[spread_ast.name.value] = True; |
| 652 | + |
| 653 | + @staticmethod |
| 654 | + def effective_type(var_type, var_def): |
| 655 | + if not var_def.default_value or isinstance(var_def, GraphQLNonNull): |
| 656 | + return var_type |
| 657 | + return GraphQLNonNull(var_type) |
| 658 | + |
| 659 | + @staticmethod |
| 660 | + def var_type_allowed_for_type(var_type, expected_type): |
| 661 | + if isinstance(expected_type, GraphQLNonNull): |
| 662 | + if isinstance(var_type, GraphQLNonNull): |
| 663 | + return VariablesInAllowedPosition.var_type_allowed_for_type(var_type.of_type, expected_type.of_type) |
| 664 | + return False |
| 665 | + if isinstance(var_type, GraphQLNonNull): |
| 666 | + return VariablesInAllowedPosition.var_type_allowed_for_type(var_type.of_type, expected_type) |
| 667 | + if isinstance(var_type, GraphQLList) and isinstance(expected_type, GraphQLList): |
| 668 | + return VariablesInAllowedPosition.var_type_allowed_for_type(var_type.of_type, expected_type.of_type) |
| 669 | + return var_type == expected_type |
| 670 | + |
| 671 | + @staticmethod |
| 672 | + def bad_var_pos_message(var_name, var_type, expected_type): |
| 673 | + return 'Variable {} of type {} used in position expecting type {}'.format(var_name, var_type, expected_type) |
627 | 674 |
|
628 | 675 |
|
629 | 676 | class OverlappingFieldsCanBeMerged(ValidationRule):
|
|
0 commit comments