Skip to content

Commit 3587e67

Browse files
committed
refactor: create context middleware
Signed-off-by: Gustavo Inacio <[email protected]>
1 parent 906fb3f commit 3587e67

File tree

3 files changed

+136
-0
lines changed

3 files changed

+136
-0
lines changed

crates/service/src/error.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,13 @@ use thiserror::Error;
1717
pub enum IndexerServiceError {
1818
#[error("No Tap receipt was found in the request")]
1919
ReceiptNotFound,
20+
#[error("Could not find deployment id")]
21+
DeploymentIdNotFound,
22+
#[error(transparent)]
23+
AxumError(#[from] axum::Error),
24+
25+
#[error(transparent)]
26+
SerializationError(#[from] serde_json::Error),
2027

2128
#[error("Issues with provided receipt: {0}")]
2229
ReceiptError(#[from] tap_core::Error),
@@ -45,6 +52,9 @@ impl IntoResponse for IndexerServiceError {
4552

4653
ReceiptError(_) | EscrowAccount(_) | ProcessingError(_) => StatusCode::BAD_REQUEST,
4754
ReceiptNotFound => StatusCode::PAYMENT_REQUIRED,
55+
DeploymentIdNotFound => StatusCode::INTERNAL_SERVER_ERROR,
56+
AxumError(_) => StatusCode::BAD_REQUEST,
57+
SerializationError(_) => StatusCode::BAD_REQUEST,
4858
};
4959
tracing::error!(%self, "An IndexerServiceError occoured.");
5060
(

crates/service/src/middleware.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
pub mod auth;
55
mod inject_allocation;
6+
mod inject_context;
67
mod inject_deployment;
78
mod inject_labels;
89
mod inject_receipt;
Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
// Copyright 2023-, Edge & Node, GraphOps, and Semiotic Labs.
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
//! Injects tap context to be used by the checks
5+
//!
6+
//! Requires Deployment Id extension to available
7+
8+
use serde_json::value::RawValue;
9+
use std::sync::Arc;
10+
11+
use axum::{
12+
body::to_bytes,
13+
extract::{Path, Request},
14+
middleware::Next,
15+
response::Response,
16+
RequestExt,
17+
};
18+
use tap_core::receipt::Context;
19+
use thegraph_core::DeploymentId;
20+
21+
use crate::{error::IndexerServiceError, tap::AgoraQuery};
22+
23+
#[derive(Debug, serde::Deserialize, serde::Serialize)]
24+
struct QueryBody {
25+
query: String,
26+
variables: Option<Box<RawValue>>,
27+
}
28+
29+
pub async fn context_middleware(
30+
mut request: Request,
31+
next: Next,
32+
) -> Result<Response, IndexerServiceError> {
33+
let deployment_id = match request.extensions().get::<DeploymentId>() {
34+
Some(deployment) => *deployment,
35+
None => match request.extract_parts::<Path<DeploymentId>>().await {
36+
Ok(Path(deployment)) => deployment,
37+
Err(_) => return Err(IndexerServiceError::DeploymentIdNotFound),
38+
},
39+
};
40+
41+
let (mut parts, body) = request.into_parts();
42+
let bytes = to_bytes(body, usize::MAX).await?;
43+
let query_body: QueryBody = serde_json::from_slice(&bytes)?;
44+
45+
let variables = query_body
46+
.variables
47+
.as_ref()
48+
.map(ToString::to_string)
49+
.unwrap_or_default();
50+
51+
let mut ctx = Context::new();
52+
ctx.insert(AgoraQuery {
53+
deployment_id,
54+
query: query_body.query.clone(),
55+
variables,
56+
});
57+
parts.extensions.insert(Arc::new(ctx));
58+
let request = Request::from_parts(parts, bytes.into());
59+
Ok(next.run(request).await)
60+
}
61+
62+
#[cfg(test)]
63+
mod tests {
64+
use std::sync::Arc;
65+
66+
use axum::{
67+
body::Body,
68+
http::{Extensions, Request},
69+
middleware::from_fn,
70+
routing::get,
71+
Router,
72+
};
73+
use reqwest::StatusCode;
74+
use tap_core::receipt::Context;
75+
use test_assets::ESCROW_SUBGRAPH_DEPLOYMENT;
76+
use tower::ServiceExt;
77+
78+
use crate::{
79+
middleware::inject_context::{context_middleware, QueryBody},
80+
tap::AgoraQuery,
81+
};
82+
83+
#[tokio::test]
84+
async fn test_context_middleware() {
85+
let middleware = from_fn(context_middleware);
86+
let deployment = *ESCROW_SUBGRAPH_DEPLOYMENT;
87+
let query_body = QueryBody {
88+
query: "hello".to_string(),
89+
variables: None,
90+
};
91+
let body = serde_json::to_string(&query_body).unwrap();
92+
93+
let handle = move |extensions: Extensions| async move {
94+
let ctx = extensions
95+
.get::<Arc<Context>>()
96+
.expect("Should contain context");
97+
let agora = ctx.get::<AgoraQuery>().expect("should contain agora query");
98+
assert_eq!(agora.deployment_id, deployment);
99+
assert_eq!(agora.query, query_body.query);
100+
101+
let variables = query_body
102+
.variables
103+
.as_ref()
104+
.map(ToString::to_string)
105+
.unwrap_or_default();
106+
assert_eq!(agora.variables, variables);
107+
Body::empty()
108+
};
109+
110+
let app = Router::new().route("/", get(handle)).layer(middleware);
111+
112+
let res = app
113+
.oneshot(
114+
Request::builder()
115+
.uri("/")
116+
.extension(deployment)
117+
.extension(deployment)
118+
.body(body)
119+
.unwrap(),
120+
)
121+
.await
122+
.unwrap();
123+
assert_eq!(res.status(), StatusCode::OK);
124+
}
125+
}

0 commit comments

Comments
 (0)