@@ -827,16 +827,209 @@ private fun ValidationScope.validateScalars() {
827827 }
828828}
829829
830+ internal class FieldAndNode (val field : GQLInputValueDefinition , val node : Node )
831+
832+ internal class Node (val typeDefinition : GQLInputObjectTypeDefinition ) {
833+ val isOneOf = typeDefinition.directives.findOneOf()
834+
835+ /* *
836+ * Whether that node is valid (can reach a leaf node)
837+ * This will be updated as we traverse the graph
838+ */
839+ var isValid = false
840+
841+ /* *
842+ * Whether that node is visited
843+ */
844+ var visited = false
845+ var edgeCount = 0
846+ val predecessors = mutableSetOf<Node >()
847+ val sucessors = mutableSetOf<FieldAndNode >()
848+
849+ /* *
850+ * tarjan
851+ */
852+ var index: Int? = null
853+ var lowLink: Int? = null
854+ var onStack = false
855+
856+ override fun toString () = typeDefinition.name
857+ }
858+
859+ private fun ValidationScope.reverseGraph (inputObjectTypeDefinitions : List <GQLInputObjectTypeDefinition >): MutableCollection <Node > {
860+ val nodes = mutableMapOf<String , Node >()
861+ inputObjectTypeDefinitions.forEach {
862+ nodes.put(it.name, Node (it))
863+ }
864+
865+ nodes.values.forEach { node ->
866+ /* *
867+ * Track the leaf fields.
868+ * - `@oneOf` are not valid by default but may become if they have one escape hatch.
869+ * - other types are valid by default but may become invalid if they have one non-null reference
870+ */
871+ node.isValid = ! node.isOneOf
872+ node.typeDefinition.inputFields.forEach { field ->
873+ val fieldType = field.type
874+ if (node.isOneOf) {
875+ if (fieldType is GQLNamedType ) {
876+ val fieldTypeDefinition = typeDefinitions.get(fieldType.name)
877+ if (fieldTypeDefinition is GQLInputObjectTypeDefinition ) {
878+ val successor = nodes.get(fieldTypeDefinition.name)!!
879+ successor.predecessors.add(node)
880+ node.sucessors.add(FieldAndNode (field, successor))
881+ } else {
882+ // scalar or enum
883+ node.isValid = true
884+ }
885+ } else {
886+ // Maybe a list
887+ // Should not be non-null. If it is, other validation rules will catch it.
888+ node.isValid = true
889+ }
890+ } else {
891+ if (fieldType is GQLNonNullType ) {
892+ val innerType = fieldType.type
893+ if (innerType is GQLNamedType ) {
894+ val fieldTypeDefinition = typeDefinitions.get(innerType.name)
895+ if (fieldTypeDefinition is GQLInputObjectTypeDefinition ) {
896+ // Not a leaf field
897+ node.isValid = false
898+ node.edgeCount++
899+ val successor = nodes.get(fieldTypeDefinition.name)!!
900+ successor.predecessors.add(node)
901+ node.sucessors.add(FieldAndNode (field, successor))
902+ }
903+ } else {
904+ // List type => escape
905+ }
906+ } else {
907+ // Nullable type => escape
908+ }
909+ }
910+ }
911+ }
912+ return nodes.values
913+ }
914+
915+ /* *
916+ * walks the reverse graph, starting with the leaf nodes to find all the valid nodes
917+ */
918+ private fun findValid (nodes : Collection <Node >) {
919+ val stack = ArrayDeque <Node >()
920+ // Start with the leaf, non-oneOf types
921+ stack.addAll(nodes.filter { it.isValid })
922+
923+ while (stack.isNotEmpty()) {
924+ val node = stack.removeFirst()
925+ if (node.visited) continue
926+ node.visited = true
927+ node.predecessors.forEach { predecessor ->
928+ if (predecessor.isOneOf) {
929+ predecessor.isValid = true
930+ stack.addAll(predecessor.predecessors)
931+ } else {
932+ predecessor.edgeCount--
933+ if (predecessor.edgeCount == 0 ) {
934+ predecessor.isValid = true
935+ stack.addAll(predecessor.predecessors)
936+ }
937+ }
938+ }
939+ }
940+ }
941+
942+ private fun removeValid (nodes : MutableCollection <Node >) {
943+ nodes.removeAll { it.isValid }
944+ // At this point, there shouldn't be any edge pointing to a valid node or that would be an escape input field
945+ // nodes.forEach {
946+ // check(it.sucessors.none { it.node.isValid })
947+ // }
948+ }
949+
950+ internal class PathElement (val typename : String , val inputField : GQLInputValueDefinition )
951+ internal typealias Scc = Collection <Node >
952+
953+ /* *
954+ * For error reporting purposes, find the longest cycle inside the SCC
955+ */
956+ private fun findWitnessCycle (scc : Collection <Node >): List <PathElement > {
957+ val start = scc.first()
958+
959+ val path = mutableListOf<PathElement >()
960+ val visited = mutableSetOf<Node >()
961+
962+ fun dfs (current : Node ): Boolean {
963+ visited.add(current)
964+ for (fieldAndNode in current.sucessors) {
965+ val next = fieldAndNode.node
966+ path.add(PathElement (current.typeDefinition.name, fieldAndNode.field))
967+ if (next == start) return true
968+ if (next !in visited && dfs(next)) return true
969+ path.removeAt(path.lastIndex)
970+ }
971+ visited.remove(current)
972+ return false
973+ }
974+
975+ dfs(start)
976+ return path
977+ }
978+
979+ internal fun tarjanScc (nodes : Collection <Node >): Collection <Scc > {
980+ var index = 0
981+ val stack = ArrayDeque <Node >()
982+ val result = mutableListOf<Scc >()
983+
984+ fun strongConnect (v : Node ) {
985+ v.index = index
986+ v.lowLink = index
987+ index++
988+ stack.addLast(v)
989+ v.onStack = true
990+
991+ v.sucessors.forEach {
992+ val w = it.node
993+ if (w.index == null ) {
994+ strongConnect(w)
995+ v.lowLink = minOf(v.lowLink!! , w.lowLink!! )
996+ } else if (w.onStack) {
997+ v.lowLink = minOf(v.lowLink!! , w.index!! )
998+ }
999+ }
1000+
1001+ if (v.lowLink == v.index) {
1002+ val scc = mutableListOf<Node >()
1003+ while (true ) {
1004+ val w = stack.removeLast()
1005+ w.onStack = false
1006+ scc.add(w)
1007+ if (w == v) break
1008+ }
1009+ result.add(scc)
1010+ }
1011+ }
1012+
1013+ nodes.forEach {
1014+ if (it.index == null ) {
1015+ strongConnect(it)
1016+ }
1017+ }
1018+
1019+ return result
1020+ }
1021+
8301022private fun ValidationScope.validateInputObjects () {
831- val traversalState = TraversalState ()
1023+ val inputObjects = typeDefinitions.values.filterIsInstance<GQLInputObjectTypeDefinition >()
1024+ validateInputObjectsCycles(inputObjects)
1025+
8321026 val defaultValueTraversalState = DefaultValueTraversalState ()
833- typeDefinitions.values.filterIsInstance< GQLInputObjectTypeDefinition >() .forEach { o ->
1027+ inputObjects .forEach { o ->
8341028 if (o.inputFields.isEmpty()) {
8351029 registerIssue(" Input object must specify one or more input fields" , o.sourceLocation)
8361030 }
8371031
8381032 validateDirectivesInConstContext(o.directives, o)
839- validateInputFieldCycles(o, traversalState)
8401033 validateInputObjectDefaultValue(o, defaultValueTraversalState)
8411034
8421035 val isOneOfInputObject = o.directives.findOneOf()
@@ -853,65 +1046,46 @@ private fun ValidationScope.validateInputObjects() {
8531046 }
8541047}
8551048
856- private class TraversalState {
857- val visitedTypes = mutableSetOf<String >()
858- val fieldPath = mutableListOf<Pair <String , SourceLocation ?>>()
859- val fieldPathIndexByTypeName = mutableMapOf<String , Int >()
860- }
861-
862- private class DefaultValueTraversalState {
863- val visitedFields = mutableSetOf<String >()
864- val fieldPath = mutableListOf<Pair <String , SourceLocation ?>>()
865- val fieldPathIndex = mutableMapOf<String , Int >()
866- }
867-
868-
869- private fun ValidationScope.validateInputFieldCycles (inputObjectTypeDefinition : GQLInputObjectTypeDefinition , state : TraversalState ) {
870- if (state.visitedTypes.contains(inputObjectTypeDefinition.name)) {
871- return
872- }
873- state.visitedTypes.add(inputObjectTypeDefinition.name)
874-
875- state.fieldPathIndexByTypeName[inputObjectTypeDefinition.name] = state.fieldPath.size
876-
877- inputObjectTypeDefinition.inputFields.forEach {
878- val type = it.type
879- if (type is GQLNonNullType && type.type is GQLNamedType ) {
880- val fieldType = typeDefinitions.get(type.type.name)
881- if (fieldType is GQLInputObjectTypeDefinition ) {
882- val cycleIndex = state.fieldPathIndexByTypeName.get(fieldType.name)
883-
884- state.fieldPath.add(" ${fieldType.name} .${it.name} " to it.sourceLocation)
885-
886- if (cycleIndex == null ) {
887- validateInputFieldCycles(fieldType, state)
888- } else {
889- val cyclePath = state.fieldPath.subList(cycleIndex, state.fieldPath.size)
890-
891- cyclePath.forEach {
892- issues.add(
893- OtherValidationIssue (
894- buildString {
895- append(" Invalid circular reference. The Input Object '${fieldType.name} ' references itself " )
896- if (cyclePath.size > 1 ) {
897- append(" via the non-null fields: " )
898- } else {
899- append(" in the non-null field " )
900- }
901- append(cyclePath.map { it.first }.joinToString(" , " ))
902- },
903- it.second
904- )
905- )
1049+ private fun ValidationScope.validateInputObjectsCycles (inputObjectTypeDefinitions : List <GQLInputObjectTypeDefinition >) {
1050+ val nodes = reverseGraph(inputObjectTypeDefinitions)
1051+ findValid(nodes)
1052+ removeValid(nodes)
1053+ tarjanScc(nodes).forEach { scc ->
1054+ if (scc.size == 1 ) {
1055+ val firstNode = scc.first()
1056+ val fieldAndNode = firstNode.sucessors.firstOrNull()
1057+ if (fieldAndNode != null && fieldAndNode.node.typeDefinition.name == firstNode.typeDefinition.name) {
1058+ registerIssue(" Input object `${firstNode.typeDefinition.name} ` references itself through field `${firstNode.typeDefinition.name} .${fieldAndNode.field.name} ` and cannot be constructed." , fieldAndNode.field.sourceLocation)
1059+ } else {
1060+ // Trivial SCC containing a single, non self-referncing node are not an issue.
1061+ }
1062+ } else {
1063+ val cycle = findWitnessCycle(scc)
1064+ cycle.indices.forEach { i ->
1065+ val start = cycle.get(i)
1066+ val cycleAsString = buildString {
1067+ var j = i
1068+ repeat(cycle.size) {
1069+ val cur = cycle.get(j)
1070+ append(" ${cur.typename} .${cur.inputField.name} --> " )
1071+ j++
1072+ if (j == cycle.size) {
1073+ j = 0
1074+ }
9061075 }
1076+ append(start.typename)
9071077 }
908-
909- state.fieldPath.removeLast()
1078+ registerIssue(" Input object `${start.typename} ` references itself through an unbreakable chain of input fields and cannot be constructed: $cycleAsString " , start.inputField.sourceLocation)
9101079 }
9111080 }
9121081 }
1082+ }
9131083
914- state.fieldPathIndexByTypeName.remove(inputObjectTypeDefinition.name)
1084+
1085+ private class DefaultValueTraversalState {
1086+ val visitedFields = mutableSetOf<String >()
1087+ val fieldPath = mutableListOf<Pair <String , SourceLocation ?>>()
1088+ val fieldPathIndex = mutableMapOf<String , Int >()
9151089}
9161090
9171091private fun ValidationScope.validateInputObjectDefaultValue (
0 commit comments