@@ -21,15 +21,23 @@ import com.expediagroup.graphql.generator.TopLevelObject
2121import com.expediagroup.graphql.generator.exceptions.GraphQLKotlinException
2222import com.expediagroup.graphql.generator.execution.GraphQLContext
2323import com.expediagroup.graphql.generator.toSchema
24+ import com.expediagroup.graphql.server.execution.DefaultDataLoaderRegistryFactory
25+ import com.expediagroup.graphql.server.execution.KotlinDataLoader
26+ import com.expediagroup.graphql.server.extensions.getValueFromDataLoader
2427import com.expediagroup.graphql.server.spring.subscriptions.SpringGraphQLSubscriptionHandler
2528import com.expediagroup.graphql.server.types.GraphQLRequest
2629import graphql.GraphQL
30+ import graphql.schema.DataFetchingEnvironment
2731import graphql.schema.GraphQLSchema
2832import io.mockk.mockk
33+ import org.dataloader.DataLoader
2934import org.junit.jupiter.api.Test
3035import reactor.core.publisher.Flux
36+ import reactor.kotlin.core.publisher.toFlux
37+ import reactor.kotlin.core.publisher.toMono
3138import reactor.test.StepVerifier
3239import java.time.Duration
40+ import java.util.concurrent.CompletableFuture
3341import kotlin.random.Random
3442import kotlin.test.assertEquals
3543import kotlin.test.assertNotNull
@@ -44,7 +52,16 @@ class SpringGraphQLSubscriptionHandlerTest {
4452 subscriptions = listOf (TopLevelObject (BasicSubscription ()))
4553 )
4654 private val testGraphQL: GraphQL = GraphQL .newGraphQL(testSchema).build()
47- private val subscriptionHandler = SpringGraphQLSubscriptionHandler (testGraphQL)
55+ private val mockLoader: KotlinDataLoader <String , String > = object : KotlinDataLoader <String , String > {
56+ override val dataLoaderName: String = " MockDataLoader"
57+ override fun getDataLoader (): DataLoader <String , String > = DataLoader <String , String > { ids ->
58+ CompletableFuture .supplyAsync {
59+ ids.map { " $it :value" }
60+ }
61+ }
62+ }
63+ private val dataLoaderRegistryFactory = DefaultDataLoaderRegistryFactory (listOf (mockLoader))
64+ private val subscriptionHandler = SpringGraphQLSubscriptionHandler (testGraphQL, dataLoaderRegistryFactory)
4865
4966 @Test
5067 fun `verify subscription` () {
@@ -64,6 +81,26 @@ class SpringGraphQLSubscriptionHandlerTest {
6481 .verify()
6582 }
6683
84+ @Test
85+ fun `verify subscription with data loader` () {
86+ val request = GraphQLRequest (query = " subscription { dataLoaderValue }" )
87+ val responseFlux = subscriptionHandler.executeSubscription(request, mockk())
88+
89+ StepVerifier .create(responseFlux)
90+ .thenConsumeWhile { response ->
91+ assertNotNull(response.data as ? Map <* , * >) { data ->
92+ assertNotNull(data[" dataLoaderValue" ] as ? String ) { value ->
93+ assertEquals(" foo:value" , value)
94+ }
95+ }
96+ assertNull(response.errors)
97+ assertNull(response.extensions)
98+ true
99+ }
100+ .expectComplete()
101+ .verify()
102+ }
103+
67104 @Test
68105 fun `verify subscription with context` () {
69106 val request = GraphQLRequest (query = " subscription { contextualTicker }" )
@@ -122,6 +159,8 @@ class SpringGraphQLSubscriptionHandlerTest {
122159 fun contextualTicker (context : SubscriptionContext ): Flux <String > = Flux .range(1 , 5 )
123160 .delayElements(Duration .ofMillis(100 ))
124161 .map { " ${context.value} :${Random .nextInt(100 )} " }
162+
163+ fun dataLoaderValue (dfe : DataFetchingEnvironment ): Flux <String > = dfe.getValueFromDataLoader<String , String >(" MockDataLoader" , " foo" ).toMono().toFlux()
125164 }
126165
127166 data class SubscriptionContext (val value : String ) : GraphQLContext
0 commit comments