Skip to content

Commit a4cbcf9

Browse files
committed
Merge branch 'jhgg_overlapping-fields-can-be-merged'
2 parents be05925 + 8939f4a commit a4cbcf9

File tree

6 files changed

+737
-18
lines changed

6 files changed

+737
-18
lines changed

graphql/core/type/definition.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,9 @@ class GraphQLType(object):
4646
def __str__(self):
4747
return self.name
4848

49+
def is_same_type(self, other):
50+
return self.__class__ is other.__class__ and self.name == other.name
51+
4952

5053
class GraphQLScalarType(GraphQLType):
5154
"""Scalar Type Definition
@@ -440,6 +443,9 @@ def __init__(self, type):
440443
def __str__(self):
441444
return '[' + str(self.of_type) + ']'
442445

446+
def is_same_type(self, other):
447+
return isinstance(other, GraphQLList) and self.of_type.is_same_type(other.of_type)
448+
443449

444450
class GraphQLNonNull(GraphQLType):
445451
"""Non-Null Modifier
@@ -465,3 +471,6 @@ def __init__(self, type):
465471

466472
def __str__(self):
467473
return str(self.of_type) + '!'
474+
475+
def is_same_type(self, other):
476+
return isinstance(other, GraphQLNonNull) and self.of_type.is_same_type(other.of_type)

graphql/core/type/schema.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
GraphQLUnionType,
66
GraphQLList,
77
GraphQLNonNull,
8+
GraphQLInputObjectType
89
)
910
from .introspection import IntrospectionSchema
1011
from .directives import GraphQLIncludeDirective, GraphQLSkipDirective
@@ -67,10 +68,17 @@ def _build_type_map(self):
6768

6869

6970
def type_map_reducer(map, type):
71+
if not type:
72+
return map
73+
7074
if isinstance(type, GraphQLList) or isinstance(type, GraphQLNonNull):
7175
return type_map_reducer(map, type.of_type)
7276

73-
if not type or type.name in map:
77+
if type.name in map:
78+
assert map[type.name] == type, (
79+
'Schema must contain unique named types but contains multiple types named "{}".'
80+
.format(type.name)
81+
)
7482
return map
7583
map[type.name] = type
7684

@@ -86,13 +94,15 @@ def type_map_reducer(map, type):
8694
type_map_reducer, type.get_interfaces(), reduced_map
8795
)
8896

89-
if isinstance(type, (GraphQLObjectType, GraphQLInterfaceType)):
97+
if isinstance(type, (GraphQLObjectType, GraphQLInterfaceType, GraphQLInputObjectType)):
9098
field_map = type.get_fields()
9199
for field_name, field in field_map.items():
92-
field_arg_types = [arg.type for arg in field.args]
93-
reduced_map = reduce(
94-
type_map_reducer, field_arg_types, reduced_map
95-
)
96-
reduced_map = type_map_reducer(reduced_map, field.type)
100+
if hasattr(field, 'args'):
101+
field_arg_types = [arg.type for arg in field.args]
102+
reduced_map = reduce(
103+
type_map_reducer, field_arg_types, reduced_map
104+
)
105+
106+
reduced_map = type_map_reducer(reduced_map, getattr(field, 'type', None))
97107

98108
return reduced_map

graphql/core/validation/rules.py

Lines changed: 238 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
1+
import itertools
12
from ..utils import type_from_ast, is_valid_literal_value
3+
from .utils import PairSet, DefaultOrderedDict
24
from ..error import GraphQLError
35
from ..type.definition import (
46
is_composite_type,
57
is_input_type,
68
is_leaf_type,
9+
get_named_type,
710
GraphQLNonNull,
811
GraphQLList,
912
GraphQLObjectType,
@@ -243,7 +246,7 @@ def reduce_spread_fragments(spreads):
243246
)
244247
for fragment_definition in self.fragment_definitions
245248
if fragment_definition.name.value not in fragment_names_used
246-
]
249+
]
247250

248251
if errors:
249252
return errors
@@ -295,11 +298,14 @@ def do_types_overlap(t1, t2):
295298

296299
@staticmethod
297300
def type_incompatible_spread_message(frag_name, parent_type, frag_type):
298-
return 'Fragment {} cannot be spread here as objects of type {} can never be of type {}'.format(frag_name, parent_type, frag_type)
301+
return 'Fragment {} cannot be spread here as objects of type {} can never be of type {}'.format(frag_name,
302+
parent_type,
303+
frag_type)
299304

300305
@staticmethod
301306
def type_incompatible_anon_spread_message(parent_type, frag_type):
302-
return 'Fragment cannot be spread here as objects of type {} can never be of type {}'.format(parent_type, frag_type)
307+
return 'Fragment cannot be spread here as objects of type {} can never be of type {}'.format(parent_type,
308+
frag_type)
303309

304310

305311
class NoFragmentCycles(ValidationRule):
@@ -309,7 +315,7 @@ def __init__(self, context):
309315
node.name.value: self.gather_spreads(node)
310316
for node in context.get_ast().definitions
311317
if isinstance(node, ast.FragmentDefinition)
312-
}
318+
}
313319
self.known_to_lead_to_cycle = set()
314320

315321
def enter_FragmentDefinition(self, node, *args):
@@ -444,7 +450,7 @@ def leave_OperationDefinition(self, *args):
444450
)
445451
for variable_definition in self.variable_definitions
446452
if variable_definition.variable.name.value not in self.variable_name_used
447-
]
453+
]
448454

449455
if errors:
450456
return errors
@@ -731,8 +737,233 @@ def var_type_allowed_for_type(cls, var_type, expected_type):
731737

732738
@staticmethod
733739
def bad_var_pos_message(var_name, var_type, expected_type):
734-
return 'Variable "{}" of type "{}" used in position expecting type "{}".'.format(var_name, var_type, expected_type)
740+
return 'Variable "{}" of type "{}" used in position expecting type "{}".'.format(var_name, var_type,
741+
expected_type)
735742

736743

737744
class OverlappingFieldsCanBeMerged(ValidationRule):
738-
pass
745+
def __init__(self, context):
746+
super(OverlappingFieldsCanBeMerged, self).__init__(context)
747+
self.compared_set = PairSet()
748+
749+
def find_conflicts(self, field_map):
750+
conflicts = []
751+
for response_name, fields in field_map.items():
752+
field_len = len(fields)
753+
if field_len <= 1:
754+
continue
755+
756+
for field_a in fields:
757+
for field_b in fields:
758+
conflict = self.find_conflict(response_name, field_a, field_b)
759+
if conflict:
760+
conflicts.append(conflict)
761+
762+
return conflicts
763+
764+
@staticmethod
765+
def ast_to_hashable(ast):
766+
"""
767+
This function will take an AST, and return a portion of it that is unique enough to identify the AST,
768+
but without the unhashable bits.
769+
"""
770+
if not ast:
771+
return None
772+
773+
return ast.__class__, ast.loc['start'], ast.loc['end']
774+
775+
def find_conflict(self, response_name, pair1, pair2):
776+
ast1, def1 = pair1
777+
ast2, def2 = pair2
778+
779+
ast1_hashable = self.ast_to_hashable(ast1)
780+
ast2_hashable = self.ast_to_hashable(ast2)
781+
782+
if ast1 is ast2 or self.compared_set.has(ast1_hashable, ast2_hashable):
783+
return
784+
785+
self.compared_set.add(ast1_hashable, ast2_hashable)
786+
787+
name1 = ast1.name.value
788+
name2 = ast2.name.value
789+
790+
if name1 != name2:
791+
return (
792+
(response_name, '{} and {} are different fields'.format(name1, name2)),
793+
(ast1, ast2)
794+
)
795+
796+
type1 = def1 and def1.type
797+
type2 = def2 and def2.type
798+
799+
if type1 and type2 and not self.same_type(type1, type2):
800+
return (
801+
(response_name, 'they return differing types {} and {}'.format(type1, type2)),
802+
(ast1, ast2)
803+
)
804+
805+
if not self.same_arguments(ast1.arguments, ast2.arguments):
806+
return (
807+
(response_name, 'they have differing arguments'),
808+
(ast1, ast2)
809+
)
810+
811+
if not self.same_directives(ast1.directives, ast2.directives):
812+
return (
813+
(response_name, 'they have differing directives'),
814+
(ast1, ast2)
815+
)
816+
817+
selection_set1 = ast1.selection_set
818+
selection_set2 = ast2.selection_set
819+
820+
if selection_set1 and selection_set2:
821+
visited_fragment_names = set()
822+
823+
subfield_map = self.collect_field_asts_and_defs(
824+
get_named_type(type1),
825+
selection_set1,
826+
visited_fragment_names
827+
)
828+
829+
subfield_map = self.collect_field_asts_and_defs(
830+
get_named_type(type2),
831+
selection_set2,
832+
visited_fragment_names,
833+
subfield_map
834+
)
835+
836+
conflicts = self.find_conflicts(subfield_map)
837+
if conflicts:
838+
return (
839+
(response_name, [conflict[0] for conflict in conflicts]),
840+
tuple(itertools.chain((ast1, ast2), *[conflict[1] for conflict in conflicts]))
841+
)
842+
843+
def leave_SelectionSet(self, node, *args):
844+
field_map = self.collect_field_asts_and_defs(
845+
self.context.get_parent_type(),
846+
node
847+
)
848+
849+
conflicts = self.find_conflicts(field_map)
850+
if conflicts:
851+
return [
852+
GraphQLError(self.fields_conflict_message(reason_name, reason), list(fields)) for
853+
(reason_name, reason), fields in conflicts
854+
]
855+
856+
@staticmethod
857+
def same_type(type1, type2):
858+
return type1.is_same_type(type2)
859+
860+
@staticmethod
861+
def same_value(value1, value2):
862+
return (not value1 and not value2) or print_ast(value1) == print_ast(value2)
863+
864+
@classmethod
865+
def same_arguments(cls, arguments1, arguments2):
866+
# Check to see if they are empty arguments or nones. If they are, we can
867+
# bail out early.
868+
if not (arguments1 or arguments2):
869+
return True
870+
871+
if len(arguments1) != len(arguments2):
872+
return False
873+
874+
arguments2_values_to_arg = {a.name.value: a for a in arguments2}
875+
876+
for argument1 in arguments1:
877+
argument2 = arguments2_values_to_arg.get(argument1.name.value)
878+
if not argument2:
879+
return False
880+
881+
if not cls.same_value(argument1.value, argument2.value):
882+
return False
883+
884+
return True
885+
886+
@classmethod
887+
def same_directives(cls, directives1, directives2):
888+
# Check to see if they are empty directives or nones. If they are, we can
889+
# bail out early.
890+
if not (directives1 or directives2):
891+
return True
892+
893+
if len(directives1) != len(directives2):
894+
return False
895+
896+
directives2_values_to_arg = {a.name.value: a for a in directives2}
897+
898+
for directive1 in directives1:
899+
directive2 = directives2_values_to_arg.get(directive1.name.value)
900+
if not directive2:
901+
return False
902+
903+
if not cls.same_arguments(directive1.arguments, directive2.arguments):
904+
return False
905+
906+
return True
907+
908+
def collect_field_asts_and_defs(self, parent_type, selection_set, visited_fragment_names=None, ast_and_defs=None):
909+
if visited_fragment_names is None:
910+
visited_fragment_names = set()
911+
912+
if ast_and_defs is None:
913+
# An ordered dictionary is required, otherwise the error message will be out of order.
914+
# We need to preserve the order that the item was inserted into the dict, as that will dictate
915+
# in which order the reasons in the error message should show.
916+
# Otherwise, the error messages will be inconsistently ordered for the same AST.
917+
# And this can make it so that tests fail half the time, and fool a user into thinking that
918+
# the errors are different, when in-fact they are the same, just that the ordering of the reasons differ.
919+
ast_and_defs = DefaultOrderedDict(list)
920+
921+
for selection in selection_set.selections:
922+
if isinstance(selection, ast.Field):
923+
field_name = selection.name.value
924+
field_def = None
925+
if isinstance(parent_type, (GraphQLObjectType, GraphQLInterfaceType)):
926+
field_def = parent_type.get_fields().get(field_name)
927+
928+
response_name = selection.alias.value if selection.alias else field_name
929+
ast_and_defs[response_name].append((selection, field_def))
930+
931+
elif isinstance(selection, ast.InlineFragment):
932+
self.collect_field_asts_and_defs(
933+
type_from_ast(self.context.get_schema(), selection.type_condition),
934+
selection.selection_set,
935+
visited_fragment_names,
936+
ast_and_defs
937+
)
938+
939+
elif isinstance(selection, ast.FragmentSpread):
940+
fragment_name = selection.name.value
941+
if fragment_name in visited_fragment_names:
942+
continue
943+
944+
visited_fragment_names.add(fragment_name)
945+
fragment = self.context.get_fragment(fragment_name)
946+
947+
if not fragment:
948+
continue
949+
950+
self.collect_field_asts_and_defs(
951+
type_from_ast(self.context.get_schema(), fragment.type_condition),
952+
fragment.selection_set,
953+
visited_fragment_names,
954+
ast_and_defs
955+
)
956+
957+
return ast_and_defs
958+
959+
@classmethod
960+
def fields_conflict_message(cls, reason_name, reason):
961+
return 'Fields "{}" conflict because {}'.format(reason_name, cls.reason_message(reason))
962+
963+
@classmethod
964+
def reason_message(cls, reason):
965+
if isinstance(reason, list):
966+
return ' and '.join('subfields "{}" conflict because {}'.format(reason_name, cls.reason_message(sub_reason))
967+
for reason_name, sub_reason in reason)
968+
969+
return reason

0 commit comments

Comments
 (0)