Skip to content

Commit d07e86d

Browse files
committed
Implement OverlappingFieldsCanBeMerged
* Fix errors in `type_map_reducer`. * Implement `expect_{passes,fails}_rule_with_schema` * Add support to opt out of argument sorting in `expect_invalid`. * Implement `PairSet` and `DefaultOrderedDict`. * Warn if schema does not contain unique named types. * Implement `is_same_type` onto `GraphQLType`, `GraphQLList` and `GraphQLNonNull`.
1 parent d1c04e7 commit d07e86d

File tree

6 files changed

+719
-13
lines changed

6 files changed

+719
-13
lines changed

graphql/core/type/definition.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,9 @@ class GraphQLType(object):
4646
def __str__(self):
4747
return self.name
4848

49+
def is_same_type(self, other):
50+
return self.__class__ is other.__class__ and self.name == other.name
51+
4952

5053
class GraphQLScalarType(GraphQLType):
5154
"""Scalar Type Definition
@@ -440,6 +443,9 @@ def __init__(self, type):
440443
def __str__(self):
441444
return '[' + str(self.of_type) + ']'
442445

446+
def is_same_type(self, other):
447+
return isinstance(other, GraphQLList) and self.of_type.is_same_type(other.of_type)
448+
443449

444450
class GraphQLNonNull(GraphQLType):
445451
"""Non-Null Modifier
@@ -465,3 +471,6 @@ def __init__(self, type):
465471

466472
def __str__(self):
467473
return str(self.of_type) + '!'
474+
475+
def is_same_type(self, other):
476+
return isinstance(other, GraphQLNonNull) and self.of_type.is_same_type(other.of_type)

graphql/core/type/schema.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
from functools import reduce
2+
from py._log import warning
23
from .definition import (
34
GraphQLObjectType,
45
GraphQLInterfaceType,
56
GraphQLUnionType,
67
GraphQLList,
78
GraphQLNonNull,
9+
GraphQLInputObjectType
810
)
911
from .introspection import IntrospectionSchema
1012
from .directives import GraphQLIncludeDirective, GraphQLSkipDirective
@@ -67,10 +69,18 @@ def _build_type_map(self):
6769

6870

6971
def type_map_reducer(map, type):
72+
if not type:
73+
return map
74+
7075
if isinstance(type, GraphQLList) or isinstance(type, GraphQLNonNull):
7176
return type_map_reducer(map, type.of_type)
7277

73-
if not type or type.name in map:
78+
if type.name in map:
79+
if map[type.name] != type:
80+
warning.warn(
81+
'Schema must contain unique named types but contains multiple types named "{}".'
82+
.format(type.name)
83+
)
7484
return map
7585
map[type.name] = type
7686

@@ -86,13 +96,15 @@ def type_map_reducer(map, type):
8696
type_map_reducer, type.get_interfaces(), reduced_map
8797
)
8898

89-
if isinstance(type, (GraphQLObjectType, GraphQLInterfaceType)):
99+
if isinstance(type, (GraphQLObjectType, GraphQLInterfaceType, GraphQLInputObjectType)):
90100
field_map = type.get_fields()
91101
for field_name, field in field_map.items():
92-
field_arg_types = [arg.type for arg in field.args]
93-
reduced_map = reduce(
94-
type_map_reducer, field_arg_types, reduced_map
95-
)
96-
reduced_map = type_map_reducer(reduced_map, field.type)
102+
if hasattr(field, 'args'):
103+
field_arg_types = [arg.type for arg in field.args]
104+
reduced_map = reduce(
105+
type_map_reducer, field_arg_types, reduced_map
106+
)
107+
108+
reduced_map = type_map_reducer(reduced_map, getattr(field, 'type', None))
97109

98110
return reduced_map

graphql/core/validation/rules.py

Lines changed: 229 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
1+
import itertools
12
from ..utils import type_from_ast, is_valid_literal_value
3+
from .utils import PairSet, DefaultOrderedDict
24
from ..error import GraphQLError
3-
from ..type.definition import is_composite_type, is_input_type, is_leaf_type, GraphQLNonNull
5+
from ..type.definition import is_composite_type, is_input_type, is_leaf_type, get_named_type, GraphQLNonNull, \
6+
GraphQLObjectType, GraphQLInterfaceType
47
from ..language import ast
58
from ..language.visitor import Visitor, visit
69
from ..language.printer import print_ast
@@ -627,4 +630,228 @@ class VariablesInAllowedPosition(ValidationRule):
627630

628631

629632
class OverlappingFieldsCanBeMerged(ValidationRule):
630-
pass
633+
def __init__(self, context):
634+
super(OverlappingFieldsCanBeMerged, self).__init__(context)
635+
self.compared_set = PairSet()
636+
637+
def find_conflicts(self, field_map):
638+
conflicts = []
639+
for response_name, fields in field_map.items():
640+
field_len = len(fields)
641+
if field_len <= 1:
642+
continue
643+
644+
for field_a in fields:
645+
for field_b in fields:
646+
conflict = self.find_conflict(response_name, field_a, field_b)
647+
if conflict:
648+
conflicts.append(conflict)
649+
650+
return conflicts
651+
652+
@staticmethod
653+
def ast_to_hashable(ast):
654+
"""
655+
This function will take an AST, and return a portion of it that is unique enough to identify the AST,
656+
but without the unhashable bits.
657+
"""
658+
if not ast:
659+
return None
660+
661+
return ast.__class__, ast.loc['start'], ast.loc['end']
662+
663+
def find_conflict(self, response_name, pair1, pair2):
664+
ast1, def1 = pair1
665+
ast2, def2 = pair2
666+
667+
ast1_hashable = self.ast_to_hashable(ast1)
668+
ast2_hashable = self.ast_to_hashable(ast2)
669+
670+
if ast1 is ast2 or self.compared_set.has(ast1_hashable, ast2_hashable):
671+
return
672+
673+
self.compared_set.add(ast1_hashable, ast2_hashable)
674+
675+
name1 = ast1.name.value
676+
name2 = ast2.name.value
677+
678+
if name1 != name2:
679+
return (
680+
(response_name, '{} and {} are different fields'.format(name1, name2)),
681+
(ast1, ast2)
682+
)
683+
684+
type1 = def1 and def1.type
685+
type2 = def2 and def2.type
686+
687+
if type1 and type2 and not self.same_type(type1, type2):
688+
return (
689+
(response_name, 'they return differing types {} and {}'.format(type1, type2)),
690+
(ast1, ast2)
691+
)
692+
693+
if not self.same_arguments(ast1.arguments, ast2.arguments):
694+
return (
695+
(response_name, 'they have differing arguments'),
696+
(ast1, ast2)
697+
)
698+
699+
if not self.same_directives(ast1.directives, ast2.directives):
700+
return (
701+
(response_name, 'they have differing directives'),
702+
(ast1, ast2)
703+
)
704+
705+
selection_set1 = ast1.selection_set
706+
selection_set2 = ast2.selection_set
707+
708+
if selection_set1 and selection_set2:
709+
visited_fragment_names = set()
710+
711+
subfield_map = self.collect_field_asts_and_defs(
712+
get_named_type(type1),
713+
selection_set1,
714+
visited_fragment_names
715+
)
716+
717+
subfield_map = self.collect_field_asts_and_defs(
718+
get_named_type(type2),
719+
selection_set2,
720+
visited_fragment_names,
721+
subfield_map
722+
)
723+
724+
conflicts = self.find_conflicts(subfield_map)
725+
if conflicts:
726+
return (
727+
(response_name, [conflict[0] for conflict in conflicts]),
728+
tuple(itertools.chain((ast1, ast2), *[conflict[1] for conflict in conflicts]))
729+
)
730+
731+
def leave_SelectionSet(self, node, *args):
732+
field_map = self.collect_field_asts_and_defs(
733+
self.context.get_parent_type(),
734+
node
735+
)
736+
737+
conflicts = self.find_conflicts(field_map)
738+
if conflicts:
739+
return [
740+
GraphQLError(self.fields_conflict_message(reason_name, reason), list(fields)) for
741+
(reason_name, reason), fields in conflicts
742+
]
743+
744+
@staticmethod
745+
def same_type(type1, type2):
746+
return type1.is_same_type(type2)
747+
748+
@staticmethod
749+
def same_value(value1, value2):
750+
return (not value1 and not value2) or print_ast(value1) == print_ast(value2)
751+
752+
@classmethod
753+
def same_arguments(cls, arguments1, arguments2):
754+
# Check to see if they are empty arguments or nones. If they are, we can
755+
# bail out early.
756+
if not (arguments1 and arguments2):
757+
return True
758+
759+
if len(arguments1) != len(arguments2):
760+
return False
761+
762+
arguments2_values_to_arg = {a.name.value: a for a in arguments2}
763+
764+
for argument1 in arguments1:
765+
argument2 = arguments2_values_to_arg.get(argument1.name.value)
766+
if not argument2:
767+
return False
768+
769+
if not cls.same_value(argument1.value, argument2.value):
770+
return False
771+
772+
return True
773+
774+
@classmethod
775+
def same_directives(cls, directives1, directives2):
776+
# Check to see if they are empty directives or nones. If they are, we can
777+
# bail out early.
778+
if not (directives1 and directives2):
779+
return True
780+
781+
if len(directives1) != len(directives2):
782+
return False
783+
784+
directives2_values_to_arg = {a.name.value: a for a in directives2}
785+
786+
for directive1 in directives1:
787+
directive2 = directives2_values_to_arg.get(directive1.name.value)
788+
if not directive2:
789+
return False
790+
791+
if not cls.same_arguments(directive1.arguments, directive2.arguments):
792+
return False
793+
794+
return True
795+
796+
def collect_field_asts_and_defs(self, parent_type, selection_set, visited_fragment_names=None, ast_and_defs=None):
797+
if visited_fragment_names is None:
798+
visited_fragment_names = set()
799+
800+
if ast_and_defs is None:
801+
# An ordered dictionary is required, otherwise the error message will be out of order.
802+
# We need to preserve the order that the item was inserted into the dict, as that will dictate
803+
# in which order the reasons in the error message should show.
804+
# Otherwise, the error messages will be inconsistently ordered for the same AST.
805+
# And this can make it so that tests fail half the time, and fool a user into thinking that
806+
# the errors are different, when in-fact they are the same, just that the ordering of the reasons differ.
807+
ast_and_defs = DefaultOrderedDict(list)
808+
809+
for selection in selection_set.selections:
810+
if isinstance(selection, ast.Field):
811+
field_name = selection.name.value
812+
field_def = None
813+
if isinstance(parent_type, (GraphQLObjectType, GraphQLInterfaceType)):
814+
field_def = parent_type.get_fields().get(field_name)
815+
816+
response_name = selection.alias.value if selection.alias else field_name
817+
ast_and_defs[response_name].append((selection, field_def))
818+
819+
elif isinstance(selection, ast.InlineFragment):
820+
self.collect_field_asts_and_defs(
821+
type_from_ast(self.context.get_schema(), selection.type_condition),
822+
selection.selection_set,
823+
visited_fragment_names,
824+
ast_and_defs
825+
)
826+
827+
elif isinstance(selection, ast.FragmentSpread):
828+
fragment_name = selection.name.value
829+
if fragment_name in visited_fragment_names:
830+
continue
831+
832+
visited_fragment_names.add(fragment_name)
833+
fragment = self.context.get_fragment(fragment_name)
834+
835+
if not fragment:
836+
continue
837+
838+
self.collect_field_asts_and_defs(
839+
type_from_ast(self.context.get_schema(), fragment.type_condition),
840+
fragment.selection_set,
841+
visited_fragment_names,
842+
ast_and_defs
843+
)
844+
845+
return ast_and_defs
846+
847+
@classmethod
848+
def fields_conflict_message(cls, reason_name, reason):
849+
return 'Fields "{}" conflict because {}'.format(reason_name, cls.reason_message(reason))
850+
851+
@classmethod
852+
def reason_message(cls, reason):
853+
if isinstance(reason, list):
854+
return ' and '.join('subfields "{}" conflict because {}'.format(reason_name, cls.reason_message(sub_reason))
855+
for reason_name, sub_reason in reason)
856+
857+
return reason

graphql/core/validation/utils.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
from collections import OrderedDict, Callable
2+
3+
4+
class PairSet(object):
5+
def __init__(self):
6+
self._data = set()
7+
8+
def __contains__(self, item):
9+
return item in self._data
10+
11+
def has(self, a, b):
12+
return (a, b) in self._data
13+
14+
def add(self, a, b):
15+
self._data.add((a, b))
16+
self._data.add((b, a))
17+
return self
18+
19+
def remove(self, a, b):
20+
self._data.discard((a, b))
21+
self._data.discard((b, a))
22+
23+
24+
class DefaultOrderedDict(OrderedDict):
25+
# Source: http://stackoverflow.com/a/6190500/562769
26+
def __init__(self, default_factory=None, *a, **kw):
27+
if (default_factory is not None and
28+
not isinstance(default_factory, Callable)):
29+
raise TypeError('first argument must be callable')
30+
OrderedDict.__init__(self, *a, **kw)
31+
self.default_factory = default_factory
32+
33+
def __getitem__(self, key):
34+
try:
35+
return OrderedDict.__getitem__(self, key)
36+
except KeyError:
37+
return self.__missing__(key)
38+
39+
def __missing__(self, key):
40+
if self.default_factory is None:
41+
raise KeyError(key)
42+
self[key] = value = self.default_factory()
43+
return value
44+
45+
def __reduce__(self):
46+
if self.default_factory is None:
47+
args = tuple()
48+
else:
49+
args = self.default_factory,
50+
return type(self), args, None, None, self.items()
51+
52+
def copy(self):
53+
return self.__copy__()
54+
55+
def __copy__(self):
56+
return type(self)(self.default_factory, self)
57+
58+
def __deepcopy__(self, memo):
59+
import copy
60+
return type(self)(self.default_factory,
61+
copy.deepcopy(self.items()))
62+
63+
def __repr__(self, _repr_running=None):
64+
if _repr_running is None:
65+
_repr_running = {}
66+
67+
return 'OrderedDefaultDict(%s, %s)' % (self.default_factory,
68+
OrderedDict.__repr__(self, _repr_running))

0 commit comments

Comments
 (0)