diff --git a/Cargo.toml b/Cargo.toml index a9e3903..802bd58 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,5 +23,10 @@ tokio = { version = "1.11", default-features = false, features = ["rt-multi-thre tower-http = { version = "0.3", features = ["compression-full", "fs", "set-header", "trace"] } tower = "0.4" fastrand = "1.5" +flate2 = "1.0" brotli = { version = "3", default-features = false, features = ["std"]} rcgen = { version = "0.9", default-features = false } + +[dev-dependencies] +tokio-test = "0.4" +axum-test-helper = "0.1" diff --git a/src/main.rs b/src/main.rs index d97023a..a1ff3f2 100644 --- a/src/main.rs +++ b/src/main.rs @@ -38,7 +38,7 @@ fn main() -> Result<(), anyhow::Error> { let output = wasm_bindgen::generate(&options, &wasm_file)?; - info!("compressed wasm output is {} large", pretty_size(output.compressed_wasm.len())); + info!("compressed wasm output is {} large", pretty_size(output.br_compressed_wasm.len())); let rt = tokio::runtime::Runtime::new()?; rt.block_on(server::run_server(options, output))?; diff --git a/src/server.rs b/src/server.rs index 0d6052c..6e1e4ba 100644 --- a/src/server.rs +++ b/src/server.rs @@ -1,8 +1,10 @@ use std::collections::HashMap; use std::net::SocketAddr; +use std::str::from_utf8; use axum::headers::HeaderName; -use axum::http::{HeaderValue, StatusCode, Uri}; +use axum::http::header::ACCEPT_ENCODING; +use axum::http::{HeaderMap, HeaderValue, StatusCode, Uri}; use axum::response::{Html, IntoResponse, Response}; use axum::routing::{get, get_service}; use axum::Router; @@ -29,7 +31,38 @@ pub struct Options { } pub async fn run_server(options: Options, output: WasmBindgenOutput) -> Result<()> { - let WasmBindgenOutput { js, compressed_wasm, snippets, local_modules } = output; + let app = get_router(&options, output); + let mut address_string = options.address; + if !address_string.contains(":") { + address_string += + &(":".to_owned() + &pick_port::pick_free_port(1334, 10).unwrap_or(1334).to_string()); + } + let addr: SocketAddr = address_string.parse().expect("Couldn't parse address"); + + if options.https { + let certificate = rcgen::generate_simple_self_signed([String::from("localhost")])?; + let config = RustlsConfig::from_der( + vec![certificate.serialize_der()?], + certificate.serialize_private_key_der(), + ) + .await?; + + tracing::info!("starting webserver at https://{}", addr); + axum_server_dual_protocol::bind_dual_protocol(addr, config) + .set_upgrade(true) + .serve(app.into_make_service()) + .await?; + } else { + tracing::info!("starting webserver at http://{}", addr); + axum_server::bind(addr).serve(app.into_make_service()).await?; + } + + Ok(()) +} + +fn get_router(options: &Options, output: WasmBindgenOutput) -> Router { + let WasmBindgenOutput { js, br_compressed_wasm, gzip_compressed_wasm, snippets, local_modules } = + output; let middleware_stack = ServiceBuilder::new() .layer(CompressionLayer::new()) @@ -53,13 +86,40 @@ pub async fn run_server(options: Options, output: WasmBindgenOutput) -> Result<( let html = html.replace("{{ TITLE }}", &options.title); let serve_dir = - get_service(ServeDir::new(options.directory)).handle_error(internal_server_error); + get_service(ServeDir::new(options.directory.clone())).handle_error(internal_server_error); - let serve_wasm = || async move { - ([("content-encoding", "br")], WithContentType("application/wasm", compressed_wasm)) + let serve_wasm = |headers: HeaderMap| async move { + if let Some(accept_encoding) = headers.get(ACCEPT_ENCODING) { + match from_utf8(accept_encoding.as_bytes()) { + Ok(encodings) => { + let split_encodings: Vec<&str> = encodings.split(",").map(str::trim).collect(); + if split_encodings.contains(&"br") { + Ok(( + [("content-encoding", "br")], + WithContentType("application/wasm", br_compressed_wasm), + )) + } else if split_encodings.contains(&"gzip") { + Ok(( + [("content-encoding", "gzip")], + WithContentType("application/wasm", gzip_compressed_wasm), + )) + } else { + tracing::warn!("Unsupported encoding in request for wasm.wasm"); + Err(( + StatusCode::BAD_REQUEST, + format!("Unsupported encoding(s): {:?}", split_encodings), + )) + } + } + Err(err) => Err((StatusCode::BAD_REQUEST, err.to_string())), + } + } else { + tracing::error!("Received request missing the accept-encoding header"); + Err((StatusCode::BAD_REQUEST, "Missing `accept-encoding` header".to_string())) + } }; - let app = Router::new() + Router::new() .route("/", get(move || async { Html(html) })) .route("/api/wasm.js", get(|| async { WithContentType("application/javascript", js) })) .route("/api/wasm.wasm", get(serve_wasm)) @@ -77,34 +137,7 @@ pub async fn run_server(options: Options, output: WasmBindgenOutput) -> Result<( }), ) .fallback(serve_dir) - .layer(middleware_stack); - - let mut address_string = options.address; - if !address_string.contains(":") { - address_string += - &(":".to_owned() + &pick_port::pick_free_port(1334, 10).unwrap_or(1334).to_string()); - } - let addr: SocketAddr = address_string.parse().expect("Couldn't parse address"); - - if options.https { - let certificate = rcgen::generate_simple_self_signed([String::from("localhost")])?; - let config = RustlsConfig::from_der( - vec![certificate.serialize_der()?], - certificate.serialize_private_key_der(), - ) - .await?; - - tracing::info!("starting webserver at https://{}", addr); - axum_server_dual_protocol::bind_dual_protocol(addr, config) - .set_upgrade(true) - .serve(app.into_make_service()) - .await?; - } else { - tracing::info!("starting webserver at http://{}", addr); - axum_server::bind(addr).serve(app.into_make_service()).await?; - } - - Ok(()) + .layer(middleware_stack) } fn get_snippet_source( @@ -165,3 +198,78 @@ mod pick_port { .or_else(ask_free_tcp_port) } } + +#[cfg(test)] +mod tests { + use std::collections::HashMap; + + use crate::server::get_router; + use crate::wasm_bindgen::WasmBindgenOutput; + use crate::Options; + use axum::http::StatusCode; + use axum_test_helper::TestClient; + + const FAKE_BR_COMPRESSED_WASM: [u8; 4] = [1, 2, 3, 4]; + const FAKE_GZIP_COMPRESSED_WASM: [u8; 4] = [0x1f, 0x8b, 0x08, 0x08]; + + fn fake_options() -> Options { + Options { + title: "title".to_string(), + address: "127.0.0.1:0".to_string(), + directory: ".".to_string(), + https: false, + no_module: false, + } + } + + fn fake_wasm_bindgen_output() -> WasmBindgenOutput { + WasmBindgenOutput { + js: "fake js".to_string(), + br_compressed_wasm: FAKE_BR_COMPRESSED_WASM.to_vec(), + gzip_compressed_wasm: FAKE_GZIP_COMPRESSED_WASM.to_vec(), + snippets: HashMap::>::new(), + local_modules: HashMap::::new(), + } + } + + fn make_test_client() -> TestClient { + let options = fake_options(); + let output = fake_wasm_bindgen_output(); + let router = get_router(&options, output); + TestClient::new(router) + } + + #[tokio::test] + async fn test_router_bad_request() { + let client = make_test_client(); + + // Test without any supported compression + let res = client.get("/api/wasm.wasm").header("accept-encoding", "deflate").send().await; + assert_eq!(res.status(), StatusCode::BAD_REQUEST); + } + + #[tokio::test] + async fn test_router_br() { + let client = make_test_client(); + let mut res = client + .get("/api/wasm.wasm") + .header("accept-encoding", "gzip, deflate, br") + .send() + .await; + assert_eq!(res.status(), StatusCode::OK); + let result = res.chunk().await.unwrap(); + assert_eq!(result.to_vec(), FAKE_BR_COMPRESSED_WASM); + } + + #[tokio::test] + async fn test_router_gzip() { + let client = make_test_client(); + // Test without br compression, defaulting to gzip + let mut res = + client.get("/api/wasm.wasm").header("accept-encoding", "gzip, deflate").send().await; + assert_eq!(res.status(), StatusCode::OK); + let result = res.chunk().await.unwrap(); + // This is the gzip 3-byte file header + assert_eq!(result.to_vec(), FAKE_GZIP_COMPRESSED_WASM); + } +} diff --git a/src/wasm_bindgen.rs b/src/wasm_bindgen.rs index 551f7b4..e412600 100644 --- a/src/wasm_bindgen.rs +++ b/src/wasm_bindgen.rs @@ -5,9 +5,12 @@ use std::collections::HashMap; use std::path::Path; use tracing::debug; +const COMPRESSION_LEVEL: u32 = 2; + pub struct WasmBindgenOutput { pub js: String, - pub compressed_wasm: Vec, + pub br_compressed_wasm: Vec, + pub gzip_compressed_wasm: Vec, pub snippets: HashMap>, pub local_modules: HashMap, } @@ -35,22 +38,44 @@ pub fn generate(options: &Options, wasm_file: &Path) -> Result Result, std::io::Error> { +fn br_compress(mut bytes: &[u8]) -> Result, std::io::Error> { use brotli::enc::{self, BrotliEncoderParams}; let mut output = Vec::new(); - enc::BrotliCompress(&mut bytes, &mut output, &BrotliEncoderParams { - quality: 5, // https://github.com/jakobhellermann/wasm-server-runner/pull/22#issuecomment-1235804905 - ..Default::default() - })?; + enc::BrotliCompress( + &mut bytes, + &mut output, + &BrotliEncoderParams { + quality: 5, // https://github.com/jakobhellermann/wasm-server-runner/pull/22#issuecomment-1235804905 + ..Default::default() + }, + )?; Ok(output) } + +fn gzip_compress(bytes: &[u8]) -> Result, std::io::Error> { + use flate2::write::GzEncoder; + use flate2::Compression; + use std::io::prelude::*; + + let mut encoder = GzEncoder::new(Vec::new(), Compression::new(COMPRESSION_LEVEL)); + + encoder.write_all(bytes)?; + let compressed = encoder.finish()?; + + Ok(compressed) +}