Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions crates/service/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ use thiserror::Error;

#[derive(Debug, Error)]
pub enum IndexerServiceError {
#[error("No Tap receipt was found in the request")]
ReceiptNotFound,

#[error("Issues with provided receipt: {0}")]
ReceiptError(#[from] tap_core::Error),
#[error("No attestation signer found for allocation `{0}`")]
Expand Down Expand Up @@ -53,6 +56,7 @@ impl IntoResponse for IndexerServiceError {
| InvalidFreeQueryAuthToken
| EscrowAccount(_)
| ProcessingError(_) => StatusCode::BAD_REQUEST,
ReceiptNotFound => StatusCode::PAYMENT_REQUIRED,
};
tracing::error!(%self, "An IndexerServiceError occoured.");
(
Expand Down
1 change: 1 addition & 0 deletions crates/service/src/middleware.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright 2023-, Edge & Node, GraphOps, and Semiotic Labs.
// SPDX-License-Identifier: Apache-2.0

mod auth;
mod inject_allocation;
mod inject_deployment;
mod inject_labels;
Expand Down
10 changes: 10 additions & 0 deletions crates/service/src/middleware/auth.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// Copyright 2023-, Edge & Node, GraphOps, and Semiotic Labs.
// SPDX-License-Identifier: Apache-2.0

mod bearer;
mod or;
mod tap;

pub use bearer::Bearer;
pub use or::OrExt;
pub use tap::tap_receipt_authorize;
67 changes: 67 additions & 0 deletions crates/service/src/middleware/auth/bearer.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
// Copyright 2023-, Edge & Node, GraphOps, and Semiotic Labs.
// SPDX-License-Identifier: Apache-2.0

//! Bearer struct from tower-http but exposing the `new()` function
//! to allow creation
//!
//! This code is from *tower-http*

use std::{fmt, marker::PhantomData};

use axum::http::{HeaderValue, Request, Response};
use reqwest::{header, StatusCode};
use tower_http::validate_request::ValidateRequest;

pub struct Bearer<ResBody> {
header_value: HeaderValue,
_ty: PhantomData<fn() -> ResBody>,
}

impl<ResBody> Bearer<ResBody> {
pub fn new(token: &str) -> Self
where
ResBody: Default,
{
Self {
header_value: format!("Bearer {}", token)
.parse()
.expect("token is not a valid header value"),
_ty: PhantomData,
}
}
}

impl<ResBody> Clone for Bearer<ResBody> {
fn clone(&self) -> Self {
Self {
header_value: self.header_value.clone(),
_ty: PhantomData,
}
}
}

impl<ResBody> fmt::Debug for Bearer<ResBody> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Bearer")
.field("header_value", &self.header_value)
.finish()
}
}

impl<B, ResBody> ValidateRequest<B> for Bearer<ResBody>
where
ResBody: Default,
{
type ResponseBody = ResBody;

fn validate(&mut self, request: &mut Request<B>) -> Result<(), Response<Self::ResponseBody>> {
match request.headers().get(header::AUTHORIZATION) {
Some(actual) if actual == self.header_value => Ok(()),
_ => {
let mut res = Response::new(ResBody::default());
*res.status_mut() = StatusCode::UNAUTHORIZED;
Err(res)
}
}
}
}
122 changes: 122 additions & 0 deletions crates/service/src/middleware/auth/or.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
// Copyright 2023-, Edge & Node, GraphOps, and Semiotic Labs.
// SPDX-License-Identifier: Apache-2.0

//! Merge a ValidateRequest and an AsyncAuthorizeRequest
//!
//! executes a ValidateRequest returning the request if it succeeds
//! or else, executed the future and return it

use std::{future::Future, marker::PhantomData, pin::Pin, task::Poll};

use axum::http::{Request, Response};
use pin_project::pin_project;
use tower_http::{auth::AsyncAuthorizeRequest, validate_request::ValidateRequest};

pub trait OrExt<T, B, Resp>: Sized {
fn or(self, other: T) -> Or<Self, T, B, Resp>;
}

impl<T, A, B, Resp, Fut> OrExt<A, B, Resp> for T
where
B: 'static + Send,
Resp: 'static + Send,
T: ValidateRequest<B, ResponseBody = Resp>,
A: AsyncAuthorizeRequest<B, RequestBody = B, ResponseBody = Resp, Future = Fut>
+ Clone
+ 'static
+ Send,
Fut: Future<Output = Result<Request<B>, Response<Resp>>> + Send,
{
fn or(self, other: A) -> Or<Self, A, B, Resp> {
Or(self, other, PhantomData)
}
}

pub struct Or<T, E, B, Resp>(T, E, PhantomData<fn(B) -> Resp>);

impl<T, E, B, Resp> Clone for Or<T, E, B, Resp>
where
T: Clone,
E: Clone,
{
fn clone(&self) -> Self {
Self(self.0.clone(), self.1.clone(), self.2)
}
}

impl<T, E, Req, Resp, Fut> AsyncAuthorizeRequest<Req> for Or<T, E, Req, Resp>
where
Req: 'static + Send,
Resp: 'static + Send,
T: ValidateRequest<Req, ResponseBody = Resp>,
E: AsyncAuthorizeRequest<Req, RequestBody = Req, ResponseBody = Resp, Future = Fut>
+ Clone
+ 'static
+ Send,
Fut: Future<Output = Result<Request<Req>, Response<Resp>>> + Send,
{
type RequestBody = Req;
type ResponseBody = Resp;

type Future = OrFuture<Fut, Req, Resp>;

fn authorize(&mut self, mut request: axum::http::Request<Req>) -> Self::Future {
let mut this = self.1.clone();
if self.0.validate(&mut request).is_ok() {
return OrFuture::with_result(Ok(request));
}
OrFuture::with_future(this.authorize(request))
}
}

#[pin_project::pin_project(project = KindProj)]
pub enum Kind<Fut, Req, Resp> {
QueryResult {
#[pin]
fut: Fut,
},
ReturnResult {
validation_result: Option<Result<Request<Req>, Response<Resp>>>,
},
}

#[pin_project]
pub struct OrFuture<Fut, Req, Resp> {
#[pin]
kind: Kind<Fut, Req, Resp>,
}

impl<Fut, Req, Resp> OrFuture<Fut, Req, Resp> {
fn with_result(validation_result: Result<Request<Req>, Response<Resp>>) -> Self {
let validation_result = Some(validation_result);
Self {
kind: Kind::ReturnResult { validation_result },
}
}

fn with_future(fut: Fut) -> Self {
Self {
kind: Kind::QueryResult { fut },
}
}
}

impl<Fut, Req, Resp> Future for OrFuture<Fut, Req, Resp>
where
Fut: Future<Output = Result<Request<Req>, Response<Resp>>>,
{
type Output = Result<Request<Req>, Response<Resp>>;

fn poll(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Self::Output> {
let this = self.project();
match this.kind.project() {
KindProj::QueryResult { fut } => fut.poll(cx),
KindProj::ReturnResult { validation_result } => {
Poll::Ready(validation_result.take().expect("cannot poll twice"))
}
}
}
}
67 changes: 67 additions & 0 deletions crates/service/src/middleware/auth/tap.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
// Copyright 2023-, Edge & Node, GraphOps, and Semiotic Labs.
// SPDX-License-Identifier: Apache-2.0

//! Validates Tap receipts
//!
//! This looks for a Context in the extensions of the request to inject
//! as part of the checks.
//!
//! This also uses MetricLabels injected to

use std::{future::Future, sync::Arc};

use axum::{
body::Body,
http::{Request, Response},
response::IntoResponse,
};
use tap_core::{
manager::{adapters::ReceiptStore, Manager},
receipt::{Context, SignedReceipt},
};
use tower_http::auth::AsyncAuthorizeRequest;

use crate::{error::IndexerServiceError, middleware::metrics::MetricLabels};

pub fn tap_receipt_authorize<T, B>(
tap_manager: &'static Manager<T>,
failed_receipt_metric: &'static prometheus::CounterVec,
) -> impl AsyncAuthorizeRequest<
B,
RequestBody = B,
ResponseBody = Body,
Future = impl Future<Output = Result<Request<B>, Response<Body>>> + Send,
> + Clone
+ Send
where
T: ReceiptStore + Sync + Send,
B: Send,
{
|request: Request<B>| {
let receipt = request.extensions().get::<SignedReceipt>().cloned();
// load labels from previous middlewares
let labels = request.extensions().get::<MetricLabels>().cloned();
// load context from previous middlewares
let ctx = request.extensions().get::<Arc<Context>>().cloned();

async {
let execute = || async {
let receipt = receipt.ok_or(IndexerServiceError::ReceiptNotFound)?;

// Verify the receipt and store it in the database
tap_manager
.verify_and_store_receipt(&ctx.unwrap_or_default(), receipt)
.await
.inspect_err(|_| {
if let Some(labels) = labels {
failed_receipt_metric
.with_label_values(&labels.get_labels())
.inc()
}
})?;
Ok::<_, IndexerServiceError>(request)
};
execute().await.map_err(|error| error.into_response())
}
}
}