Skip to content

Commit 2c9f5e1

Browse files
committed
Validate directive arguments inside SDL
Replicates graphql/graphql-js@4631744
1 parent 330c5e6 commit 2c9f5e1

10 files changed

+313
-61
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ a query language for APIs created by Facebook.
1313

1414
The current version 1.0.0rc2 of GraphQL-core-next is up-to-date with GraphQL.js
1515
version 14.0.0rc2. All parts of the API are covered by an extensive test
16-
suite of currently 1585 unit tests.
16+
suite of currently 1602 unit tests.
1717

1818

1919
## Documentation
Lines changed: 66 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
1-
from typing import List
1+
from typing import Dict, List, Union
22

33
from ...error import GraphQLError
4-
from ...language import ArgumentNode, FieldNode, DirectiveNode
4+
from ...language import (
5+
ArgumentNode, FieldNode, DirectiveDefinitionNode, DirectiveNode, SKIP)
56
from ...pyutils import quoted_or_list, suggestion_list
6-
from . import ValidationRule
7+
from ...type import specified_directives
8+
from . import ASTValidationRule, SDLValidationContext, ValidationContext
79

810
__all__ = [
9-
'KnownArgumentNamesRule',
11+
'KnownArgumentNamesRule', 'KnownArgumentNamesOnDirectivesRule',
1012
'unknown_arg_message', 'unknown_directive_arg_message']
1113

1214

@@ -30,38 +32,72 @@ def unknown_directive_arg_message(
3032
return message
3133

3234

33-
class KnownArgumentNamesRule(ValidationRule):
35+
class KnownArgumentNamesOnDirectivesRule(ASTValidationRule):
36+
"""Known argument names on directives
37+
38+
A GraphQL directive is only valid if all supplied arguments are defined.
39+
"""
40+
41+
context: Union[ValidationContext, SDLValidationContext]
42+
43+
def __init__(self, context: Union[
44+
ValidationContext, SDLValidationContext]) -> None:
45+
super().__init__(context)
46+
directive_args: Dict[str, List[str]] = {}
47+
48+
schema = context.schema
49+
defined_directives = (
50+
schema.directives if schema else specified_directives)
51+
for directive in defined_directives:
52+
directive_args[directive.name] = list(directive.args)
53+
54+
ast_definitions = context.document.definitions
55+
for def_ in ast_definitions:
56+
if isinstance(def_, DirectiveDefinitionNode):
57+
directive_args[def_.name.value] = [
58+
arg.name.value for arg in def_.arguments
59+
] if def_.arguments else []
60+
61+
self.directive_args = directive_args
62+
63+
def enter_directive(self, directive_node: DirectiveNode, *_args):
64+
directive_name = directive_node.name.value
65+
known_args = self.directive_args.get(directive_name)
66+
if directive_node.arguments and known_args:
67+
for arg_node in directive_node.arguments:
68+
arg_name = arg_node.name.value
69+
if arg_name not in known_args:
70+
suggestions = suggestion_list(arg_name, known_args)
71+
self.report_error(GraphQLError(
72+
unknown_directive_arg_message(
73+
arg_name, directive_name, suggestions), arg_node))
74+
return SKIP
75+
76+
77+
class KnownArgumentNamesRule(KnownArgumentNamesOnDirectivesRule):
3478
"""Known argument names
3579
3680
A GraphQL field is only valid if all supplied arguments are defined by
3781
that field.
3882
"""
3983

84+
context: ValidationContext
85+
86+
def __init__(self, context: ValidationContext) -> None:
87+
super().__init__(context)
88+
4089
def enter_argument(
41-
self, node: ArgumentNode, _key, _parent, _path, ancestors):
90+
self, arg_node: ArgumentNode, *args):
4291
context = self.context
4392
arg_def = context.get_argument()
44-
if not arg_def:
45-
argument_of = ancestors[-1]
46-
if isinstance(argument_of, FieldNode):
47-
field_def = context.get_field_def()
48-
parent_type = context.get_parent_type()
49-
if field_def and parent_type:
50-
context.report_error(GraphQLError(
51-
unknown_arg_message(
52-
node.name.value,
53-
argument_of.name.value,
54-
parent_type.name,
55-
suggestion_list(
56-
node.name.value, list(field_def.args))),
57-
[node]))
58-
elif isinstance(argument_of, DirectiveNode):
59-
directive = context.get_directive()
60-
if directive:
61-
context.report_error(GraphQLError(
62-
unknown_directive_arg_message(
63-
node.name.value,
64-
directive.name,
65-
suggestion_list(
66-
node.name.value, list(directive.args))),
67-
[node]))
93+
field_def = context.get_field_def()
94+
parent_type = context.get_parent_type()
95+
if not arg_def and field_def and parent_type:
96+
arg_name = arg_node.name.value
97+
field_name = args[3][-1].name.value
98+
known_args_names = list(field_def.args)
99+
context.report_error(GraphQLError(
100+
unknown_arg_message(
101+
arg_name, field_name, parent_type.name,
102+
suggestion_list(arg_name, known_args_names)), arg_node))
103+

graphql/validation/rules/known_directives.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,14 @@ class KnownDirectivesRule(ASTValidationRule):
2727
schema and legally positioned.
2828
"""
2929

30+
context: Union[ValidationContext, SDLValidationContext]
31+
3032
def __init__(self, context: Union[
3133
ValidationContext, SDLValidationContext]) -> None:
3234
super().__init__(context)
33-
schema = context.schema
3435
locations_map: Dict[str, List[DirectiveLocation]] = {}
36+
37+
schema = context.schema
3538
defined_directives = (
3639
schema.directives if schema else cast(List, specified_directives))
3740
for directive in defined_directives:

graphql/validation/rules/lone_schema_definition.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from . import SDLValidationRule, SDLValidationContext
44

55
__all__ = [
6-
'LoneSchemaDefinition',
6+
'LoneSchemaDefinitionRule',
77
'schema_definition_not_alone_message',
88
'cannot_define_schema_within_extension_message']
99

@@ -16,7 +16,7 @@ def cannot_define_schema_within_extension_message():
1616
return 'Cannot define a new schema within a schema extension.'
1717

1818

19-
class LoneSchemaDefinition(SDLValidationRule):
19+
class LoneSchemaDefinitionRule(SDLValidationRule):
2020
"""Lone Schema definition
2121
2222
A GraphQL document is only valid if it contains only one schema definition.
Lines changed: 71 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,16 @@
1+
from typing import Dict, Union
2+
13
from ...error import GraphQLError
2-
from ...language import DirectiveNode, FieldNode
3-
from ...type import is_required_argument
4-
from . import ValidationRule
4+
from ...language import (
5+
DirectiveDefinitionNode, DirectiveNode, FieldNode,
6+
InputValueDefinitionNode, NonNullTypeNode, print_ast)
7+
from ...type import (
8+
GraphQLArgument, is_required_argument, is_type, specified_directives)
9+
from . import ASTValidationRule, SDLValidationContext, ValidationContext
510

611
__all__ = [
712
'ProvidedRequiredArgumentsRule',
13+
'ProvidedRequiredArgumentsOnDirectivesRule',
814
'missing_field_arg_message', 'missing_directive_arg_message']
915

1016

@@ -20,37 +26,83 @@ def missing_directive_arg_message(
2026
f" of type '{type_}' is required but not provided.")
2127

2228

23-
class ProvidedRequiredArgumentsRule(ValidationRule):
29+
class ProvidedRequiredArgumentsOnDirectivesRule(ASTValidationRule):
30+
"""Provided required arguments on directives
31+
32+
A directive is only valid if all required (non-null without a
33+
default value) arguments have been provided.
34+
"""
35+
36+
context: Union[ValidationContext, SDLValidationContext]
37+
38+
def __init__(self, context: Union[
39+
ValidationContext, SDLValidationContext]) -> None:
40+
super().__init__(context)
41+
required_args_map: Dict[str, Dict[str, GraphQLArgument]] = {}
42+
43+
schema = context.schema
44+
defined_directives = (
45+
schema.directives if schema else specified_directives)
46+
for directive in defined_directives:
47+
required_args_map[directive.name] = {
48+
name: arg for name, arg in directive.args.items()
49+
if is_required_argument(arg)}
50+
51+
ast_definitions = context.document.definitions
52+
for def_ in ast_definitions:
53+
if isinstance(def_, DirectiveDefinitionNode):
54+
required_args_map[def_.name.value] = {
55+
arg.name.value: arg for arg in filter(
56+
is_required_argument_node, def_.arguments)
57+
} if def_.arguments else {}
58+
59+
self.required_args_map = required_args_map
60+
61+
def leave_directive(self, directive_node: DirectiveNode, *_args):
62+
# Validate on leave to allow for deeper errors to appear first.
63+
directive_name = directive_node.name.value
64+
required_args = self.required_args_map.get(directive_name)
65+
if required_args:
66+
67+
arg_nodes = directive_node.arguments or []
68+
arg_node_set = {arg.name.value for arg in arg_nodes}
69+
for arg_name in required_args:
70+
if arg_name not in arg_node_set:
71+
arg_type = required_args[arg_name].type
72+
self.report_error(GraphQLError(
73+
missing_directive_arg_message(
74+
directive_name, arg_name, str(arg_type)
75+
if is_type(arg_type) else print_ast(arg_type)),
76+
[directive_node]))
77+
78+
79+
class ProvidedRequiredArgumentsRule(ProvidedRequiredArgumentsOnDirectivesRule):
2480
"""Provided required arguments
2581
2682
A field or directive is only valid if all required (non-null without a
2783
default value) field arguments have been provided.
2884
"""
2985

30-
def leave_field(self, node: FieldNode, *_args):
86+
context: ValidationContext
87+
88+
def __init__(self, context: ValidationContext) -> None:
89+
super().__init__(context)
90+
91+
def leave_field(self, field_node: FieldNode, *_args):
3192
# Validate on leave to allow for deeper errors to appear first.
3293
field_def = self.context.get_field_def()
3394
if not field_def:
3495
return self.SKIP
35-
arg_nodes = node.arguments or []
96+
arg_nodes = field_node.arguments or []
3697

3798
arg_node_map = {arg.name.value: arg for arg in arg_nodes}
3899
for arg_name, arg_def in field_def.args.items():
39100
arg_node = arg_node_map.get(arg_name)
40101
if not arg_node and is_required_argument(arg_def):
41102
self.report_error(GraphQLError(missing_field_arg_message(
42-
node.name.value, arg_name, str(arg_def.type)), [node]))
103+
field_node.name.value, arg_name, str(arg_def.type)),
104+
[field_node]))
43105

44-
def leave_directive(self, node: DirectiveNode, *_args):
45-
# Validate on leave to allow for deeper errors to appear first.
46-
directive_def = self.context.get_directive()
47-
if not directive_def:
48-
return False
49-
arg_nodes = node.arguments or []
50106

51-
arg_node_map = {arg.name.value: arg for arg in arg_nodes}
52-
for arg_name, arg_def in directive_def.args.items():
53-
arg_node = arg_node_map.get(arg_name)
54-
if not arg_node and is_required_argument(arg_def):
55-
self.report_error(GraphQLError(missing_directive_arg_message(
56-
node.name.value, arg_name, str(arg_def.type)), [node]))
107+
def is_required_argument_node(arg: InputValueDefinitionNode) -> bool:
108+
return isinstance(arg.type, NonNullTypeNode) and arg.default_value is None

graphql/validation/specified_rules.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,10 @@
8383
from .rules.unique_input_field_names import UniqueInputFieldNamesRule
8484

8585
# Schema definition language:
86-
from .rules.lone_schema_definition import LoneSchemaDefinition
86+
from .rules.lone_schema_definition import LoneSchemaDefinitionRule
87+
from .rules.known_argument_names import KnownArgumentNamesOnDirectivesRule
88+
from .rules.provided_required_arguments import (
89+
ProvidedRequiredArgumentsOnDirectivesRule)
8790

8891
__all__ = ['specified_rules', 'specified_sdl_rules']
8992

@@ -122,8 +125,10 @@
122125
UniqueInputFieldNamesRule]
123126

124127
specified_sdl_rules: List[RuleType] = [
125-
LoneSchemaDefinition,
128+
LoneSchemaDefinitionRule,
126129
KnownDirectivesRule,
127130
UniqueDirectivesPerLocationRule,
131+
KnownArgumentNamesOnDirectivesRule,
128132
UniqueArgumentNamesRule,
129-
UniqueInputFieldNamesRule]
133+
UniqueInputFieldNamesRule,
134+
ProvidedRequiredArgumentsOnDirectivesRule]

0 commit comments

Comments
 (0)