11use anyhow:: { bail, Context } ;
22use duration_str:: deserialize_duration;
3- use serde:: Deserialize ;
3+ use serde:: { Deserialize , Deserializer } ;
44use small_acme:: LetsEncrypt ;
55use std:: time:: Duration ;
66use std:: { env, fs} ;
7+ use std:: collections:: HashSet ;
8+ use headers:: { HeaderValue , Origin } ;
79use tracing:: warn;
810
911const CONFIG_PATH : & str = "/config/config.toml" ;
@@ -12,7 +14,7 @@ const CONFIG_PATH: &str = "/config/config.toml";
1214pub struct Config {
1315 pub file_dir : String ,
1416 #[ serde( default ) ]
15- pub cors : bool ,
17+ pub cors : Option < HashSet < OriginWrapper > > ,
1618 pub admin_config : Option < AdminConfig > ,
1719 pub http : Option < HttpConfig > ,
1820 pub https : Option < HttpsConfig > ,
@@ -94,7 +96,7 @@ fn default_max_upload_size() -> u64 {
9496#[ derive( Deserialize , Debug , Clone , PartialEq ) ]
9597pub struct DomainConfig {
9698 pub domain : String ,
97- pub cors : Option < bool > ,
99+ pub cors : Option < HashSet < OriginWrapper > > ,
98100 pub cache : Option < DomainCacheConfig > ,
99101 pub https : Option < DomainHttpsConfig > ,
100102 pub alias : Option < Vec < String > > ,
@@ -226,6 +228,29 @@ pub fn get_host_path_from_domain(domain: &str) -> (&str, &str) {
226228 }
227229}
228230
231+ #[ derive( Debug , Clone , PartialEq , Eq , Hash ) ]
232+ pub struct OriginWrapper ( HeaderValue ) ;
233+
234+ pub ( crate ) fn extract_origin ( data : & Option < HashSet < OriginWrapper > > ) -> Option < HashSet < HeaderValue > > {
235+ data. as_ref ( ) . map ( |set| set. iter ( ) . map ( |o| o. 0 . clone ( ) ) . collect ( ) )
236+ }
237+
238+ impl < ' de > Deserialize < ' de > for OriginWrapper {
239+ fn deserialize < D > ( deserializer : D ) -> Result < Self , D :: Error >
240+ where
241+ D : Deserializer < ' de >
242+ {
243+ let data = String :: deserialize ( deserializer) ?;
244+ let mut parts = data. splitn ( 2 , "://" ) ;
245+ let scheme = parts. next ( ) . expect ( "missing scheme" ) ;
246+ let rest = parts. next ( ) . expect ( "missing scheme" ) ;
247+ let origin = Origin :: try_from_parts ( scheme, rest, None ) . expect ( "invalid Origin" ) ;
248+
249+ Ok ( OriginWrapper ( origin. to_string ( ) . parse ( )
250+ . expect ( "Origin is always a valid HeaderValue" ) ) )
251+ }
252+ }
253+
229254#[ cfg( test) ]
230255mod test {
231256 use std:: env;
0 commit comments