Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions graphql-dgs-spring-graphql/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
dependencies {
implementation(project(":graphql-dgs"))
implementation(project(":graphql-dgs-reactive"))
implementation("org.jetbrains.kotlinx:kotlinx-coroutines-core")
implementation("org.springframework:spring-web")
implementation("org.springframework.boot:spring-boot-autoconfigure")
implementation("io.micrometer:context-propagation")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,8 @@ import graphql.schema.idl.TypeDefinitionRegistry
import io.micrometer.context.ContextRegistry
import io.micrometer.context.ContextSnapshotFactory
import io.micrometer.context.integration.Slf4jThreadLocalAccessor
import kotlinx.coroutines.CoroutineDispatcher
import kotlinx.coroutines.Dispatchers
import org.reactivestreams.Publisher
import org.slf4j.Logger
import org.slf4j.LoggerFactory
Expand Down Expand Up @@ -484,10 +486,21 @@ open class DgsSpringGraphQLAutoConfiguration(
return executor
}

/**
* Default CoroutineDispatcher used for executing Kotlin suspend functions in data fetchers.
* Defaults to [Dispatchers.Unconfined] which runs coroutines immediately on the calling thread.
* Override this bean to customize the dispatcher for your specific use case.
*/
@Bean
@Qualifier("dgsCoroutineDispatcher")
@ConditionalOnMissingBean(name = ["dgsCoroutineDispatcher"])
open fun dgsCoroutineDispatcher(): CoroutineDispatcher = Dispatchers.Unconfined

@Bean
open fun methodDataFetcherFactory(
argumentResolvers: ObjectProvider<ArgumentResolver>,
@Qualifier("dgsAsyncTaskExecutor") taskExecutorOptional: Optional<AsyncTaskExecutor>,
@Qualifier("dgsCoroutineDispatcher") coroutineDispatcher: CoroutineDispatcher,
): MethodDataFetcherFactory {
val taskExecutor =
if (taskExecutorOptional.isPresent) {
Expand All @@ -496,7 +509,12 @@ open class DgsSpringGraphQLAutoConfiguration(
null
}

return MethodDataFetcherFactory(argumentResolvers.orderedStream().toList(), DefaultParameterNameDiscoverer(), taskExecutor)
return MethodDataFetcherFactory(
argumentResolvers.orderedStream().toList(),
DefaultParameterNameDiscoverer(),
taskExecutor,
coroutineDispatcher,
)
}

@Bean
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package com.netflix.graphql.dgs.internal
import com.netflix.graphql.dgs.internal.method.ArgumentResolverComposite
import graphql.schema.DataFetcher
import graphql.schema.DataFetchingEnvironment
import kotlinx.coroutines.CoroutineDispatcher
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.reactor.mono
import org.springframework.core.BridgeMethodResolver
Expand All @@ -43,6 +44,7 @@ class DataFetcherInvoker internal constructor(
private val resolvers: ArgumentResolverComposite,
parameterNameDiscoverer: ParameterNameDiscoverer,
taskExecutor: AsyncTaskExecutor?,
private val coroutineDispatcher: CoroutineDispatcher = Dispatchers.Unconfined,
) : DataFetcher<Any?> {
private val bridgedMethod: Method = BridgeMethodResolver.findBridgedMethod(method)
private val kotlinFunction: KFunction<*>? =
Expand Down Expand Up @@ -131,7 +133,7 @@ class DataFetcherInvoker internal constructor(
}

if (kFunc.isSuspend) {
return mono(Dispatchers.Unconfined) {
return mono(coroutineDispatcher) {
kFunc.callSuspendBy(argsByName)
}.onErrorMap(InvocationTargetException::class.java) { it.targetException }
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ import graphql.TrivialDataFetcher
import graphql.schema.DataFetcher
import graphql.schema.DataFetchingEnvironment
import graphql.schema.FieldCoordinates
import kotlinx.coroutines.CoroutineDispatcher
import kotlinx.coroutines.Dispatchers
import org.springframework.core.DefaultParameterNameDiscoverer
import org.springframework.core.MethodParameter
import org.springframework.core.ParameterNameDiscoverer
Expand All @@ -39,6 +41,7 @@ class MethodDataFetcherFactory(
argumentResolvers: List<ArgumentResolver>,
internal val parameterNameDiscoverer: ParameterNameDiscoverer = DefaultParameterNameDiscoverer(),
private val asyncTaskExecutor: AsyncTaskExecutor? = null,
private val coroutineDispatcher: CoroutineDispatcher = Dispatchers.Unconfined,
) {
private val resolvers = ArgumentResolverComposite(argumentResolvers)

Expand All @@ -55,6 +58,7 @@ class MethodDataFetcherFactory(
resolvers = resolvers,
parameterNameDiscoverer = parameterNameDiscoverer,
taskExecutor = null,
coroutineDispatcher = coroutineDispatcher,
)
return object : TrivialDataFetcher<Any?> {
override fun get(environment: DataFetchingEnvironment): Any? = methodDataFetcher.get(environment)
Expand All @@ -69,6 +73,7 @@ class MethodDataFetcherFactory(
resolvers = resolvers,
parameterNameDiscoverer = parameterNameDiscoverer,
taskExecutor = asyncTaskExecutor,
coroutineDispatcher = coroutineDispatcher,
)
}

Expand Down
Loading