1
+ import itertools
1
2
from ..utils import type_from_ast , is_valid_literal_value
3
+ from .utils import PairSet , DefaultOrderedDict
2
4
from ..error import GraphQLError
3
5
from ..type .definition import (
4
6
is_composite_type ,
5
7
is_input_type ,
6
8
is_leaf_type ,
9
+ get_named_type ,
7
10
GraphQLNonNull ,
8
11
GraphQLList ,
9
12
GraphQLObjectType ,
@@ -243,7 +246,7 @@ def reduce_spread_fragments(spreads):
243
246
)
244
247
for fragment_definition in self .fragment_definitions
245
248
if fragment_definition .name .value not in fragment_names_used
246
- ]
249
+ ]
247
250
248
251
if errors :
249
252
return errors
@@ -295,11 +298,14 @@ def do_types_overlap(t1, t2):
295
298
296
299
@staticmethod
297
300
def type_incompatible_spread_message (frag_name , parent_type , frag_type ):
298
- return 'Fragment {} cannot be spread here as objects of type {} can never be of type {}' .format (frag_name , parent_type , frag_type )
301
+ return 'Fragment {} cannot be spread here as objects of type {} can never be of type {}' .format (frag_name ,
302
+ parent_type ,
303
+ frag_type )
299
304
300
305
@staticmethod
301
306
def type_incompatible_anon_spread_message (parent_type , frag_type ):
302
- return 'Fragment cannot be spread here as objects of type {} can never be of type {}' .format (parent_type , frag_type )
307
+ return 'Fragment cannot be spread here as objects of type {} can never be of type {}' .format (parent_type ,
308
+ frag_type )
303
309
304
310
305
311
class NoFragmentCycles (ValidationRule ):
@@ -309,7 +315,7 @@ def __init__(self, context):
309
315
node .name .value : self .gather_spreads (node )
310
316
for node in context .get_ast ().definitions
311
317
if isinstance (node , ast .FragmentDefinition )
312
- }
318
+ }
313
319
self .known_to_lead_to_cycle = set ()
314
320
315
321
def enter_FragmentDefinition (self , node , * args ):
@@ -444,7 +450,7 @@ def leave_OperationDefinition(self, *args):
444
450
)
445
451
for variable_definition in self .variable_definitions
446
452
if variable_definition .variable .name .value not in self .variable_name_used
447
- ]
453
+ ]
448
454
449
455
if errors :
450
456
return errors
@@ -731,8 +737,233 @@ def var_type_allowed_for_type(cls, var_type, expected_type):
731
737
732
738
@staticmethod
733
739
def bad_var_pos_message (var_name , var_type , expected_type ):
734
- return 'Variable "{}" of type "{}" used in position expecting type "{}".' .format (var_name , var_type , expected_type )
740
+ return 'Variable "{}" of type "{}" used in position expecting type "{}".' .format (var_name , var_type ,
741
+ expected_type )
735
742
736
743
737
744
class OverlappingFieldsCanBeMerged (ValidationRule ):
738
- pass
745
+ def __init__ (self , context ):
746
+ super (OverlappingFieldsCanBeMerged , self ).__init__ (context )
747
+ self .compared_set = PairSet ()
748
+
749
+ def find_conflicts (self , field_map ):
750
+ conflicts = []
751
+ for response_name , fields in field_map .items ():
752
+ field_len = len (fields )
753
+ if field_len <= 1 :
754
+ continue
755
+
756
+ for field_a in fields :
757
+ for field_b in fields :
758
+ conflict = self .find_conflict (response_name , field_a , field_b )
759
+ if conflict :
760
+ conflicts .append (conflict )
761
+
762
+ return conflicts
763
+
764
+ @staticmethod
765
+ def ast_to_hashable (ast ):
766
+ """
767
+ This function will take an AST, and return a portion of it that is unique enough to identify the AST,
768
+ but without the unhashable bits.
769
+ """
770
+ if not ast :
771
+ return None
772
+
773
+ return ast .__class__ , ast .loc ['start' ], ast .loc ['end' ]
774
+
775
+ def find_conflict (self , response_name , pair1 , pair2 ):
776
+ ast1 , def1 = pair1
777
+ ast2 , def2 = pair2
778
+
779
+ ast1_hashable = self .ast_to_hashable (ast1 )
780
+ ast2_hashable = self .ast_to_hashable (ast2 )
781
+
782
+ if ast1 is ast2 or self .compared_set .has (ast1_hashable , ast2_hashable ):
783
+ return
784
+
785
+ self .compared_set .add (ast1_hashable , ast2_hashable )
786
+
787
+ name1 = ast1 .name .value
788
+ name2 = ast2 .name .value
789
+
790
+ if name1 != name2 :
791
+ return (
792
+ (response_name , '{} and {} are different fields' .format (name1 , name2 )),
793
+ (ast1 , ast2 )
794
+ )
795
+
796
+ type1 = def1 and def1 .type
797
+ type2 = def2 and def2 .type
798
+
799
+ if type1 and type2 and not self .same_type (type1 , type2 ):
800
+ return (
801
+ (response_name , 'they return differing types {} and {}' .format (type1 , type2 )),
802
+ (ast1 , ast2 )
803
+ )
804
+
805
+ if not self .same_arguments (ast1 .arguments , ast2 .arguments ):
806
+ return (
807
+ (response_name , 'they have differing arguments' ),
808
+ (ast1 , ast2 )
809
+ )
810
+
811
+ if not self .same_directives (ast1 .directives , ast2 .directives ):
812
+ return (
813
+ (response_name , 'they have differing directives' ),
814
+ (ast1 , ast2 )
815
+ )
816
+
817
+ selection_set1 = ast1 .selection_set
818
+ selection_set2 = ast2 .selection_set
819
+
820
+ if selection_set1 and selection_set2 :
821
+ visited_fragment_names = set ()
822
+
823
+ subfield_map = self .collect_field_asts_and_defs (
824
+ get_named_type (type1 ),
825
+ selection_set1 ,
826
+ visited_fragment_names
827
+ )
828
+
829
+ subfield_map = self .collect_field_asts_and_defs (
830
+ get_named_type (type2 ),
831
+ selection_set2 ,
832
+ visited_fragment_names ,
833
+ subfield_map
834
+ )
835
+
836
+ conflicts = self .find_conflicts (subfield_map )
837
+ if conflicts :
838
+ return (
839
+ (response_name , [conflict [0 ] for conflict in conflicts ]),
840
+ tuple (itertools .chain ((ast1 , ast2 ), * [conflict [1 ] for conflict in conflicts ]))
841
+ )
842
+
843
+ def leave_SelectionSet (self , node , * args ):
844
+ field_map = self .collect_field_asts_and_defs (
845
+ self .context .get_parent_type (),
846
+ node
847
+ )
848
+
849
+ conflicts = self .find_conflicts (field_map )
850
+ if conflicts :
851
+ return [
852
+ GraphQLError (self .fields_conflict_message (reason_name , reason ), list (fields )) for
853
+ (reason_name , reason ), fields in conflicts
854
+ ]
855
+
856
+ @staticmethod
857
+ def same_type (type1 , type2 ):
858
+ return type1 .is_same_type (type2 )
859
+
860
+ @staticmethod
861
+ def same_value (value1 , value2 ):
862
+ return (not value1 and not value2 ) or print_ast (value1 ) == print_ast (value2 )
863
+
864
+ @classmethod
865
+ def same_arguments (cls , arguments1 , arguments2 ):
866
+ # Check to see if they are empty arguments or nones. If they are, we can
867
+ # bail out early.
868
+ if not (arguments1 or arguments2 ):
869
+ return True
870
+
871
+ if len (arguments1 ) != len (arguments2 ):
872
+ return False
873
+
874
+ arguments2_values_to_arg = {a .name .value : a for a in arguments2 }
875
+
876
+ for argument1 in arguments1 :
877
+ argument2 = arguments2_values_to_arg .get (argument1 .name .value )
878
+ if not argument2 :
879
+ return False
880
+
881
+ if not cls .same_value (argument1 .value , argument2 .value ):
882
+ return False
883
+
884
+ return True
885
+
886
+ @classmethod
887
+ def same_directives (cls , directives1 , directives2 ):
888
+ # Check to see if they are empty directives or nones. If they are, we can
889
+ # bail out early.
890
+ if not (directives1 or directives2 ):
891
+ return True
892
+
893
+ if len (directives1 ) != len (directives2 ):
894
+ return False
895
+
896
+ directives2_values_to_arg = {a .name .value : a for a in directives2 }
897
+
898
+ for directive1 in directives1 :
899
+ directive2 = directives2_values_to_arg .get (directive1 .name .value )
900
+ if not directive2 :
901
+ return False
902
+
903
+ if not cls .same_arguments (directive1 .arguments , directive2 .arguments ):
904
+ return False
905
+
906
+ return True
907
+
908
+ def collect_field_asts_and_defs (self , parent_type , selection_set , visited_fragment_names = None , ast_and_defs = None ):
909
+ if visited_fragment_names is None :
910
+ visited_fragment_names = set ()
911
+
912
+ if ast_and_defs is None :
913
+ # An ordered dictionary is required, otherwise the error message will be out of order.
914
+ # We need to preserve the order that the item was inserted into the dict, as that will dictate
915
+ # in which order the reasons in the error message should show.
916
+ # Otherwise, the error messages will be inconsistently ordered for the same AST.
917
+ # And this can make it so that tests fail half the time, and fool a user into thinking that
918
+ # the errors are different, when in-fact they are the same, just that the ordering of the reasons differ.
919
+ ast_and_defs = DefaultOrderedDict (list )
920
+
921
+ for selection in selection_set .selections :
922
+ if isinstance (selection , ast .Field ):
923
+ field_name = selection .name .value
924
+ field_def = None
925
+ if isinstance (parent_type , (GraphQLObjectType , GraphQLInterfaceType )):
926
+ field_def = parent_type .get_fields ().get (field_name )
927
+
928
+ response_name = selection .alias .value if selection .alias else field_name
929
+ ast_and_defs [response_name ].append ((selection , field_def ))
930
+
931
+ elif isinstance (selection , ast .InlineFragment ):
932
+ self .collect_field_asts_and_defs (
933
+ type_from_ast (self .context .get_schema (), selection .type_condition ),
934
+ selection .selection_set ,
935
+ visited_fragment_names ,
936
+ ast_and_defs
937
+ )
938
+
939
+ elif isinstance (selection , ast .FragmentSpread ):
940
+ fragment_name = selection .name .value
941
+ if fragment_name in visited_fragment_names :
942
+ continue
943
+
944
+ visited_fragment_names .add (fragment_name )
945
+ fragment = self .context .get_fragment (fragment_name )
946
+
947
+ if not fragment :
948
+ continue
949
+
950
+ self .collect_field_asts_and_defs (
951
+ type_from_ast (self .context .get_schema (), fragment .type_condition ),
952
+ fragment .selection_set ,
953
+ visited_fragment_names ,
954
+ ast_and_defs
955
+ )
956
+
957
+ return ast_and_defs
958
+
959
+ @classmethod
960
+ def fields_conflict_message (cls , reason_name , reason ):
961
+ return 'Fields "{}" conflict because {}' .format (reason_name , cls .reason_message (reason ))
962
+
963
+ @classmethod
964
+ def reason_message (cls , reason ):
965
+ if isinstance (reason , list ):
966
+ return ' and ' .join ('subfields "{}" conflict because {}' .format (reason_name , cls .reason_message (sub_reason ))
967
+ for reason_name , sub_reason in reason )
968
+
969
+ return reason
0 commit comments