diff --git a/graphql-dgs/src/main/kotlin/com/netflix/graphql/dgs/internal/method/SourceArgumentResolver.kt b/graphql-dgs/src/main/kotlin/com/netflix/graphql/dgs/internal/method/SourceArgumentResolver.kt index fcd00ca6f..bde571a0d 100644 --- a/graphql-dgs/src/main/kotlin/com/netflix/graphql/dgs/internal/method/SourceArgumentResolver.kt +++ b/graphql-dgs/src/main/kotlin/com/netflix/graphql/dgs/internal/method/SourceArgumentResolver.kt @@ -32,7 +32,7 @@ class SourceArgumentResolver : ArgumentResolver { throw IllegalArgumentException("Source is null. Are you trying to use @Source on a root field (e.g. @DgsQuery)?") } - if (parameter.parameterType == source.javaClass) { + if (parameter.parameterType.isAssignableFrom(source.javaClass)) { return source } else { throw IllegalArgumentException( diff --git a/graphql-dgs/src/test/kotlin/com/netflix/graphql/dgs/internal/SourceArgumentTest.kt b/graphql-dgs/src/test/kotlin/com/netflix/graphql/dgs/internal/SourceArgumentTest.kt index 13228e3ac..ab76b3952 100644 --- a/graphql-dgs/src/test/kotlin/com/netflix/graphql/dgs/internal/SourceArgumentTest.kt +++ b/graphql-dgs/src/test/kotlin/com/netflix/graphql/dgs/internal/SourceArgumentTest.kt @@ -52,9 +52,17 @@ internal class SourceArgumentTest { ), ) + interface Entertainment { + val title: String + } + data class Show( - val title: String, - ) + override val title: String, + ) : Entertainment + + data class Movie( + override val title: String, + ) : Entertainment @Test fun `@Source argument`() { @@ -96,6 +104,60 @@ internal class SourceArgumentTest { } } + @Test + fun `@Source argument with interface type`() { + @DgsComponent + class Fetcher { + @DgsQuery + fun shows(): List = listOf(Show("Stranger Things")) + + @DgsQuery + fun movies(): List = listOf(Movie("Batman")) + + @DgsData(parentType = "Show") + @DgsData(parentType = "Movie") + fun description( + @Source entertainment: Entertainment, + ): String = "Description of ${entertainment.title}" + } + + contextRunner.withBean(Fetcher::class.java).run { context -> + val provider = schemaProvider(context) + val schema = provider.schema().graphQLSchema + + val build = GraphQL.newGraphQL(schema).build() + val executionResult = + build.execute( + """{ + | shows { + | title + | description + | } + | + | movies { + | title + | description + | } + |} + """.trimMargin(), + ) + + assertThat(executionResult.errors).isEmpty() + assertThat(executionResult.isDataPresent).isTrue + val data = executionResult.getData>() + + @Suppress("UNCHECKED_CAST") + val showData = (data["shows"] as List>)[0] + assertThat(showData["title"]).isEqualTo("Stranger Things") + assertThat(showData["description"]).isEqualTo("Description of Stranger Things") + + @Suppress("UNCHECKED_CAST") + val movieData = (data["movies"] as List>)[0] + assertThat(movieData["title"]).isEqualTo("Batman") + assertThat(movieData["description"]).isEqualTo("Description of Batman") + } + } + @Test fun `Incorrect @Source argument type`() { @DgsComponent diff --git a/graphql-dgs/src/test/resources/source-argument-test/schema.graphqls b/graphql-dgs/src/test/resources/source-argument-test/schema.graphqls index a66266a5f..54158e79a 100644 --- a/graphql-dgs/src/test/resources/source-argument-test/schema.graphqls +++ b/graphql-dgs/src/test/resources/source-argument-test/schema.graphqls @@ -1,8 +1,14 @@ type Query { shows: [Show] + movies: [Movie] } type Show { title: String description: String +} + +type Movie { + title: String + description: String } \ No newline at end of file