|
| 1 | +import itertools |
1 | 2 | from ..utils import type_from_ast, is_valid_literal_value
|
| 3 | +from .utils import PairSet, DefaultOrderedDict |
2 | 4 | from ..error import GraphQLError
|
3 |
| -from ..type.definition import is_composite_type, is_input_type, is_leaf_type, GraphQLNonNull |
| 5 | +from ..type.definition import is_composite_type, is_input_type, is_leaf_type, get_named_type, GraphQLNonNull, \ |
| 6 | + GraphQLObjectType, GraphQLInterfaceType |
4 | 7 | from ..language import ast
|
5 | 8 | from ..language.visitor import Visitor, visit
|
6 | 9 | from ..language.printer import print_ast
|
@@ -627,4 +630,228 @@ class VariablesInAllowedPosition(ValidationRule):
|
627 | 630 |
|
628 | 631 |
|
629 | 632 | class OverlappingFieldsCanBeMerged(ValidationRule):
|
630 |
| - pass |
| 633 | + def __init__(self, context): |
| 634 | + super(OverlappingFieldsCanBeMerged, self).__init__(context) |
| 635 | + self.compared_set = PairSet() |
| 636 | + |
| 637 | + def find_conflicts(self, field_map): |
| 638 | + conflicts = [] |
| 639 | + for response_name, fields in field_map.items(): |
| 640 | + field_len = len(fields) |
| 641 | + if field_len <= 1: |
| 642 | + continue |
| 643 | + |
| 644 | + for field_a in fields: |
| 645 | + for field_b in fields: |
| 646 | + conflict = self.find_conflict(response_name, field_a, field_b) |
| 647 | + if conflict: |
| 648 | + conflicts.append(conflict) |
| 649 | + |
| 650 | + return conflicts |
| 651 | + |
| 652 | + @staticmethod |
| 653 | + def ast_to_hashable(ast): |
| 654 | + """ |
| 655 | + This function will take an AST, and return a portion of it that is unique enough to identify the AST, |
| 656 | + but without the unhashable bits. |
| 657 | + """ |
| 658 | + if not ast: |
| 659 | + return None |
| 660 | + |
| 661 | + return ast.__class__, ast.loc['start'], ast.loc['end'] |
| 662 | + |
| 663 | + def find_conflict(self, response_name, pair1, pair2): |
| 664 | + ast1, def1 = pair1 |
| 665 | + ast2, def2 = pair2 |
| 666 | + |
| 667 | + ast1_hashable = self.ast_to_hashable(ast1) |
| 668 | + ast2_hashable = self.ast_to_hashable(ast2) |
| 669 | + |
| 670 | + if ast1 is ast2 or self.compared_set.has(ast1_hashable, ast2_hashable): |
| 671 | + return |
| 672 | + |
| 673 | + self.compared_set.add(ast1_hashable, ast2_hashable) |
| 674 | + |
| 675 | + name1 = ast1.name.value |
| 676 | + name2 = ast2.name.value |
| 677 | + |
| 678 | + if name1 != name2: |
| 679 | + return ( |
| 680 | + (response_name, '{} and {} are different fields'.format(name1, name2)), |
| 681 | + (ast1, ast2) |
| 682 | + ) |
| 683 | + |
| 684 | + type1 = def1 and def1.type |
| 685 | + type2 = def2 and def2.type |
| 686 | + |
| 687 | + if type1 and type2 and not self.same_type(type1, type2): |
| 688 | + return ( |
| 689 | + (response_name, 'they return differing types {} and {}'.format(type1, type2)), |
| 690 | + (ast1, ast2) |
| 691 | + ) |
| 692 | + |
| 693 | + if not self.same_arguments(ast1.arguments, ast2.arguments): |
| 694 | + return ( |
| 695 | + (response_name, 'they have differing arguments'), |
| 696 | + (ast1, ast2) |
| 697 | + ) |
| 698 | + |
| 699 | + if not self.same_directives(ast1.directives, ast2.directives): |
| 700 | + return ( |
| 701 | + (response_name, 'they have differing directives'), |
| 702 | + (ast1, ast2) |
| 703 | + ) |
| 704 | + |
| 705 | + selection_set1 = ast1.selection_set |
| 706 | + selection_set2 = ast2.selection_set |
| 707 | + |
| 708 | + if selection_set1 and selection_set2: |
| 709 | + visited_fragment_names = set() |
| 710 | + |
| 711 | + subfield_map = self.collect_field_asts_and_defs( |
| 712 | + get_named_type(type1), |
| 713 | + selection_set1, |
| 714 | + visited_fragment_names |
| 715 | + ) |
| 716 | + |
| 717 | + subfield_map = self.collect_field_asts_and_defs( |
| 718 | + get_named_type(type2), |
| 719 | + selection_set2, |
| 720 | + visited_fragment_names, |
| 721 | + subfield_map |
| 722 | + ) |
| 723 | + |
| 724 | + conflicts = self.find_conflicts(subfield_map) |
| 725 | + if conflicts: |
| 726 | + return ( |
| 727 | + (response_name, [conflict[0] for conflict in conflicts]), |
| 728 | + tuple(itertools.chain((ast1, ast2), *[conflict[1] for conflict in conflicts])) |
| 729 | + ) |
| 730 | + |
| 731 | + def leave_SelectionSet(self, node, *args): |
| 732 | + field_map = self.collect_field_asts_and_defs( |
| 733 | + self.context.get_parent_type(), |
| 734 | + node |
| 735 | + ) |
| 736 | + |
| 737 | + conflicts = self.find_conflicts(field_map) |
| 738 | + if conflicts: |
| 739 | + return [ |
| 740 | + GraphQLError(self.fields_conflict_message(reason_name, reason), list(fields)) for |
| 741 | + (reason_name, reason), fields in conflicts |
| 742 | + ] |
| 743 | + |
| 744 | + @staticmethod |
| 745 | + def same_type(type1, type2): |
| 746 | + return type1.is_same_type(type2) |
| 747 | + |
| 748 | + @staticmethod |
| 749 | + def same_value(value1, value2): |
| 750 | + return (not value1 and not value2) or print_ast(value1) == print_ast(value2) |
| 751 | + |
| 752 | + @classmethod |
| 753 | + def same_arguments(cls, arguments1, arguments2): |
| 754 | + # Check to see if they are empty arguments or nones. If they are, we can |
| 755 | + # bail out early. |
| 756 | + if not (arguments1 and arguments2): |
| 757 | + return True |
| 758 | + |
| 759 | + if len(arguments1) != len(arguments2): |
| 760 | + return False |
| 761 | + |
| 762 | + arguments2_values_to_arg = {a.name.value: a for a in arguments2} |
| 763 | + |
| 764 | + for argument1 in arguments1: |
| 765 | + argument2 = arguments2_values_to_arg.get(argument1.name.value) |
| 766 | + if not argument2: |
| 767 | + return False |
| 768 | + |
| 769 | + if not cls.same_value(argument1.value, argument2.value): |
| 770 | + return False |
| 771 | + |
| 772 | + return True |
| 773 | + |
| 774 | + @classmethod |
| 775 | + def same_directives(cls, directives1, directives2): |
| 776 | + # Check to see if they are empty directives or nones. If they are, we can |
| 777 | + # bail out early. |
| 778 | + if not (directives1 and directives2): |
| 779 | + return True |
| 780 | + |
| 781 | + if len(directives1) != len(directives2): |
| 782 | + return False |
| 783 | + |
| 784 | + directives2_values_to_arg = {a.name.value: a for a in directives2} |
| 785 | + |
| 786 | + for directive1 in directives1: |
| 787 | + directive2 = directives2_values_to_arg.get(directive1.name.value) |
| 788 | + if not directive2: |
| 789 | + return False |
| 790 | + |
| 791 | + if not cls.same_arguments(directive1.arguments, directive2.arguments): |
| 792 | + return False |
| 793 | + |
| 794 | + return True |
| 795 | + |
| 796 | + def collect_field_asts_and_defs(self, parent_type, selection_set, visited_fragment_names=None, ast_and_defs=None): |
| 797 | + if visited_fragment_names is None: |
| 798 | + visited_fragment_names = set() |
| 799 | + |
| 800 | + if ast_and_defs is None: |
| 801 | + # An ordered dictionary is required, otherwise the error message will be out of order. |
| 802 | + # We need to preserve the order that the item was inserted into the dict, as that will dictate |
| 803 | + # in which order the reasons in the error message should show. |
| 804 | + # Otherwise, the error messages will be inconsistently ordered for the same AST. |
| 805 | + # And this can make it so that tests fail half the time, and fool a user into thinking that |
| 806 | + # the errors are different, when in-fact they are the same, just that the ordering of the reasons differ. |
| 807 | + ast_and_defs = DefaultOrderedDict(list) |
| 808 | + |
| 809 | + for selection in selection_set.selections: |
| 810 | + if isinstance(selection, ast.Field): |
| 811 | + field_name = selection.name.value |
| 812 | + field_def = None |
| 813 | + if isinstance(parent_type, (GraphQLObjectType, GraphQLInterfaceType)): |
| 814 | + field_def = parent_type.get_fields().get(field_name) |
| 815 | + |
| 816 | + response_name = selection.alias.value if selection.alias else field_name |
| 817 | + ast_and_defs[response_name].append((selection, field_def)) |
| 818 | + |
| 819 | + elif isinstance(selection, ast.InlineFragment): |
| 820 | + self.collect_field_asts_and_defs( |
| 821 | + type_from_ast(self.context.get_schema(), selection.type_condition), |
| 822 | + selection.selection_set, |
| 823 | + visited_fragment_names, |
| 824 | + ast_and_defs |
| 825 | + ) |
| 826 | + |
| 827 | + elif isinstance(selection, ast.FragmentSpread): |
| 828 | + fragment_name = selection.name.value |
| 829 | + if fragment_name in visited_fragment_names: |
| 830 | + continue |
| 831 | + |
| 832 | + visited_fragment_names.add(fragment_name) |
| 833 | + fragment = self.context.get_fragment(fragment_name) |
| 834 | + |
| 835 | + if not fragment: |
| 836 | + continue |
| 837 | + |
| 838 | + self.collect_field_asts_and_defs( |
| 839 | + type_from_ast(self.context.get_schema(), fragment.type_condition), |
| 840 | + fragment.selection_set, |
| 841 | + visited_fragment_names, |
| 842 | + ast_and_defs |
| 843 | + ) |
| 844 | + |
| 845 | + return ast_and_defs |
| 846 | + |
| 847 | + @classmethod |
| 848 | + def fields_conflict_message(cls, reason_name, reason): |
| 849 | + return 'Fields "{}" conflict because {}'.format(reason_name, cls.reason_message(reason)) |
| 850 | + |
| 851 | + @classmethod |
| 852 | + def reason_message(cls, reason): |
| 853 | + if isinstance(reason, list): |
| 854 | + return ' and '.join('subfields "{}" conflict because {}'.format(reason_name, cls.reason_message(sub_reason)) |
| 855 | + for reason_name, sub_reason in reason) |
| 856 | + |
| 857 | + return reason |
0 commit comments