From eab17c2c06aea2d216f4e275136583abc49e5b66 Mon Sep 17 00:00:00 2001 From: Rama Palaniappan Date: Tue, 17 Feb 2026 10:56:15 -0800 Subject: [PATCH 1/2] Issue-642 Header Transform --- crates/apollo-mcp-server/src/headers.rs | 284 ++++++++++++++++++ crates/apollo-mcp-server/src/server.rs | 5 +- crates/apollo-mcp-server/src/server/states.rs | 4 +- .../src/server/states/running.rs | 23 +- .../src/server/states/starting.rs | 76 +++++ 5 files changed, 387 insertions(+), 5 deletions(-) diff --git a/crates/apollo-mcp-server/src/headers.rs b/crates/apollo-mcp-server/src/headers.rs index f7aa5d52..73386cc5 100644 --- a/crates/apollo-mcp-server/src/headers.rs +++ b/crates/apollo-mcp-server/src/headers.rs @@ -1,5 +1,6 @@ use std::ops::Deref; use std::str::FromStr; +use std::sync::Arc; use headers::HeaderMapExt; use http::Extensions; @@ -11,6 +12,14 @@ use crate::auth::ValidToken; /// List of header names to forward from MCP clients to GraphQL API pub type ForwardHeaders = Vec; +/// A callback that transforms HTTP headers before they are sent to the upstream GraphQL endpoint. +/// +/// The callback receives a mutable reference to the fully assembled headers (after static headers, +/// forwarded headers, auth token passthrough, and mcp-session-id have all been applied). +/// This enables custom authentication schemes, header-based routing, HMAC signing, and other +/// transformations without requiring an intermediary proxy. +pub type HeaderTransform = Arc; + /// Build headers for a GraphQL request by combining static headers with forwarded headers pub fn build_request_headers( static_headers: &HeaderMap, @@ -18,6 +27,7 @@ pub fn build_request_headers( incoming_headers: &HeaderMap, extensions: &Extensions, disable_auth_token_passthrough: bool, + header_transform: Option<&HeaderTransform>, ) -> HeaderMap { // Starts with static headers let mut headers = static_headers.clone(); @@ -35,6 +45,11 @@ pub fn build_request_headers( headers.insert("mcp-session-id", session_id.clone()); } + // Apply consumer-provided header transformation + if let Some(transform) = header_transform { + transform(&mut headers); + } + headers } @@ -101,6 +116,7 @@ mod tests { &incoming_headers, &extensions, false, + None, ); assert_eq!(result.get("x-api-key").unwrap(), "static-key"); @@ -125,6 +141,7 @@ mod tests { &incoming_headers, &extensions, false, + None, ); assert_eq!(result.get("x-tenant-id").unwrap(), "tenant-123"); @@ -151,6 +168,7 @@ mod tests { &incoming_headers, &extensions, false, + None, ); assert!(result.get("authorization").is_some()); @@ -176,6 +194,7 @@ mod tests { &incoming_headers, &extensions, true, + None, ); assert!(result.get("authorization").is_none()); @@ -197,6 +216,7 @@ mod tests { &incoming_headers, &extensions, false, + None, ); assert_eq!(result.get("mcp-session-id").unwrap(), "session-123"); @@ -234,6 +254,7 @@ mod tests { &incoming_headers, &extensions, false, + None, ); // Verify all parts combined correctly @@ -245,6 +266,269 @@ mod tests { } } + mod header_transform { + use super::*; + + #[test] + fn applies_transform_to_headers() { + let mut static_headers = HeaderMap::new(); + static_headers.insert("x-api-key", HeaderValue::from_static("static-key")); + + let transform: super::super::HeaderTransform = + std::sync::Arc::new(|headers: &mut HeaderMap| { + headers.insert("x-injected", HeaderValue::from_static("injected-value")); + }); + + let result = super::super::build_request_headers( + &static_headers, + &vec![], + &HeaderMap::new(), + &Extensions::new(), + false, + Some(&transform), + ); + + assert_eq!(result.get("x-api-key").unwrap(), "static-key"); + assert_eq!(result.get("x-injected").unwrap(), "injected-value"); + } + + #[test] + fn transform_can_modify_existing_headers() { + let mut static_headers = HeaderMap::new(); + static_headers.insert("authorization", HeaderValue::from_static("original-auth")); + + let transform: super::super::HeaderTransform = + std::sync::Arc::new(|headers: &mut HeaderMap| { + headers.insert( + "authorization", + HeaderValue::from_static("transformed-auth"), + ); + }); + + let result = super::super::build_request_headers( + &static_headers, + &vec![], + &HeaderMap::new(), + &Extensions::new(), + false, + Some(&transform), + ); + + assert_eq!(result.get("authorization").unwrap(), "transformed-auth"); + } + + #[test] + fn transform_runs_after_all_other_headers() { + let mut static_headers = HeaderMap::new(); + static_headers.insert("x-static", HeaderValue::from_static("static-value")); + + let forward_header_names = vec!["x-forwarded".to_string()]; + + let mut incoming_headers = HeaderMap::new(); + incoming_headers.insert("x-forwarded", HeaderValue::from_static("forwarded-value")); + incoming_headers.insert("mcp-session-id", HeaderValue::from_static("session-abc")); + + let extensions = Extensions::new(); + + let transform: super::super::HeaderTransform = + std::sync::Arc::new(|headers: &mut HeaderMap| { + // Verify all headers are present before transform runs + assert!(headers.contains_key("x-static")); + assert!(headers.contains_key("x-forwarded")); + assert!(headers.contains_key("mcp-session-id")); + headers.insert("x-transform-ran", HeaderValue::from_static("true")); + }); + + let result = super::super::build_request_headers( + &static_headers, + &forward_header_names, + &incoming_headers, + &extensions, + false, + Some(&transform), + ); + + assert_eq!(result.get("x-static").unwrap(), "static-value"); + assert_eq!(result.get("x-forwarded").unwrap(), "forwarded-value"); + assert_eq!(result.get("mcp-session-id").unwrap(), "session-abc"); + assert_eq!(result.get("x-transform-ran").unwrap(), "true"); + } + + #[test] + fn none_transform_is_noop() { + let mut static_headers = HeaderMap::new(); + static_headers.insert("x-api-key", HeaderValue::from_static("key")); + + let result = super::super::build_request_headers( + &static_headers, + &vec![], + &HeaderMap::new(), + &Extensions::new(), + false, + None, + ); + + assert_eq!(result.get("x-api-key").unwrap(), "key"); + } + + #[test] + fn transform_can_remove_headers() { + let mut static_headers = HeaderMap::new(); + static_headers.insert("x-api-key", HeaderValue::from_static("static-key")); + static_headers.insert("x-secret", HeaderValue::from_static("should-be-removed")); + + let transform: super::super::HeaderTransform = + std::sync::Arc::new(|headers: &mut HeaderMap| { + headers.remove("x-secret"); + }); + + let result = super::super::build_request_headers( + &static_headers, + &vec![], + &HeaderMap::new(), + &Extensions::new(), + false, + Some(&transform), + ); + + assert_eq!(result.get("x-api-key").unwrap(), "static-key"); + assert!(result.get("x-secret").is_none()); + } + + #[test] + fn transform_can_override_auth_token() { + use headers::Authorization; + + let static_headers = HeaderMap::new(); + let forward_header_names = vec![]; + let incoming_headers = HeaderMap::new(); + + let mut extensions = Extensions::new(); + let token = ValidToken { + token: Authorization::bearer("original-token").unwrap(), + scopes: vec![], + }; + extensions.insert(token); + + let transform: super::super::HeaderTransform = + std::sync::Arc::new(|headers: &mut HeaderMap| { + headers.insert( + "authorization", + HeaderValue::from_static("Bearer custom-signed-token"), + ); + }); + + let result = super::super::build_request_headers( + &static_headers, + &forward_header_names, + &incoming_headers, + &extensions, + false, + Some(&transform), + ); + + // The transform should override the auth token passthrough + assert_eq!( + result.get("authorization").unwrap(), + "Bearer custom-signed-token" + ); + } + + #[test] + fn transform_is_arc_cloneable_and_reusable() { + let mut static_headers = HeaderMap::new(); + static_headers.insert("x-api-key", HeaderValue::from_static("key")); + + let call_count = std::sync::Arc::new(std::sync::atomic::AtomicUsize::new(0)); + let call_count_clone = call_count.clone(); + + let transform: super::super::HeaderTransform = + std::sync::Arc::new(move |headers: &mut HeaderMap| { + call_count_clone.fetch_add(1, std::sync::atomic::Ordering::SeqCst); + headers.insert("x-request-id", HeaderValue::from_static("unique")); + }); + + // Clone the Arc to simulate sharing across threads/invocations + let transform_clone = transform.clone(); + + let result1 = super::super::build_request_headers( + &static_headers, + &vec![], + &HeaderMap::new(), + &Extensions::new(), + false, + Some(&transform), + ); + + let result2 = super::super::build_request_headers( + &static_headers, + &vec![], + &HeaderMap::new(), + &Extensions::new(), + false, + Some(&transform_clone), + ); + + assert_eq!(result1.get("x-request-id").unwrap(), "unique"); + assert_eq!(result2.get("x-request-id").unwrap(), "unique"); + assert_eq!(call_count.load(std::sync::atomic::Ordering::SeqCst), 2); + } + + #[test] + fn combined_scenario_with_transform() { + use headers::Authorization; + + // Static headers + let mut static_headers = HeaderMap::new(); + static_headers.insert("x-api-key", HeaderValue::from_static("static-key")); + + // Forward specific headers + let forward_header_names = vec!["x-tenant-id".to_string()]; + + // Incoming headers + let mut incoming_headers = HeaderMap::new(); + incoming_headers.insert("x-tenant-id", HeaderValue::from_static("tenant-123")); + incoming_headers.insert("mcp-session-id", HeaderValue::from_static("session-456")); + + // OAuth token + let mut extensions = Extensions::new(); + let token = ValidToken { + token: Authorization::bearer("oauth-token").unwrap(), + scopes: vec![], + }; + extensions.insert(token); + + // Transform that adds HMAC signature based on all other headers + let transform: super::super::HeaderTransform = + std::sync::Arc::new(|headers: &mut HeaderMap| { + // Simulate HMAC signing: the transform can read existing headers + let has_api_key = headers.contains_key("x-api-key"); + let has_tenant = headers.contains_key("x-tenant-id"); + let has_session = headers.contains_key("mcp-session-id"); + let has_auth = headers.contains_key("authorization"); + assert!(has_api_key && has_tenant && has_session && has_auth); + headers.insert("x-hmac-signature", HeaderValue::from_static("sig-abc123")); + }); + + let result = super::super::build_request_headers( + &static_headers, + &forward_header_names, + &incoming_headers, + &extensions, + false, + Some(&transform), + ); + + // All original headers are present + assert_eq!(result.get("x-api-key").unwrap(), "static-key"); + assert_eq!(result.get("x-tenant-id").unwrap(), "tenant-123"); + assert_eq!(result.get("mcp-session-id").unwrap(), "session-456"); + assert_eq!(result.get("authorization").unwrap(), "Bearer oauth-token"); + // Plus the transform-injected header + assert_eq!(result.get("x-hmac-signature").unwrap(), "sig-abc123"); + } + } + mod forward_headers { use super::*; use tracing_test::traced_test; diff --git a/crates/apollo-mcp-server/src/server.rs b/crates/apollo-mcp-server/src/server.rs index 3eeee689..7719dd3c 100644 --- a/crates/apollo-mcp-server/src/server.rs +++ b/crates/apollo-mcp-server/src/server.rs @@ -12,7 +12,7 @@ use crate::cors::CorsConfig; use crate::custom_scalar_map::CustomScalarMap; use crate::errors::ServerError; use crate::event::Event as ServerEvent; -use crate::headers::ForwardHeaders; +use crate::headers::{ForwardHeaders, HeaderTransform}; use crate::health::HealthCheckConfig; use crate::host_validation::HostValidationConfig; use crate::operations::{MutationMode, OperationSource}; @@ -52,6 +52,7 @@ pub struct Server { health_check: HealthCheckConfig, cors: CorsConfig, server_info: ServerInfoConfig, + header_transform: Option, } #[derive(Debug, Clone, Deserialize, Default, JsonSchema)] @@ -149,6 +150,7 @@ impl Server { health_check: HealthCheckConfig, cors: CorsConfig, server_info: ServerInfoConfig, + header_transform: Option, ) -> Self { let headers = { let mut headers = headers.clone(); @@ -184,6 +186,7 @@ impl Server { health_check, cors, server_info, + header_transform, } } diff --git a/crates/apollo-mcp-server/src/server/states.rs b/crates/apollo-mcp-server/src/server/states.rs index 6bbac409..ccafc025 100644 --- a/crates/apollo-mcp-server/src/server/states.rs +++ b/crates/apollo-mcp-server/src/server/states.rs @@ -9,7 +9,7 @@ use crate::{ cors::CorsConfig, custom_scalar_map::CustomScalarMap, errors::{OperationError, ServerError}, - headers::ForwardHeaders, + headers::{ForwardHeaders, HeaderTransform}, health::HealthCheckConfig, operations::MutationMode, server_info::ServerInfoConfig, @@ -60,6 +60,7 @@ struct Config { health_check: HealthCheckConfig, cors: CorsConfig, server_info: ServerInfoConfig, + header_transform: Option, } impl StateMachine { @@ -101,6 +102,7 @@ impl StateMachine { health_check: server.health_check, cors: server.cors, server_info: server.server_info, + header_transform: server.header_transform, }, }); diff --git a/crates/apollo-mcp-server/src/server/states/running.rs b/crates/apollo-mcp-server/src/server/states/running.rs index 03ee3b63..2a9fa6d5 100644 --- a/crates/apollo-mcp-server/src/server/states/running.rs +++ b/crates/apollo-mcp-server/src/server/states/running.rs @@ -34,7 +34,7 @@ use crate::{ custom_scalar_map::CustomScalarMap, errors::McpError, explorer::{EXPLORER_TOOL_NAME, Explorer}, - headers::{ForwardHeaders, build_request_headers}, + headers::{ForwardHeaders, HeaderTransform, build_request_headers}, health::HealthCheck, introspection::tools::{ execute::{EXECUTE_TOOL_NAME, Execute}, @@ -66,6 +66,7 @@ pub(super) struct Running { pub(super) disable_schema_description: bool, pub(super) enable_output_schema: bool, pub(super) disable_auth_token_passthrough: bool, + pub(super) header_transform: Option, pub(super) health_check: Option, pub(super) server_info: ServerInfoConfig, } @@ -382,9 +383,14 @@ impl ServerHandler for Running { &axum_parts.headers, &axum_parts.extensions, self.disable_auth_token_passthrough, + self.header_transform.as_ref(), ) } else { - self.headers.clone() + let mut headers = self.headers.clone(); + if let Some(transform) = &self.header_transform { + transform(&mut headers); + } + headers }; execute_operation( @@ -412,9 +418,14 @@ impl ServerHandler for Running { &axum_parts.headers, &axum_parts.extensions, self.disable_auth_token_passthrough, + self.header_transform.as_ref(), ) } else { - self.headers.clone() + let mut headers = self.headers.clone(); + if let Some(transform) = &self.header_transform { + transform(&mut headers); + } + headers }; // Access the "app" query parameter from the HTTP request @@ -585,6 +596,7 @@ mod tests { disable_schema_description: false, enable_output_schema: false, disable_auth_token_passthrough: false, + header_transform: None, health_check: None, server_info: ServerInfoConfig::default(), }; @@ -645,6 +657,7 @@ mod tests { disable_schema_description: false, enable_output_schema: false, disable_auth_token_passthrough: false, + header_transform: None, health_check: None, server_info: ServerInfoConfig::default(), }; @@ -723,6 +736,7 @@ mod tests { disable_schema_description: false, enable_output_schema: false, disable_auth_token_passthrough: false, + header_transform: None, health_check: None, server_info: ServerInfoConfig::default(), } @@ -1392,6 +1406,7 @@ mod tests { disable_schema_description: false, enable_output_schema: false, disable_auth_token_passthrough: false, + header_transform: None, health_check: None, server_info: ServerInfoConfig::default(), }; @@ -1445,6 +1460,7 @@ mod tests { disable_schema_description: false, enable_output_schema: false, disable_auth_token_passthrough: false, + header_transform: None, health_check: None, server_info: custom_config, }; @@ -1499,6 +1515,7 @@ mod tests { disable_schema_description: false, enable_output_schema: false, disable_auth_token_passthrough: false, + header_transform: None, health_check: None, server_info: Default::default(), } diff --git a/crates/apollo-mcp-server/src/server/states/starting.rs b/crates/apollo-mcp-server/src/server/states/starting.rs index c299e947..b82a7ba4 100644 --- a/crates/apollo-mcp-server/src/server/states/starting.rs +++ b/crates/apollo-mcp-server/src/server/states/starting.rs @@ -166,6 +166,7 @@ impl Starting { disable_schema_description: self.config.disable_schema_description, enable_output_schema: self.config.enable_output_schema, disable_auth_token_passthrough: self.config.disable_auth_token_passthrough, + header_transform: self.config.header_transform.clone(), health_check: health_check.clone(), server_info: self.config.server_info.clone(), }; @@ -340,6 +341,7 @@ mod tests { }, cors: Default::default(), server_info: Default::default(), + header_transform: None, }, schema: Schema::parse_and_validate("type Query { hello: String }", "test.graphql") .expect("Valid schema"), @@ -349,6 +351,79 @@ mod tests { assert!(running.await.is_ok()); } + #[tokio::test] + async fn start_server_with_header_transform() { + let transform: crate::headers::HeaderTransform = + std::sync::Arc::new(|headers: &mut reqwest::header::HeaderMap| { + headers.insert( + "x-custom-auth", + reqwest::header::HeaderValue::from_static("signed-value"), + ); + }); + + let starting = Starting { + config: Config { + transport: Transport::StreamableHttp { + auth: None, + address: "127.0.0.1".parse().unwrap(), + port: 7797, + stateful_mode: false, + host_validation: HostValidationConfig::default(), + }, + endpoint: Url::parse("http://localhost:4000").expect("valid url"), + mutation_mode: MutationMode::All, + execute_introspection: true, + headers: HeaderMap::new(), + forward_headers: vec![], + validate_introspection: true, + introspect_introspection: true, + search_introspection: true, + introspect_minify: false, + search_minify: false, + execute_tool_hint: None, + introspect_tool_hint: None, + search_tool_hint: None, + validate_tool_hint: None, + explorer_graph_ref: None, + custom_scalar_map: None, + disable_type_description: false, + disable_schema_description: false, + enable_output_schema: false, + disable_auth_token_passthrough: false, + search_leaf_depth: 5, + index_memory_bytes: 1024 * 1024 * 1024, + health_check: HealthCheckConfig { + enabled: true, + ..Default::default() + }, + cors: Default::default(), + server_info: Default::default(), + header_transform: Some(transform), + }, + schema: Schema::parse_and_validate("type Query { hello: String }", "test.graphql") + .expect("Valid schema"), + operations: vec![], + }; + let running = starting.start().await; + assert!(running.is_ok()); + + // Verify the header_transform was propagated to the Running state + let running = running.unwrap(); + assert!( + running.header_transform.is_some(), + "header_transform should propagate from Starting config to Running state" + ); + + // Verify the transform works correctly when applied + let mut test_headers = HeaderMap::new(); + (running.header_transform.as_ref().unwrap())(&mut test_headers); + assert_eq!( + test_headers.get("x-custom-auth").unwrap(), + "signed-value", + "propagated transform should be functional" + ); + } + #[tokio::test] async fn start_sse_server_returns_unsupported_error() { let starting = Starting { @@ -383,6 +458,7 @@ mod tests { health_check: HealthCheckConfig::default(), cors: Default::default(), server_info: Default::default(), + header_transform: None, }, schema: Schema::parse_and_validate("type Query { hello: String }", "test.graphql") .expect("Valid schema"), From 5ba04bd42ff5545c4aaf051a8d053c81e2f98740 Mon Sep 17 00:00:00 2001 From: Rama Palaniappan Date: Wed, 18 Feb 2026 10:47:12 -0800 Subject: [PATCH 2/2] chore: add changeset for header transform feature Co-authored-by: Cursor --- .changeset/header_transform.md | 7 +++++++ 1 file changed, 7 insertions(+) create mode 100644 .changeset/header_transform.md diff --git a/.changeset/header_transform.md b/.changeset/header_transform.md new file mode 100644 index 00000000..c80901bb --- /dev/null +++ b/.changeset/header_transform.md @@ -0,0 +1,7 @@ +--- +default: minor +--- + +# Add header transform support + +Add a `HeaderTransform` callback that allows consumers to modify HTTP headers before they are sent to the upstream GraphQL endpoint. The callback runs after all other header processing (static headers, forwarded headers, auth token passthrough, and mcp-session-id), enabling custom authentication schemes, header-based routing, HMAC signing, and other transformations without requiring an intermediary proxy.