Skip to content

Commit 589abaf

Browse files
authored
Merge pull request #399 from tirr-c/route-middleware
Per-route middleware
2 parents dd9d42d + e4f2f2d commit 589abaf

File tree

4 files changed

+269
-11
lines changed

4 files changed

+269
-11
lines changed

src/endpoint.rs

Lines changed: 53 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
1+
use std::sync::Arc;
2+
13
use async_std::future::Future;
24

5+
use crate::middleware::Next;
36
use crate::utils::BoxFuture;
4-
use crate::{response::IntoResponse, Request, Response};
7+
use crate::{response::IntoResponse, Middleware, Request, Response};
58

69
/// An HTTP request handler.
710
///
@@ -63,3 +66,52 @@ where
6366
Box::pin(async move { fut.await.into_response() })
6467
}
6568
}
69+
70+
pub struct MiddlewareEndpoint<E, State> {
71+
endpoint: E,
72+
middleware: Vec<Arc<dyn Middleware<State>>>,
73+
}
74+
75+
impl<E: Clone, State> Clone for MiddlewareEndpoint<E, State> {
76+
fn clone(&self) -> Self {
77+
Self {
78+
endpoint: self.endpoint.clone(),
79+
middleware: self.middleware.clone(),
80+
}
81+
}
82+
}
83+
84+
impl<E, State> std::fmt::Debug for MiddlewareEndpoint<E, State> {
85+
fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
86+
write!(
87+
fmt,
88+
"MiddlewareEndpoint (length: {})",
89+
self.middleware.len(),
90+
)
91+
}
92+
}
93+
94+
impl<E, State> MiddlewareEndpoint<E, State>
95+
where
96+
E: Endpoint<State>,
97+
{
98+
pub fn wrap_with_middleware(ep: E, middleware: &[Arc<dyn Middleware<State>>]) -> Self {
99+
Self {
100+
endpoint: ep,
101+
middleware: middleware.to_vec(),
102+
}
103+
}
104+
}
105+
106+
impl<E, State: 'static> Endpoint<State> for MiddlewareEndpoint<E, State>
107+
where
108+
E: Endpoint<State>,
109+
{
110+
fn call<'a>(&'a self, req: Request<State>) -> BoxFuture<'a, Response> {
111+
let next = Next {
112+
endpoint: &self.endpoint,
113+
next_middleware: &self.middleware,
114+
};
115+
next.run(req)
116+
}
117+
}

src/router.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
use route_recognizer::{Match, Params, Router as MethodRouter};
22
use std::collections::HashMap;
33

4-
use crate::endpoint::{DynEndpoint, Endpoint};
4+
use crate::endpoint::DynEndpoint;
55
use crate::utils::BoxFuture;
66
use crate::{Request, Response};
77

@@ -29,15 +29,15 @@ impl<State: 'static> Router<State> {
2929
}
3030
}
3131

32-
pub(crate) fn add(&mut self, path: &str, method: http::Method, ep: impl Endpoint<State>) {
32+
pub(crate) fn add(&mut self, path: &str, method: http::Method, ep: Box<DynEndpoint<State>>) {
3333
self.method_map
3434
.entry(method)
3535
.or_insert_with(MethodRouter::new)
36-
.add(path, Box::new(ep))
36+
.add(path, ep)
3737
}
3838

39-
pub(crate) fn add_all(&mut self, path: &str, ep: impl Endpoint<State>) {
40-
self.all_method_router.add(path, Box::new(ep))
39+
pub(crate) fn add_all(&mut self, path: &str, ep: Box<DynEndpoint<State>>) {
40+
self.all_method_router.add(path, ep)
4141
}
4242

4343
pub(crate) fn route(&self, path: &str, method: http::Method) -> Selection<'_, State> {

src/server/route.rs

Lines changed: 61 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
1+
use std::sync::Arc;
2+
3+
use crate::endpoint::MiddlewareEndpoint;
14
use crate::utils::BoxFuture;
2-
use crate::{router::Router, Endpoint, Response};
5+
use crate::{router::Router, Endpoint, Middleware, Response};
36

47
/// A handle to a route.
58
///
@@ -13,6 +16,7 @@ use crate::{router::Router, Endpoint, Response};
1316
pub struct Route<'a, State> {
1417
router: &'a mut Router<State>,
1518
path: String,
19+
middleware: Vec<Arc<dyn Middleware<State>>>,
1620
/// Indicates whether the path of current route is treated as a prefix. Set by
1721
/// [`strip_prefix`].
1822
///
@@ -25,6 +29,7 @@ impl<'a, State: 'static> Route<'a, State> {
2529
Route {
2630
router,
2731
path,
32+
middleware: Vec::new(),
2833
prefix: false,
2934
}
3035
}
@@ -44,6 +49,7 @@ impl<'a, State: 'static> Route<'a, State> {
4449
Route {
4550
router: &mut self.router,
4651
path: p,
52+
middleware: self.middleware.clone(),
4753
prefix: false,
4854
}
4955
}
@@ -60,6 +66,18 @@ impl<'a, State: 'static> Route<'a, State> {
6066
self
6167
}
6268

69+
/// Apply the given middleware to the current route.
70+
pub fn middleware(&mut self, middleware: impl Middleware<State>) -> &mut Self {
71+
self.middleware.push(Arc::new(middleware));
72+
self
73+
}
74+
75+
/// Reset the middleware chain for the current route, if any.
76+
pub fn reset_middleware(&mut self) -> &mut Self {
77+
self.middleware.clear();
78+
self
79+
}
80+
6381
/// Nest a [`Server`] at the current path.
6482
///
6583
/// [`Server`]: struct.Server.html
@@ -78,10 +96,29 @@ impl<'a, State: 'static> Route<'a, State> {
7896
pub fn method(&mut self, method: http::Method, ep: impl Endpoint<State>) -> &mut Self {
7997
if self.prefix {
8098
let ep = StripPrefixEndpoint::new(ep);
81-
self.router.add(&self.path, method.clone(), ep.clone());
99+
let (ep1, ep2): (Box<dyn Endpoint<_>>, Box<dyn Endpoint<_>>) =
100+
if self.middleware.is_empty() {
101+
let ep = Box::new(ep);
102+
(ep.clone(), ep)
103+
} else {
104+
let ep = Box::new(MiddlewareEndpoint::wrap_with_middleware(
105+
ep,
106+
&self.middleware,
107+
));
108+
(ep.clone(), ep)
109+
};
110+
self.router.add(&self.path, method.clone(), ep1);
82111
let wildcard = self.at("*--tide-path-rest");
83-
wildcard.router.add(&wildcard.path, method, ep);
112+
wildcard.router.add(&wildcard.path, method, ep2);
84113
} else {
114+
let ep: Box<dyn Endpoint<_>> = if self.middleware.is_empty() {
115+
Box::new(ep)
116+
} else {
117+
Box::new(MiddlewareEndpoint::wrap_with_middleware(
118+
ep,
119+
&self.middleware,
120+
))
121+
};
85122
self.router.add(&self.path, method, ep);
86123
}
87124
self
@@ -93,10 +130,29 @@ impl<'a, State: 'static> Route<'a, State> {
93130
pub fn all(&mut self, ep: impl Endpoint<State>) -> &mut Self {
94131
if self.prefix {
95132
let ep = StripPrefixEndpoint::new(ep);
96-
self.router.add_all(&self.path, ep.clone());
133+
let (ep1, ep2): (Box<dyn Endpoint<_>>, Box<dyn Endpoint<_>>) =
134+
if self.middleware.is_empty() {
135+
let ep = Box::new(ep);
136+
(ep.clone(), ep)
137+
} else {
138+
let ep = Box::new(MiddlewareEndpoint::wrap_with_middleware(
139+
ep,
140+
&self.middleware,
141+
));
142+
(ep.clone(), ep)
143+
};
144+
self.router.add_all(&self.path, ep1);
97145
let wildcard = self.at("*--tide-path-rest");
98-
wildcard.router.add_all(&wildcard.path, ep);
146+
wildcard.router.add_all(&wildcard.path, ep2);
99147
} else {
148+
let ep: Box<dyn Endpoint<_>> = if self.middleware.is_empty() {
149+
Box::new(ep)
150+
} else {
151+
Box::new(MiddlewareEndpoint::wrap_with_middleware(
152+
ep,
153+
&self.middleware,
154+
))
155+
};
100156
self.router.add_all(&self.path, ep);
101157
}
102158
self

tests/route_middleware.rs

Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
use futures::future::BoxFuture;
2+
use http_service::Body;
3+
use http_service_mock::make_server;
4+
use tide::Middleware;
5+
6+
struct TestMiddleware(&'static str, &'static str);
7+
8+
impl TestMiddleware {
9+
fn with_header_name(name: &'static str, value: &'static str) -> Self {
10+
Self(name, value)
11+
}
12+
}
13+
14+
impl<State: Send + Sync + 'static> Middleware<State> for TestMiddleware {
15+
fn handle<'a>(
16+
&'a self,
17+
req: tide::Request<State>,
18+
next: tide::Next<'a, State>,
19+
) -> BoxFuture<'a, tide::Response> {
20+
Box::pin(async move {
21+
let res = next.run(req).await;
22+
res.set_header(self.0, self.1)
23+
})
24+
}
25+
}
26+
27+
async fn echo_path<State>(req: tide::Request<State>) -> String {
28+
req.uri().path().to_string()
29+
}
30+
31+
#[test]
32+
fn route_middleware() {
33+
let mut app = tide::new();
34+
let mut foo_route = app.at("/foo");
35+
foo_route // /foo
36+
.middleware(TestMiddleware::with_header_name("X-Foo", "foo"))
37+
.get(echo_path);
38+
foo_route
39+
.at("/bar") // nested, /foo/bar
40+
.middleware(TestMiddleware::with_header_name("X-Bar", "bar"))
41+
.get(echo_path);
42+
foo_route // /foo
43+
.post(echo_path)
44+
.reset_middleware()
45+
.put(echo_path);
46+
let mut server = make_server(app.into_http_service()).unwrap();
47+
48+
let req = http::Request::get("/foo").body(Body::empty()).unwrap();
49+
let res = server.simulate(req).unwrap();
50+
assert_eq!(res.headers().get("X-Foo"), Some(&"foo".parse().unwrap()));
51+
52+
let req = http::Request::post("/foo").body(Body::empty()).unwrap();
53+
let res = server.simulate(req).unwrap();
54+
assert_eq!(res.headers().get("X-Foo"), Some(&"foo".parse().unwrap()));
55+
56+
let req = http::Request::put("/foo").body(Body::empty()).unwrap();
57+
let res = server.simulate(req).unwrap();
58+
assert_eq!(res.headers().get("X-Foo"), None);
59+
60+
let req = http::Request::get("/foo/bar").body(Body::empty()).unwrap();
61+
let res = server.simulate(req).unwrap();
62+
assert_eq!(res.headers().get("X-Foo"), Some(&"foo".parse().unwrap()));
63+
assert_eq!(res.headers().get("X-Bar"), Some(&"bar".parse().unwrap()));
64+
}
65+
66+
#[test]
67+
fn app_and_route_middleware() {
68+
let mut app = tide::new();
69+
app.middleware(TestMiddleware::with_header_name("X-Root", "root"));
70+
app.at("/foo")
71+
.middleware(TestMiddleware::with_header_name("X-Foo", "foo"))
72+
.get(echo_path);
73+
app.at("/bar")
74+
.middleware(TestMiddleware::with_header_name("X-Bar", "bar"))
75+
.get(echo_path);
76+
let mut server = make_server(app.into_http_service()).unwrap();
77+
78+
let req = http::Request::get("/foo").body(Body::empty()).unwrap();
79+
let res = server.simulate(req).unwrap();
80+
assert_eq!(res.headers().get("X-Root"), Some(&"root".parse().unwrap()));
81+
assert_eq!(res.headers().get("X-Foo"), Some(&"foo".parse().unwrap()));
82+
assert_eq!(res.headers().get("X-Bar"), None);
83+
84+
let req = http::Request::get("/bar").body(Body::empty()).unwrap();
85+
let res = server.simulate(req).unwrap();
86+
assert_eq!(res.headers().get("X-Root"), Some(&"root".parse().unwrap()));
87+
assert_eq!(res.headers().get("X-Foo"), None);
88+
assert_eq!(res.headers().get("X-Bar"), Some(&"bar".parse().unwrap()));
89+
}
90+
91+
#[test]
92+
fn nested_app_with_route_middleware() {
93+
let mut inner = tide::new();
94+
inner.middleware(TestMiddleware::with_header_name("X-Inner", "inner"));
95+
inner
96+
.at("/baz")
97+
.middleware(TestMiddleware::with_header_name("X-Baz", "baz"))
98+
.get(echo_path);
99+
100+
let mut app = tide::new();
101+
app.middleware(TestMiddleware::with_header_name("X-Root", "root"));
102+
app.at("/foo")
103+
.middleware(TestMiddleware::with_header_name("X-Foo", "foo"))
104+
.get(echo_path);
105+
app.at("/bar")
106+
.middleware(TestMiddleware::with_header_name("X-Bar", "bar"))
107+
.nest(inner);
108+
let mut server = make_server(app.into_http_service()).unwrap();
109+
110+
let req = http::Request::get("/foo").body(Body::empty()).unwrap();
111+
let res = server.simulate(req).unwrap();
112+
assert_eq!(res.headers().get("X-Root"), Some(&"root".parse().unwrap()));
113+
assert_eq!(res.headers().get("X-Inner"), None);
114+
assert_eq!(res.headers().get("X-Foo"), Some(&"foo".parse().unwrap()));
115+
assert_eq!(res.headers().get("X-Bar"), None);
116+
assert_eq!(res.headers().get("X-Baz"), None);
117+
118+
let req = http::Request::get("/bar/baz").body(Body::empty()).unwrap();
119+
let res = server.simulate(req).unwrap();
120+
assert_eq!(res.headers().get("X-Root"), Some(&"root".parse().unwrap()));
121+
assert_eq!(
122+
res.headers().get("X-Inner"),
123+
Some(&"inner".parse().unwrap())
124+
);
125+
assert_eq!(res.headers().get("X-Foo"), None);
126+
assert_eq!(res.headers().get("X-Bar"), Some(&"bar".parse().unwrap()));
127+
assert_eq!(res.headers().get("X-Baz"), Some(&"baz".parse().unwrap()));
128+
}
129+
130+
#[test]
131+
fn subroute_not_nested() {
132+
let mut app = tide::new();
133+
app.at("/parent") // /parent
134+
.middleware(TestMiddleware::with_header_name("X-Parent", "Parent"))
135+
.get(echo_path);
136+
app.at("/parent/child") // /parent/child, not nested
137+
.middleware(TestMiddleware::with_header_name("X-Child", "child"))
138+
.get(echo_path);
139+
let mut server = make_server(app.into_http_service()).unwrap();
140+
141+
let req = http::Request::get("/parent/child")
142+
.body(Body::empty())
143+
.unwrap();
144+
let res = server.simulate(req).unwrap();
145+
assert_eq!(res.headers().get("X-Parent"), None);
146+
assert_eq!(
147+
res.headers().get("X-Child"),
148+
Some(&"child".parse().unwrap())
149+
);
150+
}

0 commit comments

Comments
 (0)