Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions .changeset/header_transform.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
---
default: minor
---

# Add header transform support

Add a `HeaderTransform` callback that allows consumers to modify HTTP headers before they are sent to the upstream GraphQL endpoint. The callback runs after all other header processing (static headers, forwarded headers, auth token passthrough, and mcp-session-id), enabling custom authentication schemes, header-based routing, HMAC signing, and other transformations without requiring an intermediary proxy.
284 changes: 284 additions & 0 deletions crates/apollo-mcp-server/src/headers.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::ops::Deref;
use std::str::FromStr;
use std::sync::Arc;

use headers::HeaderMapExt;
use http::Extensions;
Expand All @@ -11,13 +12,22 @@ use crate::auth::ValidToken;
/// List of header names to forward from MCP clients to GraphQL API
pub type ForwardHeaders = Vec<String>;

/// A callback that transforms HTTP headers before they are sent to the upstream GraphQL endpoint.
///
/// The callback receives a mutable reference to the fully assembled headers (after static headers,
/// forwarded headers, auth token passthrough, and mcp-session-id have all been applied).
/// This enables custom authentication schemes, header-based routing, HMAC signing, and other
/// transformations without requiring an intermediary proxy.
pub type HeaderTransform = Arc<dyn Fn(&mut HeaderMap) + Send + Sync>;

/// Build headers for a GraphQL request by combining static headers with forwarded headers
pub fn build_request_headers(
static_headers: &HeaderMap,
forward_header_names: &ForwardHeaders,
incoming_headers: &HeaderMap,
extensions: &Extensions,
disable_auth_token_passthrough: bool,
header_transform: Option<&HeaderTransform>,
) -> HeaderMap {
// Starts with static headers
let mut headers = static_headers.clone();
Expand All @@ -35,6 +45,11 @@ pub fn build_request_headers(
headers.insert("mcp-session-id", session_id.clone());
}

// Apply consumer-provided header transformation
if let Some(transform) = header_transform {
transform(&mut headers);
}

headers
}

Expand Down Expand Up @@ -101,6 +116,7 @@ mod tests {
&incoming_headers,
&extensions,
false,
None,
);

assert_eq!(result.get("x-api-key").unwrap(), "static-key");
Expand All @@ -125,6 +141,7 @@ mod tests {
&incoming_headers,
&extensions,
false,
None,
);

assert_eq!(result.get("x-tenant-id").unwrap(), "tenant-123");
Expand All @@ -151,6 +168,7 @@ mod tests {
&incoming_headers,
&extensions,
false,
None,
);

assert!(result.get("authorization").is_some());
Expand All @@ -176,6 +194,7 @@ mod tests {
&incoming_headers,
&extensions,
true,
None,
);

assert!(result.get("authorization").is_none());
Expand All @@ -197,6 +216,7 @@ mod tests {
&incoming_headers,
&extensions,
false,
None,
);

assert_eq!(result.get("mcp-session-id").unwrap(), "session-123");
Expand Down Expand Up @@ -234,6 +254,7 @@ mod tests {
&incoming_headers,
&extensions,
false,
None,
);

// Verify all parts combined correctly
Expand All @@ -245,6 +266,269 @@ mod tests {
}
}

mod header_transform {
use super::*;

#[test]
fn applies_transform_to_headers() {
let mut static_headers = HeaderMap::new();
static_headers.insert("x-api-key", HeaderValue::from_static("static-key"));

let transform: super::super::HeaderTransform =
std::sync::Arc::new(|headers: &mut HeaderMap| {
headers.insert("x-injected", HeaderValue::from_static("injected-value"));
});

let result = super::super::build_request_headers(
&static_headers,
&vec![],
&HeaderMap::new(),
&Extensions::new(),
false,
Some(&transform),
);

assert_eq!(result.get("x-api-key").unwrap(), "static-key");
assert_eq!(result.get("x-injected").unwrap(), "injected-value");
}

#[test]
fn transform_can_modify_existing_headers() {
let mut static_headers = HeaderMap::new();
static_headers.insert("authorization", HeaderValue::from_static("original-auth"));

let transform: super::super::HeaderTransform =
std::sync::Arc::new(|headers: &mut HeaderMap| {
headers.insert(
"authorization",
HeaderValue::from_static("transformed-auth"),
);
});

let result = super::super::build_request_headers(
&static_headers,
&vec![],
&HeaderMap::new(),
&Extensions::new(),
false,
Some(&transform),
);

assert_eq!(result.get("authorization").unwrap(), "transformed-auth");
}

#[test]
fn transform_runs_after_all_other_headers() {
let mut static_headers = HeaderMap::new();
static_headers.insert("x-static", HeaderValue::from_static("static-value"));

let forward_header_names = vec!["x-forwarded".to_string()];

let mut incoming_headers = HeaderMap::new();
incoming_headers.insert("x-forwarded", HeaderValue::from_static("forwarded-value"));
incoming_headers.insert("mcp-session-id", HeaderValue::from_static("session-abc"));

let extensions = Extensions::new();

let transform: super::super::HeaderTransform =
std::sync::Arc::new(|headers: &mut HeaderMap| {
// Verify all headers are present before transform runs
assert!(headers.contains_key("x-static"));
assert!(headers.contains_key("x-forwarded"));
assert!(headers.contains_key("mcp-session-id"));
headers.insert("x-transform-ran", HeaderValue::from_static("true"));
});

let result = super::super::build_request_headers(
&static_headers,
&forward_header_names,
&incoming_headers,
&extensions,
false,
Some(&transform),
);

assert_eq!(result.get("x-static").unwrap(), "static-value");
assert_eq!(result.get("x-forwarded").unwrap(), "forwarded-value");
assert_eq!(result.get("mcp-session-id").unwrap(), "session-abc");
assert_eq!(result.get("x-transform-ran").unwrap(), "true");
}

#[test]
fn none_transform_is_noop() {
let mut static_headers = HeaderMap::new();
static_headers.insert("x-api-key", HeaderValue::from_static("key"));

let result = super::super::build_request_headers(
&static_headers,
&vec![],
&HeaderMap::new(),
&Extensions::new(),
false,
None,
);

assert_eq!(result.get("x-api-key").unwrap(), "key");
}

#[test]
fn transform_can_remove_headers() {
let mut static_headers = HeaderMap::new();
static_headers.insert("x-api-key", HeaderValue::from_static("static-key"));
static_headers.insert("x-secret", HeaderValue::from_static("should-be-removed"));

let transform: super::super::HeaderTransform =
std::sync::Arc::new(|headers: &mut HeaderMap| {
headers.remove("x-secret");
});

let result = super::super::build_request_headers(
&static_headers,
&vec![],
&HeaderMap::new(),
&Extensions::new(),
false,
Some(&transform),
);

assert_eq!(result.get("x-api-key").unwrap(), "static-key");
assert!(result.get("x-secret").is_none());
}

#[test]
fn transform_can_override_auth_token() {
use headers::Authorization;

let static_headers = HeaderMap::new();
let forward_header_names = vec![];
let incoming_headers = HeaderMap::new();

let mut extensions = Extensions::new();
let token = ValidToken {
token: Authorization::bearer("original-token").unwrap(),
scopes: vec![],
};
extensions.insert(token);

let transform: super::super::HeaderTransform =
std::sync::Arc::new(|headers: &mut HeaderMap| {
headers.insert(
"authorization",
HeaderValue::from_static("Bearer custom-signed-token"),
);
});

let result = super::super::build_request_headers(
&static_headers,
&forward_header_names,
&incoming_headers,
&extensions,
false,
Some(&transform),
);

// The transform should override the auth token passthrough
assert_eq!(
result.get("authorization").unwrap(),
"Bearer custom-signed-token"
);
}

#[test]
fn transform_is_arc_cloneable_and_reusable() {
let mut static_headers = HeaderMap::new();
static_headers.insert("x-api-key", HeaderValue::from_static("key"));

let call_count = std::sync::Arc::new(std::sync::atomic::AtomicUsize::new(0));
let call_count_clone = call_count.clone();

let transform: super::super::HeaderTransform =
std::sync::Arc::new(move |headers: &mut HeaderMap| {
call_count_clone.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
headers.insert("x-request-id", HeaderValue::from_static("unique"));
});

// Clone the Arc to simulate sharing across threads/invocations
let transform_clone = transform.clone();

let result1 = super::super::build_request_headers(
&static_headers,
&vec![],
&HeaderMap::new(),
&Extensions::new(),
false,
Some(&transform),
);

let result2 = super::super::build_request_headers(
&static_headers,
&vec![],
&HeaderMap::new(),
&Extensions::new(),
false,
Some(&transform_clone),
);

assert_eq!(result1.get("x-request-id").unwrap(), "unique");
assert_eq!(result2.get("x-request-id").unwrap(), "unique");
assert_eq!(call_count.load(std::sync::atomic::Ordering::SeqCst), 2);
}

#[test]
fn combined_scenario_with_transform() {
use headers::Authorization;

// Static headers
let mut static_headers = HeaderMap::new();
static_headers.insert("x-api-key", HeaderValue::from_static("static-key"));

// Forward specific headers
let forward_header_names = vec!["x-tenant-id".to_string()];

// Incoming headers
let mut incoming_headers = HeaderMap::new();
incoming_headers.insert("x-tenant-id", HeaderValue::from_static("tenant-123"));
incoming_headers.insert("mcp-session-id", HeaderValue::from_static("session-456"));

// OAuth token
let mut extensions = Extensions::new();
let token = ValidToken {
token: Authorization::bearer("oauth-token").unwrap(),
scopes: vec![],
};
extensions.insert(token);

// Transform that adds HMAC signature based on all other headers
let transform: super::super::HeaderTransform =
std::sync::Arc::new(|headers: &mut HeaderMap| {
// Simulate HMAC signing: the transform can read existing headers
let has_api_key = headers.contains_key("x-api-key");
let has_tenant = headers.contains_key("x-tenant-id");
let has_session = headers.contains_key("mcp-session-id");
let has_auth = headers.contains_key("authorization");
assert!(has_api_key && has_tenant && has_session && has_auth);
headers.insert("x-hmac-signature", HeaderValue::from_static("sig-abc123"));
});

let result = super::super::build_request_headers(
&static_headers,
&forward_header_names,
&incoming_headers,
&extensions,
false,
Some(&transform),
);

// All original headers are present
assert_eq!(result.get("x-api-key").unwrap(), "static-key");
assert_eq!(result.get("x-tenant-id").unwrap(), "tenant-123");
assert_eq!(result.get("mcp-session-id").unwrap(), "session-456");
assert_eq!(result.get("authorization").unwrap(), "Bearer oauth-token");
// Plus the transform-injected header
assert_eq!(result.get("x-hmac-signature").unwrap(), "sig-abc123");
}
}

mod forward_headers {
use super::*;
use tracing_test::traced_test;
Expand Down
Loading