Skip to content

Commit a2dff47

Browse files
committed
test: add more tests to increase coverage
1 parent 5b59d7f commit a2dff47

File tree

2 files changed

+159
-5
lines changed

2 files changed

+159
-5
lines changed

crates/apollo-mcp-server/src/auth/valid_token.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ use url::Url;
1212
/// Note: This is used as a marker to ensure that we have validated this
1313
/// separately from just reading the header itself.
1414
#[derive(Clone, Debug, PartialEq)]
15-
pub(crate) struct ValidToken(pub(super) Authorization<Bearer>);
15+
pub(crate) struct ValidToken(pub(crate) Authorization<Bearer>);
1616

1717
impl Deref for ValidToken {
1818
type Target = Authorization<Bearer>;

crates/apollo-mcp-server/src/headers.rs

Lines changed: 158 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -64,10 +64,164 @@ fn forward_headers(names: &[String], incoming: &HeaderMap, outgoing: &mut Header
6464
#[cfg(test)]
6565
mod tests {
6666
use super::*;
67+
use headers::Authorization;
68+
use http::Extensions;
6769
use reqwest::header::HeaderValue;
6870

71+
use crate::auth::ValidToken;
72+
73+
#[test]
74+
fn test_build_request_headers_includes_static_headers() {
75+
let mut static_headers = HeaderMap::new();
76+
static_headers.insert("x-api-key", HeaderValue::from_static("static-key"));
77+
static_headers.insert("user-agent", HeaderValue::from_static("mcp-server"));
78+
79+
let forward_header_names = vec![];
80+
let incoming_headers = HeaderMap::new();
81+
let extensions = Extensions::new();
82+
83+
let result = build_request_headers(
84+
&static_headers,
85+
&forward_header_names,
86+
&incoming_headers,
87+
&extensions,
88+
false,
89+
);
90+
91+
assert_eq!(result.get("x-api-key").unwrap(), "static-key");
92+
assert_eq!(result.get("user-agent").unwrap(), "mcp-server");
93+
}
94+
95+
#[test]
96+
fn test_build_request_headers_forwards_configured_headers() {
97+
let static_headers = HeaderMap::new();
98+
let forward_header_names = vec!["x-tenant-id".to_string(), "x-trace-id".to_string()];
99+
100+
let mut incoming_headers = HeaderMap::new();
101+
incoming_headers.insert("x-tenant-id", HeaderValue::from_static("tenant-123"));
102+
incoming_headers.insert("x-trace-id", HeaderValue::from_static("trace-456"));
103+
incoming_headers.insert("other-header", HeaderValue::from_static("ignored"));
104+
105+
let extensions = Extensions::new();
106+
107+
let result = build_request_headers(
108+
&static_headers,
109+
&forward_header_names,
110+
&incoming_headers,
111+
&extensions,
112+
false,
113+
);
114+
115+
assert_eq!(result.get("x-tenant-id").unwrap(), "tenant-123");
116+
assert_eq!(result.get("x-trace-id").unwrap(), "trace-456");
117+
assert!(result.get("other-header").is_none());
118+
}
119+
120+
#[test]
121+
fn test_build_request_headers_adds_oauth_token_when_enabled() {
122+
let static_headers = HeaderMap::new();
123+
let forward_header_names = vec![];
124+
let incoming_headers = HeaderMap::new();
125+
126+
let mut extensions = Extensions::new();
127+
let token = ValidToken(Authorization::bearer("test-token").unwrap());
128+
extensions.insert(token);
129+
130+
let result = build_request_headers(
131+
&static_headers,
132+
&forward_header_names,
133+
&incoming_headers,
134+
&extensions,
135+
false,
136+
);
137+
138+
assert!(result.get("authorization").is_some());
139+
assert_eq!(result.get("authorization").unwrap(), "Bearer test-token");
140+
}
141+
142+
#[test]
143+
fn test_build_request_headers_skips_oauth_token_when_disabled() {
144+
let static_headers = HeaderMap::new();
145+
let forward_header_names = vec![];
146+
let incoming_headers = HeaderMap::new();
147+
148+
let mut extensions = Extensions::new();
149+
let token = ValidToken(Authorization::bearer("test-token").unwrap());
150+
extensions.insert(token);
151+
152+
let result = build_request_headers(
153+
&static_headers,
154+
&forward_header_names,
155+
&incoming_headers,
156+
&extensions,
157+
true,
158+
);
159+
160+
assert!(result.get("authorization").is_none());
161+
}
162+
163+
#[test]
164+
fn test_build_request_headers_forwards_mcp_session_id() {
165+
let static_headers = HeaderMap::new();
166+
let forward_header_names = vec![];
167+
168+
let mut incoming_headers = HeaderMap::new();
169+
incoming_headers.insert("mcp-session-id", HeaderValue::from_static("session-123"));
170+
171+
let extensions = Extensions::new();
172+
173+
let result = build_request_headers(
174+
&static_headers,
175+
&forward_header_names,
176+
&incoming_headers,
177+
&extensions,
178+
false,
179+
);
180+
181+
assert_eq!(result.get("mcp-session-id").unwrap(), "session-123");
182+
}
183+
184+
#[test]
185+
fn test_build_request_headers_combined_scenario() {
186+
// Static headers
187+
let mut static_headers = HeaderMap::new();
188+
static_headers.insert("x-api-key", HeaderValue::from_static("static-key"));
189+
190+
// Forward specific headers
191+
let forward_header_names = vec!["x-tenant-id".to_string()];
192+
193+
// Incoming headers
194+
let mut incoming_headers = HeaderMap::new();
195+
incoming_headers.insert("x-tenant-id", HeaderValue::from_static("tenant-123"));
196+
incoming_headers.insert("mcp-session-id", HeaderValue::from_static("session-456"));
197+
incoming_headers.insert(
198+
"ignored-header",
199+
HeaderValue::from_static("should-not-appear"),
200+
);
201+
202+
// OAuth token
203+
let mut extensions = Extensions::new();
204+
let token = ValidToken(Authorization::bearer("oauth-token").unwrap());
205+
extensions.insert(token);
206+
207+
let result = build_request_headers(
208+
&static_headers,
209+
&forward_header_names,
210+
&incoming_headers,
211+
&extensions,
212+
false,
213+
);
214+
215+
// Verify all parts combined correctly
216+
assert_eq!(result.get("x-api-key").unwrap(), "static-key");
217+
assert_eq!(result.get("x-tenant-id").unwrap(), "tenant-123");
218+
assert_eq!(result.get("mcp-session-id").unwrap(), "session-456");
219+
assert_eq!(result.get("authorization").unwrap(), "Bearer oauth-token");
220+
assert!(result.get("ignored-header").is_none());
221+
}
222+
69223
#[test]
70-
fn test_forward_no_headers_by_default() {
224+
fn test_forward_headers_no_headers_by_default() {
71225
let names: Vec<String> = vec![];
72226

73227
let mut incoming = HeaderMap::new();
@@ -81,7 +235,7 @@ mod tests {
81235
}
82236

83237
#[test]
84-
fn test_forward_only_allowed_headers() {
238+
fn test_forward_headers_only_specific_headers() {
85239
let names = vec![
86240
"x-tenant-id".to_string(), // Multi-tenancy
87241
"x-trace-id".to_string(), // Distributed tracing
@@ -112,7 +266,7 @@ mod tests {
112266
}
113267

114268
#[test]
115-
fn test_hop_by_hop_headers_blocked() {
269+
fn test_forward_headers_blocks_hop_by_hop_headers() {
116270
let names = vec!["connection".to_string(), "content-length".to_string()];
117271

118272
let mut incoming = HeaderMap::new();
@@ -128,7 +282,7 @@ mod tests {
128282
}
129283

130284
#[test]
131-
fn test_case_insensitive_matching() {
285+
fn test_forward_headers_case_insensitive_matching() {
132286
let names = vec!["X-Tenant-ID".to_string()];
133287

134288
let mut incoming = HeaderMap::new();

0 commit comments

Comments
 (0)