Skip to content

Commit e428cbe

Browse files
committed
refactor: add auth middleware
Signed-off-by: Gustavo Inacio <[email protected]>
1 parent a3dce07 commit e428cbe

File tree

6 files changed

+265
-0
lines changed

6 files changed

+265
-0
lines changed

crates/service/src/error.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@ use thiserror::Error;
1515

1616
#[derive(Debug, Error)]
1717
pub enum IndexerServiceError {
18+
#[error("No Tap receipt was found in the request")]
19+
ReceiptNotFound,
20+
1821
#[error("Issues with provided receipt: {0}")]
1922
ReceiptError(#[from] tap_core::Error),
2023
#[error("No attestation signer found for allocation `{0}`")]
@@ -53,6 +56,7 @@ impl IntoResponse for IndexerServiceError {
5356
| InvalidFreeQueryAuthToken
5457
| EscrowAccount(_)
5558
| ProcessingError(_) => StatusCode::BAD_REQUEST,
59+
ReceiptNotFound => StatusCode::PAYMENT_REQUIRED,
5660
};
5761
tracing::error!(%self, "An IndexerServiceError occoured.");
5862
(

crates/service/src/middleware.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
// Copyright 2023-, Edge & Node, GraphOps, and Semiotic Labs.
22
// SPDX-License-Identifier: Apache-2.0
33

4+
mod auth;
45
mod inject_allocation;
56
mod inject_deployment;
67
mod inject_labels;
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
mod bearer;
2+
mod or;
3+
mod tap;
4+
5+
pub use bearer::Bearer;
6+
pub use or::OrExt;
7+
pub use tap::tap_receipt_authorize;
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
// Copyright 2023-, Edge & Node, GraphOps, and Semiotic Labs.
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
//! Bearer struct from tower-http but exposing the `new()` function
5+
//! to allow creation
6+
7+
use std::{fmt, marker::PhantomData};
8+
9+
use axum::http::{HeaderValue, Request, Response};
10+
use reqwest::{header, StatusCode};
11+
use tower_http::validate_request::ValidateRequest;
12+
13+
pub struct Bearer<ResBody> {
14+
header_value: HeaderValue,
15+
_ty: PhantomData<fn() -> ResBody>,
16+
}
17+
18+
impl<ResBody> Bearer<ResBody> {
19+
pub fn new(token: &str) -> Self
20+
where
21+
ResBody: Default,
22+
{
23+
Self {
24+
header_value: format!("Bearer {}", token)
25+
.parse()
26+
.expect("token is not a valid header value"),
27+
_ty: PhantomData,
28+
}
29+
}
30+
}
31+
32+
impl<ResBody> Clone for Bearer<ResBody> {
33+
fn clone(&self) -> Self {
34+
Self {
35+
header_value: self.header_value.clone(),
36+
_ty: PhantomData,
37+
}
38+
}
39+
}
40+
41+
impl<ResBody> fmt::Debug for Bearer<ResBody> {
42+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
43+
f.debug_struct("Bearer")
44+
.field("header_value", &self.header_value)
45+
.finish()
46+
}
47+
}
48+
49+
impl<B, ResBody> ValidateRequest<B> for Bearer<ResBody>
50+
where
51+
ResBody: Default,
52+
{
53+
type ResponseBody = ResBody;
54+
55+
fn validate(&mut self, request: &mut Request<B>) -> Result<(), Response<Self::ResponseBody>> {
56+
match request.headers().get(header::AUTHORIZATION) {
57+
Some(actual) if actual == self.header_value => Ok(()),
58+
_ => {
59+
let mut res = Response::new(ResBody::default());
60+
*res.status_mut() = StatusCode::UNAUTHORIZED;
61+
Err(res)
62+
}
63+
}
64+
}
65+
}
Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
// Copyright 2023-, Edge & Node, GraphOps, and Semiotic Labs.
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
//! Merge a ValidateRequest and an AsyncAuthorizeRequest
5+
//!
6+
//! executes a ValidateRequest returning the request if it succeeds
7+
//! or else, executed the future and return it
8+
9+
use std::{future::Future, marker::PhantomData, pin::Pin, task::Poll};
10+
11+
use axum::http::{Request, Response};
12+
use pin_project::pin_project;
13+
use tower_http::{auth::AsyncAuthorizeRequest, validate_request::ValidateRequest};
14+
15+
pub trait OrExt<T, B, Resp>: Sized {
16+
fn or(self, other: T) -> Or<Self, T, B, Resp>;
17+
}
18+
19+
impl<T, A, B, Resp, Fut> OrExt<A, B, Resp> for T
20+
where
21+
B: 'static + Send,
22+
Resp: 'static + Send,
23+
T: ValidateRequest<B, ResponseBody = Resp>,
24+
A: AsyncAuthorizeRequest<B, RequestBody = B, ResponseBody = Resp, Future = Fut>
25+
+ Clone
26+
+ 'static
27+
+ Send,
28+
Fut: Future<Output = Result<Request<B>, Response<Resp>>> + Send,
29+
{
30+
fn or(self, other: A) -> Or<Self, A, B, Resp> {
31+
Or(self, other, PhantomData)
32+
}
33+
}
34+
35+
pub struct Or<T, E, B, Resp>(T, E, PhantomData<fn(B) -> Resp>);
36+
37+
impl<T, E, B, Resp> Clone for Or<T, E, B, Resp>
38+
where
39+
T: Clone,
40+
E: Clone,
41+
{
42+
fn clone(&self) -> Self {
43+
Self(self.0.clone(), self.1.clone(), self.2)
44+
}
45+
}
46+
47+
impl<T, E, Req, Resp, Fut> AsyncAuthorizeRequest<Req> for Or<T, E, Req, Resp>
48+
where
49+
Req: 'static + Send,
50+
Resp: 'static + Send,
51+
T: ValidateRequest<Req, ResponseBody = Resp>,
52+
E: AsyncAuthorizeRequest<Req, RequestBody = Req, ResponseBody = Resp, Future = Fut>
53+
+ Clone
54+
+ 'static
55+
+ Send,
56+
Fut: Future<Output = Result<Request<Req>, Response<Resp>>> + Send,
57+
{
58+
type RequestBody = Req;
59+
type ResponseBody = Resp;
60+
61+
type Future = OrFuture<Fut, Req, Resp>;
62+
63+
fn authorize(&mut self, mut request: axum::http::Request<Req>) -> Self::Future {
64+
let mut this = self.1.clone();
65+
if self.0.validate(&mut request).is_ok() {
66+
return OrFuture::with_result(Ok(request));
67+
}
68+
OrFuture::with_future(this.authorize(request))
69+
}
70+
}
71+
72+
#[pin_project::pin_project(project = KindProj)]
73+
pub enum Kind<Fut, Req, Resp> {
74+
QueryResult {
75+
#[pin]
76+
fut: Fut,
77+
},
78+
ReturnResult {
79+
validation_result: Option<Result<Request<Req>, Response<Resp>>>,
80+
},
81+
}
82+
83+
#[pin_project]
84+
pub struct OrFuture<Fut, Req, Resp> {
85+
#[pin]
86+
kind: Kind<Fut, Req, Resp>,
87+
}
88+
89+
impl<Fut, Req, Resp> OrFuture<Fut, Req, Resp> {
90+
fn with_result(validation_result: Result<Request<Req>, Response<Resp>>) -> Self {
91+
let validation_result = Some(validation_result);
92+
Self {
93+
kind: Kind::ReturnResult { validation_result },
94+
}
95+
}
96+
97+
fn with_future(fut: Fut) -> Self {
98+
Self {
99+
kind: Kind::QueryResult { fut },
100+
}
101+
}
102+
}
103+
104+
impl<Fut, Req, Resp> Future for OrFuture<Fut, Req, Resp>
105+
where
106+
Fut: Future<Output = Result<Request<Req>, Response<Resp>>>,
107+
{
108+
type Output = Result<Request<Req>, Response<Resp>>;
109+
110+
fn poll(
111+
self: Pin<&mut Self>,
112+
cx: &mut std::task::Context<'_>,
113+
) -> std::task::Poll<Self::Output> {
114+
let this = self.project();
115+
match this.kind.project() {
116+
KindProj::QueryResult { fut } => fut.poll(cx),
117+
KindProj::ReturnResult { validation_result } => {
118+
Poll::Ready(validation_result.take().expect("cannot poll twice"))
119+
}
120+
}
121+
}
122+
}
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
// Copyright 2023-, Edge & Node, GraphOps, and Semiotic Labs.
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
//! Validates Tap receipts
5+
//!
6+
//! This looks for a Context in the extensions of the request to inject
7+
//! as part of the checks.
8+
//!
9+
//! This also uses MetricLabels injected to
10+
11+
use std::{future::Future, sync::Arc};
12+
13+
use axum::{
14+
body::Body,
15+
http::{Request, Response},
16+
response::IntoResponse,
17+
};
18+
use tap_core::{
19+
manager::{adapters::ReceiptStore, Manager},
20+
receipt::{Context, SignedReceipt},
21+
};
22+
use tower_http::auth::AsyncAuthorizeRequest;
23+
24+
use crate::{error::IndexerServiceError, middleware::metrics::MetricLabels};
25+
26+
pub fn tap_receipt_authorize<T>(
27+
tap_manager: &'static Manager<T>,
28+
failed_receipt_metric: &'static prometheus::CounterVec,
29+
) -> impl AsyncAuthorizeRequest<
30+
Body,
31+
RequestBody = Body,
32+
ResponseBody = Body,
33+
Future = impl Future<Output = Result<Request<Body>, Response<Body>>> + Send,
34+
> + Clone
35+
+ Send
36+
where
37+
T: ReceiptStore + Sync + Send,
38+
{
39+
|request: Request<Body>| {
40+
let receipt = request.extensions().get::<SignedReceipt>().cloned();
41+
// load labels from previous middlewares
42+
let labels = request.extensions().get::<MetricLabels>().cloned();
43+
// load context from previous middlewares
44+
let ctx = request.extensions().get::<Arc<Context>>().cloned();
45+
46+
async {
47+
let execute = || async {
48+
let receipt = receipt.ok_or(IndexerServiceError::ReceiptNotFound)?;
49+
50+
// Verify the receipt and store it in the database
51+
tap_manager
52+
.verify_and_store_receipt(&ctx.unwrap_or_default(), receipt)
53+
.await
54+
.inspect_err(|_| {
55+
if let Some(labels) = labels {
56+
failed_receipt_metric
57+
.with_label_values(&labels.get_labels())
58+
.inc()
59+
}
60+
})?;
61+
Ok::<_, IndexerServiceError>(request)
62+
};
63+
execute().await.map_err(|error| error.into_response())
64+
}
65+
}
66+
}

0 commit comments

Comments
 (0)