Skip to content

Commit 81d3df5

Browse files
gabotechsgabotechs
authored andcommitted
chore: improve cors configuration
1 parent 991e8c3 commit 81d3df5

File tree

4 files changed

+101
-24
lines changed

4 files changed

+101
-24
lines changed

server/src/lib.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@ mod _test_tools;
99

1010
mod body;
1111
mod gateway_callbacks;
12-
mod route_cors;
1312
mod route_gateway;
1413
mod secret_getter;
1514
mod server;

server/src/route_cors.rs

Lines changed: 0 additions & 14 deletions
This file was deleted.

server/src/server.rs

Lines changed: 68 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,12 @@ use std::u16;
66
use anyhow::Result;
77
use async_trait::async_trait;
88
use hyper::client::HttpConnector;
9+
use hyper::header::{
10+
ACCESS_CONTROL_ALLOW_HEADERS, ACCESS_CONTROL_ALLOW_METHODS, ACCESS_CONTROL_ALLOW_ORIGIN,
11+
};
12+
use hyper::http::HeaderValue;
913
use hyper::service::service_fn;
10-
use hyper::{Body, Method, Request, Response};
14+
use hyper::{Body, Method, Request, Response, StatusCode};
1115
use hyper_tls::HttpsConnector;
1216
use tokio::net::TcpListener;
1317
use tracing::{error, info};
@@ -16,6 +20,13 @@ use crate::gateway_callbacks::{CallbackResult, OnRequest, OnSuccess};
1620
use crate::secret_getter::SecretGetter;
1721
use crate::{BytesTransferredInfo, OnBytesTransferred};
1822

23+
fn ok() -> Response<Body> {
24+
Response::builder()
25+
.status(StatusCode::OK)
26+
.body(Body::empty())
27+
.unwrap()
28+
}
29+
1930
pub struct SignwayServer<T: SecretGetter + 'static> {
2031
pub port: u16,
2132
pub secret_getter: T,
@@ -24,6 +35,9 @@ pub struct SignwayServer<T: SecretGetter + 'static> {
2435
pub on_bytes_transferred: Arc<dyn OnBytesTransferred>,
2536
pub(crate) monitor_bytes: bool,
2637
pub(crate) client: hyper::Client<HttpsConnector<HttpConnector>, Body>,
38+
pub(crate) access_control_allow_origin: HeaderValue,
39+
pub(crate) access_control_allow_methods: HeaderValue,
40+
pub(crate) access_control_allow_headers: HeaderValue,
2741
}
2842

2943
pub(crate) struct NoneCallback;
@@ -60,6 +74,9 @@ impl<T: SecretGetter> SignwayServer<T> {
6074
on_success: Box::new(NoneCallback {}),
6175
on_bytes_transferred: Arc::new(NoneCallback {}),
6276
monitor_bytes: false,
77+
access_control_allow_origin: HeaderValue::from_static("*"),
78+
access_control_allow_headers: HeaderValue::from_static("*"),
79+
access_control_allow_methods: HeaderValue::from_static("*"),
6380
client,
6481
}
6582
}
@@ -74,6 +91,9 @@ impl<T: SecretGetter> SignwayServer<T> {
7491
on_success: Box::new(NoneCallback {}),
7592
on_bytes_transferred: Arc::new(NoneCallback {}),
7693
monitor_bytes: false,
94+
access_control_allow_origin: HeaderValue::from_static("*"),
95+
access_control_allow_headers: HeaderValue::from_static("*"),
96+
access_control_allow_methods: HeaderValue::from_static("*"),
7797
client,
7898
}
7999
}
@@ -94,6 +114,38 @@ impl<T: SecretGetter> SignwayServer<T> {
94114
self
95115
}
96116

117+
pub fn access_control_allow_origin(mut self, value: &str) -> Result<Self> {
118+
self.access_control_allow_origin = value.parse()?;
119+
Ok(self)
120+
}
121+
122+
pub fn access_control_allow_methods(mut self, value: &str) -> Result<Self> {
123+
self.access_control_allow_methods = value.parse()?;
124+
Ok(self)
125+
}
126+
127+
pub fn access_control_allow_headers(mut self, value: &str) -> Result<Self> {
128+
self.access_control_allow_headers = value.parse()?;
129+
Ok(self)
130+
}
131+
132+
fn with_cors_headers<B>(&self, mut res: Response<B>) -> Response<B> {
133+
let h = res.headers_mut();
134+
h.insert(
135+
ACCESS_CONTROL_ALLOW_ORIGIN,
136+
self.access_control_allow_origin.clone(),
137+
);
138+
h.insert(
139+
ACCESS_CONTROL_ALLOW_METHODS,
140+
self.access_control_allow_methods.clone(),
141+
);
142+
h.insert(
143+
ACCESS_CONTROL_ALLOW_HEADERS,
144+
self.access_control_allow_headers.clone(),
145+
);
146+
res
147+
}
148+
97149
pub async fn start(self) -> Result<()> {
98150
let in_addr: SocketAddr = ([0, 0, 0, 0], self.port).into();
99151

@@ -109,10 +161,15 @@ impl<T: SecretGetter> SignwayServer<T> {
109161
let service = service_fn(move |req| {
110162
let arc_self = arc_self.clone();
111163
async move {
112-
if req.method() == Method::OPTIONS {
113-
Self::route_cors(&arc_self, req).await
164+
let res = if req.method() == Method::OPTIONS {
165+
Ok(ok())
166+
} else {
167+
arc_self.route_gateway(req).await
168+
};
169+
if let Ok(res) = res {
170+
Ok(arc_self.with_cors_headers(res))
114171
} else {
115-
Self::route_gateway(&arc_self, req).await
172+
res
116173
}
117174
}
118175
});
@@ -192,6 +249,13 @@ mod tests {
192249

193250
let status = response.status();
194251
assert_eq!(status, StatusCode::OK);
252+
assert_eq!(
253+
response
254+
.headers()
255+
.get("access-control-allow-origin")
256+
.unwrap_or(&HeaderValue::from_str("NONE").unwrap()),
257+
"*"
258+
);
195259
}
196260

197261
#[tokio::test]

src/main.rs

Lines changed: 33 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ use signway_server::{
1212
SecretGetterResult, SignwayServer,
1313
};
1414

15-
#[derive(Parser, Debug)]
15+
#[derive(Parser, Debug, Clone)]
1616
#[command(author, version, about, long_about = None)]
1717
struct Args {
1818
#[arg(help = "access id that is expected to sign urls for this server")]
@@ -32,6 +32,24 @@ struct Args {
3232
help = "disables the bytes transferred monitoring. This feature is experimental, because hyper does not put things easy for tracking IO results in the responses, and the current implementation might have some performance implications. https://github.com/hyperium/hyper/issues/2181"
3333
)]
3434
no_bytes_monitor: bool,
35+
36+
#[arg(
37+
long,
38+
help = "sets the Access-Control-Allow-Origin that will be answered in each request"
39+
)]
40+
access_control_allow_origin: Option<String>,
41+
42+
#[arg(
43+
long,
44+
help = "sets the Access-Control-Allow-Methods that will be answered in each request"
45+
)]
46+
access_control_allow_methods: Option<String>,
47+
48+
#[arg(
49+
long,
50+
help = "sets the Access-Control-Allow-Headers that will be answered in each request"
51+
)]
52+
access_control_allow_headers: Option<String>,
3553
}
3654

3755
struct Config {
@@ -98,14 +116,24 @@ impl OnBytesTransferred for BytesTransferredLogger {
98116

99117
#[tokio::main]
100118
async fn main() -> anyhow::Result<()> {
101-
let args: Args = Args::parse();
102-
let disable_monitoring = args.no_bytes_monitor;
103-
let config: Config = args.try_into()?;
104119
tracing_subscriber::fmt().json().init();
120+
121+
let args: Args = Args::parse();
122+
let config: Config = args.clone().try_into()?;
105123
let mut server = SignwayServer::from_env(config);
106-
if !disable_monitoring {
124+
125+
if !args.no_bytes_monitor {
107126
server = server.on_bytes_transferred(BytesTransferredLogger {});
108127
}
128+
if let Some(value) = args.access_control_allow_headers {
129+
server = server.access_control_allow_headers(&value)?;
130+
}
131+
if let Some(value) = args.access_control_allow_methods {
132+
server = server.access_control_allow_methods(&value)?;
133+
}
134+
if let Some(value) = args.access_control_allow_origin {
135+
server = server.access_control_allow_origin(&value)?;
136+
}
109137

110138
tokio::select! {
111139
result = server.start() => {

0 commit comments

Comments
 (0)