Skip to content

Commit 0c5d6ba

Browse files
smyrickShane Myrick
andauthored
Add data loader registry to subscription execution (#1085)
Co-authored-by: Shane Myrick <[email protected]>
1 parent 1b35b96 commit 0c5d6ba

File tree

3 files changed

+56
-5
lines changed

3 files changed

+56
-5
lines changed

servers/graphql-kotlin-spring-server/src/main/kotlin/com/expediagroup/graphql/server/spring/SubscriptionAutoConfiguration.kt

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
package com.expediagroup.graphql.server.spring
1818

19+
import com.expediagroup.graphql.server.execution.DataLoaderRegistryFactory
1920
import com.expediagroup.graphql.server.operations.Subscription
2021
import com.expediagroup.graphql.server.spring.subscriptions.ApolloSubscriptionHooks
2122
import com.expediagroup.graphql.server.spring.subscriptions.ApolloSubscriptionProtocolHandler
@@ -55,7 +56,10 @@ class SubscriptionAutoConfiguration {
5556

5657
@Bean
5758
@ConditionalOnMissingBean
58-
fun subscriptionHandler(graphQL: GraphQL) = SpringGraphQLSubscriptionHandler(graphQL)
59+
fun subscriptionHandler(
60+
graphQL: GraphQL,
61+
dataLoaderRegistryFactory: DataLoaderRegistryFactory
62+
) = SpringGraphQLSubscriptionHandler(graphQL, dataLoaderRegistryFactory)
5963

6064
@Bean
6165
@ConditionalOnMissingBean

servers/graphql-kotlin-spring-server/src/main/kotlin/com/expediagroup/graphql/server/spring/subscriptions/SpringGraphQLSubscriptionHandler.kt

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ package com.expediagroup.graphql.server.spring.subscriptions
1818

1919
import com.expediagroup.graphql.generator.execution.GraphQLContext
2020
import com.expediagroup.graphql.server.exception.KotlinGraphQLError
21+
import com.expediagroup.graphql.server.execution.DataLoaderRegistryFactory
2122
import com.expediagroup.graphql.server.extensions.toExecutionInput
2223
import com.expediagroup.graphql.server.extensions.toGraphQLKotlinType
2324
import com.expediagroup.graphql.server.extensions.toGraphQLResponse
@@ -32,15 +33,22 @@ import reactor.kotlin.core.publisher.toFlux
3233
/**
3334
* Default Spring implementation of GraphQL subscription handler.
3435
*/
35-
open class SpringGraphQLSubscriptionHandler(private val graphQL: GraphQL) {
36+
open class SpringGraphQLSubscriptionHandler(
37+
private val graphQL: GraphQL,
38+
private val dataLoaderRegistryFactory: DataLoaderRegistryFactory? = null
39+
) {
3640

37-
fun executeSubscription(graphQLRequest: GraphQLRequest, graphQLContext: GraphQLContext?): Flux<GraphQLResponse<*>> =
38-
graphQL.execute(graphQLRequest.toExecutionInput(graphQLContext))
41+
fun executeSubscription(graphQLRequest: GraphQLRequest, graphQLContext: GraphQLContext?): Flux<GraphQLResponse<*>> {
42+
val dataLoaderRegistry = dataLoaderRegistryFactory?.generate()
43+
val input = graphQLRequest.toExecutionInput(graphQLContext, dataLoaderRegistry)
44+
45+
return graphQL.execute(input)
3946
.getData<Publisher<ExecutionResult>>()
4047
.toFlux()
4148
.map { result -> result.toGraphQLResponse() }
4249
.onErrorResume { throwable ->
4350
val error = KotlinGraphQLError(throwable).toGraphQLKotlinType()
4451
Flux.just(GraphQLResponse<Any?>(errors = listOf(error)))
4552
}
53+
}
4654
}

servers/graphql-kotlin-spring-server/src/test/kotlin/com/expediagroup/graphql/server/spring/execution/SpringGraphQLSubscriptionHandlerTest.kt

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,15 +21,23 @@ import com.expediagroup.graphql.generator.TopLevelObject
2121
import com.expediagroup.graphql.generator.exceptions.GraphQLKotlinException
2222
import com.expediagroup.graphql.generator.execution.GraphQLContext
2323
import com.expediagroup.graphql.generator.toSchema
24+
import com.expediagroup.graphql.server.execution.DefaultDataLoaderRegistryFactory
25+
import com.expediagroup.graphql.server.execution.KotlinDataLoader
26+
import com.expediagroup.graphql.server.extensions.getValueFromDataLoader
2427
import com.expediagroup.graphql.server.spring.subscriptions.SpringGraphQLSubscriptionHandler
2528
import com.expediagroup.graphql.server.types.GraphQLRequest
2629
import graphql.GraphQL
30+
import graphql.schema.DataFetchingEnvironment
2731
import graphql.schema.GraphQLSchema
2832
import io.mockk.mockk
33+
import org.dataloader.DataLoader
2934
import org.junit.jupiter.api.Test
3035
import reactor.core.publisher.Flux
36+
import reactor.kotlin.core.publisher.toFlux
37+
import reactor.kotlin.core.publisher.toMono
3138
import reactor.test.StepVerifier
3239
import java.time.Duration
40+
import java.util.concurrent.CompletableFuture
3341
import kotlin.random.Random
3442
import kotlin.test.assertEquals
3543
import kotlin.test.assertNotNull
@@ -44,7 +52,16 @@ class SpringGraphQLSubscriptionHandlerTest {
4452
subscriptions = listOf(TopLevelObject(BasicSubscription()))
4553
)
4654
private val testGraphQL: GraphQL = GraphQL.newGraphQL(testSchema).build()
47-
private val subscriptionHandler = SpringGraphQLSubscriptionHandler(testGraphQL)
55+
private val mockLoader: KotlinDataLoader<String, String> = object : KotlinDataLoader<String, String> {
56+
override val dataLoaderName: String = "MockDataLoader"
57+
override fun getDataLoader(): DataLoader<String, String> = DataLoader<String, String> { ids ->
58+
CompletableFuture.supplyAsync {
59+
ids.map { "$it:value" }
60+
}
61+
}
62+
}
63+
private val dataLoaderRegistryFactory = DefaultDataLoaderRegistryFactory(listOf(mockLoader))
64+
private val subscriptionHandler = SpringGraphQLSubscriptionHandler(testGraphQL, dataLoaderRegistryFactory)
4865

4966
@Test
5067
fun `verify subscription`() {
@@ -64,6 +81,26 @@ class SpringGraphQLSubscriptionHandlerTest {
6481
.verify()
6582
}
6683

84+
@Test
85+
fun `verify subscription with data loader`() {
86+
val request = GraphQLRequest(query = "subscription { dataLoaderValue }")
87+
val responseFlux = subscriptionHandler.executeSubscription(request, mockk())
88+
89+
StepVerifier.create(responseFlux)
90+
.thenConsumeWhile { response ->
91+
assertNotNull(response.data as? Map<*, *>) { data ->
92+
assertNotNull(data["dataLoaderValue"] as? String) { value ->
93+
assertEquals("foo:value", value)
94+
}
95+
}
96+
assertNull(response.errors)
97+
assertNull(response.extensions)
98+
true
99+
}
100+
.expectComplete()
101+
.verify()
102+
}
103+
67104
@Test
68105
fun `verify subscription with context`() {
69106
val request = GraphQLRequest(query = "subscription { contextualTicker }")
@@ -122,6 +159,8 @@ class SpringGraphQLSubscriptionHandlerTest {
122159
fun contextualTicker(context: SubscriptionContext): Flux<String> = Flux.range(1, 5)
123160
.delayElements(Duration.ofMillis(100))
124161
.map { "${context.value}:${Random.nextInt(100)}" }
162+
163+
fun dataLoaderValue(dfe: DataFetchingEnvironment): Flux<String> = dfe.getValueFromDataLoader<String, String>("MockDataLoader", "foo").toMono().toFlux()
125164
}
126165

127166
data class SubscriptionContext(val value: String) : GraphQLContext

0 commit comments

Comments
 (0)