Skip to content

Commit eb305e0

Browse files
committed
refactor: add inject allocation middleware
Signed-off-by: Gustavo Inacio <[email protected]>
1 parent 25b2455 commit eb305e0

File tree

2 files changed

+128
-0
lines changed

2 files changed

+128
-0
lines changed

crates/service/src/middleware.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
// Copyright 2023-, Edge & Node, GraphOps, and Semiotic Labs.
22
// SPDX-License-Identifier: Apache-2.0
33

4+
mod inject_allocation;
45
mod inject_deployment;
56
mod inject_receipt;
67
mod inject_sender;
Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
// Copyright 2023-, Edge & Node, GraphOps, and Semiotic Labs.
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
//! injects allocation id in extensions
5+
//! - check if allocation id already exists
6+
//! - else, try to fetch allocation id from deployment_id and allocations watcher
7+
//! - execute query
8+
//!
9+
//! Needs signed receipt Extension to be added OR deployment id
10+
11+
use std::collections::HashMap;
12+
13+
use alloy::primitives::Address;
14+
use axum::{
15+
extract::{Request, State},
16+
middleware::Next,
17+
response::Response,
18+
};
19+
use tap_core::receipt::SignedReceipt;
20+
use thegraph_core::DeploymentId;
21+
use tokio::sync::watch;
22+
23+
#[derive(Clone)]
24+
pub struct Allocation(pub Address);
25+
26+
impl From<Allocation> for String {
27+
fn from(value: Allocation) -> Self {
28+
value.0.to_string()
29+
}
30+
}
31+
32+
#[derive(Clone)]
33+
pub struct AllocationState {
34+
pub deployment_to_allocation: watch::Receiver<HashMap<DeploymentId, Address>>,
35+
}
36+
37+
pub async fn allocation_middleware(
38+
State(my_state): State<AllocationState>,
39+
mut request: Request,
40+
next: Next,
41+
) -> Response {
42+
if let Some(receipt) = request.extensions().get::<SignedReceipt>() {
43+
let allocation = receipt.message.allocation_id;
44+
request.extensions_mut().insert(Allocation(allocation));
45+
} else if let Some(deployment_id) = request.extensions().get::<DeploymentId>() {
46+
if let Some(allocation) = my_state
47+
.deployment_to_allocation
48+
.borrow()
49+
.get(deployment_id)
50+
{
51+
request.extensions_mut().insert(Allocation(*allocation));
52+
}
53+
}
54+
55+
next.run(request).await
56+
}
57+
58+
#[cfg(test)]
59+
mod tests {
60+
use crate::middleware::inject_allocation::Allocation;
61+
62+
use super::{allocation_middleware, AllocationState};
63+
64+
use alloy::primitives::Address;
65+
use axum::{
66+
body::Body,
67+
http::{Extensions, Request},
68+
middleware::from_fn_with_state,
69+
routing::get,
70+
Router,
71+
};
72+
use reqwest::StatusCode;
73+
use test_assets::{create_signed_receipt, ESCROW_SUBGRAPH_DEPLOYMENT};
74+
use tokio::sync::watch;
75+
use tower::ServiceExt;
76+
77+
#[tokio::test]
78+
async fn test_allocation_middleware() {
79+
let deployment = *ESCROW_SUBGRAPH_DEPLOYMENT;
80+
let deployment_to_allocation =
81+
watch::channel(vec![(deployment, Address::ZERO)].into_iter().collect()).1;
82+
let state = AllocationState {
83+
deployment_to_allocation,
84+
};
85+
86+
let middleware = from_fn_with_state(state, allocation_middleware);
87+
88+
async fn handle(extensions: Extensions) -> Body {
89+
let allocation = extensions
90+
.get::<Allocation>()
91+
.expect("Should contain allocation");
92+
assert_eq!(allocation.0, Address::ZERO);
93+
Body::empty()
94+
}
95+
96+
let app = Router::new().route("/", get(handle)).layer(middleware);
97+
98+
let receipt = create_signed_receipt(Address::ZERO, 1, 1, 1).await;
99+
100+
// with receipt
101+
let res = app
102+
.clone()
103+
.oneshot(
104+
Request::builder()
105+
.uri("/")
106+
.extension(receipt)
107+
.body(Body::empty())
108+
.unwrap(),
109+
)
110+
.await
111+
.unwrap();
112+
assert_eq!(res.status(), StatusCode::OK);
113+
114+
// with deployment
115+
let res = app
116+
.oneshot(
117+
Request::builder()
118+
.uri("/")
119+
.extension(deployment)
120+
.body(Body::empty())
121+
.unwrap(),
122+
)
123+
.await
124+
.unwrap();
125+
assert_eq!(res.status(), StatusCode::OK);
126+
}
127+
}

0 commit comments

Comments
 (0)