Skip to content

Commit 9032ba8

Browse files
author
rawls238
committed
first pass
1 parent 045390b commit 9032ba8

File tree

3 files changed

+81
-5
lines changed

3 files changed

+81
-5
lines changed

graphql/core/validation/__init__.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,9 +68,13 @@ def enter(self, node, key, parent, path, ancestors):
6868
result = False
6969

7070
if result is None and getattr(self.instance, 'visit_spread_fragments', False) and isinstance(node, FragmentSpread):
71-
fragment = self.instance.context.get_fragment(node.name.value)
72-
if fragment:
73-
visit(fragment, self)
71+
try:
72+
fragment = self.instance.context.get_fragment(node.name.value)
73+
except KeyError:
74+
pass
75+
else:
76+
if fragment:
77+
visit(fragment, self)
7478

7579
if result is False:
7680
self.type_info.leave(node)

graphql/core/validation/rules.py

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from ..utils import type_from_ast, is_valid_literal_value
22
from ..error import GraphQLError
3-
from ..type.definition import is_composite_type, is_input_type, is_leaf_type, GraphQLNonNull
3+
from ..type.definition import is_composite_type, is_input_type, is_leaf_type, GraphQLNonNull, GraphQLList
44
from ..language import ast
55
from ..language.visitor import Visitor, visit
66
from ..language.printer import print_ast
@@ -623,7 +623,54 @@ def bad_value_for_default_arg_message(var_name, type, value):
623623

624624

625625
class VariablesInAllowedPosition(ValidationRule):
626-
pass
626+
visit_spread_fragments = True
627+
628+
def __init__(self, context):
629+
super(VariablesInAllowedPosition, self).__init__(context)
630+
self.var_def_map = {}
631+
self.visited_fragment_names = {}
632+
633+
def enter_OperationDefinition(self, *args):
634+
self.var_def_map = {}
635+
self.visited_fragment_names = {}
636+
637+
def enter_VariableDefinition(self, var_def_ast, *args):
638+
self.var_def_map[var_def_ast.variable.name.value] = var_def_ast
639+
640+
def enter_Variable(self, variable_ast, *args):
641+
var_name = variable_ast.name.value
642+
var_def = self.var_def_map[var_name]
643+
var_type = var_def and type_from_ast(self.context.get_schema(), var_def.type)
644+
input_type = self.context.get_input_type()
645+
if var_type and input_type and not self.var_type_allowed_for_type(self.effective_type(var_type, var_def), input_type):
646+
return GraphQlError(self.bad_var_pos_message(var_name, var_type, input_type, [variable_ast]))
647+
648+
def enter_FragmentSpread(self, spread_ast, *args):
649+
if spread_ast.name.value in self.visited_fragment_names:
650+
return False
651+
self.visited_fragment_names[spread_ast.name.value] = True;
652+
653+
@staticmethod
654+
def effective_type(var_type, var_def):
655+
if not var_def.default_value or isinstance(var_def, GraphQLNonNull):
656+
return var_type
657+
return GraphQLNonNull(var_type)
658+
659+
@staticmethod
660+
def var_type_allowed_for_type(var_type, expected_type):
661+
if isinstance(expected_type, GraphQLNonNull):
662+
if isinstance(var_type, GraphQLNonNull):
663+
return VariablesInAllowedPosition.var_type_allowed_for_type(var_type.of_type, expected_type.of_type)
664+
return False
665+
if isinstance(var_type, GraphQLNonNull):
666+
return VariablesInAllowedPosition.var_type_allowed_for_type(var_type.of_type, expected_type)
667+
if isinstance(var_type, GraphQLList) and isinstance(expected_type, GraphQLList):
668+
return VariablesInAllowedPosition.var_type_allowed_for_type(var_type.of_type, expected_type.of_type)
669+
return var_type == expected_type
670+
671+
@staticmethod
672+
def bad_var_pos_message(var_name, var_type, expected_type):
673+
return 'Variable {} of type {} used in position expecting type {}'.format(var_name, var_type, expected_type)
627674

628675

629676
class OverlappingFieldsCanBeMerged(ValidationRule):
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
from graphql.core.language.location import SourceLocation
2+
from graphql.core.validation.rules import VariablesInAllowedPosition
3+
from utils import expect_passes_rule, expect_fails_rule
4+
5+
def test_boolean_boolean():
6+
expect_passes_rule(VariablesInAllowedPosition, '''
7+
query Query($booleanArg: Boolean)
8+
{
9+
complicatedArgs {
10+
booleanArgField(booleanArg: $booleanArg)
11+
}
12+
}
13+
''')
14+
def test_boolean_boolean_in_fragment():
15+
expect_passes_rule(VariablesInAllowedPosition, '''
16+
fragment booleanArgFrag on ComplicatedArgs {
17+
booleanArgField(booleanArg: $booleanArg)
18+
}
19+
query Query($booleanArg: Boolean)
20+
{
21+
complicatedArgs {
22+
...booleanArgFrag
23+
}
24+
}
25+
''')

0 commit comments

Comments
 (0)