Skip to content

Commit d66de1a

Browse files
committed
refactor: create context middleware
Signed-off-by: Gustavo Inacio <[email protected]>
1 parent e428cbe commit d66de1a

File tree

3 files changed

+133
-0
lines changed

3 files changed

+133
-0
lines changed

crates/service/src/error.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,13 @@ use thiserror::Error;
1717
pub enum IndexerServiceError {
1818
#[error("No Tap receipt was found in the request")]
1919
ReceiptNotFound,
20+
#[error("Could not find deployment id")]
21+
DeploymentIdNotFound,
22+
#[error(transparent)]
23+
AxumError(#[from] axum::Error),
24+
25+
#[error(transparent)]
26+
SerializationError(#[from] serde_json::Error),
2027

2128
#[error("Issues with provided receipt: {0}")]
2229
ReceiptError(#[from] tap_core::Error),
@@ -57,6 +64,9 @@ impl IntoResponse for IndexerServiceError {
5764
| EscrowAccount(_)
5865
| ProcessingError(_) => StatusCode::BAD_REQUEST,
5966
ReceiptNotFound => StatusCode::PAYMENT_REQUIRED,
67+
DeploymentIdNotFound => StatusCode::INTERNAL_SERVER_ERROR,
68+
AxumError(_) => StatusCode::BAD_REQUEST,
69+
SerializationError(_) => StatusCode::BAD_REQUEST,
6070
};
6171
tracing::error!(%self, "An IndexerServiceError occoured.");
6272
(

crates/service/src/middleware.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
mod auth;
55
mod inject_allocation;
6+
mod inject_context;
67
mod inject_deployment;
78
mod inject_labels;
89
mod inject_receipt;
Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
//! Injects tap context to be used by the checks
2+
//!
3+
//! Requires Deployment Id extension to available
4+
5+
use serde_json::value::RawValue;
6+
use std::sync::Arc;
7+
8+
use axum::{
9+
body::to_bytes,
10+
extract::{Path, Request},
11+
middleware::Next,
12+
response::Response,
13+
RequestExt,
14+
};
15+
use tap_core::receipt::Context;
16+
use thegraph_core::DeploymentId;
17+
18+
use crate::{error::IndexerServiceError, tap::AgoraQuery};
19+
20+
#[derive(Debug, serde::Deserialize, serde::Serialize)]
21+
struct QueryBody {
22+
query: String,
23+
variables: Option<Box<RawValue>>,
24+
}
25+
26+
pub async fn context_middleware(
27+
mut request: Request,
28+
next: Next,
29+
) -> Result<Response, IndexerServiceError> {
30+
let deployment_id = match request.extensions().get::<DeploymentId>() {
31+
Some(deployment) => *deployment,
32+
None => match request.extract_parts::<Path<DeploymentId>>().await {
33+
Ok(Path(deployment)) => deployment,
34+
Err(_) => return Err(IndexerServiceError::DeploymentIdNotFound),
35+
},
36+
};
37+
38+
let (mut parts, body) = request.into_parts();
39+
let bytes = to_bytes(body, usize::MAX).await?;
40+
let query_body: QueryBody = serde_json::from_slice(&bytes)?;
41+
42+
let variables = query_body
43+
.variables
44+
.as_ref()
45+
.map(ToString::to_string)
46+
.unwrap_or_default();
47+
48+
let mut ctx = Context::new();
49+
ctx.insert(AgoraQuery {
50+
deployment_id,
51+
query: query_body.query.clone(),
52+
variables,
53+
});
54+
parts.extensions.insert(Arc::new(ctx));
55+
let request = Request::from_parts(parts, bytes.into());
56+
Ok(next.run(request).await)
57+
}
58+
59+
#[cfg(test)]
60+
mod tests {
61+
use std::sync::Arc;
62+
63+
use axum::{
64+
body::Body,
65+
http::{Extensions, Request},
66+
middleware::from_fn,
67+
routing::get,
68+
Router,
69+
};
70+
use reqwest::StatusCode;
71+
use tap_core::receipt::Context;
72+
use test_assets::ESCROW_SUBGRAPH_DEPLOYMENT;
73+
use tower::ServiceExt;
74+
75+
use crate::{
76+
middleware::inject_context::{context_middleware, QueryBody},
77+
tap::AgoraQuery,
78+
};
79+
80+
#[tokio::test]
81+
async fn test_context_middleware() {
82+
let middleware = from_fn(context_middleware);
83+
let deployment = *ESCROW_SUBGRAPH_DEPLOYMENT;
84+
let query_body = QueryBody {
85+
query: "hello".to_string(),
86+
variables: None,
87+
};
88+
let body = serde_json::to_string(&query_body).unwrap();
89+
90+
let handle = move |extensions: Extensions| async move {
91+
let ctx = extensions
92+
.get::<Arc<Context>>()
93+
.expect("Should contain context");
94+
let agora = ctx.get::<AgoraQuery>().expect("should contain agora query");
95+
assert_eq!(agora.deployment_id, deployment);
96+
assert_eq!(agora.query, query_body.query);
97+
98+
let variables = query_body
99+
.variables
100+
.as_ref()
101+
.map(ToString::to_string)
102+
.unwrap_or_default();
103+
assert_eq!(agora.variables, variables);
104+
Body::empty()
105+
};
106+
107+
let app = Router::new().route("/", get(handle)).layer(middleware);
108+
109+
let res = app
110+
.oneshot(
111+
Request::builder()
112+
.uri("/")
113+
.extension(deployment)
114+
.extension(deployment)
115+
.body(body)
116+
.unwrap(),
117+
)
118+
.await
119+
.unwrap();
120+
assert_eq!(res.status(), StatusCode::OK);
121+
}
122+
}

0 commit comments

Comments
 (0)