Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,18 @@ open class DgsContext(
@JvmStatic
fun from(graphQLContext: GraphQLContext): DgsContext = graphQLContext[GraphQLContextKey.DGS_CONTEXT_KEY]

/**
* Safely retrieves DgsContext from GraphQLContext, returning null if not present.
* This is useful in scenarios where DgsContext may not have been initialized yet,
* such as during subscription callback setup with Apollo Federation.
*
* @param graphQLContext The GraphQL context to retrieve DgsContext from
* @return DgsContext if present, null otherwise
*/
@JvmStatic
fun fromOrNull(graphQLContext: GraphQLContext): DgsContext? =
graphQLContext.getOrDefault(GraphQLContextKey.DGS_CONTEXT_KEY, null)

@JvmStatic
fun from(dfe: DataFetchingEnvironment): DgsContext = from(dfe.graphQlContext)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,13 @@ class GraphQLContextContributorInstrumentation(
val graphqlContext = parameters.executionInput.graphQLContext
if (graphqlContext != null && graphQLContextContributors.isNotEmpty()) {
val extensions = parameters.executionInput.extensions
val requestData = DgsContext.from(graphqlContext).requestData

// Use null-safe access because DgsContext may not be available yet during
// subscription callback initialization (Apollo Federation HTTP callback protocol).
// The CallbackWebGraphQLInterceptor runs at LOWEST_PRECEDENCE, so DgsContext
// won't be set until after this instrumentation's createState() is called.
val requestData = DgsContext.fromOrNull(graphqlContext)?.requestData

val builderForContributors = GraphQLContext.newContext()
graphQLContextContributors.forEach { it.contribute(builderForContributors, extensions, requestData) }
graphqlContext.putAll(builderForContributors)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
/*
* Copyright 2025 Netflix, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.netflix.graphql.dgs.context

import graphql.ExecutionInput
import graphql.GraphQLContext
import org.assertj.core.api.Assertions.assertThat
import org.junit.jupiter.api.Test
import org.junit.jupiter.api.assertThrows

class DgsContextTest {

private fun buildGraphQLContextWithDgsContext(dgsContext: DgsContext): GraphQLContext {
// Build ExecutionInput with DgsContext as consumer to properly populate the GraphQLContext
val executionInput = ExecutionInput.newExecutionInput()
.query("{ __typename }")
.graphQLContext(dgsContext)
.build()
return executionInput.graphQLContext
}

@Test
fun `fromOrNull should return null when DgsContext is not present`() {
val graphQLContext = GraphQLContext.newContext().build()

val result = DgsContext.fromOrNull(graphQLContext)

assertThat(result).isNull()
}

@Test
fun `fromOrNull should return DgsContext when present`() {
val dgsContext = DgsContext(customContext = "testContext", requestData = null)
val graphQLContext = buildGraphQLContextWithDgsContext(dgsContext)

val result = DgsContext.fromOrNull(graphQLContext)

assertThat(result).isNotNull
assertThat(result?.customContext).isEqualTo("testContext")
}

@Test
fun `from should throw NullPointerException when DgsContext is not present`() {
val graphQLContext = GraphQLContext.newContext().build()

assertThrows<NullPointerException> {
DgsContext.from(graphQLContext)
}
}

@Test
fun `from should return DgsContext when present`() {
val dgsContext = DgsContext(customContext = "testContext", requestData = null)
val graphQLContext = buildGraphQLContextWithDgsContext(dgsContext)

val result = DgsContext.from(graphQLContext)

assertThat(result).isNotNull
assertThat(result.customContext).isEqualTo("testContext")
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
/*
* Copyright 2025 Netflix, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.netflix.graphql.dgs.context

import com.netflix.graphql.dgs.internal.DgsRequestData
import graphql.ExecutionInput
import graphql.execution.instrumentation.parameters.InstrumentationCreateStateParameters
import graphql.schema.GraphQLSchema
import io.mockk.every
import io.mockk.mockk
import io.mockk.verify
import org.assertj.core.api.Assertions.assertThat
import org.junit.jupiter.api.Test

class GraphQLContextContributorInstrumentationTest {

@Test
fun `createState should not throw NPE when DgsContext is not present`() {
val contributor = mockk<GraphQLContextContributor>(relaxed = true)
val instrumentation = GraphQLContextContributorInstrumentation(listOf(contributor))

val executionInput = ExecutionInput.newExecutionInput()
.query("{ __typename }")
.build()
val schema = mockk<GraphQLSchema>()
val parameters = InstrumentationCreateStateParameters(schema, executionInput)

val result = instrumentation.createState(parameters)

assertThat(result).isNull()
verify { contributor.contribute(any(), any(), null) }
}

@Test
fun `createState should pass requestData when DgsContext is present`() {
val contributor = mockk<GraphQLContextContributor>(relaxed = true)
val instrumentation = GraphQLContextContributorInstrumentation(listOf(contributor))

val requestData = mockk<DgsRequestData>()
val dgsContext = DgsContext(customContext = null, requestData = requestData)
val executionInput = ExecutionInput.newExecutionInput()
.query("{ __typename }")
.graphQLContext(dgsContext)
.build()
val schema = mockk<GraphQLSchema>()
val parameters = InstrumentationCreateStateParameters(schema, executionInput)

val result = instrumentation.createState(parameters)

assertThat(result).isNull()
verify { contributor.contribute(any(), any(), requestData) }
}

@Test
fun `createState should skip contributors when list is empty`() {
val instrumentation = GraphQLContextContributorInstrumentation(emptyList())

val executionInput = ExecutionInput.newExecutionInput()
.query("{ __typename }")
.build()
val schema = mockk<GraphQLSchema>()
val parameters = InstrumentationCreateStateParameters(schema, executionInput)

val result = instrumentation.createState(parameters)

assertThat(result).isNull()
}

@Test
fun `createState should handle null graphQLContext`() {
val contributor = mockk<GraphQLContextContributor>(relaxed = true)
val instrumentation = GraphQLContextContributorInstrumentation(listOf(contributor))

val executionInput = mockk<ExecutionInput>()
every { executionInput.graphQLContext } returns null
val schema = mockk<GraphQLSchema>()
val parameters = InstrumentationCreateStateParameters(schema, executionInput)

val result = instrumentation.createState(parameters)

assertThat(result).isNull()
verify(exactly = 0) { contributor.contribute(any(), any(), any()) }
}
}
Loading