Skip to content

Commit 2c072e0

Browse files
authored
Fix validation of @oneOf input objects. They may be invalid if they create cycles (#6894)
See graphql/graphql-spec#1211
1 parent ed689ee commit 2c072e0

30 files changed

+453
-93
lines changed

libraries/apollo-ast/src/commonMain/kotlin/com/apollographql/apollo/ast/internal/SchemaValidationScope.kt

Lines changed: 230 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -827,16 +827,209 @@ private fun ValidationScope.validateScalars() {
827827
}
828828
}
829829

830+
internal class FieldAndNode(val field: GQLInputValueDefinition, val node: Node)
831+
832+
internal class Node(val typeDefinition: GQLInputObjectTypeDefinition) {
833+
val isOneOf = typeDefinition.directives.findOneOf()
834+
835+
/**
836+
* Whether that node is valid (can reach a leaf node)
837+
* This will be updated as we traverse the graph
838+
*/
839+
var isValid = false
840+
841+
/**
842+
* Whether that node is visited
843+
*/
844+
var visited = false
845+
var edgeCount = 0
846+
val predecessors = mutableSetOf<Node>()
847+
val sucessors = mutableSetOf<FieldAndNode>()
848+
849+
/**
850+
* tarjan
851+
*/
852+
var index: Int? = null
853+
var lowLink: Int? = null
854+
var onStack = false
855+
856+
override fun toString() = typeDefinition.name
857+
}
858+
859+
private fun ValidationScope.reverseGraph(inputObjectTypeDefinitions: List<GQLInputObjectTypeDefinition>): MutableCollection<Node> {
860+
val nodes = mutableMapOf<String, Node>()
861+
inputObjectTypeDefinitions.forEach {
862+
nodes.put(it.name, Node(it))
863+
}
864+
865+
nodes.values.forEach { node ->
866+
/**
867+
* Track the leaf fields.
868+
* - `@oneOf` are not valid by default but may become if they have one escape hatch.
869+
* - other types are valid by default but may become invalid if they have one non-null reference
870+
*/
871+
node.isValid = !node.isOneOf
872+
node.typeDefinition.inputFields.forEach { field ->
873+
val fieldType = field.type
874+
if (node.isOneOf) {
875+
if (fieldType is GQLNamedType) {
876+
val fieldTypeDefinition = typeDefinitions.get(fieldType.name)
877+
if (fieldTypeDefinition is GQLInputObjectTypeDefinition) {
878+
val successor = nodes.get(fieldTypeDefinition.name)!!
879+
successor.predecessors.add(node)
880+
node.sucessors.add(FieldAndNode(field, successor))
881+
} else {
882+
// scalar or enum
883+
node.isValid = true
884+
}
885+
} else {
886+
// Maybe a list
887+
// Should not be non-null. If it is, other validation rules will catch it.
888+
node.isValid = true
889+
}
890+
} else {
891+
if (fieldType is GQLNonNullType) {
892+
val innerType = fieldType.type
893+
if (innerType is GQLNamedType) {
894+
val fieldTypeDefinition = typeDefinitions.get(innerType.name)
895+
if (fieldTypeDefinition is GQLInputObjectTypeDefinition) {
896+
// Not a leaf field
897+
node.isValid = false
898+
node.edgeCount++
899+
val successor = nodes.get(fieldTypeDefinition.name)!!
900+
successor.predecessors.add(node)
901+
node.sucessors.add(FieldAndNode(field, successor))
902+
}
903+
} else {
904+
// List type => escape
905+
}
906+
} else {
907+
// Nullable type => escape
908+
}
909+
}
910+
}
911+
}
912+
return nodes.values
913+
}
914+
915+
/**
916+
* walks the reverse graph, starting with the leaf nodes to find all the valid nodes
917+
*/
918+
private fun findValid(nodes: Collection<Node>) {
919+
val stack = ArrayDeque<Node>()
920+
// Start with the leaf, non-oneOf types
921+
stack.addAll(nodes.filter { it.isValid })
922+
923+
while (stack.isNotEmpty()) {
924+
val node = stack.removeFirst()
925+
if (node.visited) continue
926+
node.visited = true
927+
node.predecessors.forEach { predecessor ->
928+
if (predecessor.isOneOf) {
929+
predecessor.isValid = true
930+
stack.addAll(predecessor.predecessors)
931+
} else {
932+
predecessor.edgeCount--
933+
if (predecessor.edgeCount == 0) {
934+
predecessor.isValid = true
935+
stack.addAll(predecessor.predecessors)
936+
}
937+
}
938+
}
939+
}
940+
}
941+
942+
private fun removeValid(nodes: MutableCollection<Node>) {
943+
nodes.removeAll { it.isValid }
944+
// At this point, there shouldn't be any edge pointing to a valid node or that would be an escape input field
945+
// nodes.forEach {
946+
// check(it.sucessors.none { it.node.isValid })
947+
// }
948+
}
949+
950+
internal class PathElement(val typename: String, val inputField: GQLInputValueDefinition)
951+
internal typealias Scc = Collection<Node>
952+
953+
/**
954+
* For error reporting purposes, find the longest cycle inside the SCC
955+
*/
956+
private fun findWitnessCycle(scc: Collection<Node>): List<PathElement> {
957+
val start = scc.first()
958+
959+
val path = mutableListOf<PathElement>()
960+
val visited = mutableSetOf<Node>()
961+
962+
fun dfs(current: Node): Boolean {
963+
visited.add(current)
964+
for (fieldAndNode in current.sucessors) {
965+
val next = fieldAndNode.node
966+
path.add(PathElement(current.typeDefinition.name, fieldAndNode.field))
967+
if (next == start) return true
968+
if (next !in visited && dfs(next)) return true
969+
path.removeAt(path.lastIndex)
970+
}
971+
visited.remove(current)
972+
return false
973+
}
974+
975+
dfs(start)
976+
return path
977+
}
978+
979+
internal fun tarjanScc(nodes: Collection<Node>): Collection<Scc> {
980+
var index = 0
981+
val stack = ArrayDeque<Node>()
982+
val result = mutableListOf<Scc>()
983+
984+
fun strongConnect(v: Node) {
985+
v.index = index
986+
v.lowLink = index
987+
index++
988+
stack.addLast(v)
989+
v.onStack = true
990+
991+
v.sucessors.forEach {
992+
val w = it.node
993+
if (w.index == null) {
994+
strongConnect(w)
995+
v.lowLink = minOf(v.lowLink!!, w.lowLink!!)
996+
} else if (w.onStack) {
997+
v.lowLink = minOf(v.lowLink!!, w.index!!)
998+
}
999+
}
1000+
1001+
if (v.lowLink == v.index) {
1002+
val scc = mutableListOf<Node>()
1003+
while (true) {
1004+
val w = stack.removeLast()
1005+
w.onStack = false
1006+
scc.add(w)
1007+
if (w == v) break
1008+
}
1009+
result.add(scc)
1010+
}
1011+
}
1012+
1013+
nodes.forEach {
1014+
if (it.index == null) {
1015+
strongConnect(it)
1016+
}
1017+
}
1018+
1019+
return result
1020+
}
1021+
8301022
private fun ValidationScope.validateInputObjects() {
831-
val traversalState = TraversalState()
1023+
val inputObjects = typeDefinitions.values.filterIsInstance<GQLInputObjectTypeDefinition>()
1024+
validateInputObjectsCycles(inputObjects)
1025+
8321026
val defaultValueTraversalState = DefaultValueTraversalState()
833-
typeDefinitions.values.filterIsInstance<GQLInputObjectTypeDefinition>().forEach { o ->
1027+
inputObjects.forEach { o ->
8341028
if (o.inputFields.isEmpty()) {
8351029
registerIssue("Input object must specify one or more input fields", o.sourceLocation)
8361030
}
8371031

8381032
validateDirectivesInConstContext(o.directives, o)
839-
validateInputFieldCycles(o, traversalState)
8401033
validateInputObjectDefaultValue(o, defaultValueTraversalState)
8411034

8421035
val isOneOfInputObject = o.directives.findOneOf()
@@ -853,65 +1046,46 @@ private fun ValidationScope.validateInputObjects() {
8531046
}
8541047
}
8551048

856-
private class TraversalState {
857-
val visitedTypes = mutableSetOf<String>()
858-
val fieldPath = mutableListOf<Pair<String, SourceLocation?>>()
859-
val fieldPathIndexByTypeName = mutableMapOf<String, Int>()
860-
}
861-
862-
private class DefaultValueTraversalState {
863-
val visitedFields = mutableSetOf<String>()
864-
val fieldPath = mutableListOf<Pair<String, SourceLocation?>>()
865-
val fieldPathIndex = mutableMapOf<String, Int>()
866-
}
867-
868-
869-
private fun ValidationScope.validateInputFieldCycles(inputObjectTypeDefinition: GQLInputObjectTypeDefinition, state: TraversalState) {
870-
if (state.visitedTypes.contains(inputObjectTypeDefinition.name)) {
871-
return
872-
}
873-
state.visitedTypes.add(inputObjectTypeDefinition.name)
874-
875-
state.fieldPathIndexByTypeName[inputObjectTypeDefinition.name] = state.fieldPath.size
876-
877-
inputObjectTypeDefinition.inputFields.forEach {
878-
val type = it.type
879-
if (type is GQLNonNullType && type.type is GQLNamedType) {
880-
val fieldType = typeDefinitions.get(type.type.name)
881-
if (fieldType is GQLInputObjectTypeDefinition) {
882-
val cycleIndex = state.fieldPathIndexByTypeName.get(fieldType.name)
883-
884-
state.fieldPath.add("${fieldType.name}.${it.name}" to it.sourceLocation)
885-
886-
if (cycleIndex == null) {
887-
validateInputFieldCycles(fieldType, state)
888-
} else {
889-
val cyclePath = state.fieldPath.subList(cycleIndex, state.fieldPath.size)
890-
891-
cyclePath.forEach {
892-
issues.add(
893-
OtherValidationIssue(
894-
buildString {
895-
append("Invalid circular reference. The Input Object '${fieldType.name}' references itself ")
896-
if (cyclePath.size > 1) {
897-
append("via the non-null fields: ")
898-
} else {
899-
append("in the non-null field ")
900-
}
901-
append(cyclePath.map { it.first }.joinToString(", "))
902-
},
903-
it.second
904-
)
905-
)
1049+
private fun ValidationScope.validateInputObjectsCycles(inputObjectTypeDefinitions: List<GQLInputObjectTypeDefinition>) {
1050+
val nodes = reverseGraph(inputObjectTypeDefinitions)
1051+
findValid(nodes)
1052+
removeValid(nodes)
1053+
tarjanScc(nodes).forEach { scc ->
1054+
if (scc.size == 1) {
1055+
val firstNode = scc.first()
1056+
val fieldAndNode = firstNode.sucessors.firstOrNull()
1057+
if (fieldAndNode != null && fieldAndNode.node.typeDefinition.name == firstNode.typeDefinition.name) {
1058+
registerIssue("Input object `${firstNode.typeDefinition.name}` references itself through field `${firstNode.typeDefinition.name}.${fieldAndNode.field.name}` and cannot be constructed.", fieldAndNode.field.sourceLocation)
1059+
} else {
1060+
// Trivial SCC containing a single, non self-referncing node are not an issue.
1061+
}
1062+
} else {
1063+
val cycle = findWitnessCycle(scc)
1064+
cycle.indices.forEach { i ->
1065+
val start = cycle.get(i)
1066+
val cycleAsString = buildString {
1067+
var j = i
1068+
repeat(cycle.size) {
1069+
val cur = cycle.get(j)
1070+
append("${cur.typename}.${cur.inputField.name} --> ")
1071+
j++
1072+
if (j == cycle.size) {
1073+
j = 0
1074+
}
9061075
}
1076+
append(start.typename)
9071077
}
908-
909-
state.fieldPath.removeLast()
1078+
registerIssue("Input object `${start.typename}` references itself through an unbreakable chain of input fields and cannot be constructed: $cycleAsString", start.inputField.sourceLocation)
9101079
}
9111080
}
9121081
}
1082+
}
9131083

914-
state.fieldPathIndexByTypeName.remove(inputObjectTypeDefinition.name)
1084+
1085+
private class DefaultValueTraversalState {
1086+
val visitedFields = mutableSetOf<String>()
1087+
val fieldPath = mutableListOf<Pair<String, SourceLocation?>>()
1088+
val fieldPathIndex = mutableMapOf<String, Int>()
9151089
}
9161090

9171091
private fun ValidationScope.validateInputObjectDefaultValue(
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
package com.apollographql.apollo.graphql.ast.test
2+
3+
import com.apollographql.apollo.ast.GQLField
4+
import com.apollographql.apollo.ast.GQLInputObjectTypeDefinition
5+
import com.apollographql.apollo.ast.GQLInputValueDefinition
6+
import com.apollographql.apollo.ast.GQLNamedType
7+
import com.apollographql.apollo.ast.internal.FieldAndNode
8+
import com.apollographql.apollo.ast.internal.Node
9+
import com.apollographql.apollo.ast.internal.tarjanScc
10+
import kotlin.test.Test
11+
12+
class TarjanTest {
13+
fun typeDefinition(name: String) = GQLInputObjectTypeDefinition(
14+
sourceLocation = null,
15+
description = "",
16+
name = name,
17+
directives = emptyList(),
18+
inputFields = emptyList()
19+
)
20+
val field = GQLInputValueDefinition(
21+
sourceLocation = null,
22+
name = "",
23+
directives = emptyList(),
24+
description = "",
25+
type = GQLNamedType(null, ""),
26+
defaultValue = null,
27+
)
28+
29+
internal fun node(name: String) = Node(typeDefinition(name)).apply { isValid = false }
30+
31+
@Test
32+
fun test1() {
33+
val a = node("a")
34+
val b = node("b")
35+
val c = node("c")
36+
37+
a.sucessors.add(FieldAndNode(field, b))
38+
b.sucessors.add(FieldAndNode(field, c))
39+
c.sucessors.add(FieldAndNode(field, b))
40+
41+
val sccs = tarjanScc(listOf(a, b, c))
42+
println(sccs)
43+
}
44+
45+
@Test
46+
fun test2() {
47+
val a = node("a")
48+
val b = node("b")
49+
val c = node("c")
50+
val d = node("d")
51+
52+
a.sucessors.add(FieldAndNode(field, b))
53+
a.sucessors.add(FieldAndNode(field, d))
54+
b.sucessors.add(FieldAndNode(field, c))
55+
c.sucessors.add(FieldAndNode(field, b))
56+
d.sucessors.add(FieldAndNode(field, a))
57+
58+
val sccs = tarjanScc(listOf(a, b, c))
59+
println(sccs)
60+
}
61+
62+
}

libraries/apollo-ast/test-fixtures/validation/schema/input-cycles0.expected

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

libraries/apollo-ast/test-fixtures/validation/schema/input-cycles0.graphqls

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ type Query {
22
field(arg: SomeInputObject): String
33
}
44

5-
# simple cycle
5+
# Input object `SomeInputObject` references itself through field `SomeInputObject.nonNullSelf` and cannot be constructed.
66
input SomeInputObject {
77
nonNullSelf: SomeInputObject!
88
}

0 commit comments

Comments
 (0)