Skip to content

Commit 44b5e26

Browse files
committed
refactor: add auth middleware
Signed-off-by: Gustavo Inacio <[email protected]>
1 parent 1e6ad79 commit 44b5e26

File tree

6 files changed

+271
-0
lines changed

6 files changed

+271
-0
lines changed

crates/service/src/error.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@ use thiserror::Error;
1515

1616
#[derive(Debug, Error)]
1717
pub enum IndexerServiceError {
18+
#[error("No Tap receipt was found in the request")]
19+
ReceiptNotFound,
20+
1821
#[error("Issues with provided receipt: {0}")]
1922
ReceiptError(#[from] tap_core::Error),
2023
#[error("No attestation signer found for allocation `{0}`")]
@@ -53,6 +56,7 @@ impl IntoResponse for IndexerServiceError {
5356
| InvalidFreeQueryAuthToken
5457
| EscrowAccount(_)
5558
| ProcessingError(_) => StatusCode::BAD_REQUEST,
59+
ReceiptNotFound => StatusCode::PAYMENT_REQUIRED,
5660
};
5761
tracing::error!(%self, "An IndexerServiceError occoured.");
5862
(

crates/service/src/middleware.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
// Copyright 2023-, Edge & Node, GraphOps, and Semiotic Labs.
22
// SPDX-License-Identifier: Apache-2.0
33

4+
mod auth;
45
mod inject_allocation;
56
mod inject_deployment;
67
mod inject_labels;
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
// Copyright 2023-, Edge & Node, GraphOps, and Semiotic Labs.
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
mod bearer;
5+
mod or;
6+
mod tap;
7+
8+
pub use bearer::Bearer;
9+
pub use or::OrExt;
10+
pub use tap::tap_receipt_authorize;
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
// Copyright 2023-, Edge & Node, GraphOps, and Semiotic Labs.
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
//! Bearer struct from tower-http but exposing the `new()` function
5+
//! to allow creation
6+
//!
7+
//! This code is from *tower-http*
8+
9+
use std::{fmt, marker::PhantomData};
10+
11+
use axum::http::{HeaderValue, Request, Response};
12+
use reqwest::{header, StatusCode};
13+
use tower_http::validate_request::ValidateRequest;
14+
15+
pub struct Bearer<ResBody> {
16+
header_value: HeaderValue,
17+
_ty: PhantomData<fn() -> ResBody>,
18+
}
19+
20+
impl<ResBody> Bearer<ResBody> {
21+
pub fn new(token: &str) -> Self
22+
where
23+
ResBody: Default,
24+
{
25+
Self {
26+
header_value: format!("Bearer {}", token)
27+
.parse()
28+
.expect("token is not a valid header value"),
29+
_ty: PhantomData,
30+
}
31+
}
32+
}
33+
34+
impl<ResBody> Clone for Bearer<ResBody> {
35+
fn clone(&self) -> Self {
36+
Self {
37+
header_value: self.header_value.clone(),
38+
_ty: PhantomData,
39+
}
40+
}
41+
}
42+
43+
impl<ResBody> fmt::Debug for Bearer<ResBody> {
44+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
45+
f.debug_struct("Bearer")
46+
.field("header_value", &self.header_value)
47+
.finish()
48+
}
49+
}
50+
51+
impl<B, ResBody> ValidateRequest<B> for Bearer<ResBody>
52+
where
53+
ResBody: Default,
54+
{
55+
type ResponseBody = ResBody;
56+
57+
fn validate(&mut self, request: &mut Request<B>) -> Result<(), Response<Self::ResponseBody>> {
58+
match request.headers().get(header::AUTHORIZATION) {
59+
Some(actual) if actual == self.header_value => Ok(()),
60+
_ => {
61+
let mut res = Response::new(ResBody::default());
62+
*res.status_mut() = StatusCode::UNAUTHORIZED;
63+
Err(res)
64+
}
65+
}
66+
}
67+
}
Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
// Copyright 2023-, Edge & Node, GraphOps, and Semiotic Labs.
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
//! Merge a ValidateRequest and an AsyncAuthorizeRequest
5+
//!
6+
//! executes a ValidateRequest returning the request if it succeeds
7+
//! or else, executed the future and return it
8+
9+
use std::{future::Future, marker::PhantomData, pin::Pin, task::Poll};
10+
11+
use axum::http::{Request, Response};
12+
use pin_project::pin_project;
13+
use tower_http::{auth::AsyncAuthorizeRequest, validate_request::ValidateRequest};
14+
15+
pub trait OrExt<T, B, Resp>: Sized {
16+
fn or(self, other: T) -> Or<Self, T, B, Resp>;
17+
}
18+
19+
impl<T, A, B, Resp, Fut> OrExt<A, B, Resp> for T
20+
where
21+
B: 'static + Send,
22+
Resp: 'static + Send,
23+
T: ValidateRequest<B, ResponseBody = Resp>,
24+
A: AsyncAuthorizeRequest<B, RequestBody = B, ResponseBody = Resp, Future = Fut>
25+
+ Clone
26+
+ 'static
27+
+ Send,
28+
Fut: Future<Output = Result<Request<B>, Response<Resp>>> + Send,
29+
{
30+
fn or(self, other: A) -> Or<Self, A, B, Resp> {
31+
Or(self, other, PhantomData)
32+
}
33+
}
34+
35+
pub struct Or<T, E, B, Resp>(T, E, PhantomData<fn(B) -> Resp>);
36+
37+
impl<T, E, B, Resp> Clone for Or<T, E, B, Resp>
38+
where
39+
T: Clone,
40+
E: Clone,
41+
{
42+
fn clone(&self) -> Self {
43+
Self(self.0.clone(), self.1.clone(), self.2)
44+
}
45+
}
46+
47+
impl<T, E, Req, Resp, Fut> AsyncAuthorizeRequest<Req> for Or<T, E, Req, Resp>
48+
where
49+
Req: 'static + Send,
50+
Resp: 'static + Send,
51+
T: ValidateRequest<Req, ResponseBody = Resp>,
52+
E: AsyncAuthorizeRequest<Req, RequestBody = Req, ResponseBody = Resp, Future = Fut>
53+
+ Clone
54+
+ 'static
55+
+ Send,
56+
Fut: Future<Output = Result<Request<Req>, Response<Resp>>> + Send,
57+
{
58+
type RequestBody = Req;
59+
type ResponseBody = Resp;
60+
61+
type Future = OrFuture<Fut, Req, Resp>;
62+
63+
fn authorize(&mut self, mut request: axum::http::Request<Req>) -> Self::Future {
64+
let mut this = self.1.clone();
65+
if self.0.validate(&mut request).is_ok() {
66+
return OrFuture::with_result(Ok(request));
67+
}
68+
OrFuture::with_future(this.authorize(request))
69+
}
70+
}
71+
72+
#[pin_project::pin_project(project = KindProj)]
73+
pub enum Kind<Fut, Req, Resp> {
74+
QueryResult {
75+
#[pin]
76+
fut: Fut,
77+
},
78+
ReturnResult {
79+
validation_result: Option<Result<Request<Req>, Response<Resp>>>,
80+
},
81+
}
82+
83+
#[pin_project]
84+
pub struct OrFuture<Fut, Req, Resp> {
85+
#[pin]
86+
kind: Kind<Fut, Req, Resp>,
87+
}
88+
89+
impl<Fut, Req, Resp> OrFuture<Fut, Req, Resp> {
90+
fn with_result(validation_result: Result<Request<Req>, Response<Resp>>) -> Self {
91+
let validation_result = Some(validation_result);
92+
Self {
93+
kind: Kind::ReturnResult { validation_result },
94+
}
95+
}
96+
97+
fn with_future(fut: Fut) -> Self {
98+
Self {
99+
kind: Kind::QueryResult { fut },
100+
}
101+
}
102+
}
103+
104+
impl<Fut, Req, Resp> Future for OrFuture<Fut, Req, Resp>
105+
where
106+
Fut: Future<Output = Result<Request<Req>, Response<Resp>>>,
107+
{
108+
type Output = Result<Request<Req>, Response<Resp>>;
109+
110+
fn poll(
111+
self: Pin<&mut Self>,
112+
cx: &mut std::task::Context<'_>,
113+
) -> std::task::Poll<Self::Output> {
114+
let this = self.project();
115+
match this.kind.project() {
116+
KindProj::QueryResult { fut } => fut.poll(cx),
117+
KindProj::ReturnResult { validation_result } => {
118+
Poll::Ready(validation_result.take().expect("cannot poll twice"))
119+
}
120+
}
121+
}
122+
}
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
// Copyright 2023-, Edge & Node, GraphOps, and Semiotic Labs.
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
//! Validates Tap receipts
5+
//!
6+
//! This looks for a Context in the extensions of the request to inject
7+
//! as part of the checks.
8+
//!
9+
//! This also uses MetricLabels injected to
10+
11+
use std::{future::Future, sync::Arc};
12+
13+
use axum::{
14+
body::Body,
15+
http::{Request, Response},
16+
response::IntoResponse,
17+
};
18+
use tap_core::{
19+
manager::{adapters::ReceiptStore, Manager},
20+
receipt::{Context, SignedReceipt},
21+
};
22+
use tower_http::auth::AsyncAuthorizeRequest;
23+
24+
use crate::{error::IndexerServiceError, middleware::metrics::MetricLabels};
25+
26+
pub fn tap_receipt_authorize<T, B>(
27+
tap_manager: &'static Manager<T>,
28+
failed_receipt_metric: &'static prometheus::CounterVec,
29+
) -> impl AsyncAuthorizeRequest<
30+
B,
31+
RequestBody = B,
32+
ResponseBody = Body,
33+
Future = impl Future<Output = Result<Request<B>, Response<Body>>> + Send,
34+
> + Clone
35+
+ Send
36+
where
37+
T: ReceiptStore + Sync + Send,
38+
B: Send,
39+
{
40+
|request: Request<B>| {
41+
let receipt = request.extensions().get::<SignedReceipt>().cloned();
42+
// load labels from previous middlewares
43+
let labels = request.extensions().get::<MetricLabels>().cloned();
44+
// load context from previous middlewares
45+
let ctx = request.extensions().get::<Arc<Context>>().cloned();
46+
47+
async {
48+
let execute = || async {
49+
let receipt = receipt.ok_or(IndexerServiceError::ReceiptNotFound)?;
50+
51+
// Verify the receipt and store it in the database
52+
tap_manager
53+
.verify_and_store_receipt(&ctx.unwrap_or_default(), receipt)
54+
.await
55+
.inspect_err(|_| {
56+
if let Some(labels) = labels {
57+
failed_receipt_metric
58+
.with_label_values(&labels.get_labels())
59+
.inc()
60+
}
61+
})?;
62+
Ok::<_, IndexerServiceError>(request)
63+
};
64+
execute().await.map_err(|error| error.into_response())
65+
}
66+
}
67+
}

0 commit comments

Comments
 (0)