Skip to content

Commit dc32d4b

Browse files
authored
Merge pull request #2273 from Netflix/default-dgs-data-loader-provider
Avoid expensive lookup in DefaultDgsDataLoaderProvider
2 parents 661e035 + 029983a commit dc32d4b

File tree

2 files changed

+37
-43
lines changed

2 files changed

+37
-43
lines changed

graphql-dgs/src/main/kotlin/com/netflix/graphql/dgs/internal/DefaultDgsDataLoaderProvider.kt

Lines changed: 22 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ import org.slf4j.Logger
4242
import org.slf4j.LoggerFactory
4343
import org.springframework.aop.support.AopUtils
4444
import org.springframework.beans.factory.NoSuchBeanDefinitionException
45+
import org.springframework.beans.factory.getBeansWithAnnotation
4546
import org.springframework.context.ApplicationContext
4647
import org.springframework.context.ConfigurableApplicationContext
4748
import org.springframework.core.type.StandardMethodMetadata
@@ -99,7 +100,9 @@ class DefaultDgsDataLoaderProvider(
99100
mappedBatchLoaders.forEach { registerDataLoader(it, registry, contextSupplier, extensionProviders) }
100101
mappedBatchLoadersWithContext.forEach { registerDataLoader(it, registry, contextSupplier, extensionProviders) }
101102
}
102-
logger.debug("Created DGS dataloader registry in {}ms", totalTime)
103+
if (logger.isDebugEnabled) {
104+
logger.debug("Created DGS dataloader registry in {}ms", totalTime)
105+
}
103106
return registry
104107
}
105108

@@ -110,7 +113,7 @@ class DefaultDgsDataLoaderProvider(
110113
}
111114

112115
private fun addDataLoaderFields() {
113-
val dataLoaders = applicationContext.getBeansWithAnnotation(DgsComponent::class.java)
116+
val dataLoaders = applicationContext.getBeansWithAnnotation<DgsComponent>()
114117
dataLoaders.values.forEach { dgsComponent ->
115118
val javaClass = AopUtils.getTargetClass(dgsComponent)
116119

@@ -135,7 +138,7 @@ class DefaultDgsDataLoaderProvider(
135138
}
136139

137140
private fun addDataLoaderComponents() {
138-
val dataLoaders = applicationContext.getBeansWithAnnotation(DgsDataLoader::class.java)
141+
val dataLoaders = applicationContext.getBeansWithAnnotation<DgsDataLoader>()
139142
dataLoaders.forEach { (beanName, beanInstance) ->
140143
val javaClass = AopUtils.getTargetClass(beanInstance)
141144

@@ -178,18 +181,18 @@ class DefaultDgsDataLoaderProvider(
178181
annotation: DgsDataLoader,
179182
dispatchPredicate: DispatchPredicate? = null,
180183
) {
181-
if (dataLoaders.contains(dataLoaderName)) {
184+
if (dataLoaderName in dataLoaders) {
182185
throw MultipleDataLoadersDefinedException(dgsComponentClass, dataLoaders.getValue(dataLoaderName))
183186
}
184187
dataLoaders[dataLoaderName] = dgsComponentClass
185188

186189
fun <T : Any> createHolder(t: T): LoaderHolder<T> = LoaderHolder(t, annotation, dataLoaderName, dispatchPredicate)
187190

188191
when (val customizedDataLoader = runCustomizers(dataLoader, dataLoaderName, dgsComponentClass)) {
189-
is BatchLoader<*, *> -> batchLoaders.add(createHolder(customizedDataLoader))
190-
is BatchLoaderWithContext<*, *> -> batchLoadersWithContext.add(createHolder(customizedDataLoader))
191-
is MappedBatchLoader<*, *> -> mappedBatchLoaders.add(createHolder(customizedDataLoader))
192-
is MappedBatchLoaderWithContext<*, *> -> mappedBatchLoadersWithContext.add(createHolder(customizedDataLoader))
192+
is BatchLoader<*, *> -> batchLoaders += createHolder(customizedDataLoader)
193+
is BatchLoaderWithContext<*, *> -> batchLoadersWithContext += createHolder(customizedDataLoader)
194+
is MappedBatchLoader<*, *> -> mappedBatchLoaders += createHolder(customizedDataLoader)
195+
is MappedBatchLoaderWithContext<*, *> -> mappedBatchLoadersWithContext += createHolder(customizedDataLoader)
193196
else -> throw InvalidDataLoaderTypeException(dgsComponentClass)
194197
}
195198
}
@@ -198,23 +201,17 @@ class DefaultDgsDataLoaderProvider(
198201
originalDataLoader: Any,
199202
name: String,
200203
dgsComponentClass: Class<*>,
201-
): Any {
202-
var dataLoader = originalDataLoader
203-
204-
customizers.forEach {
205-
dataLoader =
206-
when (dataLoader) {
207-
is BatchLoader<*, *> -> it.provide(dataLoader as BatchLoader<*, *>, name)
208-
is BatchLoaderWithContext<*, *> -> it.provide(dataLoader as BatchLoaderWithContext<*, *>, name)
209-
is MappedBatchLoader<*, *> -> it.provide(dataLoader as MappedBatchLoader<*, *>, name)
210-
is MappedBatchLoaderWithContext<*, *> -> it.provide(dataLoader as MappedBatchLoaderWithContext<*, *>, name)
211-
else -> throw InvalidDataLoaderTypeException(dgsComponentClass)
212-
}
204+
): Any =
205+
customizers.fold(originalDataLoader) { dataLoader, customizer ->
206+
when (dataLoader) {
207+
is BatchLoader<*, *> -> customizer.provide(dataLoader, name)
208+
is BatchLoaderWithContext<*, *> -> customizer.provide(dataLoader, name)
209+
is MappedBatchLoader<*, *> -> customizer.provide(dataLoader, name)
210+
is MappedBatchLoaderWithContext<*, *> -> customizer.provide(dataLoader, name)
211+
else -> throw InvalidDataLoaderTypeException(dgsComponentClass)
212+
}
213213
}
214214

215-
return dataLoader
216-
}
217-
218215
private fun createDataLoader(
219216
batchLoader: BatchLoader<*, *>,
220217
dgsDataLoader: DgsDataLoader,
@@ -340,15 +337,11 @@ class DefaultDgsDataLoaderProvider(
340337
else -> throw IllegalArgumentException("Data loader ${holder.name} has unknown type")
341338
}
342339
// detect and throw an exception if multiple data loaders use the same name
343-
if (registry.keys.contains(holder.name)) {
340+
if (registry.getDataLoader<Any, Any>(holder.name) != null) {
344341
throw MultipleDataLoadersDefinedException(holder.theLoader.javaClass)
345342
}
346343

347-
if (holder.dispatchPredicate == null) {
348-
registry.register(holder.name, loader, DispatchPredicate.DISPATCH_ALWAYS)
349-
} else {
350-
registry.register(holder.name, loader, holder.dispatchPredicate)
351-
}
344+
registry.register(holder.name, loader, holder.dispatchPredicate ?: DispatchPredicate.DISPATCH_ALWAYS)
352345
}
353346

354347
private inline fun <reified T> wrappedDataLoader(

graphql-dgs/src/test/kotlin/com/netflix/graphql/dgs/DefaultDgsDataLoaderProviderTest.kt

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ import org.junit.jupiter.api.Nested
3535
import org.junit.jupiter.api.Test
3636
import org.junit.jupiter.api.assertThrows
3737
import org.springframework.beans.factory.BeanCreationException
38+
import org.springframework.beans.factory.getBean
3839
import org.springframework.boot.test.context.runner.ApplicationContextRunner
3940
import java.util.concurrent.CompletableFuture
4041
import java.util.concurrent.CompletionStage
@@ -51,7 +52,7 @@ class DefaultDgsDataLoaderProviderTest {
5152
ExampleBatchLoader::class.java,
5253
).withBean(ExampleBatchLoaderWithDispatchPredicate::class.java)
5354
.run { context ->
54-
val provider = context.getBean(DgsDataLoaderProvider::class.java)
55+
val provider = context.getBean<DgsDataLoaderProvider>()
5556
val dataLoaderRegistry = provider.buildRegistry()
5657
Assertions.assertEquals(2, dataLoaderRegistry.dataLoaders.size)
5758
val dataLoader = dataLoaderRegistry.getDataLoader<Any, Any>("exampleLoader")
@@ -68,7 +69,7 @@ class DefaultDgsDataLoaderProviderTest {
6869
ExampleBatchLoaderWithContext::class.java,
6970
).withBean(ExampleBatchLoaderWithContextAndDispatchPredicate::class.java)
7071
.run { context ->
71-
val provider = context.getBean(DgsDataLoaderProvider::class.java)
72+
val provider = context.getBean<DgsDataLoaderProvider>()
7273
val dataLoaderRegistry = provider.buildRegistry()
7374
Assertions.assertEquals(2, dataLoaderRegistry.dataLoaders.size)
7475
val dataLoader = dataLoaderRegistry.getDataLoader<Any, Any>("exampleLoaderWithContext")
@@ -83,7 +84,7 @@ class DefaultDgsDataLoaderProviderTest {
8384
applicationContextRunner.withBean(ExampleBatchLoader::class.java).withBean(ExampleDuplicateBatchLoader::class.java).run { context ->
8485
val exc =
8586
assertThrows<IllegalStateException> {
86-
val provider = context.getBean(DgsDataLoaderProvider::class.java)
87+
val provider = context.getBean<DgsDataLoaderProvider>()
8788
provider.buildRegistry()
8889
}
8990

@@ -103,7 +104,7 @@ class DefaultDgsDataLoaderProviderTest {
103104
.run { context ->
104105
val exc =
105106
assertThrows<IllegalStateException> {
106-
context.getBean(DgsDataLoaderProvider::class.java)
107+
context.getBean<DgsDataLoaderProvider>()
107108
}
108109
assertThat(exc.cause)
109110
.isInstanceOf(BeanCreationException::class.java)
@@ -115,7 +116,7 @@ class DefaultDgsDataLoaderProviderTest {
115116
@Test
116117
fun findDataLoadersFromFields() {
117118
applicationContextRunner.withBean(ExampleBatchLoaderFromField::class.java).run { context ->
118-
val provider = context.getBean(DgsDataLoaderProvider::class.java)
119+
val provider = context.getBean<DgsDataLoaderProvider>()
119120
val dataLoaderRegistry = provider.buildRegistry()
120121
Assertions.assertEquals(2, dataLoaderRegistry.dataLoaders.size)
121122
val dataLoader = dataLoaderRegistry.getDataLoader<Any, Any>("exampleLoaderFromField")
@@ -133,7 +134,7 @@ class DefaultDgsDataLoaderProviderTest {
133134
ExampleMappedBatchLoader::class.java,
134135
).withBean(ExampleMappedBatchLoaderWithDispatchPredicate::class.java)
135136
.run { context ->
136-
val provider = context.getBean(DgsDataLoaderProvider::class.java)
137+
val provider = context.getBean<DgsDataLoaderProvider>()
137138
val dataLoaderRegistry = provider.buildRegistry()
138139
Assertions.assertEquals(2, dataLoaderRegistry.dataLoaders.size)
139140
val dataLoader = dataLoaderRegistry.getDataLoader<Any, Any>("exampleMappedLoader")
@@ -150,7 +151,7 @@ class DefaultDgsDataLoaderProviderTest {
150151
ExampleMappedBatchLoaderWithContext::class.java,
151152
).withBean(ExampleMappedBatchLoaderWithContextAndDispatchPredicate::class.java)
152153
.run { context ->
153-
val provider = context.getBean(DgsDataLoaderProvider::class.java)
154+
val provider = context.getBean<DgsDataLoaderProvider>()
154155
val dataLoaderRegistry = provider.buildRegistry()
155156
Assertions.assertEquals(2, dataLoaderRegistry.dataLoaders.size)
156157
val dataLoader = dataLoaderRegistry.getDataLoader<Any, Any>("exampleMappedLoaderWithContext")
@@ -163,7 +164,7 @@ class DefaultDgsDataLoaderProviderTest {
163164
@Test
164165
fun findMappedDataLoadersFromFields() {
165166
applicationContextRunner.withBean(ExampleMappedBatchLoaderFromField::class.java).run { context ->
166-
val provider = context.getBean(DgsDataLoaderProvider::class.java)
167+
val provider = context.getBean<DgsDataLoaderProvider>()
167168
val dataLoaderRegistry = provider.buildRegistry()
168169
Assertions.assertEquals(2, dataLoaderRegistry.dataLoaders.size)
169170
val dataLoader = dataLoaderRegistry.getDataLoader<Any, Any>("exampleMappedLoaderFromField")
@@ -177,7 +178,7 @@ class DefaultDgsDataLoaderProviderTest {
177178
@Test
178179
fun dataLoaderConsumer() {
179180
applicationContextRunner.withBean(ExampleDataLoaderWithRegistry::class.java).run { context ->
180-
val provider = context.getBean(DgsDataLoaderProvider::class.java)
181+
val provider = context.getBean<DgsDataLoaderProvider>()
181182
val registry = provider.buildRegistry()
182183

183184
// Use the dataloader's "load" method to check if the registry was set correctly, because the dataloader instance isn't itself a DgsDataLoaderRegistryConsumer
@@ -194,7 +195,7 @@ class DefaultDgsDataLoaderProviderTest {
194195
@Test
195196
fun findDataLoadersWithoutName() {
196197
applicationContextRunner.withBean(ExampleBatchLoaderWithoutName::class.java).run { context ->
197-
val provider = context.getBean(DgsDataLoaderProvider::class.java)
198+
val provider = context.getBean<DgsDataLoaderProvider>()
198199
val dataLoaderRegistry = provider.buildRegistry()
199200
Assertions.assertEquals(1, dataLoaderRegistry.dataLoaders.size)
200201
val dataLoader =
@@ -206,7 +207,7 @@ class DefaultDgsDataLoaderProviderTest {
206207
@Test
207208
fun findDataLoadersWithoutNameByClass() {
208209
applicationContextRunner.withBean(ExampleBatchLoaderWithoutName::class.java).run { context ->
209-
val provider = context.getBean(DgsDataLoaderProvider::class.java)
210+
val provider = context.getBean<DgsDataLoaderProvider>()
210211
val dataLoaderRegistry = provider.buildRegistry()
211212
Assertions.assertEquals(1, dataLoaderRegistry.dataLoaders.size)
212213
val dataLoader =
@@ -224,7 +225,7 @@ class DefaultDgsDataLoaderProviderTest {
224225
@Test
225226
fun findDataLoadersFromFieldsWithoutName() {
226227
applicationContextRunner.withBean(ExampleBatchLoaderWithoutNameFromField::class.java).run { context ->
227-
assertThatThrownBy { context.getBean(DgsDataLoaderProvider::class.java) }
228+
assertThatThrownBy { context.getBean<DgsDataLoaderProvider>() }
228229
.rootCause()
229230
.isInstanceOf(DgsUnnamedDataLoaderOnFieldException::class.java)
230231
.hasMessage(
@@ -241,10 +242,10 @@ class DefaultDgsDataLoaderProviderTest {
241242
).withBean(DgsWrapWithContextDataLoaderCustomizer::class.java)
242243
.withBean(DataLoaderCustomizerCounter::class.java)
243244
.run { context ->
244-
val provider = context.getBean(DgsDataLoaderProvider::class.java)
245+
val provider = context.getBean<DgsDataLoaderProvider>()
245246
val dataLoaderRegistry = provider.buildRegistry()
246247

247-
val counter = context.getBean(DataLoaderCustomizerCounter::class.java)
248+
val counter = context.getBean<DataLoaderCustomizerCounter>()
248249

249250
assertThat(dataLoaderRegistry.dataLoaders.size).isEqualTo(1)
250251

0 commit comments

Comments
 (0)