Skip to content

Commit 374a6bb

Browse files
committed
feat: add static file middleware
1 parent 7d16e82 commit 374a6bb

File tree

5 files changed

+145
-10
lines changed

5 files changed

+145
-10
lines changed

crates/raw/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ license = "CC-BY-NC-ND-4.0"
1010
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
1111

1212
[dependencies]
13-
tokio = { version = "1", features = ["rt-multi-thread", "macros"] }
13+
tokio = { version = "1", features = ["rt-multi-thread", "macros", "fs"] }
1414
hyper = { version = "0.14", features = ["client", "server", "http1", "tcp"] }
1515
http = "0.2"
1616
serde = { version = "1", features = ["derive"] }

crates/raw/src/app.rs

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,10 @@ impl App {
6060
self.middleware.push(middleware(middleware_fn));
6161
}
6262

63+
pub fn add_layer(&mut self, layer: Middleware) {
64+
self.middleware.push(layer);
65+
}
66+
6367
pub async fn listen(self, addr: &str) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
6468
let listener = TcpListener::bind(addr).map_err(|err| {
6569
eprintln!("Failed to bind {}: {}", addr, err);
@@ -97,18 +101,25 @@ impl App {
97101
let method = req.method().clone();
98102
let path = req.uri().path().to_string();
99103

100-
let response = if let Some(route_match) = self.router.find(&method, &path) {
101-
let request = Request::new(req, route_match.params);
102-
let handler = route_match.handler;
103-
let middleware = Arc::new(self.middleware.clone());
104-
let next = Next::new(middleware, handler);
105-
next.run(request).await
106-
} else if self.router.allows_path(&path) {
107-
RawError::MethodNotAllowed.into_response()
104+
let (handler, params) = if let Some(route_match) = self.router.find(&method, &path) {
105+
(route_match.handler, route_match.params)
108106
} else {
109-
RawError::NotFound.into_response()
107+
let method_not_allowed = self.router.allows_path(&path);
108+
let fallback = handler(move |_req| async move {
109+
if method_not_allowed {
110+
RawError::MethodNotAllowed.into_response()
111+
} else {
112+
RawError::NotFound.into_response()
113+
}
114+
});
115+
(fallback, std::collections::HashMap::new())
110116
};
111117

118+
let request = Request::new(req, params);
119+
let middleware = Arc::new(self.middleware.clone());
120+
let next = Next::new(middleware, handler);
121+
let response = next.run(request).await;
122+
112123
Ok(response.into_inner())
113124
}
114125
}

crates/raw/src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,11 @@ pub mod middleware;
77
pub mod request;
88
pub mod response;
99
pub mod router;
10+
pub mod static_files;
1011

1112
pub use app::App;
1213
pub use config::RawConfig;
1314
pub use error::RawError;
1415
pub use request::Request;
1516
pub use response::{Html, Json, Response, StatusCode, Text};
17+
pub use static_files::static_files;

crates/raw/src/static_files.rs

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
// Credit: Ben Ajaero
2+
3+
use std::path::{Path, PathBuf};
4+
5+
use http::Method;
6+
7+
use crate::middleware::{middleware, Middleware, Next};
8+
use crate::request::Request;
9+
use crate::response::{Response, StatusCode};
10+
11+
pub fn static_files(root: impl Into<PathBuf>) -> Middleware {
12+
let root = root.into();
13+
middleware(move |req: Request, next: Next| {
14+
let root = root.clone();
15+
async move {
16+
if req.method() != Method::GET {
17+
return next.run(req).await;
18+
}
19+
20+
let path = match sanitize_path(req.path()) {
21+
Some(path) => path,
22+
None => {
23+
return Response::new(StatusCode::BAD_REQUEST, "Bad Request", "text/plain");
24+
}
25+
};
26+
27+
let candidate = root.join(path);
28+
match tokio::fs::read(&candidate).await {
29+
Ok(contents) => Response::new(StatusCode::OK, contents, content_type_for_path(&candidate)),
30+
Err(_) => next.run(req).await,
31+
}
32+
}
33+
})
34+
}
35+
36+
fn sanitize_path(path: &str) -> Option<PathBuf> {
37+
let trimmed = path.trim_start_matches('/');
38+
if trimmed.contains("..") {
39+
return None;
40+
}
41+
42+
let normalized = if trimmed.is_empty() {
43+
PathBuf::from("index.html")
44+
} else {
45+
PathBuf::from(trimmed)
46+
};
47+
48+
Some(normalized)
49+
}
50+
51+
fn content_type_for_path(path: &Path) -> &'static str {
52+
match path.extension().and_then(|ext| ext.to_str()).unwrap_or("") {
53+
"html" | "htm" => "text/html; charset=utf-8",
54+
"css" => "text/css; charset=utf-8",
55+
"js" => "application/javascript",
56+
"json" => "application/json",
57+
"png" => "image/png",
58+
"jpg" | "jpeg" => "image/jpeg",
59+
"txt" => "text/plain; charset=utf-8",
60+
_ => "application/octet-stream",
61+
}
62+
}
63+
64+
#[cfg(test)]
65+
mod tests {
66+
use super::{content_type_for_path, sanitize_path};
67+
use std::path::Path;
68+
69+
#[test]
70+
fn sanitize_rejects_parent() {
71+
assert!(sanitize_path("/../secret").is_none());
72+
}
73+
74+
#[test]
75+
fn sanitize_defaults_index() {
76+
let path = sanitize_path("/").expect("path");
77+
assert_eq!(path.to_string_lossy(), "index.html");
78+
}
79+
80+
#[test]
81+
fn content_type_defaults_binary() {
82+
assert_eq!(content_type_for_path(Path::new("/tmp/file.bin")), "application/octet-stream");
83+
}
84+
}

crates/raw/tests/static_files.rs

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
// Credit: Ben Ajaero
2+
3+
use std::fs;
4+
use std::net::TcpListener;
5+
use std::time::{SystemTime, UNIX_EPOCH};
6+
7+
use raw::{static_files, App, Response, StatusCode, Text};
8+
9+
#[tokio::test]
10+
async fn serves_static_file_before_routes() {
11+
let timestamp = SystemTime::now()
12+
.duration_since(UNIX_EPOCH)
13+
.expect("timestamp")
14+
.as_millis();
15+
let root = std::env::temp_dir().join(format!("raw-static-{}", timestamp));
16+
fs::create_dir_all(&root).expect("create dir");
17+
fs::write(root.join("index.html"), "static-content").expect("write file");
18+
19+
let listener = TcpListener::bind("127.0.0.1:0").expect("bind listener");
20+
let addr = listener.local_addr().expect("local addr");
21+
22+
let mut app = App::new();
23+
app.add_layer(static_files(&root));
24+
app.get("/", |_req| async { Response::from(Text::new("dynamic")) });
25+
26+
let server = tokio::spawn(app.serve(listener));
27+
28+
let client = hyper::Client::new();
29+
let uri = format!("http://{}/", addr).parse().expect("uri");
30+
let res = client.get(uri).await.expect("response");
31+
assert_eq!(res.status(), StatusCode::OK);
32+
33+
let body = hyper::body::to_bytes(res).await.expect("body bytes");
34+
assert_eq!(body, "static-content");
35+
36+
server.abort();
37+
let _ = fs::remove_dir_all(&root);
38+
}

0 commit comments

Comments
 (0)