Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

11 changes: 6 additions & 5 deletions mullvad-api/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,13 @@ crate-type = ["rlib", "staticlib"]
[dependencies]
anyhow = { workspace = true }
async-trait = "0.1"
bytes = "^1.2"
chrono = { workspace = true }
clap = { workspace = true, features = ["cargo", "derive"], optional = true }
clap = { workspace = true, features = ["cargo", "derive"] }
futures = { workspace = true }
http = "1.1.0"
http-body-util = "0.1.2"
hyper = { version = "1.8.1", features = ["client", "http1"] }
hyper = { version = "1.8.1", features = ["client", "http1", "server"] }
hyper-util = { workspace = true }
ipnetwork = { workspace = true }
libc = "0.2"
Expand All @@ -28,6 +29,7 @@ mullvad-encrypted-dns-proxy = { path = "../mullvad-encrypted-dns-proxy" }
mullvad-fs = { path = "../mullvad-fs" }
mullvad-types = { path = "../mullvad-types" }
mullvad-version = { path = "../mullvad-version" }
papaya = { version = "0.2" }
rustls-pki-types.workspace = true
serde = { workspace = true }
serde_json = { workspace = true }
Expand All @@ -51,10 +53,10 @@ tokio-rustls = { version = "0.26.0", default-features = false, features = [
] }
tokio-socks = "0.5.1"
tower = { workspace = true }
tracing-subscriber = { workspace = true, optional = true }
tracing-subscriber = { workspace = true }
uuid = { version = "1.4.1", features = ["v4"] }
vec1 = { workspace = true, features = ["serde"] }
webpki-roots = { workspace = true, optional = true }
webpki-roots = { version = "1.0.4" }

[dev-dependencies]
mockito = "1.6.1"
Expand All @@ -70,7 +72,6 @@ mullvad-update = { path = "../mullvad-update", features = ["client"] }
[features]
# Allow the API server to use to be configured via MULLVAD_API_HOST and MULLVAD_API_ADDR.
api-override = []
domain-fronting = ["clap", "tracing-subscriber", "webpki-roots"]

[lints]
workspace = true
125 changes: 60 additions & 65 deletions mullvad-api/src/bin/domain_fronting.rs
Original file line number Diff line number Diff line change
@@ -1,78 +1,73 @@
#[tokio::main]
async fn main() -> anyhow::Result<()> {
imp::main().await
}

#[cfg(not(feature = "domain-fronting"))]
pub mod imp {
pub async fn main() -> anyhow::Result<()> {
unimplemented!(
"cargo run -p mullvad-api --features domain-fronting --bin domain_fronting -- --front <FRONT_DOMAIN> --host <HOST_DOMAIN>"
)
}
}
use std::{io::Write, sync::Arc};

#[cfg(feature = "domain-fronting")]
mod imp {
use clap::Parser;
use http::{Method, Request};
use http_body_util::{BodyExt, Empty};
use hyper::body::Bytes;
use hyper_util::rt::TokioIo;
use mullvad_api::domain_fronting::DomainFronting;
use tracing_subscriber::{EnvFilter, filter::LevelFilter};
use clap::Parser;
use http_body_util::{BodyExt, Full};
use hyper::body::Bytes;
use hyper_util::client::legacy::Client;
use mullvad_api::{
domain_fronting::DomainFronting,
https_client_with_sni::HttpsConnectorWithSni,
proxy::{ApiConnectionMode, ProxyConfig},
};
use tracing_subscriber::{EnvFilter, filter::LevelFilter};

#[derive(Parser, Debug)]
pub struct Arguments {
/// The domain used to hide the actual destination.
#[arg(long)]
front: String,
#[derive(Parser, Debug)]
pub struct Arguments {
/// The domain used to hide the actual destination.
#[arg(long)]
front: String,

/// The host being reached via `front`.
#[arg(long)]
host: String,
}
/// The host being reached via `front`.
#[arg(long)]
host: String,

pub async fn main() -> anyhow::Result<()> {
tracing_subscriber::fmt()
.with_env_filter(EnvFilter::from_default_env().add_directive(LevelFilter::INFO.into()))
.init();
/// Session header key used to identify client sessions
#[clap(short = 's', long)]
session_header: String,
}

let Arguments { front, host } = Arguments::parse();
println!("front: {:?} host: {:?}", front, host);
let domain_front = DomainFronting::new(front.clone());
let tls_stream = domain_front
.connect()
.await
.expect("Could not resolve {front}");
#[tokio::main]
async fn main() -> anyhow::Result<()> {
tracing_subscriber::fmt()
.with_env_filter(EnvFilter::from_default_env().add_directive(LevelFilter::INFO.into()))
.init();

let io = TokioIo::new(tls_stream);
let Arguments {
front,
host,
session_header,
} = Arguments::parse();

let (mut sender, conn) = hyper::client::conn::http1::handshake(io).await?;
let df = DomainFronting::new(front, host, session_header);

tokio::task::spawn(async move {
if let Err(err) = conn.await {
println!("Connection failed: {:?}", err);
}
});
let proxy_config = df.proxy_config().await.unwrap();

// Build the request
let req = Request::builder()
.method(Method::GET)
.header(hyper::header::HOST, host)
.header(hyper::header::ACCEPT, "*/*")
.body(Empty::<Bytes>::new())?;
println!("request: {:?}", req);
let res = sender.send_request(req).await?;
let (connector, connector_handle) = HttpsConnectorWithSni::new(
Arc::new(mullvad_api::DefaultDnsResolver),
#[cfg(feature = "api-override")]
false,
);

println!("Response: {}", res.status());
println!("Headers: {:#?}\n", res.headers());
connector_handle.set_connection_mode(ApiConnectionMode::Proxied(ProxyConfig::DomainFronting(
proxy_config,
)));

// Print the response to stdout
let body = res.collect().await?.to_bytes();
tokio::io::copy(&mut body.as_ref(), &mut tokio::io::stdout()).await?;
let client: Client<_, Full<Bytes>> =
hyper_util::client::legacy::Client::builder(hyper_util::rt::TokioExecutor::new())
.build(connector);

println!("\n\nDone!");
Ok(())
}
let response = client
.get("https://api.mullvad.net/app/v1/relays".try_into().unwrap())
.await
.unwrap();
log::info!("Response status: {}", response.status());
log::debug!("Response headers: {:?}", response.headers());
let body = response
.collect()
.await
.expect("failed to fetch response body")
.to_bytes();
let _ = std::io::stdout().write(&body);
let _ = std::io::stdout().write(b"\n");
Ok(())
}
146 changes: 146 additions & 0 deletions mullvad-api/src/bin/domain_fronting_server.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
use clap::Parser;
use futures::FutureExt;
use hyper::{server::conn::http1, service::service_fn};
use hyper_util::rt::TokioIo;
use mullvad_api::domain_fronting::server::Sessions;
use rustls_pki_types::{CertificateDer, pem::PemObject};
use std::{
fs::File,
io::BufReader,
net::SocketAddr,
path::{Path, PathBuf},
sync::Arc,
};
use tokio::net::TcpListener;
use tokio_rustls::{TlsAcceptor, rustls::ServerConfig};
use tracing_subscriber::{EnvFilter, filter::LevelFilter};

#[derive(Parser, Debug)]
#[clap(name = "domain_fronting_server")]
struct Args {
/// Hostname for the server
#[clap(short = 'H', long)]
hostname: String,

/// Path to certificate file (PEM format). If omitted, plain TCP is used.
#[clap(short = 'c', long)]
cert_path: Option<PathBuf>,

/// Path to private key file (PEM format). Required if cert_path is set.
#[clap(short = 'k', long)]
key_path: Option<PathBuf>,

/// Upstream socket address to forward CONNECT requests to
#[clap(short = 'u', long)]
upstream: SocketAddr,

/// Port to listen on
#[clap(short, long, default_value = "443")]
port: u16,

/// Session header key used to identify client sessions
#[clap(short = 's', long)]
session_header: String,
}

fn load_tls_config(
cert_path: &Path,
key_path: &Path,
) -> anyhow::Result<ServerConfig> {
// Load certificate chain
let cert_file = File::open(cert_path)?;
let cert_chain =
CertificateDer::pem_reader_iter(&mut std::io::BufReader::new(BufReader::new(cert_file)))
.collect::<Result<Vec<_>, _>>()?;

// Load private key
let key = rustls_pki_types::PrivateKeyDer::from_pem_file(key_path)?;

// Create server configuration
let config = ServerConfig::builder()
.with_no_client_auth()
.with_single_cert(cert_chain, key)?;

Ok(config)
}

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
tracing_subscriber::fmt()
.with_env_filter(EnvFilter::from_default_env().add_directive(LevelFilter::INFO.into()))
.init();

let Args {
hostname,
cert_path,
key_path,
upstream,
port,
session_header,
} = Args::parse();
let bind_addr: SocketAddr = format!("0.0.0.0:{}", port).parse()?;

let tls_acceptor = match (cert_path, key_path) {
(Some(cert_path), Some(key_path)) => {
log::info!("Starting TLS domain fronting server on {}", bind_addr);
log::info!("Cert path: {}", cert_path.display());
log::info!("Key path: {}", key_path.display());
let tls_config =
tokio::task::spawn_blocking(move || load_tls_config(&cert_path, &key_path)).await?;
Some(TlsAcceptor::from(Arc::new(tls_config?)))
}
(None, None) => {
log::info!("Starting plain TCP domain fronting server on {}", bind_addr);
log::warn!("No TLS certificate provided - running without encryption");
None
}
_ => {
return Err("Both --cert-path and --key-path must be provided together".into());
}
};

log::info!("Hostname: {}", hostname);
log::info!("Upstream: {}", upstream);

let listener = TcpListener::bind(bind_addr).await?;

let sessions = Sessions::new(upstream, session_header);
loop {
let (stream, addr) = listener.accept().await?;

log::debug!("Accepted connection from {addr}");

let sessions = sessions.clone();
let tls_acceptor = tls_acceptor.clone();
tokio::spawn(async move {
match tls_acceptor {
Some(acceptor) => match acceptor.accept(stream).await {
Ok(tls_stream) => {
serve_connection(TokioIo::new(tls_stream), sessions, addr).await;
}
Err(err) => {
log::error!("TLS handshake failed for {addr}: {err}");
}
},
None => {
serve_connection(TokioIo::new(stream), sessions, addr).await;
}
}
});
}
}

async fn serve_connection<S>(io: S, sessions: Arc<Sessions>, addr: SocketAddr)
where
S: hyper::rt::Read + hyper::rt::Write + Unpin + Send + 'static,
{
let service = service_fn(move |req| sessions.clone().handle_request(req).map(Ok::<_, String>));

if let Err(err) = http1::Builder::new()
.serve_connection(io, service)
.with_upgrades()
.await
{
log::error!("Error serving connection from {addr}: {err}");
}
}
Loading
Loading