Skip to content

Commit 5f2a678

Browse files
committed
feat: add CORS support
1 parent 131a9e0 commit 5f2a678

File tree

9 files changed

+394
-6
lines changed

9 files changed

+394
-6
lines changed
Lines changed: 327 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,327 @@
1+
use http::Method;
2+
use regex::Regex;
3+
use schemars::JsonSchema;
4+
use serde::Deserialize;
5+
use tower_http::cors::{AllowOrigin, Any, CorsLayer};
6+
use url::Url;
7+
8+
use crate::errors::ServerError;
9+
10+
/// CORS configuration options
11+
#[derive(Debug, Clone, Deserialize, JsonSchema)]
12+
#[serde(default)]
13+
pub struct CorsConfig {
14+
/// Enable CORS support
15+
pub enabled: bool,
16+
17+
/// List of allowed origins (exact match)
18+
pub origins: Vec<String>,
19+
20+
/// List of origin patterns (regex matching)
21+
pub match_origins: Vec<String>,
22+
23+
/// Allow any origin (use with caution)
24+
pub allow_any_origin: bool,
25+
26+
/// Allow credentials in CORS requests
27+
pub allow_credentials: bool,
28+
29+
/// Allowed HTTP methods
30+
pub allow_methods: Vec<String>,
31+
32+
/// Allowed request headers
33+
pub allow_headers: Vec<String>,
34+
35+
/// Headers exposed to the browser
36+
pub expose_headers: Vec<String>,
37+
38+
/// Max age for preflight cache (in seconds)
39+
pub max_age: Option<u64>,
40+
}
41+
42+
impl Default for CorsConfig {
43+
fn default() -> Self {
44+
Self {
45+
enabled: false,
46+
origins: Vec::new(),
47+
match_origins: Vec::new(),
48+
allow_any_origin: false,
49+
allow_credentials: false,
50+
allow_methods: default_methods(),
51+
allow_headers: default_headers(),
52+
expose_headers: Vec::new(),
53+
max_age: Some(default_max_age()),
54+
}
55+
}
56+
}
57+
58+
/// Default allowed HTTP methods
59+
fn default_methods() -> Vec<String> {
60+
vec!["GET".to_string(), "POST".to_string(), "OPTIONS".to_string()]
61+
}
62+
63+
/// Default allowed headers
64+
fn default_headers() -> Vec<String> {
65+
vec![
66+
"content-type".to_string(),
67+
"authorization".to_string(),
68+
"mcp-session-id".to_string(),
69+
]
70+
}
71+
72+
/// Default max age for preflight cache (2 hours)
73+
fn default_max_age() -> u64 {
74+
7200
75+
}
76+
77+
impl CorsConfig {
78+
/// Build a CorsLayer from this configuration
79+
pub fn build_cors_layer(&self) -> Result<CorsLayer, ServerError> {
80+
if !self.enabled {
81+
return Err(ServerError::Cors("CORS is not enabled".to_string()));
82+
}
83+
84+
// Validate configuration
85+
self.validate()?;
86+
87+
let mut cors = CorsLayer::new();
88+
89+
// Configure origins
90+
if self.allow_any_origin {
91+
cors = cors.allow_origin(Any);
92+
} else {
93+
// Collect all origins (exact and regex patterns)
94+
let mut origin_list = Vec::new();
95+
96+
// Parse exact origins
97+
for origin_str in &self.origins {
98+
let origin = origin_str.parse::<http::HeaderValue>().map_err(|e| {
99+
ServerError::Cors(format!("Invalid origin '{}': {}", origin_str, e))
100+
})?;
101+
origin_list.push(origin);
102+
}
103+
104+
// For regex patterns, we need to use a predicate function
105+
if !self.match_origins.is_empty() {
106+
// Parse regex patterns to validate them
107+
let mut regex_patterns = Vec::new();
108+
for pattern in &self.match_origins {
109+
let regex = Regex::new(pattern).map_err(|e| {
110+
ServerError::Cors(format!("Invalid origin pattern '{}': {}", pattern, e))
111+
})?;
112+
regex_patterns.push(regex);
113+
}
114+
115+
// Use predicate function that combines exact origins and regex patterns
116+
let exact_origins = origin_list;
117+
cors = cors.allow_origin(AllowOrigin::predicate(move |origin, _| {
118+
let origin_str = origin.to_str().unwrap_or("");
119+
120+
// Check exact origins
121+
if exact_origins
122+
.iter()
123+
.any(|exact| exact.as_bytes() == origin.as_bytes())
124+
{
125+
return true;
126+
}
127+
128+
// Check regex patterns
129+
regex_patterns
130+
.iter()
131+
.any(|regex| regex.is_match(origin_str))
132+
}));
133+
} else if !origin_list.is_empty() {
134+
// Only exact origins, no regex
135+
cors = cors.allow_origin(origin_list);
136+
}
137+
}
138+
139+
// Configure credentials
140+
cors = cors.allow_credentials(self.allow_credentials);
141+
142+
// Configure methods
143+
let methods: Result<Vec<Method>, _> = self
144+
.allow_methods
145+
.iter()
146+
.map(|m| m.parse::<Method>())
147+
.collect();
148+
let methods =
149+
methods.map_err(|e| ServerError::Cors(format!("Invalid HTTP method: {}", e)))?;
150+
cors = cors.allow_methods(methods);
151+
152+
// Configure headers
153+
if !self.allow_headers.is_empty() {
154+
let headers: Result<Vec<http::HeaderName>, _> = self
155+
.allow_headers
156+
.iter()
157+
.map(|h| h.parse::<http::HeaderName>())
158+
.collect();
159+
let headers =
160+
headers.map_err(|e| ServerError::Cors(format!("Invalid header name: {}", e)))?;
161+
cors = cors.allow_headers(headers);
162+
}
163+
164+
// Configure exposed headers
165+
if !self.expose_headers.is_empty() {
166+
let headers: Result<Vec<http::HeaderName>, _> = self
167+
.expose_headers
168+
.iter()
169+
.map(|h| h.parse::<http::HeaderName>())
170+
.collect();
171+
let headers = headers
172+
.map_err(|e| ServerError::Cors(format!("Invalid exposed header name: {}", e)))?;
173+
cors = cors.expose_headers(headers);
174+
}
175+
176+
// Configure max age
177+
if let Some(max_age) = self.max_age {
178+
cors = cors.max_age(std::time::Duration::from_secs(max_age));
179+
}
180+
181+
Ok(cors)
182+
}
183+
184+
/// Validate the configuration for consistency
185+
fn validate(&self) -> Result<(), ServerError> {
186+
// Cannot use credentials with any origin
187+
if self.allow_credentials && self.allow_any_origin {
188+
return Err(ServerError::Cors(
189+
"Cannot use allow_credentials with allow_any_origin for security reasons"
190+
.to_string(),
191+
));
192+
}
193+
194+
// Must have at least some origin configuration if not allowing any origin
195+
if !self.allow_any_origin && self.origins.is_empty() && self.match_origins.is_empty() {
196+
return Err(ServerError::Cors(
197+
"Must specify origins, match_origins, or allow_any_origin when CORS is enabled"
198+
.to_string(),
199+
));
200+
}
201+
202+
// Validate that origin strings are valid URLs
203+
for origin in &self.origins {
204+
Url::parse(origin).map_err(|e| {
205+
ServerError::Cors(format!("Invalid origin URL '{}': {}", origin, e))
206+
})?;
207+
}
208+
209+
// Validate regex patterns
210+
for pattern in &self.match_origins {
211+
Regex::new(pattern).map_err(|e| {
212+
ServerError::Cors(format!("Invalid regex pattern '{}': {}", pattern, e))
213+
})?;
214+
}
215+
216+
Ok(())
217+
}
218+
}
219+
220+
#[cfg(test)]
221+
mod tests {
222+
use super::*;
223+
224+
#[test]
225+
fn test_default_config() {
226+
let config = CorsConfig::default();
227+
assert!(!config.enabled);
228+
assert!(!config.allow_any_origin);
229+
assert!(!config.allow_credentials);
230+
assert_eq!(config.allow_methods, default_methods());
231+
assert_eq!(config.allow_headers, default_headers());
232+
assert_eq!(config.max_age, Some(default_max_age()));
233+
}
234+
235+
#[test]
236+
fn test_disabled_cors_fails_to_build() {
237+
let config = CorsConfig::default();
238+
assert!(config.build_cors_layer().is_err());
239+
}
240+
241+
#[test]
242+
fn test_allow_any_origin_builds() {
243+
let config = CorsConfig {
244+
enabled: true,
245+
allow_any_origin: true,
246+
..Default::default()
247+
};
248+
assert!(config.build_cors_layer().is_ok());
249+
}
250+
251+
#[test]
252+
fn test_specific_origins_build() {
253+
let config = CorsConfig {
254+
enabled: true,
255+
origins: vec![
256+
"https://localhost:3000".to_string(),
257+
"https://studio.apollographql.com".to_string(),
258+
],
259+
..Default::default()
260+
};
261+
assert!(config.build_cors_layer().is_ok());
262+
}
263+
264+
#[test]
265+
fn test_regex_origins_build() {
266+
let config = CorsConfig {
267+
enabled: true,
268+
match_origins: vec!["^https://localhost:[0-9]+$".to_string()],
269+
..Default::default()
270+
};
271+
assert!(config.build_cors_layer().is_ok());
272+
}
273+
274+
#[test]
275+
fn test_credentials_with_any_origin_fails() {
276+
let config = CorsConfig {
277+
enabled: true,
278+
allow_any_origin: true,
279+
allow_credentials: true,
280+
..Default::default()
281+
};
282+
assert!(config.build_cors_layer().is_err());
283+
}
284+
285+
#[test]
286+
fn test_no_origins_fails() {
287+
let config = CorsConfig {
288+
enabled: true,
289+
allow_any_origin: false,
290+
origins: vec![],
291+
match_origins: vec![],
292+
..Default::default()
293+
};
294+
assert!(config.build_cors_layer().is_err());
295+
}
296+
297+
#[test]
298+
fn test_invalid_origin_fails() {
299+
let config = CorsConfig {
300+
enabled: true,
301+
origins: vec!["not-a-valid-url".to_string()],
302+
..Default::default()
303+
};
304+
assert!(config.build_cors_layer().is_err());
305+
}
306+
307+
#[test]
308+
fn test_invalid_regex_fails() {
309+
let config = CorsConfig {
310+
enabled: true,
311+
match_origins: vec!["[invalid regex".to_string()],
312+
..Default::default()
313+
};
314+
assert!(config.build_cors_layer().is_err());
315+
}
316+
317+
#[test]
318+
fn test_invalid_method_fails() {
319+
let config = CorsConfig {
320+
enabled: true,
321+
origins: vec!["https://localhost:3000".to_string()],
322+
allow_methods: vec!["invalid method with spaces".to_string()],
323+
..Default::default()
324+
};
325+
assert!(config.build_cors_layer().is_err());
326+
}
327+
}

crates/apollo-mcp-server/src/errors.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,9 @@ pub enum ServerError {
100100

101101
#[error("Failed to index schema: {0}")]
102102
Indexing(#[from] IndexingError),
103+
104+
#[error("CORS configuration error: {0}")]
105+
Cors(String),
103106
}
104107

105108
/// An MCP tool error

crates/apollo-mcp-server/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
pub mod auth;
2+
pub mod cors;
23
pub mod custom_scalar_map;
34
pub mod errors;
45
pub mod event;

crates/apollo-mcp-server/src/main.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,7 @@ async fn main() -> anyhow::Result<()> {
145145
.search_leaf_depth(config.introspection.search.leaf_depth)
146146
.index_memory_bytes(config.introspection.search.index_memory_bytes)
147147
.health_check(config.health_check)
148+
.cors(config.cors)
148149
.build()
149150
.start()
150151
.await?)

crates/apollo-mcp-server/src/runtime.rs

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,27 @@ mod test {
144144

145145
insta::assert_debug_snapshot!(config, @r#"
146146
Config {
147+
cors: CorsConfig {
148+
enabled: false,
149+
origins: [],
150+
match_origins: [],
151+
allow_any_origin: false,
152+
allow_credentials: false,
153+
allow_methods: [
154+
"GET",
155+
"POST",
156+
"OPTIONS",
157+
],
158+
allow_headers: [
159+
"content-type",
160+
"authorization",
161+
"mcp-session-id",
162+
],
163+
expose_headers: [],
164+
max_age: Some(
165+
7200,
166+
),
167+
},
147168
custom_scalars: None,
148169
endpoint: Endpoint(
149170
Url {

crates/apollo-mcp-server/src/runtime/config.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use std::path::PathBuf;
22

3-
use apollo_mcp_server::{health::HealthCheckConfig, server::Transport};
3+
use apollo_mcp_server::{cors::CorsConfig, health::HealthCheckConfig, server::Transport};
44
use reqwest::header::HeaderMap;
55
use schemars::JsonSchema;
66
use serde::Deserialize;
@@ -15,6 +15,9 @@ use super::{
1515
#[derive(Debug, Default, Deserialize, JsonSchema)]
1616
#[serde(default)]
1717
pub struct Config {
18+
/// CORS configuration
19+
pub cors: CorsConfig,
20+
1821
/// Path to a custom scalar map
1922
pub custom_scalars: Option<PathBuf>,
2023

0 commit comments

Comments
 (0)