Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,7 @@ tower = "0.4"
fastrand = "1.5"
brotli = { version = "3", default-features = false, features = ["std"]}
rcgen = { version = "0.9", default-features = false }

[dev-dependencies]
tokio-test = "0.4"
axum-test-helper = "0.1"
119 changes: 88 additions & 31 deletions src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::collections::HashMap;
use std::net::SocketAddr;

use axum::headers::HeaderName;
use axum::http::{HeaderValue, StatusCode, Uri};
use axum::http::{HeaderMap, HeaderValue, StatusCode, Uri};
use axum::response::{Html, IntoResponse, Response};
use axum::routing::{get, get_service};
use axum::Router;
Expand All @@ -29,6 +29,36 @@ pub struct Options {
}

pub async fn run_server(options: Options, output: WasmBindgenOutput) -> Result<()> {
let app = get_router(&options, output);
let mut address_string = options.address;
if !address_string.contains(":") {
address_string +=
&(":".to_owned() + &pick_port::pick_free_port(1334, 10).unwrap_or(1334).to_string());
}
let addr: SocketAddr = address_string.parse().expect("Couldn't parse address");

if options.https {
let certificate = rcgen::generate_simple_self_signed([String::from("localhost")])?;
let config = RustlsConfig::from_der(
vec![certificate.serialize_der()?],
certificate.serialize_private_key_der(),
)
.await?;

tracing::info!("starting webserver at https://{}", addr);
axum_server_dual_protocol::bind_dual_protocol(addr, config)
.set_upgrade(true)
.serve(app.into_make_service())
.await?;
} else {
tracing::info!("starting webserver at http://{}", addr);
axum_server::bind(addr).serve(app.into_make_service()).await?;
}

Ok(())
}
Comment on lines +33 to +59
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This code is unchanged as compared to lines 82-107 of the original file.


fn get_router(options: &Options, output: WasmBindgenOutput) -> Router {
let WasmBindgenOutput { js, compressed_wasm, snippets, local_modules } = output;

let middleware_stack = ServiceBuilder::new()
Expand All @@ -53,13 +83,13 @@ pub async fn run_server(options: Options, output: WasmBindgenOutput) -> Result<(
let html = html.replace("{{ TITLE }}", &options.title);

let serve_dir =
get_service(ServeDir::new(options.directory)).handle_error(internal_server_error);
get_service(ServeDir::new(options.directory.clone())).handle_error(internal_server_error);

let serve_wasm = || async move {
([("content-encoding", "br")], WithContentType("application/wasm", compressed_wasm))
};

let app = Router::new()
Router::new()
.route("/", get(move || async { Html(html) }))
.route("/api/wasm.js", get(|| async { WithContentType("application/javascript", js) }))
.route("/api/wasm.wasm", get(serve_wasm))
Expand All @@ -77,34 +107,7 @@ pub async fn run_server(options: Options, output: WasmBindgenOutput) -> Result<(
}),
)
.fallback(serve_dir)
.layer(middleware_stack);

let mut address_string = options.address;
if !address_string.contains(":") {
address_string +=
&(":".to_owned() + &pick_port::pick_free_port(1334, 10).unwrap_or(1334).to_string());
}
let addr: SocketAddr = address_string.parse().expect("Couldn't parse address");

if options.https {
let certificate = rcgen::generate_simple_self_signed([String::from("localhost")])?;
let config = RustlsConfig::from_der(
vec![certificate.serialize_der()?],
certificate.serialize_private_key_der(),
)
.await?;

tracing::info!("starting webserver at https://{}", addr);
axum_server_dual_protocol::bind_dual_protocol(addr, config)
.set_upgrade(true)
.serve(app.into_make_service())
.await?;
} else {
tracing::info!("starting webserver at http://{}", addr);
axum_server::bind(addr).serve(app.into_make_service()).await?;
}

Ok(())
.layer(middleware_stack)
}

fn get_snippet_source(
Expand Down Expand Up @@ -165,3 +168,57 @@ mod pick_port {
.or_else(ask_free_tcp_port)
}
}

#[cfg(test)]
mod tests {
use std::collections::HashMap;

use crate::server::get_router;
use crate::wasm_bindgen::WasmBindgenOutput;
use crate::Options;
use axum::http::StatusCode;
use axum_test_helper::TestClient;

const FAKE_BR_COMPRESSED_WASM: [u8; 4] = [1, 2, 3, 4];

fn fake_options() -> Options {
Options {
title: "title".to_string(),
address: "127.0.0.1:0".to_string(),
directory: ".".to_string(),
https: false,
no_module: false,
}
}

fn fake_wasm_bindgen_output() -> WasmBindgenOutput {
WasmBindgenOutput {
js: "fake js".to_string(),
compressed_wasm: FAKE_BR_COMPRESSED_WASM.to_vec(),
snippets: HashMap::<String, Vec<String>>::new(),
local_modules: HashMap::<String, String>::new(),
}
}

fn make_test_client() -> TestClient {
let options = fake_options();
let output = fake_wasm_bindgen_output();
let router = get_router(&options, output);
TestClient::new(router)
}

#[tokio::test]
async fn test_router() {
let client = make_test_client();

// Test with br compression requested
let mut res = client
.get("/api/wasm.wasm")
.header("accept-encoding", "gzip, deflate, br")
.send()
.await;
assert_eq!(res.status(), StatusCode::OK);
let result = res.chunk().await.unwrap();
assert_eq!(result.to_vec(), FAKE_BR_COMPRESSED_WASM);
}
}