diff --git a/.changesets/feat_forward_headers.md b/.changesets/feat_forward_headers.md new file mode 100644 index 00000000..91ddc5ed --- /dev/null +++ b/.changesets/feat_forward_headers.md @@ -0,0 +1,14 @@ +### Add support for forwarding headers from MCP clients to GraphQL APIs - @DaleSeo PR #428 + +Adds opt-in support for dynamic header forwarding, which enables metadata for A/B testing, feature flagging, geo information from CDNs, or internal instrumentation to be sent from MCP clients to downstream GraphQL APIs. It automatically blocks hop-by-hop headers according to the guidelines in [RFC 7230, section 6.1](https://datatracker.ietf.org/doc/html/rfc7230#section-6.1), and it only works with the Streamable HTTP transport. + +You can configure using the `forward_headers` setting: + +```yaml +forward_headers: + - x-tenant-id + - x-experiment-id + - x-geo-country +``` + +Please note that this feature is not intended for passing through credentials as documented in the best practices page. \ No newline at end of file diff --git a/crates/apollo-mcp-server/src/auth/valid_token.rs b/crates/apollo-mcp-server/src/auth/valid_token.rs index 84ce0bab..f880f29e 100644 --- a/crates/apollo-mcp-server/src/auth/valid_token.rs +++ b/crates/apollo-mcp-server/src/auth/valid_token.rs @@ -12,7 +12,7 @@ use url::Url; /// Note: This is used as a marker to ensure that we have validated this /// separately from just reading the header itself. #[derive(Clone, Debug, PartialEq)] -pub(crate) struct ValidToken(pub(super) Authorization); +pub(crate) struct ValidToken(pub(crate) Authorization); impl Deref for ValidToken { type Target = Authorization; diff --git a/crates/apollo-mcp-server/src/headers.rs b/crates/apollo-mcp-server/src/headers.rs new file mode 100644 index 00000000..d028602d --- /dev/null +++ b/crates/apollo-mcp-server/src/headers.rs @@ -0,0 +1,296 @@ +use std::ops::Deref; +use std::str::FromStr; + +use headers::HeaderMapExt; +use http::Extensions; +use reqwest::header::{HeaderMap, HeaderName}; + +use crate::auth::ValidToken; + +/// List of header names to forward from MCP clients to GraphQL API +pub type ForwardHeaders = Vec; + +/// Build headers for a GraphQL request by combining static headers with forwarded headers +pub fn build_request_headers( + static_headers: &HeaderMap, + forward_header_names: &ForwardHeaders, + incoming_headers: &HeaderMap, + extensions: &Extensions, + disable_auth_token_passthrough: bool, +) -> HeaderMap { + // Starts with static headers + let mut headers = static_headers.clone(); + + // Forward headers dynamically + forward_headers(forward_header_names, incoming_headers, &mut headers); + + // Optionally extract the validated token and propagate it to upstream servers if present + if !disable_auth_token_passthrough && let Some(token) = extensions.get::() { + headers.typed_insert(token.deref().clone()); + } + + // Forward the mcp-session-id header if present + if let Some(session_id) = incoming_headers.get("mcp-session-id") { + headers.insert("mcp-session-id", session_id.clone()); + } + + headers +} + +/// Forward matching headers from incoming headers to outgoing headers +fn forward_headers(names: &[String], incoming: &HeaderMap, outgoing: &mut HeaderMap) { + for header in names { + if let Ok(header_name) = HeaderName::from_str(header) + && let Some(value) = incoming.get(&header_name) + // Hop-by-hop headers are blocked per RFC 7230: https://datatracker.ietf.org/doc/html/rfc7230#section-6.1 + && !matches!( + header_name.as_str().to_lowercase().as_str(), + "connection" + | "keep-alive" + | "proxy-authenticate" + | "proxy-authorization" + | "te" + | "trailers" + | "transfer-encoding" + | "upgrade" + | "content-length" + ) + { + outgoing.insert(header_name, value.clone()); + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use headers::Authorization; + use http::Extensions; + use reqwest::header::HeaderValue; + + use crate::auth::ValidToken; + + #[test] + fn test_build_request_headers_includes_static_headers() { + let mut static_headers = HeaderMap::new(); + static_headers.insert("x-api-key", HeaderValue::from_static("static-key")); + static_headers.insert("user-agent", HeaderValue::from_static("mcp-server")); + + let forward_header_names = vec![]; + let incoming_headers = HeaderMap::new(); + let extensions = Extensions::new(); + + let result = build_request_headers( + &static_headers, + &forward_header_names, + &incoming_headers, + &extensions, + false, + ); + + assert_eq!(result.get("x-api-key").unwrap(), "static-key"); + assert_eq!(result.get("user-agent").unwrap(), "mcp-server"); + } + + #[test] + fn test_build_request_headers_forwards_configured_headers() { + let static_headers = HeaderMap::new(); + let forward_header_names = vec!["x-tenant-id".to_string(), "x-trace-id".to_string()]; + + let mut incoming_headers = HeaderMap::new(); + incoming_headers.insert("x-tenant-id", HeaderValue::from_static("tenant-123")); + incoming_headers.insert("x-trace-id", HeaderValue::from_static("trace-456")); + incoming_headers.insert("other-header", HeaderValue::from_static("ignored")); + + let extensions = Extensions::new(); + + let result = build_request_headers( + &static_headers, + &forward_header_names, + &incoming_headers, + &extensions, + false, + ); + + assert_eq!(result.get("x-tenant-id").unwrap(), "tenant-123"); + assert_eq!(result.get("x-trace-id").unwrap(), "trace-456"); + assert!(result.get("other-header").is_none()); + } + + #[test] + fn test_build_request_headers_adds_oauth_token_when_enabled() { + let static_headers = HeaderMap::new(); + let forward_header_names = vec![]; + let incoming_headers = HeaderMap::new(); + + let mut extensions = Extensions::new(); + let token = ValidToken(Authorization::bearer("test-token").unwrap()); + extensions.insert(token); + + let result = build_request_headers( + &static_headers, + &forward_header_names, + &incoming_headers, + &extensions, + false, + ); + + assert!(result.get("authorization").is_some()); + assert_eq!(result.get("authorization").unwrap(), "Bearer test-token"); + } + + #[test] + fn test_build_request_headers_skips_oauth_token_when_disabled() { + let static_headers = HeaderMap::new(); + let forward_header_names = vec![]; + let incoming_headers = HeaderMap::new(); + + let mut extensions = Extensions::new(); + let token = ValidToken(Authorization::bearer("test-token").unwrap()); + extensions.insert(token); + + let result = build_request_headers( + &static_headers, + &forward_header_names, + &incoming_headers, + &extensions, + true, + ); + + assert!(result.get("authorization").is_none()); + } + + #[test] + fn test_build_request_headers_forwards_mcp_session_id() { + let static_headers = HeaderMap::new(); + let forward_header_names = vec![]; + + let mut incoming_headers = HeaderMap::new(); + incoming_headers.insert("mcp-session-id", HeaderValue::from_static("session-123")); + + let extensions = Extensions::new(); + + let result = build_request_headers( + &static_headers, + &forward_header_names, + &incoming_headers, + &extensions, + false, + ); + + assert_eq!(result.get("mcp-session-id").unwrap(), "session-123"); + } + + #[test] + fn test_build_request_headers_combined_scenario() { + // Static headers + let mut static_headers = HeaderMap::new(); + static_headers.insert("x-api-key", HeaderValue::from_static("static-key")); + + // Forward specific headers + let forward_header_names = vec!["x-tenant-id".to_string()]; + + // Incoming headers + let mut incoming_headers = HeaderMap::new(); + incoming_headers.insert("x-tenant-id", HeaderValue::from_static("tenant-123")); + incoming_headers.insert("mcp-session-id", HeaderValue::from_static("session-456")); + incoming_headers.insert( + "ignored-header", + HeaderValue::from_static("should-not-appear"), + ); + + // OAuth token + let mut extensions = Extensions::new(); + let token = ValidToken(Authorization::bearer("oauth-token").unwrap()); + extensions.insert(token); + + let result = build_request_headers( + &static_headers, + &forward_header_names, + &incoming_headers, + &extensions, + false, + ); + + // Verify all parts combined correctly + assert_eq!(result.get("x-api-key").unwrap(), "static-key"); + assert_eq!(result.get("x-tenant-id").unwrap(), "tenant-123"); + assert_eq!(result.get("mcp-session-id").unwrap(), "session-456"); + assert_eq!(result.get("authorization").unwrap(), "Bearer oauth-token"); + assert!(result.get("ignored-header").is_none()); + } + + #[test] + fn test_forward_headers_no_headers_by_default() { + let names: Vec = vec![]; + + let mut incoming = HeaderMap::new(); + incoming.insert("x-tenant-id", HeaderValue::from_static("tenant-123")); + + let mut outgoing = HeaderMap::new(); + + forward_headers(&names, &incoming, &mut outgoing); + + assert!(outgoing.is_empty()); + } + + #[test] + fn test_forward_headers_only_specific_headers() { + let names = vec![ + "x-tenant-id".to_string(), // Multi-tenancy + "x-trace-id".to_string(), // Distributed tracing + "x-geo-country".to_string(), // Geo information from CDN + "x-experiment-id".to_string(), // A/B testing + "ai-client-name".to_string(), // Client identification + ]; + + let mut incoming = HeaderMap::new(); + incoming.insert("x-tenant-id", HeaderValue::from_static("tenant-123")); + incoming.insert("x-trace-id", HeaderValue::from_static("trace-456")); + incoming.insert("x-geo-country", HeaderValue::from_static("US")); + incoming.insert("x-experiment-id", HeaderValue::from_static("exp-789")); + incoming.insert("ai-client-name", HeaderValue::from_static("claude")); + incoming.insert("other-header", HeaderValue::from_static("ignored")); + + let mut outgoing = HeaderMap::new(); + + forward_headers(&names, &incoming, &mut outgoing); + + assert_eq!(outgoing.get("x-tenant-id").unwrap(), "tenant-123"); + assert_eq!(outgoing.get("x-trace-id").unwrap(), "trace-456"); + assert_eq!(outgoing.get("x-geo-country").unwrap(), "US"); + assert_eq!(outgoing.get("x-experiment-id").unwrap(), "exp-789"); + assert_eq!(outgoing.get("ai-client-name").unwrap(), "claude"); + + assert!(outgoing.get("other-header").is_none()); + } + + #[test] + fn test_forward_headers_blocks_hop_by_hop_headers() { + let names = vec!["connection".to_string(), "content-length".to_string()]; + + let mut incoming = HeaderMap::new(); + incoming.insert("connection", HeaderValue::from_static("keep-alive")); + incoming.insert("content-length", HeaderValue::from_static("1234")); + + let mut outgoing = HeaderMap::new(); + + forward_headers(&names, &incoming, &mut outgoing); + + assert!(outgoing.get("connection").is_none()); + assert!(outgoing.get("content-length").is_none()); + } + + #[test] + fn test_forward_headers_case_insensitive_matching() { + let names = vec!["X-Tenant-ID".to_string()]; + + let mut incoming = HeaderMap::new(); + incoming.insert("x-tenant-id", HeaderValue::from_static("tenant-123")); + + let mut outgoing = HeaderMap::new(); + forward_headers(&names, &incoming, &mut outgoing); + + assert_eq!(outgoing.get("x-tenant-id").unwrap(), "tenant-123"); + } +} diff --git a/crates/apollo-mcp-server/src/lib.rs b/crates/apollo-mcp-server/src/lib.rs index 1737b4e1..b4a07e52 100644 --- a/crates/apollo-mcp-server/src/lib.rs +++ b/crates/apollo-mcp-server/src/lib.rs @@ -7,6 +7,7 @@ pub mod errors; pub mod event; mod explorer; mod graphql; +pub mod headers; pub mod health; mod introspection; pub mod json_schema; diff --git a/crates/apollo-mcp-server/src/main.rs b/crates/apollo-mcp-server/src/main.rs index 0d80e937..e08c0309 100644 --- a/crates/apollo-mcp-server/src/main.rs +++ b/crates/apollo-mcp-server/src/main.rs @@ -115,6 +115,7 @@ async fn main() -> anyhow::Result<()> { .endpoint(config.endpoint.into_inner()) .maybe_explorer_graph_ref(explorer_graph_ref) .headers(config.headers) + .forward_headers(config.forward_headers) .execute_introspection(config.introspection.execute.enabled) .validate_introspection(config.introspection.validate.enabled) .introspect_introspection(config.introspection.introspect.enabled) diff --git a/crates/apollo-mcp-server/src/runtime.rs b/crates/apollo-mcp-server/src/runtime.rs index 71a39684..1a6ed528 100644 --- a/crates/apollo-mcp-server/src/runtime.rs +++ b/crates/apollo-mcp-server/src/runtime.rs @@ -240,6 +240,7 @@ mod test { ], }, headers: {}, + forward_headers: [], health_check: HealthCheckConfig { enabled: false, path: "/health", diff --git a/crates/apollo-mcp-server/src/runtime/config.rs b/crates/apollo-mcp-server/src/runtime/config.rs index 598462bd..369fb6a0 100644 --- a/crates/apollo-mcp-server/src/runtime/config.rs +++ b/crates/apollo-mcp-server/src/runtime/config.rs @@ -1,6 +1,8 @@ use std::path::PathBuf; -use apollo_mcp_server::{cors::CorsConfig, health::HealthCheckConfig, server::Transport}; +use apollo_mcp_server::{ + cors::CorsConfig, headers::ForwardHeaders, health::HealthCheckConfig, server::Transport, +}; use reqwest::header::HeaderMap; use schemars::JsonSchema; use serde::Deserialize; @@ -33,6 +35,10 @@ pub struct Config { #[schemars(schema_with = "super::schemas::header_map")] pub headers: HeaderMap, + /// List of header names to forward from MCP client requests to GraphQL requests + #[serde(default)] + pub forward_headers: ForwardHeaders, + /// Health check configuration #[serde(default)] pub health_check: HealthCheckConfig, diff --git a/crates/apollo-mcp-server/src/server.rs b/crates/apollo-mcp-server/src/server.rs index cdbd72e3..1c2735c2 100644 --- a/crates/apollo-mcp-server/src/server.rs +++ b/crates/apollo-mcp-server/src/server.rs @@ -12,6 +12,7 @@ use crate::cors::CorsConfig; use crate::custom_scalar_map::CustomScalarMap; use crate::errors::ServerError; use crate::event::Event as ServerEvent; +use crate::headers::ForwardHeaders; use crate::health::HealthCheckConfig; use crate::operations::{MutationMode, OperationSource}; @@ -26,6 +27,7 @@ pub struct Server { operation_source: OperationSource, endpoint: Url, headers: HeaderMap, + forward_headers: ForwardHeaders, execute_introspection: bool, validate_introspection: bool, introspect_introspection: bool, @@ -111,6 +113,7 @@ impl Server { operation_source: OperationSource, endpoint: Url, headers: HeaderMap, + forward_headers: ForwardHeaders, execute_introspection: bool, validate_introspection: bool, introspect_introspection: bool, @@ -139,6 +142,7 @@ impl Server { operation_source, endpoint, headers, + forward_headers, execute_introspection, validate_introspection, introspect_introspection, diff --git a/crates/apollo-mcp-server/src/server/states.rs b/crates/apollo-mcp-server/src/server/states.rs index c89f3a63..0f38ec6b 100644 --- a/crates/apollo-mcp-server/src/server/states.rs +++ b/crates/apollo-mcp-server/src/server/states.rs @@ -9,6 +9,7 @@ use crate::{ cors::CorsConfig, custom_scalar_map::CustomScalarMap, errors::{OperationError, ServerError}, + headers::ForwardHeaders, health::HealthCheckConfig, operations::MutationMode, }; @@ -34,6 +35,7 @@ struct Config { transport: Transport, endpoint: Url, headers: HeaderMap, + forward_headers: ForwardHeaders, execute_introspection: bool, validate_introspection: bool, introspect_introspection: bool, @@ -68,6 +70,7 @@ impl StateMachine { transport: server.transport, endpoint: server.endpoint, headers: server.headers, + forward_headers: server.forward_headers, execute_introspection: server.execute_introspection, validate_introspection: server.validate_introspection, introspect_introspection: server.introspect_introspection, diff --git a/crates/apollo-mcp-server/src/server/states/running.rs b/crates/apollo-mcp-server/src/server/states/running.rs index 6111e0ce..6ee7d91c 100644 --- a/crates/apollo-mcp-server/src/server/states/running.rs +++ b/crates/apollo-mcp-server/src/server/states/running.rs @@ -1,8 +1,6 @@ -use std::ops::Deref as _; use std::sync::Arc; use apollo_compiler::{Schema, validation::Valid}; -use headers::HeaderMapExt as _; use opentelemetry::trace::FutureExt; use opentelemetry::{Context, KeyValue}; use reqwest::header::HeaderMap; @@ -24,11 +22,11 @@ use url::Url; use crate::generated::telemetry::{TelemetryAttribute, TelemetryMetric}; use crate::meter; use crate::{ - auth::ValidToken, custom_scalar_map::CustomScalarMap, errors::{McpError, ServerError}, explorer::{EXPLORER_TOOL_NAME, Explorer}, graphql::{self, Executable as _}, + headers::{ForwardHeaders, build_request_headers}, health::HealthCheck, introspection::tools::{ execute::{EXECUTE_TOOL_NAME, Execute}, @@ -44,6 +42,7 @@ pub(super) struct Running { pub(super) schema: Arc>>, pub(super) operations: Arc>>, pub(super) headers: HeaderMap, + pub(super) forward_headers: ForwardHeaders, pub(super) endpoint: Url, pub(super) execute_tool: Option, pub(super) introspect_tool: Option, @@ -235,20 +234,19 @@ impl ServerHandler for Running { .await } EXECUTE_TOOL_NAME => { - let mut headers = self.headers.clone(); - if let Some(axum_parts) = context.extensions.get::() { - // Optionally extract the validated token and propagate it to upstream servers if present - if !self.disable_auth_token_passthrough - && let Some(token) = axum_parts.extensions.get::() - { - headers.typed_insert(token.deref().clone()); - } - - // Forward the mcp-session-id header if present - if let Some(session_id) = axum_parts.headers.get("mcp-session-id") { - headers.insert("mcp-session-id", session_id.clone()); - } - } + let headers = if let Some(axum_parts) = + context.extensions.get::() + { + build_request_headers( + &self.headers, + &self.forward_headers, + &axum_parts.headers, + &axum_parts.extensions, + self.disable_auth_token_passthrough, + ) + } else { + self.headers.clone() + }; self.execute_tool .as_ref() @@ -268,20 +266,19 @@ impl ServerHandler for Running { .await } _ => { - let mut headers = self.headers.clone(); - if let Some(axum_parts) = context.extensions.get::() { - // Optionally extract the validated token and propagate it to upstream servers if present - if !self.disable_auth_token_passthrough - && let Some(token) = axum_parts.extensions.get::() - { - headers.typed_insert(token.deref().clone()); - } - - // Also forward the mcp-session-id header if present - if let Some(session_id) = axum_parts.headers.get("mcp-session-id") { - headers.insert("mcp-session-id", session_id.clone()); - } - } + let headers = if let Some(axum_parts) = + context.extensions.get::() + { + build_request_headers( + &self.headers, + &self.forward_headers, + &axum_parts.headers, + &axum_parts.extensions, + self.disable_auth_token_passthrough, + ) + } else { + self.headers.clone() + }; let graphql_request = graphql::Request { input: Value::from(request.arguments.clone()), @@ -408,6 +405,7 @@ mod tests { schema: Arc::new(Mutex::new(schema)), operations: Arc::new(Mutex::new(vec![])), headers: HeaderMap::new(), + forward_headers: vec![], endpoint: "http://localhost:4000".parse().unwrap(), execute_tool: None, introspect_tool: None, diff --git a/crates/apollo-mcp-server/src/server/states/starting.rs b/crates/apollo-mcp-server/src/server/states/starting.rs index c377da5a..f356d876 100644 --- a/crates/apollo-mcp-server/src/server/states/starting.rs +++ b/crates/apollo-mcp-server/src/server/states/starting.rs @@ -140,6 +140,7 @@ impl Starting { schema, operations: Arc::new(Mutex::new(operations)), headers: self.config.headers, + forward_headers: self.config.forward_headers.clone(), endpoint: self.config.endpoint, execute_tool, introspect_tool, @@ -355,6 +356,7 @@ mod tests { mutation_mode: MutationMode::All, execute_introspection: true, headers: HeaderMap::new(), + forward_headers: vec![], validate_introspection: true, introspect_introspection: true, search_introspection: true, diff --git a/docs/source/best-practices.mdx b/docs/source/best-practices.mdx index 6f3ce712..c60d639d 100644 --- a/docs/source/best-practices.mdx +++ b/docs/source/best-practices.mdx @@ -47,3 +47,9 @@ To maintain clear trust boundaries, MCP servers must only accept tokens explicit Forwarding client tokens downstream is not allowed. Apollo MCP Server supports OAuth 2.1 authentication that follows best practices and aligns with the MCP authorization model. See our [authorization guide](/apollo-mcp-server/auth) for implementation details. + +## Avoid forwarding credentials with header forwarding + +The [forward_headers](/apollo-mcp-server/config-file#forward-headers) setting is designed for forwarding **non-sensitive metadata** like A/B testing, feature flagging, geo information from CDNs, or internal instrumentation. + +**We strongly recommend against using header forwarding for passing through credentials** such as API keys or access tokens. Forwarding credentials can introduce confused deputy vulnerabilities where your GraphQL API may incorrectly trust headers as though they were validated by the MCP Server. diff --git a/docs/source/config-file.mdx b/docs/source/config-file.mdx index 54021eaa..4a151d2d 100644 --- a/docs/source/config-file.mdx +++ b/docs/source/config-file.mdx @@ -20,6 +20,7 @@ All fields are optional. | `cors` | `Cors` | | CORS configuration | | `custom_scalars` | `FilePath` | | Path to a [custom scalar map](/apollo-mcp-server/custom-scalars) | | `endpoint` | `URL` | `http://localhost:4000/` | The target GraphQL endpoint | +| `forward_headers`| `List` | `[]` | Headers to forward from MCP clients to GraphQL API | | `graphos` | `GraphOS` | | Apollo-specific credential overrides | | `headers` | `Map` | `{}` | List of hard-coded headers to include in all GraphQL requests | | `health_check` | `HealthCheck` | | Health check configuration | @@ -43,6 +44,19 @@ These fields are under the top-level `graphos` key and define your GraphOS graph | `apollo_registry_url` | `URL` | | The URL to use for Apollo's registry | | `apollo_uplink_endpoints` | `URL` | | List of uplink URL overrides. You can also provide this with the `APOLLO_UPLINK_ENDPOINTS` environment variable | +### Forward Headers + +The `forward_headers` option allows you to forward specific headers from incoming MCP client requests to your GraphQL API. + +This is useful for: +- Multi-tenant applications (forwarding tenant IDs) +- A/B testing (forwarding experiment IDs) +- Geo information (forwarding country Codes) +- Client identification (forwarding AI client names) +- Internal instrumentation (forwarding correlation IDs) + +Supports exact header names only. Hop-by-hop headers (like `connection`, `transfer-encoding`) are automatically blocked for security. + ### CORS These fields are under the top-level `cors` key and configure Cross-Origin Resource Sharing (CORS) for browser-based MCP clients.