diff --git a/.sqlx/query-6c05fc541bf0bb2af20fbe62747456055d5ebda5cb136d9d015f101ebbfe495f.json b/.sqlx/query-6c05fc541bf0bb2af20fbe62747456055d5ebda5cb136d9d015f101ebbfe495f.json new file mode 100644 index 000000000..bc62431d0 --- /dev/null +++ b/.sqlx/query-6c05fc541bf0bb2af20fbe62747456055d5ebda5cb136d9d015f101ebbfe495f.json @@ -0,0 +1,56 @@ +{ + "db_name": "PostgreSQL", + "query": "SELECT * FROM scalar_tap_receipts", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "id", + "type_info": "Int8" + }, + { + "ordinal": 1, + "name": "signer_address", + "type_info": "Bpchar" + }, + { + "ordinal": 2, + "name": "signature", + "type_info": "Bytea" + }, + { + "ordinal": 3, + "name": "allocation_id", + "type_info": "Bpchar" + }, + { + "ordinal": 4, + "name": "timestamp_ns", + "type_info": "Numeric" + }, + { + "ordinal": 5, + "name": "nonce", + "type_info": "Numeric" + }, + { + "ordinal": 6, + "name": "value", + "type_info": "Numeric" + } + ], + "parameters": { + "Left": [] + }, + "nullable": [ + false, + false, + false, + false, + false, + false, + false + ] + }, + "hash": "6c05fc541bf0bb2af20fbe62747456055d5ebda5cb136d9d015f101ebbfe495f" +} diff --git a/Cargo.lock b/Cargo.lock index 66b2c2d02..522b249f5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3599,6 +3599,7 @@ dependencies = [ "pin-project 1.1.7", "prometheus", "reqwest 0.12.9", + "rstest", "serde", "serde_json", "sqlx", @@ -3612,6 +3613,7 @@ dependencies = [ "tokio-util", "tower 0.5.1", "tower-http", + "tower-service", "tower-test", "tower_governor", "tracing", @@ -5330,6 +5332,12 @@ version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c" +[[package]] +name = "relative-path" +version = "1.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba39f3699c378cd8970968dcbff9c43159ea4cfbd88d43c00b22f2ef10a435d2" + [[package]] name = "rend" version = "0.4.2" @@ -5537,6 +5545,36 @@ dependencies = [ "zeroize", ] +[[package]] +name = "rstest" +version = "0.23.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0a2c585be59b6b5dd66a9d2084aa1d8bd52fbdb806eafdeffb52791147862035" +dependencies = [ + "futures", + "futures-timer", + "rstest_macros", + "rustc_version 0.4.1", +] + +[[package]] +name = "rstest_macros" +version = "0.23.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "825ea780781b15345a146be27eaefb05085e337e869bff01b4306a4fd4a9ad5a" +dependencies = [ + "cfg-if", + "glob", + "proc-macro-crate", + "proc-macro2", + "quote", + "regex", + "relative-path", + "rustc_version 0.4.1", + "syn 2.0.79", + "unicode-ident", +] + [[package]] name = "ruint" version = "1.12.3" diff --git a/Cargo.toml b/Cargo.toml index 8eba38279..5b91fcddf 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -71,3 +71,4 @@ thegraph-core = { git = "https://github.com/edgeandnode/toolshed", rev = "166353 thegraph-graphql-http = "0.2.0" graphql_client = { version = "0.14.0", features = ["reqwest-rustls"] } bip39 = "2.0.0" +rstest = "0.23.0" diff --git a/crates/service/Cargo.toml b/crates/service/Cargo.toml index 1bfa00a99..626cd1bb7 100644 --- a/crates/service/Cargo.toml +++ b/crates/service/Cargo.toml @@ -59,7 +59,9 @@ pin-project = "1.1.7" [dev-dependencies] hex-literal = "0.4.1" test-assets = { path = "../test-assets" } +rstest.workspace = true tower-test = "0.4.0" +tower-service = "0.3.3" tokio-test = "0.4.4" [build-dependencies] diff --git a/crates/service/src/error.rs b/crates/service/src/error.rs index 98fcef42f..2180a7087 100644 --- a/crates/service/src/error.rs +++ b/crates/service/src/error.rs @@ -15,18 +15,22 @@ use thiserror::Error; #[derive(Debug, Error)] pub enum IndexerServiceError { + #[error("No Tap receipt was found in the request")] + ReceiptNotFound, + #[error("Could not find deployment id")] + DeploymentIdNotFound, + #[error(transparent)] + AxumError(#[from] axum::Error), + + #[error(transparent)] + SerializationError(#[from] serde_json::Error), + #[error("Issues with provided receipt: {0}")] ReceiptError(#[from] tap_core::Error), #[error("No attestation signer found for allocation `{0}`")] NoSignerForAllocation(Address), - #[error("Invalid request body: {0}")] - InvalidRequest(anyhow::Error), #[error("Error while processing the request: {0}")] ProcessingError(SubgraphServiceError), - #[error("No valid receipt or free query auth token provided")] - Unauthorized, - #[error("Invalid free query auth token")] - InvalidFreeQueryAuthToken, #[error("Failed to sign attestation")] FailedToSignAttestation, @@ -44,15 +48,13 @@ impl IntoResponse for IndexerServiceError { } let status = match self { - Unauthorized => StatusCode::UNAUTHORIZED, - NoSignerForAllocation(_) | FailedToSignAttestation => StatusCode::INTERNAL_SERVER_ERROR, - ReceiptError(_) - | InvalidRequest(_) - | InvalidFreeQueryAuthToken - | EscrowAccount(_) - | ProcessingError(_) => StatusCode::BAD_REQUEST, + ReceiptError(_) | EscrowAccount(_) | ProcessingError(_) => StatusCode::BAD_REQUEST, + ReceiptNotFound => StatusCode::PAYMENT_REQUIRED, + DeploymentIdNotFound => StatusCode::INTERNAL_SERVER_ERROR, + AxumError(_) => StatusCode::BAD_REQUEST, + SerializationError(_) => StatusCode::BAD_REQUEST, }; tracing::error!(%self, "An IndexerServiceError occoured."); ( diff --git a/crates/service/src/middleware.rs b/crates/service/src/middleware.rs index 0c76f1853..968a97909 100644 --- a/crates/service/src/middleware.rs +++ b/crates/service/src/middleware.rs @@ -1,7 +1,9 @@ // Copyright 2023-, Edge & Node, GraphOps, and Semiotic Labs. // SPDX-License-Identifier: Apache-2.0 +pub mod auth; mod inject_allocation; +mod inject_context; mod inject_deployment; mod inject_labels; mod inject_receipt; @@ -9,8 +11,9 @@ mod inject_sender; mod prometheus_metrics; pub use inject_allocation::{allocation_middleware, Allocation, AllocationState}; +pub use inject_context::context_middleware; pub use inject_deployment::deployment_middleware; pub use inject_labels::labels_middleware; pub use inject_receipt::receipt_middleware; -pub use inject_sender::{sender_middleware, Sender, SenderState}; +pub use inject_sender::{sender_middleware, SenderState}; pub use prometheus_metrics::PrometheusMetricsMiddlewareLayer; diff --git a/crates/service/src/middleware/auth.rs b/crates/service/src/middleware/auth.rs new file mode 100644 index 000000000..8e7e460e0 --- /dev/null +++ b/crates/service/src/middleware/auth.rs @@ -0,0 +1,138 @@ +// Copyright 2023-, Edge & Node, GraphOps, and Semiotic Labs. +// SPDX-License-Identifier: Apache-2.0 + +mod bearer; +mod or; +mod tap; + +pub use bearer::Bearer; +pub use or::OrExt; +pub use tap::tap_receipt_authorize; + +#[cfg(test)] +mod tests { + use std::time::Duration; + + use alloy::primitives::{address, Address}; + use axum::body::Body; + use axum::http::{Request, Response}; + use reqwest::{header, StatusCode}; + use sqlx::PgPool; + use tap_core::{manager::Manager, receipt::checks::CheckList}; + use tokio::time::sleep; + use tower::{Service, ServiceBuilder, ServiceExt}; + use tower_http::auth::AsyncRequireAuthorizationLayer; + + use crate::middleware::auth::{self, Bearer, OrExt}; + use crate::tap::IndexerTapContext; + use test_assets::{create_signed_receipt, TAP_EIP712_DOMAIN}; + + const ALLOCATION_ID: Address = address!("deadbeefcafebabedeadbeefcafebabedeadbeef"); + const BEARER_TOKEN: &str = "test"; + + async fn service( + pgpool: PgPool, + ) -> impl Service, Response = Response, Error = impl std::fmt::Debug> { + let context = IndexerTapContext::new(pgpool.clone(), TAP_EIP712_DOMAIN.clone()).await; + let tap_manager = Box::leak(Box::new(Manager::new( + TAP_EIP712_DOMAIN.clone(), + context, + CheckList::empty(), + ))); + + let registry = prometheus::Registry::new(); + let metric = Box::leak(Box::new( + prometheus::register_counter_vec_with_registry!( + "merge_checks_test", + "Failed queries to handler", + &["deployment"], + registry, + ) + .unwrap(), + )); + let free_query = Bearer::new(BEARER_TOKEN); + let tap_auth = auth::tap_receipt_authorize(tap_manager, metric); + let authorize_requests = free_query.or(tap_auth); + + let authorization_middleware = AsyncRequireAuthorizationLayer::new(authorize_requests); + + let mut service = ServiceBuilder::new() + .layer(authorization_middleware) + .service_fn(|_: Request| async { + Ok::<_, anyhow::Error>(Response::new(Body::default())) + }); + + service.ready().await.unwrap(); + service + } + + #[sqlx::test(migrations = "../../migrations")] + async fn test_composition_header_valid(pgpool: PgPool) { + let mut service = service(pgpool.clone()).await; + // should allow queries that contains the free token + // if the token does not match, return payment required + let mut req = Request::new(Default::default()); + req.headers_mut().insert( + header::AUTHORIZATION, + format!("Bearer {}", BEARER_TOKEN).parse().unwrap(), + ); + let res = service.call(req).await.unwrap(); + assert_eq!(res.status(), StatusCode::OK); + } + + #[sqlx::test(migrations = "../../migrations")] + async fn test_composition_header_invalid(pgpool: PgPool) { + let mut service = service(pgpool.clone()).await; + + // if the token exists but is wrong, try the receipt + let mut req = Request::new(Default::default()); + req.headers_mut() + .insert(header::AUTHORIZATION, "Bearer wrongtoken".parse().unwrap()); + let res = service.call(req).await.unwrap(); + // we return the error from tap + assert_eq!(res.status(), StatusCode::PAYMENT_REQUIRED); + } + + #[sqlx::test(migrations = "../../migrations")] + async fn test_composition_with_receipt(pgpool: PgPool) { + let mut service = service(pgpool.clone()).await; + + let receipt = create_signed_receipt(ALLOCATION_ID, 1, 1, 1).await; + + // check with receipt + let mut req = Request::new(Default::default()); + req.extensions_mut().insert(receipt); + let res = service.call(req).await.unwrap(); + assert_eq!(res.status(), StatusCode::OK); + + // verify receipts + if tokio::time::timeout(Duration::from_secs(1), async { + loop { + let result = sqlx::query!("SELECT * FROM scalar_tap_receipts") + .fetch_all(&pgpool) + .await + .unwrap(); + + if result.is_empty() { + sleep(Duration::from_millis(50)).await; + } else { + break; + } + } + }) + .await + .is_err() + { + panic!("Timeout assertion"); + } + } + + #[sqlx::test(migrations = "../../migrations")] + async fn test_composition_without_header_or_receipt(pgpool: PgPool) { + let mut service = service(pgpool.clone()).await; + // if it has neither, should return payment required + let req = Request::new(Default::default()); + let res = service.call(req).await.unwrap(); + assert_eq!(res.status(), StatusCode::PAYMENT_REQUIRED); + } +} diff --git a/crates/service/src/middleware/auth/bearer.rs b/crates/service/src/middleware/auth/bearer.rs new file mode 100644 index 000000000..cae0c51da --- /dev/null +++ b/crates/service/src/middleware/auth/bearer.rs @@ -0,0 +1,67 @@ +// Copyright 2023-, Edge & Node, GraphOps, and Semiotic Labs. +// SPDX-License-Identifier: Apache-2.0 + +//! Bearer struct from tower-http but exposing the `new()` function +//! to allow creation +//! +//! This code is from *tower-http* + +use std::{fmt, marker::PhantomData}; + +use axum::http::{HeaderValue, Request, Response}; +use reqwest::{header, StatusCode}; +use tower_http::validate_request::ValidateRequest; + +pub struct Bearer { + header_value: HeaderValue, + _ty: PhantomData ResBody>, +} + +impl Bearer { + pub fn new(token: &str) -> Self + where + ResBody: Default, + { + Self { + header_value: format!("Bearer {}", token) + .parse() + .expect("token is not a valid header value"), + _ty: PhantomData, + } + } +} + +impl Clone for Bearer { + fn clone(&self) -> Self { + Self { + header_value: self.header_value.clone(), + _ty: PhantomData, + } + } +} + +impl fmt::Debug for Bearer { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Bearer") + .field("header_value", &self.header_value) + .finish() + } +} + +impl ValidateRequest for Bearer +where + ResBody: Default, +{ + type ResponseBody = ResBody; + + fn validate(&mut self, request: &mut Request) -> Result<(), Response> { + match request.headers().get(header::AUTHORIZATION) { + Some(actual) if actual == self.header_value => Ok(()), + _ => { + let mut res = Response::new(ResBody::default()); + *res.status_mut() = StatusCode::UNAUTHORIZED; + Err(res) + } + } + } +} diff --git a/crates/service/src/middleware/auth/or.rs b/crates/service/src/middleware/auth/or.rs new file mode 100644 index 000000000..6f27ea244 --- /dev/null +++ b/crates/service/src/middleware/auth/or.rs @@ -0,0 +1,128 @@ +// Copyright 2023-, Edge & Node, GraphOps, and Semiotic Labs. +// SPDX-License-Identifier: Apache-2.0 + +//! Merge a ValidateRequest and an AsyncAuthorizeRequest +//! +//! executes a ValidateRequest returning the request if it succeeds +//! or else, executes the future and return it + +use std::{future::Future, marker::PhantomData, pin::Pin, task::Poll}; + +use axum::http::{Request, Response}; +use pin_project::pin_project; +use tower_http::{auth::AsyncAuthorizeRequest, validate_request::ValidateRequest}; + +/// Extension that allows using a simple .or() function and return an Or struct +pub trait OrExt: Sized { + fn or(self, other: T) -> Or; +} + +impl OrExt for T +where + B: 'static + Send, + Resp: 'static + Send, + T: ValidateRequest, + A: AsyncAuthorizeRequest + + Clone + + 'static + + Send, + Fut: Future, Response>> + Send, +{ + fn or(self, other: A) -> Or { + Or(self, other, PhantomData) + } +} + +/// Or struct capable of implementing a ValidateRequest or an AsyncAuthorizeRequest +/// +/// Uses the first parameter to validate the request sync. +/// if it passes the check return the request to pass to the next middleware +/// if it doesn't pass, check the async future returning the result +pub struct Or(T, E, PhantomData Resp>); + +impl Clone for Or +where + T: Clone, + E: Clone, +{ + fn clone(&self) -> Self { + Self(self.0.clone(), self.1.clone(), self.2) + } +} + +impl AsyncAuthorizeRequest for Or +where + Req: 'static + Send, + Resp: 'static + Send, + T: ValidateRequest, + E: AsyncAuthorizeRequest + + Clone + + 'static + + Send, + Fut: Future, Response>> + Send, +{ + type RequestBody = Req; + type ResponseBody = Resp; + + type Future = OrFuture; + + fn authorize(&mut self, mut request: axum::http::Request) -> Self::Future { + let mut this = self.1.clone(); + if self.0.validate(&mut request).is_ok() { + return OrFuture::with_result(Ok(request)); + } + OrFuture::with_future(this.authorize(request)) + } +} + +#[pin_project::pin_project(project = KindProj)] +pub enum Kind { + QueryResult { + #[pin] + fut: Fut, + }, + ReturnResult { + validation_result: Option, Response>>, + }, +} + +#[pin_project] +pub struct OrFuture { + #[pin] + kind: Kind, +} + +impl OrFuture { + fn with_result(validation_result: Result, Response>) -> Self { + let validation_result = Some(validation_result); + Self { + kind: Kind::ReturnResult { validation_result }, + } + } + + fn with_future(fut: Fut) -> Self { + Self { + kind: Kind::QueryResult { fut }, + } + } +} + +impl Future for OrFuture +where + Fut: Future, Response>>, +{ + type Output = Result, Response>; + + fn poll( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll { + let this = self.project(); + match this.kind.project() { + KindProj::QueryResult { fut } => fut.poll(cx), + KindProj::ReturnResult { validation_result } => { + Poll::Ready(validation_result.take().expect("cannot poll twice")) + } + } + } +} diff --git a/crates/service/src/middleware/auth/tap.rs b/crates/service/src/middleware/auth/tap.rs new file mode 100644 index 000000000..cc06d3a9d --- /dev/null +++ b/crates/service/src/middleware/auth/tap.rs @@ -0,0 +1,254 @@ +// Copyright 2023-, Edge & Node, GraphOps, and Semiotic Labs. +// SPDX-License-Identifier: Apache-2.0 + +//! Validates Tap receipts +//! +//! This looks for a Context in the extensions of the request to inject +//! as part of the checks. +//! +//! This also uses MetricLabels injected in the receipts to provide +//! metrics related to receipt check failure + +use std::{future::Future, sync::Arc}; + +use axum::{ + body::Body, + http::{Request, Response}, + response::IntoResponse, +}; +use tap_core::{ + manager::{adapters::ReceiptStore, Manager}, + receipt::{Context, SignedReceipt}, +}; +use tower_http::auth::AsyncAuthorizeRequest; + +use crate::{error::IndexerServiceError, middleware::prometheus_metrics::MetricLabels}; + +/// Middleware to verify and store TAP receipts +/// +/// It also optionally updates a failed receipt metric if Labels are provided +/// +/// Requires SignedReceipt, MetricLabels and Arc extensions +pub fn tap_receipt_authorize( + tap_manager: &'static Manager, + failed_receipt_metric: &'static prometheus::CounterVec, +) -> impl AsyncAuthorizeRequest< + B, + RequestBody = B, + ResponseBody = Body, + Future = impl Future, Response>> + Send, +> + Clone + + Send +where + T: ReceiptStore + Sync + Send, + B: Send, +{ + |request: Request| { + let receipt = request.extensions().get::().cloned(); + // load labels from previous middlewares + let labels = request.extensions().get::().cloned(); + // load context from previous middlewares + let ctx = request.extensions().get::>().cloned(); + + async { + let execute = || async { + let receipt = receipt.ok_or(IndexerServiceError::ReceiptNotFound)?; + // Verify the receipt and store it in the database + tap_manager + .verify_and_store_receipt(&ctx.unwrap_or_default(), receipt) + .await + .inspect_err(|_| { + if let Some(labels) = labels { + failed_receipt_metric + .with_label_values(&labels.get_labels()) + .inc() + } + })?; + Ok::<_, IndexerServiceError>(request) + }; + execute().await.map_err(|error| error.into_response()) + } + } +} + +#[cfg(test)] +mod tests { + + use core::panic; + use rstest::*; + use std::{sync::Arc, time::Duration}; + use tokio::time::sleep; + use tower::{Service, ServiceBuilder, ServiceExt}; + + use alloy::primitives::{address, Address}; + use axum::{ + body::Body, + http::{Request, Response}, + }; + use prometheus::core::Collector; + use reqwest::StatusCode; + use sqlx::PgPool; + use tap_core::{ + manager::Manager, + receipt::{ + checks::{Check, CheckError, CheckList, CheckResult}, + state::Checking, + ReceiptWithState, + }, + }; + use test_assets::{create_signed_receipt, TAP_EIP712_DOMAIN}; + use tower_http::auth::AsyncRequireAuthorizationLayer; + + use crate::{ + middleware::{ + auth::tap_receipt_authorize, + prometheus_metrics::{MetricLabelProvider, MetricLabels}, + }, + tap::IndexerTapContext, + }; + + const ALLOCATION_ID: Address = address!("deadbeefcafebabedeadbeefcafebabedeadbeef"); + + #[fixture] + fn metric() -> &'static prometheus::CounterVec { + let registry = prometheus::Registry::new(); + let metric = Box::leak(Box::new( + prometheus::register_counter_vec_with_registry!( + "tap_middleware_test", + "Failed queries to handler", + &["deployment"], + registry, + ) + .unwrap(), + )); + metric + } + + const FAILED_NONCE: u64 = 99; + + async fn service( + metric: &'static prometheus::CounterVec, + pgpool: PgPool, + ) -> impl Service, Response = Response, Error = impl std::fmt::Debug> { + let context = IndexerTapContext::new(pgpool, TAP_EIP712_DOMAIN.clone()).await; + + struct MyCheck; + #[async_trait::async_trait] + impl Check for MyCheck { + async fn check( + &self, + _: &tap_core::receipt::Context, + receipt: &ReceiptWithState, + ) -> CheckResult { + if receipt.signed_receipt().message.nonce == FAILED_NONCE { + Err(CheckError::Failed(anyhow::anyhow!("Failed"))) + } else { + Ok(()) + } + } + } + + let manager = Box::leak(Box::new(Manager::new( + TAP_EIP712_DOMAIN.clone(), + context, + CheckList::new(vec![Arc::new(MyCheck)]), + ))); + let tap_auth = tap_receipt_authorize(manager, metric); + let authorization_middleware = AsyncRequireAuthorizationLayer::new(tap_auth); + + let mut service = ServiceBuilder::new() + .layer(authorization_middleware) + .service_fn(|_: Request| async { + Ok::<_, anyhow::Error>(Response::new(Body::default())) + }); + + service.ready().await.unwrap(); + service + } + + #[rstest] + #[sqlx::test(migrations = "../../migrations")] + async fn test_tap_valid_receipt( + metric: &'static prometheus::CounterVec, + #[ignore] pgpool: PgPool, + ) { + let mut service = service(metric, pgpool.clone()).await; + + let receipt = create_signed_receipt(ALLOCATION_ID, 1, 1, 1).await; + + // check with receipt + let mut req = Request::new(Body::default()); + req.extensions_mut().insert(receipt); + let res = service.call(req).await.unwrap(); + assert_eq!(res.status(), StatusCode::OK); + + // verify receipts + if tokio::time::timeout(Duration::from_secs(1), async { + loop { + let result = sqlx::query!("SELECT * FROM scalar_tap_receipts") + .fetch_all(&pgpool) + .await + .unwrap(); + + if result.is_empty() { + sleep(Duration::from_millis(50)).await; + } else { + break; + } + } + }) + .await + .is_err() + { + panic!("Timeout assertion"); + } + } + + #[rstest] + #[sqlx::test(migrations = "../../migrations")] + async fn test_invalid_receipt_with_failed_metric( + metric: &'static prometheus::CounterVec, + #[ignore] pgpool: PgPool, + ) { + let mut service = service(metric, pgpool.clone()).await; + // if it fails tap receipt, should return failed to process payment + tap message + + assert_eq!(metric.collect().first().unwrap().get_metric().len(), 0); + + struct TestLabel; + impl MetricLabelProvider for TestLabel { + fn get_labels(&self) -> Vec<&str> { + vec!["label1"] + } + } + + // default labels, all empty + let labels: MetricLabels = Arc::new(TestLabel); + + let mut receipt = create_signed_receipt(ALLOCATION_ID, 1, 1, 1).await; + // change the nonce to make the receipt invalid + receipt.message.nonce = FAILED_NONCE; + let mut req = Request::new(Body::default()); + req.extensions_mut().insert(receipt); + req.extensions_mut().insert(labels); + let response = service.call(req); + + assert_eq!(response.await.unwrap().status(), StatusCode::BAD_REQUEST); + + assert_eq!(metric.collect().first().unwrap().get_metric().len(), 1); + } + + #[rstest] + #[sqlx::test(migrations = "../../migrations")] + async fn test_tap_missing_signed_receipt( + metric: &'static prometheus::CounterVec, + #[ignore] pgpool: PgPool, + ) { + let mut service = service(metric, pgpool.clone()).await; + // if it doesnt contain the signed receipt + // should return payment required + let req = Request::new(Body::default()); + let res = service.call(req).await.unwrap(); + assert_eq!(res.status(), StatusCode::PAYMENT_REQUIRED); + } +} diff --git a/crates/service/src/middleware/inject_context.rs b/crates/service/src/middleware/inject_context.rs new file mode 100644 index 000000000..26b73130d --- /dev/null +++ b/crates/service/src/middleware/inject_context.rs @@ -0,0 +1,126 @@ +// Copyright 2023-, Edge & Node, GraphOps, and Semiotic Labs. +// SPDX-License-Identifier: Apache-2.0 + +//! Injects tap context to be used by the checks +//! +//! Requires Deployment Id extension to available + +use serde_json::value::RawValue; +use std::sync::Arc; + +use axum::{ + body::to_bytes, + extract::{Path, Request}, + middleware::Next, + response::Response, + RequestExt, +}; +use tap_core::receipt::Context; +use thegraph_core::DeploymentId; + +use crate::{error::IndexerServiceError, tap::AgoraQuery}; + +/// Graphql query body to be decoded and passed to agora context +#[derive(Debug, serde::Deserialize, serde::Serialize)] +struct QueryBody { + query: String, + variables: Option>, +} + +/// Injects tap context in the extensions to be used by tap_receipt_authorize +pub async fn context_middleware( + mut request: Request, + next: Next, +) -> Result { + let deployment_id = match request.extensions().get::() { + Some(deployment) => *deployment, + None => match request.extract_parts::>().await { + Ok(Path(deployment)) => deployment, + Err(_) => return Err(IndexerServiceError::DeploymentIdNotFound), + }, + }; + + let (mut parts, body) = request.into_parts(); + let bytes = to_bytes(body, usize::MAX).await?; + let query_body: QueryBody = serde_json::from_slice(&bytes)?; + + let variables = query_body + .variables + .as_ref() + .map(ToString::to_string) + .unwrap_or_default(); + + let mut ctx = Context::new(); + ctx.insert(AgoraQuery { + deployment_id, + query: query_body.query.clone(), + variables, + }); + parts.extensions.insert(Arc::new(ctx)); + let request = Request::from_parts(parts, bytes.into()); + Ok(next.run(request).await) +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use axum::{ + body::Body, + http::{Extensions, Request}, + middleware::from_fn, + routing::get, + Router, + }; + use reqwest::StatusCode; + use tap_core::receipt::Context; + use test_assets::ESCROW_SUBGRAPH_DEPLOYMENT; + use tower::ServiceExt; + + use crate::{ + middleware::inject_context::{context_middleware, QueryBody}, + tap::AgoraQuery, + }; + + #[tokio::test] + async fn test_context_middleware() { + let middleware = from_fn(context_middleware); + let deployment = *ESCROW_SUBGRAPH_DEPLOYMENT; + let query_body = QueryBody { + query: "hello".to_string(), + variables: None, + }; + let body = serde_json::to_string(&query_body).unwrap(); + + let handle = move |extensions: Extensions| async move { + let ctx = extensions + .get::>() + .expect("Should contain context"); + let agora = ctx.get::().expect("should contain agora query"); + assert_eq!(agora.deployment_id, deployment); + assert_eq!(agora.query, query_body.query); + + let variables = query_body + .variables + .as_ref() + .map(ToString::to_string) + .unwrap_or_default(); + assert_eq!(agora.variables, variables); + Body::empty() + }; + + let app = Router::new().route("/", get(handle)).layer(middleware); + + let res = app + .oneshot( + Request::builder() + .uri("/") + .extension(deployment) + .body(body) + .unwrap(), + ) + .await + .unwrap(); + assert_eq!(res.status(), StatusCode::OK); + } +} diff --git a/crates/service/src/middleware/inject_receipt.rs b/crates/service/src/middleware/inject_receipt.rs index db2491c82..f8e0507c8 100644 --- a/crates/service/src/middleware/inject_receipt.rs +++ b/crates/service/src/middleware/inject_receipt.rs @@ -14,10 +14,10 @@ use crate::service::TapReceipt; /// /// This is useful to not deserialize multiple times the same receipt pub async fn receipt_middleware(mut request: Request, next: Next) -> Response { - if let Ok(TypedHeader(receipt)) = request.extract_parts::>().await { - if let Some(receipt) = receipt.into_signed_receipt() { - request.extensions_mut().insert(receipt); - } + if let Ok(TypedHeader(TapReceipt(receipt))) = + request.extract_parts::>().await + { + request.extensions_mut().insert(receipt); } next.run(request).await } diff --git a/crates/service/src/routes/request_handler.rs b/crates/service/src/routes/request_handler.rs index 3a84fce5c..7ac2fad0d 100644 --- a/crates/service/src/routes/request_handler.rs +++ b/crates/service/src/routes/request_handler.rs @@ -3,87 +3,26 @@ use std::sync::Arc; -use crate::{ - error::IndexerServiceError, - metrics::FAILED_RECEIPT, - middleware::{Allocation, Sender}, - tap::AgoraQuery, -}; +use crate::{error::IndexerServiceError, middleware::Allocation}; use axum::{ extract::{Path, State}, - http::HeaderMap, response::IntoResponse, Extension, }; -use axum_extra::TypedHeader; use reqwest::StatusCode; -use serde_json::value::RawValue; -use tap_core::receipt::Context; use thegraph_core::DeploymentId; use tracing::trace; -use crate::service::{AttestationOutput, IndexerServiceResponse, IndexerServiceState, TapReceipt}; +use crate::service::{AttestationOutput, IndexerServiceResponse, IndexerServiceState}; pub async fn request_handler( Path(manifest_id): Path, - TypedHeader(receipt): TypedHeader, - Extension(Sender(sender)): Extension, Extension(Allocation(allocation_id)): Extension, State(state): State>, - headers: HeaderMap, req: String, ) -> Result { trace!("Handling request for deployment `{manifest_id}`"); - let request: QueryBody = - serde_json::from_str(&req).map_err(|e| IndexerServiceError::InvalidRequest(e.into()))?; - - if let Some(receipt) = receipt.into_signed_receipt() { - let variables = request - .variables - .as_ref() - .map(ToString::to_string) - .unwrap_or_default(); - let mut ctx = Context::new(); - ctx.insert(AgoraQuery { - deployment_id: manifest_id, - query: request.query.clone(), - variables, - }); - - // Verify the receipt and store it in the database - state - .tap_manager - .verify_and_store_receipt(&ctx, receipt) - .await - .inspect_err(|_| { - FAILED_RECEIPT - .with_label_values(&[ - &manifest_id.to_string(), - &allocation_id.to_string(), - &sender.to_string(), - ]) - .inc() - }) - .map_err(IndexerServiceError::ReceiptError)?; - } else { - match headers - .get("authorization") - .and_then(|v| v.to_str().ok()) - .and_then(|s| s.strip_prefix("Bearer ")) - .map(|s| s.to_string()) - { - None => return Err(IndexerServiceError::Unauthorized), - Some(ref token) => { - if Some(token) != state.config.service.free_query_auth_token.as_ref() { - return Err(IndexerServiceError::InvalidFreeQueryAuthToken); - } - } - } - - trace!(?manifest_id, "New free query"); - } - // Check if we have an attestation signer for the allocation the receipt was created for let signer = state .attestation_signers @@ -94,7 +33,7 @@ pub async fn request_handler( let response = state .service_impl - .process_request(manifest_id, request) + .process_request(manifest_id, &req) .await .map_err(IndexerServiceError::ProcessingError)?; @@ -112,9 +51,3 @@ pub async fn request_handler( Ok((StatusCode::OK, response)) } - -#[derive(Debug, serde::Deserialize, serde::Serialize)] -pub struct QueryBody { - pub query: String, - pub variables: Option>, -} diff --git a/crates/service/src/service.rs b/crates/service/src/service.rs index 0a59cc5aa..45ed9ae28 100644 --- a/crates/service/src/service.rs +++ b/crates/service/src/service.rs @@ -98,7 +98,7 @@ impl SubgraphService { pub async fn process_request( &self, deployment: DeploymentId, - request: Request, + request: &Request, ) -> Result { let deployment_url = self .state @@ -110,7 +110,7 @@ impl SubgraphService { .state .graph_node_client .post(deployment_url) - .json(&request) + .json(request) .send() .await .map_err(SubgraphServiceError::QueryForwardingError)?; diff --git a/crates/service/src/service/indexer_service.rs b/crates/service/src/service/indexer_service.rs index c50cf8b3a..a8f8b3ca9 100644 --- a/crates/service/src/service/indexer_service.rs +++ b/crates/service/src/service/indexer_service.rs @@ -2,9 +2,8 @@ // SPDX-License-Identifier: Apache-2.0 use anyhow; -use axum::extract::MatchedPath; -use axum::extract::Request as ExtractRequest; use axum::{ + extract::{MatchedPath, Request as ExtractRequest}, http::{Method, Request}, middleware::{from_fn, from_fn_with_state}, response::IntoResponse, @@ -26,20 +25,24 @@ use std::{ }; use tap_core::{manager::Manager, receipt::checks::CheckList, tap_eip712_domain}; use thegraph_core::{Address, Attestation}; -use tokio::net::TcpListener; -use tokio::signal; -use tokio::sync::watch::Receiver; +use tokio::{net::TcpListener, signal, sync::watch::Receiver}; use tower::ServiceBuilder; use tower_governor::{governor::GovernorConfigBuilder, GovernorLayer}; -use tower_http::validate_request::ValidateRequestHeaderLayer; -use tower_http::{cors, cors::CorsLayer, normalize_path::NormalizePath, trace::TraceLayer}; -use tracing::warn; -use tracing::{error, info, info_span}; +use tower_http::{ + auth::AsyncRequireAuthorizationLayer, + cors::{self, CorsLayer}, + normalize_path::NormalizePath, + trace::TraceLayer, + validate_request::ValidateRequestHeaderLayer, +}; +use tracing::{error, info, info_span, warn}; use crate::{ - metrics::{HANDLER_FAILURE, HANDLER_HISTOGRAM}, + metrics::{FAILED_RECEIPT, HANDLER_FAILURE, HANDLER_HISTOGRAM}, middleware::{ - allocation_middleware, deployment_middleware, labels_middleware, receipt_middleware, + allocation_middleware, + auth::{self, Bearer, OrExt}, + context_middleware, deployment_middleware, labels_middleware, receipt_middleware, sender_middleware, AllocationState, PrometheusMetricsMiddlewareLayer, SenderState, }, routes::{health, request_handler, static_subgraph_request_handler}, @@ -96,7 +99,6 @@ pub struct IndexerServiceOptions { pub struct IndexerServiceState { pub config: Config, pub attestation_signers: Receiver>, - pub tap_manager: Manager, pub service_impl: SubgraphService, } @@ -266,16 +268,15 @@ pub async fn run(options: IndexerServiceOptions) -> Result<(), anyhow::Error> { ) .await; - let tap_manager = Manager::new( + let tap_manager = Box::leak(Box::new(Manager::new( domain_separator.clone(), indexer_context, CheckList::new(checks), - ); + ))); let state = Arc::new(IndexerServiceState { config: options.config.clone(), attestation_signers, - tap_manager, service_impl: options.service_impl, }); @@ -361,6 +362,22 @@ pub async fn run(options: IndexerServiceOptions) -> Result<(), anyhow::Error> { misc_routes = misc_routes.with_state(state.clone()); + let mut request_handler_route = post(request_handler); + + // inject auth + let failed_receipt_metric = Box::leak(Box::new(FAILED_RECEIPT.clone())); + let tap_auth = auth::tap_receipt_authorize(tap_manager, failed_receipt_metric); + + if let Some(free_auth_token) = &options.config.service.serve_auth_token { + let free_query = Bearer::new(free_auth_token); + let result = free_query.or(tap_auth); + let auth_layer = AsyncRequireAuthorizationLayer::new(result); + request_handler_route = request_handler_route.layer(auth_layer); + } else { + let auth_layer = AsyncRequireAuthorizationLayer::new(tap_auth); + request_handler_route = request_handler_route.layer(auth_layer); + } + let deployment_to_allocation = deployment_to_allocation(allocations); let allocation_state = AllocationState { deployment_to_allocation, @@ -385,7 +402,11 @@ pub async fn run(options: IndexerServiceOptions) -> Result<(), anyhow::Error> { .layer(PrometheusMetricsMiddlewareLayer::new( HANDLER_HISTOGRAM.clone(), HANDLER_FAILURE.clone(), - )); + )) + // tap context + .layer(from_fn(context_middleware)); + + request_handler_route = request_handler_route.layer(service_builder); let data_routes = Router::new() .route( @@ -393,7 +414,7 @@ pub async fn run(options: IndexerServiceOptions) -> Result<(), anyhow::Error> { .join(format!("{}/id/:id", options.url_namespace)) .to_str() .expect("Failed to set up `/{url_namespace}/id/:id` route"), - post(request_handler).route_layer(service_builder), + request_handler_route, ) .with_state(state.clone()); diff --git a/crates/service/src/service/tap_receipt_header.rs b/crates/service/src/service/tap_receipt_header.rs index d5937cb45..fcb8af52e 100644 --- a/crates/service/src/service/tap_receipt_header.rs +++ b/crates/service/src/service/tap_receipt_header.rs @@ -1,29 +1,13 @@ // Copyright 2023-, Edge & Node, GraphOps, and Semiotic Labs. // SPDX-License-Identifier: Apache-2.0 -use std::ops::Deref; - use axum_extra::headers::{self, Header, HeaderName, HeaderValue}; use lazy_static::lazy_static; use prometheus::{register_counter, Counter}; use tap_core::receipt::SignedReceipt; #[derive(Debug, PartialEq)] -pub struct TapReceipt(Option); - -impl TapReceipt { - pub fn into_signed_receipt(self) -> Option { - self.0 - } -} - -impl Deref for TapReceipt { - type Target = Option; - - fn deref(&self) -> &Self::Target { - &self.0 - } -} +pub struct TapReceipt(pub SignedReceipt); lazy_static! { static ref TAP_RECEIPT: HeaderName = HeaderName::from_static("tap-receipt"); @@ -42,14 +26,12 @@ impl Header for TapReceipt { { let mut execute = || { let value = values.next(); - let raw_receipt = value - .map(|value| value.to_str()) - .transpose() - .map_err(|_| headers::Error::invalid())?; - let parsed_receipt = raw_receipt - .map(serde_json::from_str) - .transpose() + let raw_receipt = value.ok_or(headers::Error::invalid())?; + let raw_receipt = raw_receipt + .to_str() .map_err(|_| headers::Error::invalid())?; + let parsed_receipt = + serde_json::from_str(raw_receipt).map_err(|_| headers::Error::invalid())?; Ok(TapReceipt(parsed_receipt)) }; execute().inspect_err(|_| TAP_RECEIPT_INVALID.inc()) @@ -86,7 +68,7 @@ mod test { let decoded_receipt = TapReceipt::decode(&mut header_values.into_iter()) .expect("tap receipt header value should be valid"); - assert_eq!(decoded_receipt, TapReceipt(Some(original_receipt.clone()))); + assert_eq!(decoded_receipt, TapReceipt(original_receipt)); } #[test]