diff --git a/graphql-dgs-spring-graphql/build.gradle.kts b/graphql-dgs-spring-graphql/build.gradle.kts index 9b4205efe..1a3e6138b 100644 --- a/graphql-dgs-spring-graphql/build.gradle.kts +++ b/graphql-dgs-spring-graphql/build.gradle.kts @@ -17,6 +17,7 @@ dependencies { implementation(project(":graphql-dgs")) implementation(project(":graphql-dgs-reactive")) + implementation("org.jetbrains.kotlinx:kotlinx-coroutines-core") implementation("org.springframework:spring-web") implementation("org.springframework.boot:spring-boot-autoconfigure") implementation("io.micrometer:context-propagation") diff --git a/graphql-dgs-spring-graphql/src/main/kotlin/com/netflix/graphql/dgs/springgraphql/autoconfig/DgsSpringGraphQLAutoConfiguration.kt b/graphql-dgs-spring-graphql/src/main/kotlin/com/netflix/graphql/dgs/springgraphql/autoconfig/DgsSpringGraphQLAutoConfiguration.kt index 6263628fe..32661fb5e 100644 --- a/graphql-dgs-spring-graphql/src/main/kotlin/com/netflix/graphql/dgs/springgraphql/autoconfig/DgsSpringGraphQLAutoConfiguration.kt +++ b/graphql-dgs-spring-graphql/src/main/kotlin/com/netflix/graphql/dgs/springgraphql/autoconfig/DgsSpringGraphQLAutoConfiguration.kt @@ -85,6 +85,8 @@ import graphql.schema.idl.TypeDefinitionRegistry import io.micrometer.context.ContextRegistry import io.micrometer.context.ContextSnapshotFactory import io.micrometer.context.integration.Slf4jThreadLocalAccessor +import kotlinx.coroutines.CoroutineDispatcher +import kotlinx.coroutines.Dispatchers import org.reactivestreams.Publisher import org.slf4j.Logger import org.slf4j.LoggerFactory @@ -484,10 +486,21 @@ open class DgsSpringGraphQLAutoConfiguration( return executor } + /** + * Default CoroutineDispatcher used for executing Kotlin suspend functions in data fetchers. + * Defaults to [Dispatchers.Unconfined] which runs coroutines immediately on the calling thread. + * Override this bean to customize the dispatcher for your specific use case. + */ + @Bean + @Qualifier("dgsCoroutineDispatcher") + @ConditionalOnMissingBean(name = ["dgsCoroutineDispatcher"]) + open fun dgsCoroutineDispatcher(): CoroutineDispatcher = Dispatchers.Unconfined + @Bean open fun methodDataFetcherFactory( argumentResolvers: ObjectProvider, @Qualifier("dgsAsyncTaskExecutor") taskExecutorOptional: Optional, + @Qualifier("dgsCoroutineDispatcher") coroutineDispatcher: CoroutineDispatcher, ): MethodDataFetcherFactory { val taskExecutor = if (taskExecutorOptional.isPresent) { @@ -496,7 +509,12 @@ open class DgsSpringGraphQLAutoConfiguration( null } - return MethodDataFetcherFactory(argumentResolvers.orderedStream().toList(), DefaultParameterNameDiscoverer(), taskExecutor) + return MethodDataFetcherFactory( + argumentResolvers.orderedStream().toList(), + DefaultParameterNameDiscoverer(), + taskExecutor, + coroutineDispatcher, + ) } @Bean diff --git a/graphql-dgs/src/main/kotlin/com/netflix/graphql/dgs/internal/DataFetcherInvoker.kt b/graphql-dgs/src/main/kotlin/com/netflix/graphql/dgs/internal/DataFetcherInvoker.kt index 3af33cb4d..f3e1b3b5e 100644 --- a/graphql-dgs/src/main/kotlin/com/netflix/graphql/dgs/internal/DataFetcherInvoker.kt +++ b/graphql-dgs/src/main/kotlin/com/netflix/graphql/dgs/internal/DataFetcherInvoker.kt @@ -19,6 +19,7 @@ package com.netflix.graphql.dgs.internal import com.netflix.graphql.dgs.internal.method.ArgumentResolverComposite import graphql.schema.DataFetcher import graphql.schema.DataFetchingEnvironment +import kotlinx.coroutines.CoroutineDispatcher import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.reactor.mono import org.springframework.core.BridgeMethodResolver @@ -43,6 +44,7 @@ class DataFetcherInvoker internal constructor( private val resolvers: ArgumentResolverComposite, parameterNameDiscoverer: ParameterNameDiscoverer, taskExecutor: AsyncTaskExecutor?, + private val coroutineDispatcher: CoroutineDispatcher = Dispatchers.Unconfined, ) : DataFetcher { private val bridgedMethod: Method = BridgeMethodResolver.findBridgedMethod(method) private val kotlinFunction: KFunction<*>? = @@ -131,7 +133,7 @@ class DataFetcherInvoker internal constructor( } if (kFunc.isSuspend) { - return mono(Dispatchers.Unconfined) { + return mono(coroutineDispatcher) { kFunc.callSuspendBy(argsByName) }.onErrorMap(InvocationTargetException::class.java) { it.targetException } } diff --git a/graphql-dgs/src/main/kotlin/com/netflix/graphql/dgs/internal/method/MethodDataFetcherFactory.kt b/graphql-dgs/src/main/kotlin/com/netflix/graphql/dgs/internal/method/MethodDataFetcherFactory.kt index 4af11a6bf..559b752d1 100644 --- a/graphql-dgs/src/main/kotlin/com/netflix/graphql/dgs/internal/method/MethodDataFetcherFactory.kt +++ b/graphql-dgs/src/main/kotlin/com/netflix/graphql/dgs/internal/method/MethodDataFetcherFactory.kt @@ -22,6 +22,8 @@ import graphql.TrivialDataFetcher import graphql.schema.DataFetcher import graphql.schema.DataFetchingEnvironment import graphql.schema.FieldCoordinates +import kotlinx.coroutines.CoroutineDispatcher +import kotlinx.coroutines.Dispatchers import org.springframework.core.DefaultParameterNameDiscoverer import org.springframework.core.MethodParameter import org.springframework.core.ParameterNameDiscoverer @@ -39,6 +41,7 @@ class MethodDataFetcherFactory( argumentResolvers: List, internal val parameterNameDiscoverer: ParameterNameDiscoverer = DefaultParameterNameDiscoverer(), private val asyncTaskExecutor: AsyncTaskExecutor? = null, + private val coroutineDispatcher: CoroutineDispatcher = Dispatchers.Unconfined, ) { private val resolvers = ArgumentResolverComposite(argumentResolvers) @@ -55,6 +58,7 @@ class MethodDataFetcherFactory( resolvers = resolvers, parameterNameDiscoverer = parameterNameDiscoverer, taskExecutor = null, + coroutineDispatcher = coroutineDispatcher, ) return object : TrivialDataFetcher { override fun get(environment: DataFetchingEnvironment): Any? = methodDataFetcher.get(environment) @@ -69,6 +73,7 @@ class MethodDataFetcherFactory( resolvers = resolvers, parameterNameDiscoverer = parameterNameDiscoverer, taskExecutor = asyncTaskExecutor, + coroutineDispatcher = coroutineDispatcher, ) }