Skip to content

Commit 006a76f

Browse files
committed
Merge pull request #39 from jhgg/master
Clean ups
2 parents 3aae23a + cfd2e66 commit 006a76f

11 files changed

+122
-103
lines changed

graphql/core/validation/rules.py

Lines changed: 105 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -18,75 +18,79 @@ def __init__(self, context):
1818

1919
def enter_OperationDefinition(self, node, *args):
2020
operation_name = node.name
21-
if operation_name:
22-
if operation_name.value in self.known_operation_names:
23-
return GraphQLError(
24-
self.message(operation_name.value),
25-
[self.known_operation_names[operation_name.value], operation_name]
26-
)
27-
self.known_operation_names[operation_name.value] = operation_name
21+
if not operation_name:
22+
return
23+
24+
if operation_name.value in self.known_operation_names:
25+
return GraphQLError(
26+
self.duplicate_operation_name_message(operation_name.value),
27+
[self.known_operation_names[operation_name.value], operation_name]
28+
)
29+
30+
self.known_operation_names[operation_name.value] = operation_name
2831

2932
@staticmethod
30-
def message(operation_name):
33+
def duplicate_operation_name_message(operation_name):
3134
return 'There can only be one operation named "{}".'.format(operation_name)
3235

3336

3437
class LoneAnonymousOperation(ValidationRule):
38+
operation_count = 0
39+
3540
def __init__(self, context):
3641
super(LoneAnonymousOperation, self).__init__(context)
37-
self._op_count = 0
3842

3943
def enter_Document(self, node, *args):
40-
n = 0
41-
for definition in node.definitions:
42-
if isinstance(definition, ast.OperationDefinition):
43-
n += 1
44-
self._op_count = n
44+
self.operation_count = \
45+
sum(1 for definition in node.definitions if isinstance(definition, ast.OperationDefinition))
4546

4647
def enter_OperationDefinition(self, node, *args):
47-
if not node.name and self._op_count > 1:
48-
return GraphQLError(self.message(), [node])
48+
if not node.name and self.operation_count > 1:
49+
return GraphQLError(self.anonymous_operation_not_alone_message(), [node])
4950

5051
@staticmethod
51-
def message():
52+
def anonymous_operation_not_alone_message():
5253
return 'This anonymous operation must be the only defined operation.'
5354

5455

5556
class KnownTypeNames(ValidationRule):
5657
def enter_NamedType(self, node, *args):
5758
type_name = node.name.value
5859
type = self.context.get_schema().get_type(type_name)
60+
5961
if not type:
60-
return GraphQLError(self.message(type_name), [node])
62+
return GraphQLError(self.unknown_type_message(type_name), [node])
6163

6264
@staticmethod
63-
def message(type):
65+
def unknown_type_message(type):
6466
return 'Unknown type "{}".'.format(type)
6567

6668

6769
class FragmentsOnCompositeTypes(ValidationRule):
6870
def enter_InlineFragment(self, node, *args):
6971
type = self.context.get_type()
72+
7073
if type and not is_composite_type(type):
7174
return GraphQLError(
72-
self.inline_message(print_ast(node.type_condition)),
75+
self.inline_fragment_on_non_composite_error_message(print_ast(node.type_condition)),
7376
[node.type_condition]
7477
)
7578

7679
def enter_FragmentDefinition(self, node, *args):
7780
type = self.context.get_type()
81+
7882
if type and not is_composite_type(type):
7983
return GraphQLError(
80-
self.message(node.name.value, print_ast(node.type_condition)),
84+
self.fragment_on_non_composite_error_message(node.name.value, print_ast(node.type_condition)),
8185
[node.type_condition]
8286
)
8387

8488
@staticmethod
85-
def inline_message(type):
89+
def inline_fragment_on_non_composite_error_message(type):
8690
return 'Fragment cannot condition on non composite type "{}".'.format(type)
8791

8892
@staticmethod
89-
def message(frag_name, type):
93+
def fragment_on_non_composite_error_message(frag_name, type):
9094
return 'Fragment "{}" cannot condition on non composite type "{}".'.format(frag_name, type)
9195

9296

@@ -95,55 +99,60 @@ def enter_VariableDefinition(self, node, *args):
9599
type = type_from_ast(self.context.get_schema(), node.type)
96100

97101
if type and not is_input_type(type):
98-
variable_name = node.variable.name.value
99102
return GraphQLError(
100-
self.message(variable_name, print_ast(node.type)),
103+
self.non_input_type_on_variable_message(node.variable.name.value, print_ast(node.type)),
101104
[node.type]
102105
)
103106

104107
@staticmethod
105-
def message(variable_name, type_name):
108+
def non_input_type_on_variable_message(variable_name, type_name):
106109
return 'Variable "${}" cannot be non-input type "{}".'.format(variable_name, type_name)
107110

108111

109112
class ScalarLeafs(ValidationRule):
110113
def enter_Field(self, node, *args):
111114
type = self.context.get_type()
112-
if type:
113-
if is_leaf_type(type):
114-
if node.selection_set:
115-
return GraphQLError(
116-
self.not_allowed_message(node.name.value, type),
117-
[node.selection_set]
118-
)
119-
elif not node.selection_set:
115+
116+
if not type:
117+
return
118+
119+
if is_leaf_type(type):
120+
if node.selection_set:
120121
return GraphQLError(
121-
self.required_message(node.name.value, type),
122-
[node]
122+
self.no_subselection_allowed_message(node.name.value, type),
123+
[node.selection_set]
123124
)
124125

126+
elif not node.selection_set:
127+
return GraphQLError(
128+
self.required_subselection_message(node.name.value, type),
129+
[node]
130+
)
131+
125132
@staticmethod
126-
def not_allowed_message(field, type):
133+
def no_subselection_allowed_message(field, type):
127134
return 'Field "{}" of type "{}" must not have a sub selection.'.format(field, type)
128135

129136
@staticmethod
130-
def required_message(field, type):
137+
def required_subselection_message(field, type):
131138
return 'Field "{}" of type "{}" must have a sub selection.'.format(field, type)
132139

133140

134141
class FieldsOnCorrectType(ValidationRule):
135142
def enter_Field(self, node, *args):
136143
type = self.context.get_parent_type()
137-
if type:
138-
field_def = self.context.get_field_def()
139-
if not field_def:
140-
return GraphQLError(
141-
self.message(node.name.value, type.name),
142-
[node]
143-
)
144+
if not type:
145+
return
146+
147+
field_def = self.context.get_field_def()
148+
if not field_def:
149+
return GraphQLError(
150+
self.undefined_field_message(node.name.value, type.name),
151+
[node]
152+
)
144153

145154
@staticmethod
146-
def message(field_name, type):
155+
def undefined_field_message(field_name, type):
147156
return 'Cannot query field "{}" on "{}".'.format(field_name, type)
148157

149158

@@ -159,17 +168,19 @@ def enter_FragmentDefinition(self, node, *args):
159168
self.duplicate_fragment_name_message(fragment_name),
160169
[self.known_fragment_names[fragment_name], node.name]
161170
)
171+
162172
self.known_fragment_names[fragment_name] = node.name
163173

164174
@staticmethod
165175
def duplicate_fragment_name_message(field):
166-
return 'There can only be one fragment named {}'.format(field)
176+
return 'There can only be one fragment named "{}".'.format(field)
167177

168178

169179
class KnownFragmentNames(ValidationRule):
170180
def enter_FragmentSpread(self, node, *args):
171181
fragment_name = node.name.value
172182
fragment = self.context.get_fragment(fragment_name)
183+
173184
if not fragment:
174185
return GraphQLError(
175186
self.unknown_fragment_message(fragment_name),
@@ -310,9 +321,9 @@ def collect_fragment_spread_nodes(self):
310321

311322
class NoUndefinedVariables(ValidationRule):
312323
visit_spread_fragments = True
324+
operation = None
313325

314326
def __init__(self, context):
315-
self.operation = None
316327
self.visited_fragment_names = set()
317328
self.defined_variable_names = set()
318329
super(NoUndefinedVariables, self).__init__(context)
@@ -409,27 +420,31 @@ def unused_variable_message(variable_name):
409420

410421
class KnownDirectives(ValidationRule):
411422
def enter_Directive(self, node, key, parent, path, ancestors):
412-
directive_def = None
413-
for definition in self.context.get_schema().get_directives():
414-
if definition.name == node.name.value:
415-
directive_def = definition
416-
break
423+
directive_def = next((
424+
definition for definition in self.context.get_schema().get_directives()
425+
if definition.name == node.name.value
426+
), None)
427+
417428
if not directive_def:
418429
return GraphQLError(
419-
self.message(node.name.value),
430+
self.unknown_directive_message(node.name.value),
420431
[node]
421432
)
433+
422434
applied_to = ancestors[-1]
435+
423436
if isinstance(applied_to, ast.OperationDefinition) and not directive_def.on_operation:
424437
return GraphQLError(
425438
self.misplaced_directive_message(node.name.value, 'operation'),
426439
[node]
427440
)
441+
428442
if isinstance(applied_to, ast.Field) and not directive_def.on_field:
429443
return GraphQLError(
430444
self.misplaced_directive_message(node.name.value, 'field'),
431445
[node]
432446
)
447+
433448
if (isinstance(applied_to, (ast.FragmentSpread, ast.InlineFragment, ast.FragmentDefinition)) and
434449
not directive_def.on_fragment):
435450
return GraphQLError(
@@ -438,7 +453,7 @@ def enter_Directive(self, node, key, parent, path, ancestors):
438453
)
439454

440455
@staticmethod
441-
def message(directive_name):
456+
def unknown_directive_message(directive_name):
442457
return 'Unknown directive "{}".'.format(directive_name)
443458

444459
@staticmethod
@@ -449,41 +464,41 @@ def misplaced_directive_message(directive_name, placement):
449464
class KnownArgumentNames(ValidationRule):
450465
def enter_Argument(self, node, key, parent, path, ancestors):
451466
argument_of = ancestors[-1]
467+
452468
if isinstance(argument_of, ast.Field):
453469
field_def = self.context.get_field_def()
454-
if field_def:
455-
field_arg_def = None
456-
for arg in field_def.args:
457-
if arg.name == node.name.value:
458-
field_arg_def = arg
459-
break
460-
if not field_arg_def:
461-
parent_type = self.context.get_parent_type()
462-
assert parent_type
463-
return GraphQLError(
464-
self.message(node.name.value, field_def.name, parent_type.name),
465-
[node]
466-
)
470+
if not field_def:
471+
return
472+
473+
field_arg_def = next((arg for arg in field_def.args if arg.name == node.name.value), None)
474+
475+
if not field_arg_def:
476+
parent_type = self.context.get_parent_type()
477+
assert parent_type
478+
return GraphQLError(
479+
self.unknown_arg_message(node.name.value, field_def.name, parent_type.name),
480+
[node]
481+
)
482+
467483
elif isinstance(argument_of, ast.Directive):
468484
directive = self.context.get_directive()
469-
if directive:
470-
directive_arg_def = None
471-
for arg in directive.args:
472-
if arg.name == node.name.value:
473-
directive_arg_def = arg
474-
break
475-
if not directive_arg_def:
476-
return GraphQLError(
477-
self.directive_message(node.name.value, directive.name),
478-
[node]
479-
)
485+
if not directive:
486+
return
487+
488+
directive_arg_def = next((arg for arg in directive.args if arg.name == node.name.value), None)
489+
490+
if not directive_arg_def:
491+
return GraphQLError(
492+
self.unknown_directive_arg_message(node.name.value, directive.name),
493+
[node]
494+
)
480495

481496
@staticmethod
482-
def message(arg_name, field_name, type):
497+
def unknown_arg_message(arg_name, field_name, type):
483498
return 'Unknown argument "{}" on field "{}" of type "{}".'.format(arg_name, field_name, type)
484499

485500
@staticmethod
486-
def directive_message(arg_name, directive_name):
501+
def unknown_directive_arg_message(arg_name, directive_name):
487502
return 'Unknown argument "{}" on directive "@{}".'.format(arg_name, directive_name)
488503

489504

@@ -492,24 +507,26 @@ def __init__(self, context):
492507
super(UniqueArgumentNames, self).__init__(context)
493508
self.known_arg_names = {}
494509

495-
def enter_Field(self, node, *args):
510+
def enter_Field(self, *args):
496511
self.known_arg_names = {}
497512

498-
def enter_Directive(self, node, key, parent, path, ancestors):
513+
def enter_Directive(self, *args):
499514
self.known_arg_names = {}
500515

501516
def enter_Argument(self, node, *args):
502517
arg_name = node.name.value
518+
503519
if arg_name in self.known_arg_names:
504520
return GraphQLError(
505521
self.duplicate_arg_message(arg_name),
506522
[self.known_arg_names[arg_name], node.name]
507523
)
524+
508525
self.known_arg_names[arg_name] = node.name
509526

510527
@staticmethod
511528
def duplicate_arg_message(field):
512-
return 'There can only be one argument named {}'.format(field)
529+
return 'There can only be one argument named "{}".'.format(field)
513530

514531

515532
class ArgumentsOfCorrectType(ValidationRule):
@@ -528,7 +545,7 @@ def bad_value_message(arg_name, type, value):
528545

529546

530547
class ProvidedNonNullArguments(ValidationRule):
531-
def leave_Field(self, node, key, parent, path, ancestors):
548+
def leave_Field(self, node, *args):
532549
field_def = self.context.get_field_def()
533550
if not field_def:
534551
return False
@@ -544,10 +561,11 @@ def leave_Field(self, node, key, parent, path, ancestors):
544561
self.missing_field_arg_message(node.name.value, arg_def.name, arg_def.type),
545562
[node]
546563
))
564+
547565
if errors:
548566
return errors
549567

550-
def leave_Directive(self, node, key, parent, path, ancestors):
568+
def leave_Directive(self, node, *args):
551569
directive_def = self.context.get_directive()
552570
if not directive_def:
553571
return False
@@ -581,6 +599,7 @@ def enter_VariableDefinition(self, node, *args):
581599
name = node.variable.name.value
582600
default_value = node.default_value
583601
type = self.context.get_input_type()
602+
584603
if isinstance(type, GraphQLNonNull) and default_value:
585604
return GraphQLError(
586605
self.default_for_non_null_arg_message(name, type, type.of_type),

0 commit comments

Comments
 (0)