Skip to content

Commit 8597cd8

Browse files
committed
fix: refactor chain middleware dynamic dispatch
1 parent 15609b4 commit 8597cd8

File tree

3 files changed

+74
-85
lines changed

3 files changed

+74
-85
lines changed

sentry/src/chain.rs

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,24 @@
11
use crate::{Application, ResponseError};
2+
use futures::future::BoxFuture;
23
use hyper::{Body, Request};
34
use primitives::adapter::Adapter;
4-
use std::future::Future;
55

66
// chain middleware function calls
77
//
88
// function signature
99
// fn middleware(mut req: Request) -> Result<Request, ResponseError>
1010

11-
pub async fn chain<'a, A, M, MF>(
11+
pub async fn chain<'a, A: Adapter + 'static, M>(
1212
req: Request<Body>,
1313
app: &'a Application<A>,
1414
middlewares: Vec<M>,
1515
) -> Result<Request<Body>, ResponseError>
1616
where
17-
A: Adapter,
18-
MF: Future<Output = Result<Request<Body>, ResponseError>> + Send,
19-
M: FnMut(Request<Body>, &'a Application<A>) -> MF,
17+
M: FnMut(
18+
Request<Body>,
19+
&'a Application<A>,
20+
) -> BoxFuture<'a, Result<Request<Body>, ResponseError>>
21+
+ 'static,
2022
{
2123
let mut req = Ok(req);
2224

sentry/src/lib.rs

Lines changed: 19 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ use crate::middleware::cors::{cors, Cors};
99
use crate::routes::channel::channel_status;
1010
use crate::routes::event_aggregate::list_channel_event_aggregates;
1111
use crate::routes::validator_message::{extract_params, list_validator_messages};
12+
use futures::future::{BoxFuture, FutureExt};
1213
use hyper::{Body, Method, Request, Response, StatusCode};
1314
use lazy_static::lazy_static;
1415
use primitives::adapter::Adapter;
@@ -20,7 +21,6 @@ use routes::cfg::config;
2021
use routes::channel::{channel_list, create_channel, last_approved};
2122
use slog::{error, Logger};
2223
use std::collections::HashMap;
23-
2424
pub mod middleware {
2525
pub mod auth;
2626
pub mod channel;
@@ -54,22 +54,18 @@ lazy_static! {
5454
static ref ADVERTISER_ANALYTICS_BY_CHANNEL_ID: Regex = Regex::new(r"^/analytics/for-advertiser/0x([a-zA-Z0-9]{64})/?$").expect("The regex should be valid");
5555
}
5656

57-
async fn config_middleware<A: Adapter>(
58-
req: Request<Body>,
59-
_: &Application<A>,
60-
) -> Result<Request<Body>, ResponseError> {
61-
Ok(req)
62-
}
63-
64-
async fn auth_required_middleware<A: Adapter>(
57+
fn auth_required_middleware<'a, A: Adapter>(
6558
req: Request<Body>,
6659
_: &Application<A>,
67-
) -> Result<Request<Body>, ResponseError> {
68-
if req.extensions().get::<Session>().is_some() {
69-
Ok(req)
70-
} else {
71-
Err(ResponseError::Unauthorized)
60+
) -> BoxFuture<'a, Result<Request<Body>, ResponseError>> {
61+
async move {
62+
if req.extensions().get::<Session>().is_some() {
63+
Ok(req)
64+
} else {
65+
Err(ResponseError::Unauthorized)
66+
}
7267
}
68+
.boxed()
7369
}
7470

7571
#[derive(Debug)]
@@ -137,8 +133,7 @@ impl<A: Adapter + 'static> Application<A> {
137133

138134
("/analytics", &Method::GET) => analytics(req, &self).await,
139135
("/analytics/for-advertiser", &Method::GET) => {
140-
// @TODO get advertiser channels
141-
let req = match chain(req, &self, vec![auth_required_middleware]).await {
136+
let req = match chain(req, &self, vec![Box::new(auth_required_middleware)]).await {
142137
Ok(req) => req,
143138
Err(error) => {
144139
return map_response_error(error);
@@ -147,7 +142,7 @@ impl<A: Adapter + 'static> Application<A> {
147142
advertiser_analytics(req, &self).await
148143
}
149144
("/analytics/for-publisher", &Method::GET) => {
150-
let req = match chain(req, &self, vec![auth_required_middleware]).await {
145+
let req = match chain(req, &self, vec![Box::new(auth_required_middleware)]).await {
151146
Ok(req) => req,
152147
Err(error) => {
153148
return map_response_error(error);
@@ -170,7 +165,7 @@ impl<A: Adapter + 'static> Application<A> {
170165
}
171166
}
172167

173-
async fn analytics_router<A: Adapter>(
168+
async fn analytics_router<A: Adapter + 'static>(
174169
mut req: Request<Body>,
175170
app: &Application<A>,
176171
) -> Result<Response<Body>, ResponseError> {
@@ -184,15 +179,15 @@ async fn analytics_router<A: Adapter>(
184179
.map_or("".to_string(), |m| m.as_str().to_string())]);
185180
req.extensions_mut().insert(param);
186181

187-
let req = chain(req, app, vec![channel_load]).await?;
182+
let req = chain(req, app, vec![Box::new(channel_load)]).await?;
188183
analytics(req, app).await
189184
} else if let Some(caps) = ADVERTISER_ANALYTICS_BY_CHANNEL_ID.captures(route) {
190185
let param = RouteParams(vec![caps
191186
.get(1)
192187
.map_or("".to_string(), |m| m.as_str().to_string())]);
193188
req.extensions_mut().insert(param);
194189

195-
let req = chain(req, app, vec![auth_required_middleware]).await?;
190+
let req = auth_required_middleware(req, app).await?;
196191
advertiser_analytics(req, app).await
197192
} else {
198193
Err(ResponseError::NotFound)
@@ -202,7 +197,7 @@ async fn analytics_router<A: Adapter>(
202197
}
203198
}
204199

205-
async fn channels_router<A: Adapter>(
200+
async fn channels_router<A: Adapter + 'static>(
206201
mut req: Request<Body>,
207202
app: &Application<A>,
208203
) -> Result<Response<Body>, ResponseError> {
@@ -217,15 +212,6 @@ async fn channels_router<A: Adapter>(
217212
.map_or("".to_string(), |m| m.as_str().to_string())]);
218213
req.extensions_mut().insert(param);
219214

220-
// example with middleware
221-
// @TODO remove later
222-
let req = match chain(req, app, vec![config_middleware]).await {
223-
Ok(req) => req,
224-
Err(error) => {
225-
return Err(error);
226-
}
227-
};
228-
229215
last_approved(req, app).await
230216
} else if let (Some(caps), &Method::GET) =
231217
(CHANNEL_STATUS_BY_CHANNEL_ID.captures(&path), method)
@@ -235,13 +221,7 @@ async fn channels_router<A: Adapter>(
235221
.map_or("".to_string(), |m| m.as_str().to_string())]);
236222
req.extensions_mut().insert(param);
237223

238-
let req = match chain(req, app, vec![channel_load]).await {
239-
Ok(req) => req,
240-
Err(error) => {
241-
return Err(error);
242-
}
243-
};
244-
224+
let req = channel_load(req, app).await?;
245225
channel_status(req, app).await
246226
} else if let (Some(caps), &Method::GET) = (CHANNEL_VALIDATOR_MESSAGES.captures(&path), method)
247227
{
@@ -251,7 +231,7 @@ async fn channels_router<A: Adapter>(
251231

252232
req.extensions_mut().insert(param);
253233

254-
let req = match chain(req, app, vec![channel_load]).await {
234+
let req = match chain(req, app, vec![Box::new(channel_load)]).await {
255235
Ok(req) => req,
256236
Err(error) => {
257237
return Err(error);
@@ -280,7 +260,7 @@ async fn channels_router<A: Adapter>(
280260
]);
281261
req.extensions_mut().insert(param);
282262

283-
let req = chain(req, app, vec![channel_load]).await?;
263+
let req = chain(req, app, vec![Box::new(channel_load)]).await?;
284264

285265
list_channel_event_aggregates(req, app).await
286266
} else {

sentry/src/middleware/channel.rs

Lines changed: 48 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,61 +1,68 @@
11
use crate::db::{get_channel_by_id, get_channel_by_id_and_validator};
22
use crate::{Application, ResponseError, RouteParams};
3+
use futures::future::{BoxFuture, FutureExt};
34
use hex::FromHex;
45
use hyper::{Body, Request};
56
use primitives::adapter::Adapter;
67
use primitives::{ChannelId, ValidatorId};
78
use std::convert::TryFrom;
89

910
/// channel_load & channel_if_exist
10-
pub async fn channel_load<A: Adapter>(
11+
pub fn channel_load<'a, A: Adapter + 'static>(
1112
mut req: Request<Body>,
12-
app: &Application<A>,
13-
) -> Result<Request<Body>, ResponseError> {
14-
let id = req
15-
.extensions()
16-
.get::<RouteParams>()
17-
.ok_or_else(|| ResponseError::BadRequest("Route params not found".to_string()))?
18-
.get(0)
19-
.ok_or_else(|| ResponseError::BadRequest("No id".to_string()))?;
20-
21-
let channel_id = ChannelId::from_hex(id)
22-
.map_err(|_| ResponseError::BadRequest("Wrong Channel Id".to_string()))?;
23-
let channel = get_channel_by_id(&app.pool, &channel_id)
24-
.await?
25-
.ok_or_else(|| ResponseError::NotFound)?;
26-
27-
req.extensions_mut().insert(channel);
28-
29-
Ok(req)
13+
app: &'a Application<A>,
14+
) -> BoxFuture<'a, Result<Request<Body>, ResponseError>> {
15+
async move {
16+
let id = req
17+
.extensions()
18+
.get::<RouteParams>()
19+
.ok_or_else(|| ResponseError::BadRequest("Route params not found".to_string()))?
20+
.get(0)
21+
.ok_or_else(|| ResponseError::BadRequest("No id".to_string()))?;
22+
23+
let channel_id = ChannelId::from_hex(id)
24+
.map_err(|_| ResponseError::BadRequest("Wrong Channel Id".to_string()))?;
25+
let channel = get_channel_by_id(&app.pool, &channel_id)
26+
.await?
27+
.ok_or_else(|| ResponseError::NotFound)?;
28+
29+
req.extensions_mut().insert(channel);
30+
31+
Ok(req)
32+
}
33+
.boxed()
3034
}
3135

32-
pub async fn channel_if_active<A: Adapter>(
36+
pub fn channel_if_active<'a, A: Adapter + 'static>(
3337
mut req: Request<Body>,
34-
app: &Application<A>,
35-
) -> Result<Request<Body>, ResponseError> {
36-
let route_params = req
37-
.extensions()
38-
.get::<RouteParams>()
39-
.ok_or_else(|| ResponseError::BadRequest("Route params not found".to_string()))?;
38+
app: &'a Application<A>,
39+
) -> BoxFuture<'a, Result<Request<Body>, ResponseError>> {
40+
async move {
41+
let route_params = req
42+
.extensions()
43+
.get::<RouteParams>()
44+
.ok_or_else(|| ResponseError::BadRequest("Route params not found".to_string()))?;
4045

41-
let id = route_params
42-
.get(0)
43-
.ok_or_else(|| ResponseError::BadRequest("No id".to_string()))?;
46+
let id = route_params
47+
.get(0)
48+
.ok_or_else(|| ResponseError::BadRequest("No id".to_string()))?;
4449

45-
let channel_id = ChannelId::from_hex(id)
46-
.map_err(|_| ResponseError::BadRequest("Wrong Channel Id".to_string()))?;
50+
let channel_id = ChannelId::from_hex(id)
51+
.map_err(|_| ResponseError::BadRequest("Wrong Channel Id".to_string()))?;
4752

48-
let validator_id = route_params
49-
.get(1)
50-
.ok_or_else(|| ResponseError::BadRequest("No Validator Id".to_string()))?;
51-
let validator_id = ValidatorId::try_from(&validator_id)
52-
.map_err(|_| ResponseError::BadRequest("Wrong Validator Id".to_string()))?;
53+
let validator_id = route_params
54+
.get(1)
55+
.ok_or_else(|| ResponseError::BadRequest("No Validator Id".to_string()))?;
56+
let validator_id = ValidatorId::try_from(&validator_id)
57+
.map_err(|_| ResponseError::BadRequest("Wrong Validator Id".to_string()))?;
5358

54-
let channel = get_channel_by_id_and_validator(&app.pool, &channel_id, &validator_id)
55-
.await?
56-
.ok_or_else(|| ResponseError::NotFound)?;
59+
let channel = get_channel_by_id_and_validator(&app.pool, &channel_id, &validator_id)
60+
.await?
61+
.ok_or_else(|| ResponseError::NotFound)?;
5762

58-
req.extensions_mut().insert(channel);
63+
req.extensions_mut().insert(channel);
5964

60-
Ok(req)
65+
Ok(req)
66+
}
67+
.boxed()
6168
}

0 commit comments

Comments
 (0)