Skip to content

Commit af46f69

Browse files
committed
feat: Add support for forwarding headers from MCP clients to GraphQL APIs
1 parent a33d5c0 commit af46f69

File tree

8 files changed

+189
-32
lines changed

8 files changed

+189
-32
lines changed
Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
use std::ops::Deref;
2+
use std::str::FromStr;
3+
4+
use headers::HeaderMapExt;
5+
use http::Extensions;
6+
use reqwest::header::{HeaderMap, HeaderName};
7+
8+
use crate::auth::ValidToken;
9+
10+
/// List of header names to forward from MCP clients to GraphQL API
11+
pub type ForwardHeaders = Vec<String>;
12+
13+
/// Build headers for a GraphQL request by combining static headers with forwarded headers
14+
pub fn build_request_headers(
15+
static_headers: &HeaderMap,
16+
forward_header_names: &ForwardHeaders,
17+
incoming_headers: &HeaderMap,
18+
extensions: &Extensions,
19+
disable_auth_token_passthrough: bool,
20+
) -> HeaderMap {
21+
// Starts with static headers
22+
let mut headers = static_headers.clone();
23+
24+
// Forward headers dynamically
25+
forward_headers(forward_header_names, incoming_headers, &mut headers);
26+
27+
// Optionally extract the validated token and propagate it to upstream servers if present
28+
if !disable_auth_token_passthrough && let Some(token) = extensions.get::<ValidToken>() {
29+
headers.typed_insert(token.deref().clone());
30+
}
31+
32+
// Forward the mcp-session-id header if present
33+
if let Some(session_id) = incoming_headers.get("mcp-session-id") {
34+
headers.insert("mcp-session-id", session_id.clone());
35+
}
36+
37+
headers
38+
}
39+
40+
/// Forward matching headers from incoming headers to outgoing headers
41+
fn forward_headers(names: &[String], incoming: &HeaderMap, outgoing: &mut HeaderMap) {
42+
for header in names {
43+
if let Ok(header_name) = HeaderName::from_str(header)
44+
&& let Some(value) = incoming.get(&header_name)
45+
// Hop-by-hop headers are blocked per RFC 7230: https://datatracker.ietf.org/doc/html/rfc7230#section-6.1
46+
&& !matches!(
47+
header_name.as_str().to_lowercase().as_str(),
48+
"connection"
49+
| "keep-alive"
50+
| "proxy-authenticate"
51+
| "proxy-authorization"
52+
| "te"
53+
| "trailers"
54+
| "transfer-encoding"
55+
| "upgrade"
56+
| "content-length"
57+
)
58+
{
59+
outgoing.insert(header_name, value.clone());
60+
}
61+
}
62+
}
63+
64+
#[cfg(test)]
65+
mod tests {
66+
use super::*;
67+
use reqwest::header::HeaderValue;
68+
69+
#[test]
70+
fn test_forward_no_headers_by_default() {
71+
let names: Vec<String> = vec![];
72+
73+
let mut incoming = HeaderMap::new();
74+
incoming.insert("x-tenant-id", HeaderValue::from_static("tenant-123"));
75+
76+
let mut outgoing = HeaderMap::new();
77+
78+
forward_headers(&names, &incoming, &mut outgoing);
79+
80+
assert!(outgoing.is_empty());
81+
}
82+
83+
#[test]
84+
fn test_forward_only_allowed_headers() {
85+
let names = vec![
86+
"x-tenant-id".to_string(), // Multi-tenancy
87+
"x-trace-id".to_string(), // Distributed tracing
88+
"x-geo-country".to_string(), // Geo information from CDN
89+
"x-experiment-id".to_string(), // A/B testing
90+
"ai-client-name".to_string(), // Client identification
91+
];
92+
93+
let mut incoming = HeaderMap::new();
94+
incoming.insert("x-tenant-id", HeaderValue::from_static("tenant-123"));
95+
incoming.insert("x-trace-id", HeaderValue::from_static("trace-456"));
96+
incoming.insert("x-geo-country", HeaderValue::from_static("US"));
97+
incoming.insert("x-experiment-id", HeaderValue::from_static("exp-789"));
98+
incoming.insert("ai-client-name", HeaderValue::from_static("claude"));
99+
incoming.insert("other-header", HeaderValue::from_static("ignored"));
100+
101+
let mut outgoing = HeaderMap::new();
102+
103+
forward_headers(&names, &incoming, &mut outgoing);
104+
105+
assert_eq!(outgoing.get("x-tenant-id").unwrap(), "tenant-123");
106+
assert_eq!(outgoing.get("x-trace-id").unwrap(), "trace-456");
107+
assert_eq!(outgoing.get("x-geo-country").unwrap(), "US");
108+
assert_eq!(outgoing.get("x-experiment-id").unwrap(), "exp-789");
109+
assert_eq!(outgoing.get("ai-client-name").unwrap(), "claude");
110+
111+
assert!(outgoing.get("other-header").is_none());
112+
}
113+
114+
#[test]
115+
fn test_hop_by_hop_headers_blocked() {
116+
let names = vec!["connection".to_string(), "content-length".to_string()];
117+
118+
let mut incoming = HeaderMap::new();
119+
incoming.insert("connection", HeaderValue::from_static("keep-alive"));
120+
incoming.insert("content-length", HeaderValue::from_static("1234"));
121+
122+
let mut outgoing = HeaderMap::new();
123+
124+
forward_headers(&names, &incoming, &mut outgoing);
125+
126+
assert!(outgoing.get("connection").is_none());
127+
assert!(outgoing.get("content-length").is_none());
128+
}
129+
130+
#[test]
131+
fn test_case_insensitive_matching() {
132+
let names = vec!["X-Tenant-ID".to_string()];
133+
134+
let mut incoming = HeaderMap::new();
135+
incoming.insert("x-tenant-id", HeaderValue::from_static("tenant-123"));
136+
137+
let mut outgoing = HeaderMap::new();
138+
forward_headers(&names, &incoming, &mut outgoing);
139+
140+
assert_eq!(outgoing.get("x-tenant-id").unwrap(), "tenant-123");
141+
}
142+
}

crates/apollo-mcp-server/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ pub mod custom_scalar_map;
66
pub mod errors;
77
pub mod event;
88
mod explorer;
9+
pub mod headers;
910
mod graphql;
1011
pub mod health;
1112
mod introspection;

crates/apollo-mcp-server/src/main.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ async fn main() -> anyhow::Result<()> {
115115
.endpoint(config.endpoint.into_inner())
116116
.maybe_explorer_graph_ref(explorer_graph_ref)
117117
.headers(config.headers)
118+
.forward_headers(config.forward_headers)
118119
.execute_introspection(config.introspection.execute.enabled)
119120
.validate_introspection(config.introspection.validate.enabled)
120121
.introspect_introspection(config.introspection.introspect.enabled)

crates/apollo-mcp-server/src/runtime/config.rs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
use std::path::PathBuf;
22

3-
use apollo_mcp_server::{cors::CorsConfig, health::HealthCheckConfig, server::Transport};
3+
use apollo_mcp_server::{
4+
cors::CorsConfig, headers::ForwardHeaders, health::HealthCheckConfig, server::Transport,
5+
};
46
use reqwest::header::HeaderMap;
57
use schemars::JsonSchema;
68
use serde::Deserialize;
@@ -33,6 +35,10 @@ pub struct Config {
3335
#[schemars(schema_with = "super::schemas::header_map")]
3436
pub headers: HeaderMap,
3537

38+
/// List of header names to forward from MCP client requests to GraphQL requests
39+
#[serde(default)]
40+
pub forward_headers: ForwardHeaders,
41+
3642
/// Health check configuration
3743
#[serde(default)]
3844
pub health_check: HealthCheckConfig,

crates/apollo-mcp-server/src/server.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ use crate::cors::CorsConfig;
1212
use crate::custom_scalar_map::CustomScalarMap;
1313
use crate::errors::ServerError;
1414
use crate::event::Event as ServerEvent;
15+
use crate::headers::ForwardHeaders;
1516
use crate::health::HealthCheckConfig;
1617
use crate::operations::{MutationMode, OperationSource};
1718

@@ -26,6 +27,7 @@ pub struct Server {
2627
operation_source: OperationSource,
2728
endpoint: Url,
2829
headers: HeaderMap,
30+
forward_headers: ForwardHeaders,
2931
execute_introspection: bool,
3032
validate_introspection: bool,
3133
introspect_introspection: bool,
@@ -111,6 +113,7 @@ impl Server {
111113
operation_source: OperationSource,
112114
endpoint: Url,
113115
headers: HeaderMap,
116+
forward_headers: ForwardHeaders,
114117
execute_introspection: bool,
115118
validate_introspection: bool,
116119
introspect_introspection: bool,
@@ -139,6 +142,7 @@ impl Server {
139142
operation_source,
140143
endpoint,
141144
headers,
145+
forward_headers,
142146
execute_introspection,
143147
validate_introspection,
144148
introspect_introspection,

crates/apollo-mcp-server/src/server/states.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ use crate::{
99
cors::CorsConfig,
1010
custom_scalar_map::CustomScalarMap,
1111
errors::{OperationError, ServerError},
12+
headers::ForwardHeaders,
1213
health::HealthCheckConfig,
1314
operations::MutationMode,
1415
};
@@ -34,6 +35,7 @@ struct Config {
3435
transport: Transport,
3536
endpoint: Url,
3637
headers: HeaderMap,
38+
forward_headers: ForwardHeaders,
3739
execute_introspection: bool,
3840
validate_introspection: bool,
3941
introspect_introspection: bool,
@@ -68,6 +70,7 @@ impl StateMachine {
6870
transport: server.transport,
6971
endpoint: server.endpoint,
7072
headers: server.headers,
73+
forward_headers: server.forward_headers,
7174
execute_introspection: server.execute_introspection,
7275
validate_introspection: server.validate_introspection,
7376
introspect_introspection: server.introspect_introspection,

crates/apollo-mcp-server/src/server/states/running.rs

Lines changed: 29 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
1-
use std::ops::Deref as _;
21
use std::sync::Arc;
32

43
use apollo_compiler::{Schema, validation::Valid};
5-
use headers::HeaderMapExt as _;
64
use opentelemetry::trace::FutureExt;
75
use opentelemetry::{Context, KeyValue};
86
use reqwest::header::HeaderMap;
@@ -24,10 +22,10 @@ use url::Url;
2422
use crate::generated::telemetry::{TelemetryAttribute, TelemetryMetric};
2523
use crate::meter;
2624
use crate::{
27-
auth::ValidToken,
2825
custom_scalar_map::CustomScalarMap,
2926
errors::{McpError, ServerError},
3027
explorer::{EXPLORER_TOOL_NAME, Explorer},
28+
headers::{ForwardHeaders, build_request_headers},
3129
graphql::{self, Executable as _},
3230
health::HealthCheck,
3331
introspection::tools::{
@@ -44,6 +42,7 @@ pub(super) struct Running {
4442
pub(super) schema: Arc<Mutex<Valid<Schema>>>,
4543
pub(super) operations: Arc<Mutex<Vec<Operation>>>,
4644
pub(super) headers: HeaderMap,
45+
pub(super) forward_headers: ForwardHeaders,
4746
pub(super) endpoint: Url,
4847
pub(super) execute_tool: Option<Execute>,
4948
pub(super) introspect_tool: Option<Introspect>,
@@ -235,20 +234,19 @@ impl ServerHandler for Running {
235234
.await
236235
}
237236
EXECUTE_TOOL_NAME => {
238-
let mut headers = self.headers.clone();
239-
if let Some(axum_parts) = context.extensions.get::<axum::http::request::Parts>() {
240-
// Optionally extract the validated token and propagate it to upstream servers if present
241-
if !self.disable_auth_token_passthrough
242-
&& let Some(token) = axum_parts.extensions.get::<ValidToken>()
243-
{
244-
headers.typed_insert(token.deref().clone());
245-
}
246-
247-
// Forward the mcp-session-id header if present
248-
if let Some(session_id) = axum_parts.headers.get("mcp-session-id") {
249-
headers.insert("mcp-session-id", session_id.clone());
250-
}
251-
}
237+
let headers = if let Some(axum_parts) =
238+
context.extensions.get::<axum::http::request::Parts>()
239+
{
240+
build_request_headers(
241+
&self.headers,
242+
&self.forward_headers,
243+
&axum_parts.headers,
244+
&axum_parts.extensions,
245+
self.disable_auth_token_passthrough,
246+
)
247+
} else {
248+
self.headers.clone()
249+
};
252250

253251
self.execute_tool
254252
.as_ref()
@@ -268,20 +266,19 @@ impl ServerHandler for Running {
268266
.await
269267
}
270268
_ => {
271-
let mut headers = self.headers.clone();
272-
if let Some(axum_parts) = context.extensions.get::<axum::http::request::Parts>() {
273-
// Optionally extract the validated token and propagate it to upstream servers if present
274-
if !self.disable_auth_token_passthrough
275-
&& let Some(token) = axum_parts.extensions.get::<ValidToken>()
276-
{
277-
headers.typed_insert(token.deref().clone());
278-
}
279-
280-
// Also forward the mcp-session-id header if present
281-
if let Some(session_id) = axum_parts.headers.get("mcp-session-id") {
282-
headers.insert("mcp-session-id", session_id.clone());
283-
}
284-
}
269+
let headers = if let Some(axum_parts) =
270+
context.extensions.get::<axum::http::request::Parts>()
271+
{
272+
build_request_headers(
273+
&self.headers,
274+
&self.forward_headers,
275+
&axum_parts.headers,
276+
&axum_parts.extensions,
277+
self.disable_auth_token_passthrough,
278+
)
279+
} else {
280+
self.headers.clone()
281+
};
285282

286283
let graphql_request = graphql::Request {
287284
input: Value::from(request.arguments.clone()),
@@ -408,6 +405,7 @@ mod tests {
408405
schema: Arc::new(Mutex::new(schema)),
409406
operations: Arc::new(Mutex::new(vec![])),
410407
headers: HeaderMap::new(),
408+
forward_headers: vec![],
411409
endpoint: "http://localhost:4000".parse().unwrap(),
412410
execute_tool: None,
413411
introspect_tool: None,

crates/apollo-mcp-server/src/server/states/starting.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,7 @@ impl Starting {
140140
schema,
141141
operations: Arc::new(Mutex::new(operations)),
142142
headers: self.config.headers,
143+
forward_headers: self.config.forward_headers.clone(),
143144
endpoint: self.config.endpoint,
144145
execute_tool,
145146
introspect_tool,
@@ -355,6 +356,7 @@ mod tests {
355356
mutation_mode: MutationMode::All,
356357
execute_introspection: true,
357358
headers: HeaderMap::new(),
359+
forward_headers: vec![],
358360
validate_introspection: true,
359361
introspect_introspection: true,
360362
search_introspection: true,

0 commit comments

Comments
 (0)