diff --git a/Cargo.lock b/Cargo.lock index d828dd8d2..43720d595 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2004,6 +2004,8 @@ dependencies = [ "mimalloc", "moka", "ntex", + "ntex-service", + "ntex-util", "rand 0.9.2", "regex-automata", "reqwest", @@ -2046,12 +2048,14 @@ name = "hive-router-plan-executor" version = "6.0.1" dependencies = [ "ahash", + "arc-swap", "async-trait", "bumpalo", "bytes", "criterion", "dashmap", "futures", + "futures-util", "graphql-parser", "graphql-tools", "hive-router-config", @@ -2064,9 +2068,13 @@ dependencies = [ "indexmap 2.12.0", "insta", "itoa", + "multer", + "ntex", "ntex-http", "ordered-float", + "redis", "regex-automata", + "reqwest", "ryu", "serde", "sonic-rs", @@ -2797,6 +2805,16 @@ version = "0.3.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" +[[package]] +name = "mime_guess" +version = "2.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f7c44f8e672c00fe5308fa235f821cb4198414e1c77935c1ab6948d3fd78550e" +dependencies = [ + "mime", + "unicase", +] + [[package]] name = "minimal-lexical" version = "0.2.1" @@ -4217,6 +4235,22 @@ dependencies = [ "crossbeam-utils", ] +[[package]] +name = "redis" +version = "0.32.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "014cc767fefab6a3e798ca45112bccad9c6e0e218fbd49720042716c73cfef44" +dependencies = [ + "combine", + "itoa", + "num-bigint", + "percent-encoding", + "ryu", + "sha1_smol", + "socket2 0.6.1", + "url", +] + [[package]] name = "redox_syscall" version = "0.2.16" @@ -4328,6 +4362,7 @@ dependencies = [ "bytes", "encoding_rs", "futures-core", + "futures-util", "h2", "http", "http-body", @@ -4339,6 +4374,7 @@ dependencies = [ "js-sys", "log", "mime", + "mime_guess", "native-tls", "percent-encoding", "pin-project-lite", @@ -4902,6 +4938,12 @@ dependencies = [ "digest", ] +[[package]] +name = "sha1_smol" +version = "1.0.1" 
+source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbfa15b3dddfee50a0fff136974b3e1bde555604ba463834a7eb7deb6417705d" + [[package]] name = "sha2" version = "0.10.9" @@ -5804,6 +5846,12 @@ dependencies = [ "web-time", ] +[[package]] +name = "unicase" +version = "2.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75b844d17643ee918803943289730bec8aac480150456169e647ed0b576ba539" + [[package]] name = "unicode-ident" version = "1.0.22" diff --git a/bin/router/Cargo.toml b/bin/router/Cargo.toml index 4969b0a46..15d5e2051 100644 --- a/bin/router/Cargo.toml +++ b/bin/router/Cargo.toml @@ -53,3 +53,5 @@ tokio-util = "0.7.16" cookie = "0.18.1" regex-automata = "0.4.10" arc-swap = "1.7.1" +ntex-util = "2.15.0" +ntex-service = "3.5.0" \ No newline at end of file diff --git a/bin/router/src/jwt/mod.rs b/bin/router/src/jwt/mod.rs index d95854b1c..d7804e156 100644 --- a/bin/router/src/jwt/mod.rs +++ b/bin/router/src/jwt/mod.rs @@ -265,26 +265,27 @@ impl JwtAuthRuntime { Ok(token_data) } - pub fn validate_request(&self, request: &mut HttpRequest) -> Result<(), JwtError> { + pub fn validate_request( + &self, + request: &HttpRequest, + ) -> Result, JwtError> { let valid_jwks = self.jwks.all(); match self.authenticate(&valid_jwks, request) { - Ok((token_payload, maybe_token_prefix, token)) => { - request.extensions_mut().insert(JwtRequestContext { - token_payload, - token_raw: token, - token_prefix: maybe_token_prefix, - }); - } + Ok((token_payload, maybe_token_prefix, token)) => Ok(Some(JwtRequestContext { + token_payload, + token_raw: token, + token_prefix: maybe_token_prefix, + })), Err(e) => { warn!("jwt token error: {:?}", e); if self.config.require_authentication.is_some_and(|v| v) { return Err(e); } + + Ok(None) } } - - Ok(()) } } diff --git a/bin/router/src/lib.rs b/bin/router/src/lib.rs index 6a3f7f5c0..5e5a0e353 100644 --- a/bin/router/src/lib.rs +++ b/bin/router/src/lib.rs @@ -4,6 +4,7 @@ mod http_utils; mod 
jwt; mod logger; mod pipeline; +mod plugins; mod schema_state; mod shared_state; mod supergraph; @@ -19,7 +20,12 @@ use crate::{ }, jwt::JwtAuthRuntime, logger::configure_logging, - pipeline::graphql_request_handler, + pipeline::{ + error::PipelineError, + graphql_request_handler, + header::{RequestAccepts, APPLICATION_GRAPHQL_RESPONSE_JSON_STR}, + }, + plugins::plugins_service::PluginService, }; pub use crate::{schema_state::SchemaState, shared_state::RouterSharedState}; @@ -33,7 +39,7 @@ use ntex::{ use tracing::{info, warn}; async fn graphql_endpoint_handler( - mut request: HttpRequest, + req: HttpRequest, body_bytes: Bytes, schema_state: web::types::State>, app_state: web::types::State>, @@ -45,26 +51,32 @@ async fn graphql_endpoint_handler( if let Some(early_response) = app_state .cors_runtime .as_ref() - .and_then(|cors| cors.get_early_response(&request)) + .and_then(|cors| cors.get_early_response(&req)) { return early_response; } - let mut res = graphql_request_handler( - &mut request, + let accept_ok = !req.accepts_content_type(&APPLICATION_GRAPHQL_RESPONSE_JSON_STR); + + let mut response = match graphql_request_handler( + &req, body_bytes, supergraph, app_state.get_ref(), schema_state.get_ref(), ) - .await; + .await + { + Ok(response_with_req) => response_with_req, + Err(error) => return PipelineError { accept_ok, error }.into(), + }; // Apply CORS headers to the final response if CORS is configured. 
if let Some(cors) = app_state.cors_runtime.as_ref() { - cors.set_headers(&request, res.headers_mut()); + cors.set_headers(&req, response.headers_mut()); } - res + response } else { warn!("No supergraph available yet, unable to process request"); @@ -86,6 +98,7 @@ pub async fn router_entrypoint() -> Result<(), Box> { let maybe_error = web::HttpServer::new(move || { web::App::new() + .wrap(PluginService) .state(shared_state.clone()) .state(schema_state.clone()) .configure(configure_ntex_app) @@ -112,10 +125,17 @@ pub async fn configure_app_from_config( }; let router_config_arc = Arc::new(router_config); - let schema_state = - SchemaState::new_from_config(bg_tasks_manager, router_config_arc.clone()).await?; + let shared_state = Arc::new(RouterSharedState::new( + router_config_arc.clone(), + jwt_runtime, + )?); + let schema_state = SchemaState::new_from_config( + bg_tasks_manager, + router_config_arc.clone(), + shared_state.clone(), + ) + .await?; let schema_state_arc = Arc::new(schema_state); - let shared_state = Arc::new(RouterSharedState::new(router_config_arc, jwt_runtime)?); Ok((shared_state, schema_state_arc)) } diff --git a/bin/router/src/pipeline/coerce_variables.rs b/bin/router/src/pipeline/coerce_variables.rs index 8c472695e..ab5759b5e 100644 --- a/bin/router/src/pipeline/coerce_variables.rs +++ b/bin/router/src/pipeline/coerce_variables.rs @@ -1,6 +1,7 @@ use std::collections::HashMap; -use std::sync::Arc; +use hive_router_plan_executor::hooks::on_graphql_params::GraphQLParams; +use hive_router_plan_executor::hooks::on_supergraph_load::SupergraphData; use hive_router_plan_executor::variables::collect_variables; use hive_router_query_planner::state::supergraph_state::OperationKind; use http::Method; @@ -8,10 +9,8 @@ use ntex::web::HttpRequest; use sonic_rs::Value; use tracing::{error, trace, warn}; -use crate::pipeline::error::{PipelineError, PipelineErrorFromAcceptHeader, PipelineErrorVariant}; -use crate::pipeline::execution_request::ExecutionRequest; +use 
crate::pipeline::error::PipelineErrorVariant; use crate::pipeline::normalize::GraphQLNormalizationPayload; -use crate::schema_state::SupergraphData; #[derive(Clone, Debug)] pub struct CoerceVariablesPayload { @@ -22,22 +21,22 @@ pub struct CoerceVariablesPayload { pub fn coerce_request_variables( req: &HttpRequest, supergraph: &SupergraphData, - execution_params: &mut ExecutionRequest, - normalized_operation: &Arc, -) -> Result { + graphql_params: &mut GraphQLParams, + normalized_operation: &GraphQLNormalizationPayload, +) -> Result { if req.method() == Method::GET { if let Some(OperationKind::Mutation) = normalized_operation.operation_for_plan.operation_kind { error!("Mutation is not allowed over GET, stopping"); - return Err(req.new_pipeline_error(PipelineErrorVariant::MutationNotAllowedOverHttpGet)); + return Err(PipelineErrorVariant::MutationNotAllowedOverHttpGet); } } match collect_variables( &normalized_operation.operation_for_plan, - &mut execution_params.variables, + &mut graphql_params.variables, &supergraph.metadata, ) { Ok(values) => { @@ -55,7 +54,7 @@ pub fn coerce_request_variables( "failed to collect variables from incoming request: {}", err_msg ); - Err(req.new_pipeline_error(PipelineErrorVariant::VariablesCoercionError(err_msg))) + Err(PipelineErrorVariant::VariablesCoercionError(err_msg)) } } } diff --git a/bin/router/src/pipeline/csrf_prevention.rs b/bin/router/src/pipeline/csrf_prevention.rs index 51561dd99..37c063b09 100644 --- a/bin/router/src/pipeline/csrf_prevention.rs +++ b/bin/router/src/pipeline/csrf_prevention.rs @@ -1,7 +1,7 @@ use hive_router_config::csrf::CSRFPreventionConfig; use ntex::web::HttpRequest; -use crate::pipeline::error::{PipelineError, PipelineErrorFromAcceptHeader, PipelineErrorVariant}; +use crate::pipeline::error::PipelineErrorVariant; // NON_PREFLIGHTED_CONTENT_TYPES are content types that do not require a preflight // OPTIONS request. 
These are content types that are considered "simple" by the CORS @@ -15,9 +15,9 @@ const NON_PREFLIGHTED_CONTENT_TYPES: [&str; 3] = [ #[inline] pub fn perform_csrf_prevention( - req: &mut HttpRequest, + req: &HttpRequest, csrf_config: &CSRFPreventionConfig, -) -> Result<(), PipelineError> { +) -> Result<(), PipelineErrorVariant> { // If CSRF prevention is not configured or disabled, skip the checks. if !csrf_config.enabled || csrf_config.required_headers.is_empty() { return Ok(()); @@ -39,7 +39,7 @@ pub fn perform_csrf_prevention( if has_required_header { Ok(()) } else { - Err(req.new_pipeline_error(PipelineErrorVariant::CsrfPreventionFailed)) + Err(PipelineErrorVariant::CsrfPreventionFailed) } } diff --git a/bin/router/src/pipeline/deserialize_graphql_params.rs b/bin/router/src/pipeline/deserialize_graphql_params.rs new file mode 100644 index 000000000..b22b18a3a --- /dev/null +++ b/bin/router/src/pipeline/deserialize_graphql_params.rs @@ -0,0 +1,114 @@ +use std::collections::HashMap; + +use hive_router_plan_executor::hooks::on_graphql_params::GraphQLParams; +use http::Method; +use ntex::util::Bytes; +use ntex::web::types::Query; +use ntex::web::HttpRequest; +use tracing::{trace, warn}; + +use crate::pipeline::error::PipelineErrorVariant; +use crate::pipeline::header::AssertRequestJson; + +#[derive(serde::Deserialize, Debug)] +struct GETQueryParams { + pub query: Option, + #[serde(rename = "camelCase")] + pub operation_name: Option, + pub variables: Option, + pub extensions: Option, +} + +impl TryInto for GETQueryParams { + type Error = PipelineErrorVariant; + + fn try_into(self) -> Result { + let variables = match self.variables.as_deref() { + Some(v_str) if !v_str.is_empty() => match sonic_rs::from_str(v_str) { + Ok(vars) => vars, + Err(e) => { + return Err(PipelineErrorVariant::FailedToParseVariables(e)); + } + }, + _ => HashMap::new(), + }; + + let extensions = match self.extensions.as_deref() { + Some(e_str) if !e_str.is_empty() => match 
sonic_rs::from_str(e_str) { + Ok(exts) => Some(exts), + Err(e) => { + return Err(PipelineErrorVariant::FailedToParseExtensions(e)); + } + }, + _ => None, + }; + + let execution_request = GraphQLParams { + query: self.query, + operation_name: self.operation_name, + variables, + extensions, + }; + + Ok(execution_request) + } +} + +pub trait GetQueryStr { + fn get_query(&self) -> Result<&str, PipelineErrorVariant>; +} + +impl GetQueryStr for GraphQLParams { + fn get_query(&self) -> Result<&str, PipelineErrorVariant> { + self.query + .as_deref() + .ok_or(PipelineErrorVariant::GetMissingQueryParam("query")) + } +} + +#[inline] +pub fn deserialize_graphql_params( + req: &HttpRequest, + body_bytes: Bytes, +) -> Result { + let http_method = req.method(); + let graphql_params: GraphQLParams = match *http_method { + Method::GET => { + trace!("processing GET GraphQL operation"); + let query_params_str = req + .uri() + .query() + .ok_or_else(|| PipelineErrorVariant::GetInvalidQueryParams)?; + let query_params = Query::::from_query(query_params_str) + .map_err(PipelineErrorVariant::GetUnprocessableQueryParams)? + .0; + + trace!("parsed GET query params: {:?}", query_params); + + query_params.try_into()? + } + Method::POST => { + trace!("Processing POST GraphQL request"); + + req.assert_json_content_type()?; + + let execution_request = unsafe { + sonic_rs::from_slice_unchecked::(&body_bytes).map_err(|e| { + warn!("Failed to parse body: {}", e); + PipelineErrorVariant::FailedToParseBody(e) + })? 
+ }; + + execution_request + } + _ => { + warn!("unsupported HTTP method: {}", http_method); + + return Err(PipelineErrorVariant::UnsupportedHttpMethod( + http_method.to_owned(), + )); + } + }; + + Ok(graphql_params) +} diff --git a/bin/router/src/pipeline/error.rs b/bin/router/src/pipeline/error.rs index eec36ea76..1d856d62e 100644 --- a/bin/router/src/pipeline/error.rs +++ b/bin/router/src/pipeline/error.rs @@ -10,15 +10,12 @@ use hive_router_query_planner::{ }; use http::{HeaderName, Method, StatusCode}; use ntex::{ - http::ResponseBuilder, - web::{self, error::QueryPayloadError, HttpRequest}, + http::{Response, ResponseBuilder}, + web::error::QueryPayloadError, }; use serde::{Deserialize, Serialize}; -use crate::pipeline::{ - header::{RequestAccepts, APPLICATION_GRAPHQL_RESPONSE_JSON_STR}, - progressive_override::LabelEvaluationError, -}; +use crate::pipeline::progressive_override::LabelEvaluationError; #[derive(Debug)] pub struct PipelineError { @@ -26,18 +23,6 @@ pub struct PipelineError { pub error: PipelineErrorVariant, } -pub trait PipelineErrorFromAcceptHeader { - fn new_pipeline_error(&self, error: PipelineErrorVariant) -> PipelineError; -} - -impl PipelineErrorFromAcceptHeader for HttpRequest { - #[inline] - fn new_pipeline_error(&self, error: PipelineErrorVariant) -> PipelineError { - let accept_ok = !self.accepts_content_type(&APPLICATION_GRAPHQL_RESPONSE_JSON_STR); - PipelineError { accept_ok, error } - } -} - #[derive(Debug, thiserror::Error)] pub enum PipelineErrorVariant { // HTTP-related errors @@ -78,7 +63,7 @@ pub enum PipelineErrorVariant { #[error("Failed to execute a plan: {0}")] PlanExecutionError(PlanExecutionError), #[error("Failed to produce a plan: {0}")] - PlannerError(Arc), + PlannerError(PlannerError), #[error(transparent)] LabelEvaluationError(LabelEvaluationError), @@ -156,11 +141,11 @@ pub struct FailedExecutionResult { pub errors: Option>, } -impl PipelineError { - pub fn into_response(self) -> web::HttpResponse { - let status = 
self.error.default_status_code(self.accept_ok); +impl From for Response { + fn from(val: PipelineError) -> Self { + let status = val.error.default_status_code(val.accept_ok); - if let PipelineErrorVariant::ValidationErrors(validation_errors) = self.error { + if let PipelineErrorVariant::ValidationErrors(validation_errors) = val.error { let validation_error_result = FailedExecutionResult { errors: Some(validation_errors.iter().map(|error| error.into()).collect()), }; @@ -168,8 +153,8 @@ impl PipelineError { return ResponseBuilder::new(status).json(&validation_error_result); } - let code = self.error.graphql_error_code(); - let message = self.error.graphql_error_message(); + let code = val.error.graphql_error_code(); + let message = val.error.graphql_error_message(); let graphql_error = GraphQLError::from_message_and_extensions( message, diff --git a/bin/router/src/pipeline/execution.rs b/bin/router/src/pipeline/execution.rs index 42ace79ce..991f503c1 100644 --- a/bin/router/src/pipeline/execution.rs +++ b/bin/router/src/pipeline/execution.rs @@ -1,16 +1,15 @@ use std::collections::HashMap; -use std::sync::Arc; use crate::pipeline::coerce_variables::CoerceVariablesPayload; -use crate::pipeline::error::{PipelineError, PipelineErrorFromAcceptHeader, PipelineErrorVariant}; +use crate::pipeline::error::PipelineErrorVariant; use crate::pipeline::normalize::GraphQLNormalizationPayload; -use crate::schema_state::SupergraphData; use crate::shared_state::RouterSharedState; -use hive_router_plan_executor::execute_query_plan; use hive_router_plan_executor::execution::client_request_details::ClientRequestDetails; use hive_router_plan_executor::execution::jwt_forward::JwtAuthForwardingPlan; use hive_router_plan_executor::execution::plan::{PlanExecutionOutput, QueryPlanExecutionContext}; +use hive_router_plan_executor::hooks::on_supergraph_load::SupergraphData; use hive_router_plan_executor::introspection::resolve::IntrospectionContext; +use 
hive_router_plan_executor::plugin_context::PluginManager; use hive_router_query_planner::planner::plan_nodes::QueryPlan; use http::HeaderName; use ntex::web::HttpRequest; @@ -24,16 +23,18 @@ enum ExposeQueryPlanMode { DryRun, } +#[allow(clippy::too_many_arguments)] #[inline] pub async fn execute_plan( req: &HttpRequest, supergraph: &SupergraphData, - app_state: &Arc, - normalized_payload: &Arc, - query_plan_payload: &Arc, + app_state: &RouterSharedState, + normalized_payload: &GraphQLNormalizationPayload, + query_plan_payload: &QueryPlan, variable_payload: &CoerceVariablesPayload, client_request_details: &ClientRequestDetails<'_, '_>, -) -> Result { + plugin_manager: PluginManager<'_>, +) -> Result { let mut expose_query_plan = ExposeQueryPlanMode::No; if app_state.router_config.query_planner.allow_expose { @@ -65,7 +66,7 @@ pub async fn execute_plan( metadata: &supergraph.metadata, }; - let jwt_forward_plan: Option = if app_state + let jwt_auth_forwarding: Option = if app_state .router_config .jwt .is_jwt_extensions_forwarding_enabled() @@ -79,13 +80,15 @@ pub async fn execute_plan( .forward_claims_to_upstream_extensions .field_name, ) - .map_err(|e| req.new_pipeline_error(PipelineErrorVariant::JwtForwardingError(e)))? + .map_err(PipelineErrorVariant::JwtForwardingError)? 
} else { None }; - execute_query_plan(QueryPlanExecutionContext { + let ctx = QueryPlanExecutionContext { + plugin_manager: &plugin_manager, query_plan: query_plan_payload, + operation_for_plan: &normalized_payload.operation_for_plan, projection_plan: &normalized_payload.projection_plan, headers_plan: &app_state.headers_plan, variable_values: &variable_payload.variables_map, @@ -93,12 +96,12 @@ pub async fn execute_plan( client_request: client_request_details, introspection_context: &introspection_context, operation_type_name: normalized_payload.root_type_name, - jwt_auth_forwarding: &jwt_forward_plan, + jwt_auth_forwarding, executors: &supergraph.subgraph_executor_map, - }) - .await - .map_err(|err| { + }; + + ctx.execute_query_plan().await.map_err(|err| { tracing::error!("Failed to execute query plan: {}", err); - req.new_pipeline_error(PipelineErrorVariant::PlanExecutionError(err)) + PipelineErrorVariant::PlanExecutionError(err) }) } diff --git a/bin/router/src/pipeline/execution_request.rs b/bin/router/src/pipeline/execution_request.rs deleted file mode 100644 index c17a6f355..000000000 --- a/bin/router/src/pipeline/execution_request.rs +++ /dev/null @@ -1,134 +0,0 @@ -use std::collections::HashMap; - -use http::Method; -use ntex::util::Bytes; -use ntex::web::types::Query; -use ntex::web::HttpRequest; -use serde::{Deserialize, Deserializer}; -use sonic_rs::Value; -use tracing::{trace, warn}; - -use crate::pipeline::error::{PipelineError, PipelineErrorFromAcceptHeader, PipelineErrorVariant}; -use crate::pipeline::header::AssertRequestJson; - -#[derive(serde::Deserialize, Debug)] -struct GETQueryParams { - pub query: Option, - #[serde(rename = "camelCase")] - pub operation_name: Option, - pub variables: Option, - pub extensions: Option, -} - -#[derive(Deserialize, Debug, Clone)] -#[serde(rename_all = "camelCase")] -pub struct ExecutionRequest { - pub query: String, - pub operation_name: Option, - #[serde(default, deserialize_with = "deserialize_null_default")] - 
pub variables: HashMap, - // TODO: We don't use extensions yet, but we definitely will in the future. - #[allow(dead_code)] - pub extensions: Option>, -} - -fn deserialize_null_default<'de, D, T>(deserializer: D) -> Result -where - T: Default + Deserialize<'de>, - D: Deserializer<'de>, -{ - let opt = Option::::deserialize(deserializer)?; - Ok(opt.unwrap_or_default()) -} - -impl TryInto for GETQueryParams { - type Error = PipelineErrorVariant; - - fn try_into(self) -> Result { - let query = match self.query { - Some(q) => q, - None => return Err(PipelineErrorVariant::GetMissingQueryParam("query")), - }; - - let variables = match self.variables.as_deref() { - Some(v_str) if !v_str.is_empty() => match sonic_rs::from_str(v_str) { - Ok(vars) => vars, - Err(e) => { - return Err(PipelineErrorVariant::FailedToParseVariables(e)); - } - }, - _ => HashMap::new(), - }; - - let extensions = match self.extensions.as_deref() { - Some(e_str) if !e_str.is_empty() => match sonic_rs::from_str(e_str) { - Ok(exts) => Some(exts), - Err(e) => { - return Err(PipelineErrorVariant::FailedToParseExtensions(e)); - } - }, - _ => None, - }; - - let execution_request = ExecutionRequest { - query, - operation_name: self.operation_name, - variables, - extensions, - }; - - Ok(execution_request) - } -} - -#[inline] -pub async fn get_execution_request( - req: &mut HttpRequest, - body_bytes: Bytes, -) -> Result { - let http_method = req.method(); - let execution_request: ExecutionRequest = match *http_method { - Method::GET => { - trace!("processing GET GraphQL operation"); - let query_params_str = req.uri().query().ok_or_else(|| { - req.new_pipeline_error(PipelineErrorVariant::GetInvalidQueryParams) - })?; - let query_params = Query::::from_query(query_params_str) - .map_err(|e| { - req.new_pipeline_error(PipelineErrorVariant::GetUnprocessableQueryParams(e)) - })? 
- .0; - - trace!("parsed GET query params: {:?}", query_params); - - query_params - .try_into() - .map_err(|err| req.new_pipeline_error(err))? - } - Method::POST => { - trace!("Processing POST GraphQL request"); - - req.assert_json_content_type()?; - - let execution_request = unsafe { - sonic_rs::from_slice_unchecked::(&body_bytes).map_err(|e| { - warn!("Failed to parse body: {}", e); - req.new_pipeline_error(PipelineErrorVariant::FailedToParseBody(e)) - })? - }; - - execution_request - } - _ => { - warn!("unsupported HTTP method: {}", http_method); - - return Err( - req.new_pipeline_error(PipelineErrorVariant::UnsupportedHttpMethod( - http_method.to_owned(), - )), - ); - } - }; - - Ok(execution_request) -} diff --git a/bin/router/src/pipeline/header.rs b/bin/router/src/pipeline/header.rs index 92a591235..19ea8c7af 100644 --- a/bin/router/src/pipeline/header.rs +++ b/bin/router/src/pipeline/header.rs @@ -6,7 +6,7 @@ use lazy_static::lazy_static; use ntex::web::HttpRequest; use tracing::{trace, warn}; -use crate::pipeline::error::{PipelineError, PipelineErrorFromAcceptHeader, PipelineErrorVariant}; +use crate::pipeline::error::PipelineErrorVariant; lazy_static! 
{ pub static ref APPLICATION_JSON_STR: &'static str = "application/json"; @@ -34,31 +34,29 @@ impl RequestAccepts for HttpRequest { } pub trait AssertRequestJson { - fn assert_json_content_type(&self) -> Result<(), PipelineError>; + fn assert_json_content_type(&self) -> Result<(), PipelineErrorVariant>; } impl AssertRequestJson for HttpRequest { #[inline] - fn assert_json_content_type(&self) -> Result<(), PipelineError> { + fn assert_json_content_type(&self) -> Result<(), PipelineErrorVariant> { match self.headers().get(CONTENT_TYPE) { Some(value) => { - let content_type_str = value.to_str().map_err(|_| { - self.new_pipeline_error(PipelineErrorVariant::InvalidHeaderValue(CONTENT_TYPE)) - })?; + let content_type_str = value + .to_str() + .map_err(|_| PipelineErrorVariant::InvalidHeaderValue(CONTENT_TYPE))?; if !content_type_str.contains(*APPLICATION_JSON_STR) { warn!( "Invalid content type on a POST request: {}", content_type_str ); - return Err( - self.new_pipeline_error(PipelineErrorVariant::UnsupportedContentType) - ); + return Err(PipelineErrorVariant::UnsupportedContentType); } Ok(()) } None => { trace!("POST without content type detected"); - Err(self.new_pipeline_error(PipelineErrorVariant::MissingContentTypeHeader)) + Err(PipelineErrorVariant::MissingContentTypeHeader) } } } diff --git a/bin/router/src/pipeline/mod.rs b/bin/router/src/pipeline/mod.rs index 2b4721972..4d2f264ee 100644 --- a/bin/router/src/pipeline/mod.rs +++ b/bin/router/src/pipeline/mod.rs @@ -1,8 +1,16 @@ use std::sync::Arc; -use hive_router_plan_executor::execution::{ - client_request_details::{ClientRequestDetails, JwtRequestDetails, OperationDetails}, - plan::PlanExecutionOutput, +use hive_router_plan_executor::{ + execution::{ + client_request_details::{ClientRequestDetails, JwtRequestDetails, OperationDetails}, + plan::PlanExecutionOutput, + }, + hooks::{ + on_graphql_params::{OnGraphQLParamsEndPayload, OnGraphQLParamsStartPayload}, + on_supergraph_load::SupergraphData, + }, + 
plugin_context::{PluginContext, PluginManager, RouterHttpRequest}, + plugin_trait::ControlFlowResult, }; use hive_router_query_planner::{ state::supergraph_state::OperationKind, utils::cancellation::CancellationToken, @@ -18,29 +26,29 @@ use crate::{ pipeline::{ coerce_variables::coerce_request_variables, csrf_prevention::perform_csrf_prevention, - error::{PipelineError, PipelineErrorFromAcceptHeader, PipelineErrorVariant}, + deserialize_graphql_params::{deserialize_graphql_params, GetQueryStr}, + error::PipelineErrorVariant, execution::execute_plan, - execution_request::get_execution_request, header::{ RequestAccepts, APPLICATION_GRAPHQL_RESPONSE_JSON, APPLICATION_GRAPHQL_RESPONSE_JSON_STR, APPLICATION_JSON, TEXT_HTML_CONTENT_TYPE, }, normalize::normalize_request_with_cache, - parser::parse_operation_with_cache, + parser::{parse_operation_with_cache, ParseResult}, progressive_override::request_override_context, - query_plan::plan_operation_with_cache, + query_plan::{plan_operation_with_cache, QueryPlanResult}, validation::validate_operation_with_cache, }, - schema_state::{SchemaState, SupergraphData}, + schema_state::SchemaState, shared_state::RouterSharedState, }; pub mod coerce_variables; pub mod cors; pub mod csrf_prevention; +pub mod deserialize_graphql_params; pub mod error; pub mod execution; -pub mod execution_request; pub mod header; pub mod normalize; pub mod parser; @@ -52,96 +60,180 @@ static GRAPHIQL_HTML: &str = include_str!("../../static/graphiql.html"); #[inline] pub async fn graphql_request_handler( - req: &mut HttpRequest, + req: &HttpRequest, body_bytes: Bytes, supergraph: &SupergraphData, - shared_state: &Arc, - schema_state: &Arc, -) -> web::HttpResponse { + shared_state: &RouterSharedState, + schema_state: &SchemaState, +) -> Result { if req.method() == Method::GET && req.accepts_content_type(*TEXT_HTML_CONTENT_TYPE) { if shared_state.router_config.graphiql.enabled { - return web::HttpResponse::Ok() + return Ok(web::HttpResponse::Ok() 
.header(CONTENT_TYPE, *TEXT_HTML_CONTENT_TYPE) - .body(GRAPHIQL_HTML); + .body(GRAPHIQL_HTML)); } else { - return web::HttpResponse::NotFound().into(); + return Ok(web::HttpResponse::NotFound().into()); } } - if let Some(jwt) = &shared_state.jwt_auth_runtime { + let jwt_context = if let Some(jwt) = &shared_state.jwt_auth_runtime { match jwt.validate_request(req) { - Ok(_) => (), - Err(err) => return err.make_response(), + Ok(jwt_context) => jwt_context, + Err(err) => return Ok(err.make_response()), } - } + } else { + None + }; - match execute_pipeline(req, body_bytes, supergraph, shared_state, schema_state).await { - Ok(response) => { - let response_bytes = Bytes::from(response.body); - let response_headers = response.headers; - - let response_content_type: &'static HeaderValue = - if req.accepts_content_type(*APPLICATION_GRAPHQL_RESPONSE_JSON_STR) { - &APPLICATION_GRAPHQL_RESPONSE_JSON - } else { - &APPLICATION_JSON - }; - - let mut response_builder = web::HttpResponse::Ok(); - for (header_name, header_value) in response_headers { - if let Some(header_name) = header_name { - response_builder.header(header_name, header_value); - } - } + let response_content_type: &'static HeaderValue = + if req.accepts_content_type(*APPLICATION_GRAPHQL_RESPONSE_JSON_STR) { + &APPLICATION_GRAPHQL_RESPONSE_JSON + } else { + &APPLICATION_JSON + }; - response_builder - .header(http::header::CONTENT_TYPE, response_content_type) - .body(response_bytes) + let plugin_context = req + .extensions() + .get::>() + .cloned() + .expect("Plugin manager should be loaded"); + + let plugin_manager = PluginManager { + plugins: shared_state.plugins.clone(), + router_http_request: RouterHttpRequest { + uri: req.uri(), + method: req.method(), + version: req.version(), + headers: req.headers(), + match_info: req.match_info(), + query_string: req.query_string(), + path: req.path(), + }, + context: plugin_context, + }; + + let response = execute_pipeline( + req, + body_bytes, + supergraph, + shared_state, 
+ schema_state, + jwt_context, + plugin_manager, + ) + .await?; + let response_bytes = Bytes::from(response.body); + let response_headers = response.headers; + + let mut response_builder = web::HttpResponse::Ok(); + for (header_name, header_value) in response_headers { + if let Some(header_name) = header_name { + response_builder.header(header_name, header_value); } - Err(err) => err.into_response(), } + + Ok(response_builder + .header(http::header::CONTENT_TYPE, response_content_type) + .body(response_bytes)) } #[inline] #[allow(clippy::await_holding_refcell_ref)] pub async fn execute_pipeline( - req: &mut HttpRequest, - body_bytes: Bytes, + req: &HttpRequest, + body: Bytes, supergraph: &SupergraphData, - shared_state: &Arc, - schema_state: &Arc, -) -> Result { + shared_state: &RouterSharedState, + schema_state: &SchemaState, + jwt_context: Option, + plugin_manager: PluginManager<'_>, +) -> Result { perform_csrf_prevention(req, &shared_state.router_config.csrf)?; - let mut execution_request = get_execution_request(req, body_bytes).await?; - let parser_payload = parse_operation_with_cache(req, shared_state, &execution_request).await?; - validate_operation_with_cache(req, supergraph, schema_state, shared_state, &parser_payload) - .await?; + /* Handle on_deserialize hook in the plugins - START */ + let mut deserialization_end_callbacks = vec![]; + let mut deserialization_payload: OnGraphQLParamsStartPayload = OnGraphQLParamsStartPayload { + router_http_request: &plugin_manager.router_http_request, + context: &plugin_manager.context, + body, + graphql_params: None, + }; + for plugin in shared_state.plugins.as_ref() { + let result = plugin.on_graphql_params(deserialization_payload).await; + deserialization_payload = result.payload; + match result.control_flow { + ControlFlowResult::Continue => { /* continue to next plugin */ } + ControlFlowResult::EndResponse(response) => { + return Ok(response); + } + ControlFlowResult::OnEnd(callback) => { + 
deserialization_end_callbacks.push(callback); + } + } + } + let graphql_params = deserialization_payload.graphql_params.unwrap_or_else(|| { + deserialize_graphql_params(req, deserialization_payload.body) + .expect("Failed to parse execution request") + }); + + let mut payload = OnGraphQLParamsEndPayload { + graphql_params, + context: &plugin_manager.context, + }; + for deserialization_end_callback in deserialization_end_callbacks { + let result = deserialization_end_callback(payload); + payload = result.payload; + match result.control_flow { + ControlFlowResult::Continue => { /* continue to next plugin */ } + ControlFlowResult::EndResponse(response) => { + return Ok(response); + } + ControlFlowResult::OnEnd(_) => { + // on_end callbacks should not return OnEnd again + unreachable!("on_end callback returned OnEnd again"); + } + } + } + let mut graphql_params = payload.graphql_params; + /* Handle on_deserialize hook in the plugins - END */ - let normalize_payload = normalize_request_with_cache( - req, + let parser_result = + parse_operation_with_cache(shared_state, &graphql_params, &plugin_manager).await?; + + let parser_payload = match parser_result { + ParseResult::Payload(payload) => payload, + ParseResult::Response(response) => { + return Ok(response); + } + }; + + validate_operation_with_cache( supergraph, schema_state, - &execution_request, + shared_state, &parser_payload, + &plugin_manager, ) .await?; + + let normalize_payload = + normalize_request_with_cache(supergraph, schema_state, &graphql_params, &parser_payload) + .await?; + let variable_payload = - coerce_request_variables(req, supergraph, &mut execution_request, &normalize_payload)?; + coerce_request_variables(req, supergraph, &mut graphql_params, &normalize_payload)?; let query_plan_cancellation_token = CancellationToken::with_timeout(shared_state.router_config.query_planner.timeout); - let req_extensions = req.extensions(); - let jwt_context = req_extensions.get::(); let jwt_request_details = match 
jwt_context { Some(jwt_context) => JwtRequestDetails::Authenticated { - token: jwt_context.token_raw.as_str(), - prefix: jwt_context.token_prefix.as_deref(), scopes: jwt_context.extract_scopes(), - claims: &jwt_context + claims: jwt_context .get_claims_value() - .map_err(|e| req.new_pipeline_error(PipelineErrorVariant::JwtForwardingError(e)))?, + .map_err(PipelineErrorVariant::JwtForwardingError)?, + token: jwt_context.token_raw, + prefix: jwt_context.token_prefix, }, None => JwtRequestDetails::Unauthenticated, }; @@ -158,26 +250,33 @@ pub async fn execute_pipeline( Some(OperationKind::Subscription) => "subscription", None => "query", }, - query: &execution_request.query, + query: graphql_params.get_query()?, }, - jwt: &jwt_request_details, + jwt: jwt_request_details, }; let progressive_override_ctx = request_override_context( &shared_state.override_labels_evaluator, &client_request_details, ) - .map_err(|error| req.new_pipeline_error(PipelineErrorVariant::LabelEvaluationError(error)))?; + .map_err(PipelineErrorVariant::LabelEvaluationError)?; - let query_plan_payload = plan_operation_with_cache( - req, + let query_plan_result = plan_operation_with_cache( supergraph, schema_state, &normalize_payload, &progressive_override_ctx, &query_plan_cancellation_token, + shared_state, + &plugin_manager, ) .await?; + let query_plan_payload = match query_plan_result { + QueryPlanResult::QueryPlan(plan) => plan, + QueryPlanResult::Response(response) => { + return Ok(response); + } + }; let execution_result = execute_plan( req, @@ -187,6 +286,7 @@ pub async fn execute_pipeline( &query_plan_payload, &variable_payload, &client_request_details, + plugin_manager, ) .await?; diff --git a/bin/router/src/pipeline/normalize.rs b/bin/router/src/pipeline/normalize.rs index 4fc2cc5ef..97cbb80ac 100644 --- a/bin/router/src/pipeline/normalize.rs +++ b/bin/router/src/pipeline/normalize.rs @@ -1,17 +1,17 @@ use std::hash::{Hash, Hasher}; use std::sync::Arc; +use 
hive_router_plan_executor::hooks::on_graphql_params::GraphQLParams; +use hive_router_plan_executor::hooks::on_supergraph_load::SupergraphData; use hive_router_plan_executor::introspection::partition::partition_operation; use hive_router_plan_executor::projection::plan::FieldProjectionPlan; use hive_router_query_planner::ast::normalization::normalize_operation; use hive_router_query_planner::ast::operation::OperationDefinition; -use ntex::web::HttpRequest; use xxhash_rust::xxh3::Xxh3; -use crate::pipeline::error::{PipelineError, PipelineErrorFromAcceptHeader, PipelineErrorVariant}; -use crate::pipeline::execution_request::ExecutionRequest; +use crate::pipeline::error::PipelineErrorVariant; use crate::pipeline::parser::GraphQLParserPayload; -use crate::schema_state::{SchemaState, SupergraphData}; +use crate::schema_state::SchemaState; use tracing::{error, trace}; #[derive(Debug)] @@ -25,16 +25,15 @@ pub struct GraphQLNormalizationPayload { #[inline] pub async fn normalize_request_with_cache( - req: &HttpRequest, supergraph: &SupergraphData, - schema_state: &Arc, - execution_params: &ExecutionRequest, + schema_state: &SchemaState, + graphql_params: &GraphQLParams, parser_payload: &GraphQLParserPayload, -) -> Result, PipelineError> { - let cache_key = match &execution_params.operation_name { +) -> Result, PipelineErrorVariant> { + let cache_key = match &graphql_params.operation_name { Some(operation_name) => { let mut hasher = Xxh3::new(); - execution_params.query.hash(&mut hasher); + graphql_params.query.hash(&mut hasher); operation_name.hash(&mut hasher); hasher.finish() } @@ -54,7 +53,7 @@ pub async fn normalize_request_with_cache( None => match normalize_operation( &supergraph.planner.supergraph, &parser_payload.parsed_operation, - execution_params.operation_name.as_deref(), + graphql_params.operation_name.as_deref(), ) { Ok(doc) => { trace!( @@ -86,7 +85,7 @@ pub async fn normalize_request_with_cache( error!("Failed to normalize GraphQL operation: {}", err); 
trace!("{:?}", err); - Err(req.new_pipeline_error(PipelineErrorVariant::NormalizationError(err))) + Err(PipelineErrorVariant::NormalizationError(err)) } }, } diff --git a/bin/router/src/pipeline/parser.rs b/bin/router/src/pipeline/parser.rs index 6e8a37141..aebbd6beb 100644 --- a/bin/router/src/pipeline/parser.rs +++ b/bin/router/src/pipeline/parser.rs @@ -2,12 +2,18 @@ use std::hash::{Hash, Hasher}; use std::sync::Arc; use graphql_parser::query::Document; +use hive_router_plan_executor::execution::plan::PlanExecutionOutput; +use hive_router_plan_executor::hooks::on_graphql_params::GraphQLParams; +use hive_router_plan_executor::hooks::on_graphql_parse::{ + OnGraphQLParseEndPayload, OnGraphQLParseStartPayload, +}; +use hive_router_plan_executor::plugin_context::PluginManager; +use hive_router_plan_executor::plugin_trait::ControlFlowResult; use hive_router_query_planner::utils::parsing::safe_parse_operation; -use ntex::web::HttpRequest; use xxhash_rust::xxh3::Xxh3; -use crate::pipeline::error::{PipelineError, PipelineErrorFromAcceptHeader, PipelineErrorVariant}; -use crate::pipeline::execution_request::ExecutionRequest; +use crate::pipeline::deserialize_graphql_params::GetQueryStr; +use crate::pipeline::error::PipelineErrorVariant; use crate::shared_state::RouterSharedState; use tracing::{error, trace}; @@ -17,28 +23,85 @@ pub struct GraphQLParserPayload { pub cache_key: u64, } +pub enum ParseResult { + Payload(GraphQLParserPayload), + Response(PlanExecutionOutput), +} + #[inline] pub async fn parse_operation_with_cache( - req: &HttpRequest, - app_state: &Arc, - execution_params: &ExecutionRequest, -) -> Result { + app_state: &RouterSharedState, + graphql_params: &GraphQLParams, + plugin_manager: &PluginManager<'_>, +) -> Result { let cache_key = { let mut hasher = Xxh3::new(); - execution_params.query.hash(&mut hasher); + graphql_params.query.hash(&mut hasher); hasher.finish() }; + /* Handle on_graphql_parse hook in the plugins - START */ + let mut start_payload = 
OnGraphQLParseStartPayload { + router_http_request: &plugin_manager.router_http_request, + context: &plugin_manager.context, + graphql_params, + document: None, + }; let parsed_operation = if let Some(cached) = app_state.parse_cache.get(&cache_key).await { trace!("Found cached parsed operation for query"); cached } else { - let parsed = safe_parse_operation(&execution_params.query).map_err(|err| { - error!("Failed to parse GraphQL operation: {}", err); - req.new_pipeline_error(PipelineErrorVariant::FailedToParseOperation(err)) - })?; - trace!("sucessfully parsed GraphQL operation"); - let parsed_arc = Arc::new(parsed); + let mut on_end_callbacks = vec![]; + for plugin in app_state.plugins.as_ref() { + let result = plugin.on_graphql_parse(start_payload).await; + start_payload = result.payload; + match result.control_flow { + ControlFlowResult::Continue => { + // continue to next plugin + } + ControlFlowResult::EndResponse(response) => { + return Ok(ParseResult::Response(response)); + } + ControlFlowResult::OnEnd(callback) => { + // store the callback to be called later + on_end_callbacks.push(callback); + } + } + } + + let document = match start_payload.document { + Some(parsed) => parsed, + None => { + let query_str = graphql_params.get_query()?; + let parsed = safe_parse_operation(query_str).map_err(|err| { + error!("Failed to parse GraphQL operation: {}", err); + PipelineErrorVariant::FailedToParseOperation(err) + })?; + trace!("successfully parsed GraphQL operation"); + parsed + } + }; + let mut end_payload = OnGraphQLParseEndPayload { document }; + for callback in on_end_callbacks { + let result = callback(end_payload); + end_payload = result.payload; + match result.control_flow { + ControlFlowResult::Continue => { + // continue to next callback + } + ControlFlowResult::EndResponse(response) => { + return Ok(ParseResult::Response(response)); + } + ControlFlowResult::OnEnd(_) => { + // on_end callbacks should not return OnEnd again + unreachable!(); + } + } + } 
+ let document = end_payload.document; + /* Handle on_graphql_parse hook in the plugins - END */ + + let parsed_arc = Arc::new(document); app_state .parse_cache .insert(cache_key, parsed_arc.clone()) @@ -46,8 +109,8 @@ pub async fn parse_operation_with_cache( parsed_arc }; - Ok(GraphQLParserPayload { + Ok(ParseResult::Payload(GraphQLParserPayload { parsed_operation, cache_key, - }) + })) } diff --git a/bin/router/src/pipeline/progressive_override.rs b/bin/router/src/pipeline/progressive_override.rs index d0b09c183..4743d8672 100644 --- a/bin/router/src/pipeline/progressive_override.rs +++ b/bin/router/src/pipeline/progressive_override.rs @@ -51,9 +51,9 @@ pub struct RequestOverrideContext { } #[inline] -pub fn request_override_context<'exec, 'req>( +pub fn request_override_context( override_labels_evaluator: &OverrideLabelsEvaluator, - client_request_details: &ClientRequestDetails<'exec, 'req>, + client_request_details: &ClientRequestDetails<'_, '_>, ) -> Result { let active_flags = override_labels_evaluator.evaluate(client_request_details)?; diff --git a/bin/router/src/pipeline/query_plan.rs b/bin/router/src/pipeline/query_plan.rs index b2f730be7..d1f83ca7d 100644 --- a/bin/router/src/pipeline/query_plan.rs +++ b/bin/router/src/pipeline/query_plan.rs @@ -1,24 +1,43 @@ use std::hash::{Hash, Hasher}; use std::sync::Arc; -use crate::pipeline::error::{PipelineError, PipelineErrorFromAcceptHeader, PipelineErrorVariant}; +use crate::pipeline::error::PipelineErrorVariant; use crate::pipeline::normalize::GraphQLNormalizationPayload; use crate::pipeline::progressive_override::{RequestOverrideContext, StableOverrideContext}; -use crate::schema_state::{SchemaState, SupergraphData}; +use crate::schema_state::SchemaState; +use crate::RouterSharedState; +use hive_router_plan_executor::execution::plan::PlanExecutionOutput; +use hive_router_plan_executor::hooks::on_query_plan::{ + OnQueryPlanEndPayload, OnQueryPlanStartPayload, +}; +use 
hive_router_plan_executor::hooks::on_supergraph_load::SupergraphData; +use hive_router_plan_executor::plugin_context::PluginManager; +use hive_router_plan_executor::plugin_trait::ControlFlowResult; use hive_router_query_planner::planner::plan_nodes::QueryPlan; +use hive_router_query_planner::planner::PlannerError; use hive_router_query_planner::utils::cancellation::CancellationToken; -use ntex::web::HttpRequest; use xxhash_rust::xxh3::Xxh3; +pub enum QueryPlanResult { + QueryPlan(Arc), + Response(PlanExecutionOutput), +} + +pub enum QueryPlanGetterError { + Planner(PlannerError), + Response(PlanExecutionOutput), +} + #[inline] -pub async fn plan_operation_with_cache( - req: &HttpRequest, +pub async fn plan_operation_with_cache<'req>( supergraph: &SupergraphData, - schema_state: &Arc, - normalized_operation: &Arc, + schema_state: &SchemaState, + normalized_operation: &GraphQLNormalizationPayload, request_override_context: &RequestOverrideContext, cancellation_token: &CancellationToken, -) -> Result, PipelineError> { + app_state: &RouterSharedState, + plugin_manager: &PluginManager<'req>, +) -> Result { let stable_override_context = StableOverrideContext::new(&supergraph.planner.supergraph, request_override_context); @@ -30,7 +49,7 @@ pub async fn plan_operation_with_cache( let plan_result = schema_state .plan_cache - .try_get_with(plan_cache_key, async move { + .try_get_with(plan_cache_key, async { if is_pure_introspection { return Ok(Arc::new(QueryPlan { kind: "QueryPlan".to_string(), @@ -38,20 +57,76 @@ pub async fn plan_operation_with_cache( })); } - supergraph - .planner - .plan_from_normalized_operation( - filtered_operation_for_plan, - (&request_override_context.clone()).into(), - cancellation_token, - ) - .map(Arc::new) + /* Handle on_query_plan hook in the plugins - START */ + let mut start_payload = OnQueryPlanStartPayload { + router_http_request: &plugin_manager.router_http_request, + context: &plugin_manager.context, + filtered_operation_for_plan, + 
planner_override_context: (&request_override_context.clone()).into(), + cancellation_token, + query_plan: None, + planner: &supergraph.planner, + }; + + let mut on_end_callbacks = vec![]; + for plugin in app_state.plugins.as_ref() { + let result = plugin.on_query_plan(start_payload).await; + start_payload = result.payload; + match result.control_flow { + ControlFlowResult::Continue => { + // continue to next plugin + } + ControlFlowResult::EndResponse(response) => { + return Err(QueryPlanGetterError::Response(response)); + } + ControlFlowResult::OnEnd(callback) => { + on_end_callbacks.push(callback); + } + } + } + let query_plan = match start_payload.query_plan { + Some(plan) => plan, + None => supergraph + .planner + .plan_from_normalized_operation( + filtered_operation_for_plan, + (&request_override_context.clone()).into(), + cancellation_token, + ) + .map_err(QueryPlanGetterError::Planner)?, + }; + + let mut end_payload = OnQueryPlanEndPayload { query_plan }; + + for callback in on_end_callbacks { + let result = callback(end_payload); + end_payload = result.payload; + match result.control_flow { + ControlFlowResult::Continue => { + // continue to next callback + } + ControlFlowResult::EndResponse(response) => { + return Err(QueryPlanGetterError::Response(response)); + } + ControlFlowResult::OnEnd(_) => { + // on_end callbacks should not return OnEnd again + } + } + } + + Ok(Arc::new(end_payload.query_plan)) + /* Handle on_query_plan hook in the plugins - END */ }) .await; match plan_result { - Ok(plan) => Ok(plan), - Err(e) => Err(req.new_pipeline_error(PipelineErrorVariant::PlannerError(e.clone()))), + Ok(plan) => Ok(QueryPlanResult::QueryPlan(plan)), + Err(e) => match e.as_ref() { + QueryPlanGetterError::Planner(e) => Err(PipelineErrorVariant::PlannerError(e.clone())), + QueryPlanGetterError::Response(response) => { + Ok(QueryPlanResult::Response(response.clone())) + } + }, } } diff --git a/bin/router/src/pipeline/validation.rs 
b/bin/router/src/pipeline/validation.rs index 85d44c2f1..92cb1eb6f 100644 --- a/bin/router/src/pipeline/validation.rs +++ b/bin/router/src/pipeline/validation.rs @@ -1,21 +1,27 @@ use std::sync::Arc; -use crate::pipeline::error::{PipelineError, PipelineErrorFromAcceptHeader, PipelineErrorVariant}; +use crate::pipeline::error::PipelineErrorVariant; use crate::pipeline::parser::GraphQLParserPayload; -use crate::schema_state::{SchemaState, SupergraphData}; +use crate::schema_state::SchemaState; use crate::shared_state::RouterSharedState; use graphql_tools::validation::validate::validate; -use ntex::web::HttpRequest; +use hive_router_plan_executor::execution::plan::PlanExecutionOutput; +use hive_router_plan_executor::hooks::on_graphql_validation::{ + OnGraphQLValidationEndPayload, OnGraphQLValidationStartPayload, +}; +use hive_router_plan_executor::hooks::on_supergraph_load::SupergraphData; +use hive_router_plan_executor::plugin_context::PluginManager; +use hive_router_plan_executor::plugin_trait::ControlFlowResult; use tracing::{error, trace}; #[inline] pub async fn validate_operation_with_cache( - req: &HttpRequest, supergraph: &SupergraphData, - schema_state: &Arc, - app_state: &Arc, + schema_state: &SchemaState, + app_state: &RouterSharedState, parser_payload: &GraphQLParserPayload, -) -> Result<(), PipelineError> { + plugin_manager: &PluginManager<'_>, +) -> Result, PipelineErrorVariant> { let consumer_schema_ast = &supergraph.planner.consumer_schema.document; let validation_result = match schema_state @@ -37,12 +43,59 @@ pub async fn validate_operation_with_cache( parser_payload.cache_key ); - let res = validate( + /* Handle on_graphql_validate hook in the plugins - START */ + let mut start_payload = OnGraphQLValidationStartPayload::new( + plugin_manager, consumer_schema_ast, &parser_payload.parsed_operation, &app_state.validation_plan, ); - let arc_res = Arc::new(res); + let mut on_end_callbacks = vec![]; + for plugin in app_state.plugins.as_ref() { + let result 
= plugin.on_graphql_validation(start_payload).await; + start_payload = result.payload; + match result.control_flow { + ControlFlowResult::Continue => { + // continue to next plugin + } + ControlFlowResult::EndResponse(response) => { + return Ok(Some(response)); + } + ControlFlowResult::OnEnd(callback) => { + on_end_callbacks.push(callback); + } + } + } + + let errors = match start_payload.errors { + Some(errors) => errors, + None => validate( + consumer_schema_ast, + start_payload.document, + start_payload.get_validation_plan(), + ), + }; + + let mut end_payload = OnGraphQLValidationEndPayload { errors }; + + for callback in on_end_callbacks { + let result = callback(end_payload); + end_payload = result.payload; + match result.control_flow { + ControlFlowResult::Continue => { + // continue to next callback + } + ControlFlowResult::EndResponse(response) => { + return Ok(Some(response)); + } + ControlFlowResult::OnEnd(_) => { + // on_end callbacks should not return OnEnd again + } + } + } + /* Handle on_graphql_validate hook in the plugins - END */ + + let arc_res = Arc::new(end_payload.errors); schema_state .validate_cache @@ -59,10 +112,8 @@ pub async fn validate_operation_with_cache( ); trace!("Validation errors: {:?}", validation_result); - return Err( - req.new_pipeline_error(PipelineErrorVariant::ValidationErrors(validation_result)) - ); + return Err(PipelineErrorVariant::ValidationErrors(validation_result)); } - Ok(()) + Ok(None) } diff --git a/bin/router/src/plugins/mod.rs b/bin/router/src/plugins/mod.rs new file mode 100644 index 000000000..3753246b2 --- /dev/null +++ b/bin/router/src/plugins/mod.rs @@ -0,0 +1 @@ +pub mod plugins_service; diff --git a/bin/router/src/plugins/plugins_service.rs b/bin/router/src/plugins/plugins_service.rs new file mode 100644 index 000000000..223a702ee --- /dev/null +++ b/bin/router/src/plugins/plugins_service.rs @@ -0,0 +1,121 @@ +use std::sync::Arc; + +use hive_router_plan_executor::{ + execution::plan::PlanExecutionOutput, + 
hooks::on_http_request::{OnHttpRequestPayload, OnHttpResponsePayload}, + plugin_context::PluginContext, + plugin_trait::ControlFlowResult, +}; +use http::StatusCode; +use ntex::{ + http::ResponseBuilder, + service::{Service, ServiceCtx}, + web::{self, DefaultError}, + Middleware, +}; + +use crate::RouterSharedState; + +pub struct PluginService; + +impl Middleware for PluginService { + type Service = PluginMiddleware; + + fn create(&self, service: S) -> Self::Service { + PluginMiddleware { service } + } +} + +pub struct PluginMiddleware { + // This is special: We need this to avoid lifetime issues. + service: S, +} + +impl Service> for PluginMiddleware +where + S: Service, Response = web::WebResponse, Error = web::Error>, +{ + type Response = web::WebResponse; + type Error = S::Error; + + ntex::forward_ready!(service); + + async fn call( + &self, + req: web::WebRequest, + ctx: ServiceCtx<'_, Self>, + ) -> Result { + let plugins = req + .app_state::>() + .map(|shared_state| shared_state.plugins.clone()); + + if let Some(plugins) = plugins { + let plugin_context = Arc::new(PluginContext::default()); + req.extensions_mut().insert(plugin_context.clone()); + let mut start_payload = OnHttpRequestPayload { + router_http_request: req, + context: &plugin_context, + }; + + let mut on_end_callbacks = vec![]; + + let mut early_response: Option = None; + for plugin in plugins.iter() { + let result = plugin.on_http_request(start_payload); + start_payload = result.payload; + match result.control_flow { + ControlFlowResult::Continue => { + // continue to next plugin + } + ControlFlowResult::OnEnd(callback) => { + on_end_callbacks.push(callback); + } + ControlFlowResult::EndResponse(response) => { + early_response = Some(response); + break; + } + } + } + + let req = start_payload.router_http_request; + + let response = if let Some(early_response) = early_response { + let mut builder = ResponseBuilder::new(StatusCode::OK); + for (key, value) in early_response.headers.iter() { + 
builder.header(key, value); + } + let res = builder.body(early_response.body); + req.into_response(res) + } else { + ctx.call(&self.service, req).await? + }; + + let mut end_payload = OnHttpResponsePayload { + response, + context: &plugin_context, + }; + + for callback in on_end_callbacks.into_iter().rev() { + let result = callback(end_payload); + end_payload = result.payload; + match result.control_flow { + ControlFlowResult::Continue => { + // continue to next callback + } + ControlFlowResult::EndResponse(_response) => { + // Short-circuit the request with the provided response + unimplemented!() + } + ControlFlowResult::OnEnd(_) => { + // This should not happen + unreachable!(); + } + } + } + + return Ok(end_payload.response); + } + + ctx.call(&self.service, req).await + } +} diff --git a/bin/router/src/schema_state.rs b/bin/router/src/schema_state.rs index f14cc6cf0..69db5db3b 100644 --- a/bin/router/src/schema_state.rs +++ b/bin/router/src/schema_state.rs @@ -1,10 +1,14 @@ use arc_swap::{ArcSwap, Guard}; use async_trait::async_trait; -use graphql_tools::validation::utils::ValidationError; +use graphql_tools::{static_graphql::schema::Document, validation::utils::ValidationError}; use hive_router_config::{supergraph::SupergraphSource, HiveRouterConfig}; use hive_router_plan_executor::{ executors::error::SubgraphExecutorError, - introspection::schema::{SchemaMetadata, SchemaWithMetadata}, + hooks::on_supergraph_load::{ + OnSupergraphLoadEndPayload, OnSupergraphLoadStartPayload, SupergraphData, + }, + introspection::schema::SchemaWithMetadata, + plugin_trait::{ControlFlowResult, RouterPlugin}, SubgraphExecutorMap, }; use hive_router_query_planner::planner::plan_nodes::QueryPlan; @@ -26,6 +30,7 @@ use crate::{ base::{LoadSupergraphError, ReloadSupergraphResult, SupergraphLoader}, resolve_from_config, }, + RouterSharedState, }; pub struct SchemaState { @@ -35,12 +40,6 @@ pub struct SchemaState { pub normalize_cache: Cache>, } -pub struct SupergraphData { - pub 
metadata: SchemaMetadata, - pub planner: Planner, - pub subgraph_executor_map: SubgraphExecutorMap, -} - #[derive(Debug, thiserror::Error)] pub enum SupergraphManagerError { #[error("Failed to load supergraph: {0}")] @@ -65,6 +64,7 @@ impl SchemaState { pub async fn new_from_config( bg_tasks_manager: &mut BackgroundTasksManager, router_config: Arc, + app_state: Arc, ) -> Result { let (tx, mut rx) = mpsc::channel::(1); let background_loader = SupergraphBackgroundLoader::new(&router_config.supergraph, tx)?; @@ -85,9 +85,62 @@ impl SchemaState { while let Some(new_sdl) = rx.recv().await { debug!("Received new supergraph SDL, building new supergraph state..."); - match Self::build_data(router_config.clone(), &new_sdl) { - Ok(new_data) => { - swappable_data_spawn_clone.store(Arc::new(Some(new_data))); + let new_ast = parse_schema(&new_sdl); + + let mut start_payload = OnSupergraphLoadStartPayload { + current_supergraph_data: swappable_data_spawn_clone.clone(), + new_ast, + }; + + let mut on_end_callbacks = vec![]; + + for plugin in app_state.plugins.as_ref() { + let result = plugin.on_supergraph_reload(start_payload); + start_payload = result.payload; + match result.control_flow { + ControlFlowResult::Continue => { + // continue to next plugin + } + ControlFlowResult::EndResponse(_) => { + unreachable!("Plugins should not end supergraph reload processing"); + } + ControlFlowResult::OnEnd(callback) => { + on_end_callbacks.push(callback); + } + } + } + + let new_ast = start_payload.new_ast; + + match Self::build_data(router_config.clone(), &new_ast, app_state.plugins.clone()) { + Ok(new_supergraph_data) => { + let mut end_payload = OnSupergraphLoadEndPayload { + new_supergraph_data, + }; + + for callback in on_end_callbacks { + let result = callback(end_payload); + end_payload = result.payload; + match result.control_flow { + ControlFlowResult::Continue => { + // continue to next callback + } + ControlFlowResult::EndResponse(_) => { + unreachable!( + "Plugins should not 
end supergraph reload processing" + ); + } + ControlFlowResult::OnEnd(_) => { + unreachable!( + "End callbacks should not register further end callbacks" + ); + } + } + } + + let new_supergraph_data = end_payload.new_supergraph_data; + + swappable_data_spawn_clone.store(Arc::new(Some(new_supergraph_data))); debug!("Supergraph updated successfully"); task_plan_cache.invalidate_all(); @@ -112,15 +165,16 @@ impl SchemaState { fn build_data( router_config: Arc, - supergraph_sdl: &str, + parsed_supergraph_sdl: &Document, + plugins: Arc>>, ) -> Result { - let parsed_supergraph_sdl = parse_schema(supergraph_sdl); - let supergraph_state = SupergraphState::new(&parsed_supergraph_sdl); - let planner = Planner::new_from_supergraph(&parsed_supergraph_sdl)?; + let supergraph_state = SupergraphState::new(parsed_supergraph_sdl); + let planner = Planner::new_from_supergraph(parsed_supergraph_sdl)?; let metadata = planner.consumer_schema.schema_metadata(); let subgraph_executor_map = SubgraphExecutorMap::from_http_endpoint_map( supergraph_state.subgraph_endpoint_map, router_config, + plugins.clone(), )?; Ok(SupergraphData { diff --git a/bin/router/src/shared_state.rs b/bin/router/src/shared_state.rs index f36bda6cd..877ffa0e3 100644 --- a/bin/router/src/shared_state.rs +++ b/bin/router/src/shared_state.rs @@ -3,6 +3,7 @@ use hive_router_config::HiveRouterConfig; use hive_router_plan_executor::headers::{ compile::compile_headers_plan, errors::HeaderRuleCompileError, plan::HeaderRulesPlan, }; +use hive_router_plan_executor::plugin_trait::RouterPlugin; use moka::future::Cache; use std::sync::Arc; @@ -18,6 +19,7 @@ pub struct RouterSharedState { pub override_labels_evaluator: OverrideLabelsEvaluator, pub cors_runtime: Option, pub jwt_auth_runtime: Option, + pub plugins: Arc>>, } impl RouterSharedState { @@ -36,6 +38,7 @@ impl RouterSharedState { ) .map_err(Box::new)?, jwt_auth_runtime, + plugins: Arc::new(vec![]), }) } } diff --git a/lib/executor/Cargo.toml b/lib/executor/Cargo.toml 
index 27f7af1bf..07aefd5b6 100644 --- a/lib/executor/Cargo.toml +++ b/lib/executor/Cargo.toml @@ -30,10 +30,14 @@ xxhash-rust = { workspace = true } tokio = { workspace = true, features = ["sync"] } dashmap = { workspace = true } vrl = { workspace = true } +reqwest = { workspace = true, features = ["multipart"] } ahash = "0.8.12" regex-automata = "0.4.10" strum = { version = "0.27.2", features = ["derive"] } + +arc-swap = "1.7.1" +ntex = { version = "2", features = ["tokio"] } ntex-http = "0.1.15" ordered-float = "4.2.0" hyper-tls = { version = "0.6.0", features = ["vendored"] } @@ -49,6 +53,9 @@ itoa = "1.0.15" ryu = "1.0.20" indexmap = "2.10.0" bumpalo = "3.19.0" +redis = "0.32.7" +multer = "3.1.0" +futures-util = "0.3.31" [dev-dependencies] subgraphs = { path = "../../bench/subgraphs" } diff --git a/lib/executor/src/execution/client_request_details.rs b/lib/executor/src/execution/client_request_details.rs index 6985376cc..20b7dcf98 100644 --- a/lib/executor/src/execution/client_request_details.rs +++ b/lib/executor/src/execution/client_request_details.rs @@ -18,14 +18,14 @@ pub struct ClientRequestDetails<'exec, 'req> { pub url: &'req http::Uri, pub headers: &'req NtexHeaderMap, pub operation: OperationDetails<'exec>, - pub jwt: &'exec JwtRequestDetails<'req>, + pub jwt: JwtRequestDetails, } -pub enum JwtRequestDetails<'exec> { +pub enum JwtRequestDetails { Authenticated { - token: &'exec str, - prefix: Option<&'exec str>, - claims: &'exec sonic_rs::Value, + token: String, + prefix: Option, + claims: sonic_rs::Value, scopes: Option>, }, Unauthenticated, @@ -67,7 +67,7 @@ impl From<&ClientRequestDetails<'_, '_>> for Value { ])); // .request.jwt - let jwt_value = match details.jwt { + let jwt_value = match &details.jwt { JwtRequestDetails::Authenticated { token, prefix, @@ -78,7 +78,7 @@ impl From<&ClientRequestDetails<'_, '_>> for Value { ("token".into(), token.to_string().into()), ( "prefix".into(), - prefix.unwrap_or_default().to_string().into(), + 
prefix.as_deref().unwrap_or_default().to_string().into(), ), ("claims".into(), sonic_value_to_vrl_value(claims)), ( diff --git a/lib/executor/src/execution/jwt_forward.rs b/lib/executor/src/execution/jwt_forward.rs index 24c19ff9f..9aefc601c 100644 --- a/lib/executor/src/execution/jwt_forward.rs +++ b/lib/executor/src/execution/jwt_forward.rs @@ -8,7 +8,7 @@ pub struct JwtAuthForwardingPlan { pub extension_field_value: Value, } -impl JwtRequestDetails<'_> { +impl JwtRequestDetails { pub fn build_forwarding_plan( &self, extension_field_name: &str, diff --git a/lib/executor/src/execution/plan.rs b/lib/executor/src/execution/plan.rs index f86356312..ad1a46744 100644 --- a/lib/executor/src/execution/plan.rs +++ b/lib/executor/src/execution/plan.rs @@ -2,9 +2,12 @@ use std::collections::{BTreeSet, HashMap}; use bytes::{BufMut, Bytes}; use futures::{future::BoxFuture, stream::FuturesUnordered, StreamExt}; -use hive_router_query_planner::planner::plan_nodes::{ - ConditionNode, FetchNode, FetchRewrite, FlattenNode, FlattenNodePath, ParallelNode, PlanNode, - QueryPlan, SequenceNode, +use hive_router_query_planner::{ + ast::operation::OperationDefinition, + planner::plan_nodes::{ + ConditionNode, FetchNode, FetchRewrite, FlattenNode, FlattenNodePath, ParallelNode, + PlanNode, QueryPlan, SequenceNode, + }, }; use http::HeaderMap; use serde::Deserialize; @@ -19,7 +22,7 @@ use crate::{ rewrites::FetchRewriteExt, }, executors::{ - common::{HttpExecutionRequest, HttpExecutionResponse}, + common::{HttpExecutionResponse, SubgraphExecutionRequest}, map::SubgraphExecutorMap, }, headers::{ @@ -27,10 +30,13 @@ use crate::{ request::modify_subgraph_request_headers, response::{apply_subgraph_response_headers, modify_client_response_headers}, }, + hooks::on_execute::{OnExecuteEndPayload, OnExecuteStartPayload}, introspection::{ resolve::{resolve_introspection, IntrospectionContext}, schema::SchemaMetadata, }, + plugin_context::PluginManager, + plugin_trait::ControlFlowResult, 
projection::{ plan::FieldProjectionPlan, request::{project_requires, RequestProjectionContext}, @@ -49,7 +55,9 @@ use crate::{ }; pub struct QueryPlanExecutionContext<'exec, 'req> { + pub plugin_manager: &'exec PluginManager<'exec>, pub query_plan: &'exec QueryPlan, + pub operation_for_plan: &'exec OperationDefinition, pub projection_plan: &'exec Vec, pub headers_plan: &'exec HeaderRulesPlan, pub variable_values: &'exec Option>, @@ -58,67 +66,125 @@ pub struct QueryPlanExecutionContext<'exec, 'req> { pub introspection_context: &'exec IntrospectionContext<'exec, 'static>, pub operation_type_name: &'exec str, pub executors: &'exec SubgraphExecutorMap, - pub jwt_auth_forwarding: &'exec Option, + pub jwt_auth_forwarding: Option, } +#[derive(Clone)] pub struct PlanExecutionOutput { pub body: Vec, pub headers: HeaderMap, + pub status: http::StatusCode, } -pub async fn execute_query_plan<'exec, 'req>( - ctx: QueryPlanExecutionContext<'exec, 'req>, -) -> Result { - let init_value = if let Some(introspection_query) = ctx.introspection_context.query { - resolve_introspection(introspection_query, ctx.introspection_context) - } else { - Value::Null - }; +impl<'exec, 'req> QueryPlanExecutionContext<'exec, 'req> { + pub async fn execute_query_plan(self) -> Result { + let init_value = if let Some(introspection_query) = self.introspection_context.query { + resolve_introspection(introspection_query, self.introspection_context) + } else { + Value::Null + }; - let mut exec_ctx = ExecutionContext::new(ctx.query_plan, init_value); - let executor = Executor::new( - ctx.variable_values, - ctx.executors, - ctx.introspection_context.metadata, - ctx.client_request, - ctx.headers_plan, - ctx.jwt_auth_forwarding, - // Deduplicate subgraph requests only if the operation type is a query - ctx.operation_type_name == "Query", - ); - - if ctx.query_plan.node.is_some() { - executor - .execute(&mut exec_ctx, ctx.query_plan.node.as_ref()) - .await?; - } + let dedupe_subgraph_requests = 
self.operation_type_name == "Query"; + let mut start_payload = OnExecuteStartPayload { + router_http_request: &self.plugin_manager.router_http_request, + context: &self.plugin_manager.context, + query_plan: self.query_plan, + operation_for_plan: self.operation_for_plan, + data: init_value, + errors: Vec::new(), + extensions: self.extensions.clone(), + variable_values: self.variable_values, + dedupe_subgraph_requests, + }; + + let mut on_end_callbacks = vec![]; + + for plugin in self.plugin_manager.plugins.iter() { + let result = plugin.on_execute(start_payload).await; + start_payload = result.payload; + match result.control_flow { + ControlFlowResult::Continue => { /* continue to next plugin */ } + ControlFlowResult::EndResponse(response) => { + return Ok(response); + } + ControlFlowResult::OnEnd(callback) => { + on_end_callbacks.push(callback); + } + } + } + + let query_plan = start_payload.query_plan; + + let init_value = start_payload.data; + + let mut exec_ctx = ExecutionContext::new(query_plan, init_value); + let executor = Executor::new( + self.variable_values, + self.executors, + self.introspection_context.metadata, + self.client_request, + self.headers_plan, + self.jwt_auth_forwarding, + // Deduplicate subgraph requests only if the operation type is a query + self.operation_type_name == "Query", + self.plugin_manager, + ); - let mut response_headers = HeaderMap::new(); - modify_client_response_headers(exec_ctx.response_headers_aggregator, &mut response_headers) + if query_plan.node.is_some() { + executor + .execute(&mut exec_ctx, query_plan.node.as_ref()) + .await?; + } + + let mut response_headers = HeaderMap::new(); + modify_client_response_headers(exec_ctx.response_headers_aggregator, &mut response_headers) + .with_plan_context(LazyPlanContext { + subgraph_name: || None, + affected_path: || None, + })?; + + let mut end_payload = OnExecuteEndPayload { + data: exec_ctx.final_response, + errors: exec_ctx.errors, + extensions: start_payload.extensions, + 
response_size_estimate: exec_ctx.response_storage.estimate_final_response_size(), + }; + + for callback in on_end_callbacks { + let result = callback(end_payload); + end_payload = result.payload; + match result.control_flow { + ControlFlowResult::Continue => { /* continue to next callback */ } + ControlFlowResult::EndResponse(output) => { + return Ok(output); + } + ControlFlowResult::OnEnd(_) => { + // on_end callbacks should not return OnEnd again + unreachable!("on_end callback returned OnEnd again"); + } + } + } + + let body = project_by_operation( + &end_payload.data, + end_payload.errors, + &self.extensions, + self.operation_type_name, + self.projection_plan, + self.variable_values, + end_payload.response_size_estimate, + ) .with_plan_context(LazyPlanContext { subgraph_name: || None, affected_path: || None, })?; - let final_response = &exec_ctx.final_response; - let body = project_by_operation( - final_response, - exec_ctx.errors, - &ctx.extensions, - ctx.operation_type_name, - ctx.projection_plan, - ctx.variable_values, - exec_ctx.response_storage.estimate_final_response_size(), - ) - .with_plan_context(LazyPlanContext { - subgraph_name: || None, - affected_path: || None, - })?; - - Ok(PlanExecutionOutput { - body, - headers: response_headers, - }) + Ok(PlanExecutionOutput { + body, + headers: response_headers, + status: http::StatusCode::OK, + }) + } } pub struct Executor<'exec, 'req> { @@ -127,8 +193,9 @@ pub struct Executor<'exec, 'req> { executors: &'exec SubgraphExecutorMap, client_request: &'exec ClientRequestDetails<'exec, 'req>, headers_plan: &'exec HeaderRulesPlan, - jwt_forwarding_plan: &'exec Option, + jwt_forwarding_plan: Option, dedupe_subgraph_requests: bool, + plugin_manager: &'exec PluginManager<'exec>, } struct ConcurrencyScope<'exec, T> { @@ -222,8 +289,9 @@ impl<'exec, 'req> Executor<'exec, 'req> { schema_metadata: &'exec SchemaMetadata, client_request: &'exec ClientRequestDetails<'exec, 'req>, headers_plan: &'exec HeaderRulesPlan, - 
jwt_forwarding_plan: &'exec Option, + jwt_forwarding_plan: Option, dedupe_subgraph_requests: bool, + plugin_manager: &'exec PluginManager<'exec>, ) -> Self { Executor { variable_values, @@ -233,6 +301,7 @@ impl<'exec, 'req> Executor<'exec, 'req> { headers_plan, dedupe_subgraph_requests, jwt_forwarding_plan, + plugin_manager, } } @@ -700,7 +769,7 @@ impl<'exec, 'req> Executor<'exec, 'req> { let variable_refs = select_fetch_variables(self.variable_values, node.variable_usages.as_ref()); - let mut subgraph_request = HttpExecutionRequest { + let mut subgraph_request = SubgraphExecutionRequest { query: node.operation.document_str.as_str(), dedupe: self.dedupe_subgraph_requests, operation_name: node.operation_name.as_deref(), @@ -722,7 +791,12 @@ impl<'exec, 'req> Executor<'exec, 'req> { subgraph_name: node.service_name.clone(), response: self .executors - .execute(&node.service_name, subgraph_request, self.client_request) + .execute( + &node.service_name, + subgraph_request, + self.client_request, + self.plugin_manager, + ) .await .into(), })) diff --git a/lib/executor/src/executors/common.rs b/lib/executor/src/executors/common.rs index bdcd4d819..6b3c804b5 100644 --- a/lib/executor/src/executors/common.rs +++ b/lib/executor/src/executors/common.rs @@ -9,7 +9,7 @@ use sonic_rs::Value; pub trait SubgraphExecutor { async fn execute<'a>( &self, - execution_request: HttpExecutionRequest<'a>, + execution_request: SubgraphExecutionRequest<'a>, ) -> HttpExecutionResponse; fn to_boxed_arc<'a>(self) -> Arc> @@ -26,7 +26,7 @@ pub type SubgraphExecutorBoxedArc = Arc>; pub type SubgraphRequestExtensions = HashMap; -pub struct HttpExecutionRequest<'a> { +pub struct SubgraphExecutionRequest<'a> { pub query: &'a str, pub dedupe: bool, pub operation_name: Option<&'a str>, @@ -37,7 +37,7 @@ pub struct HttpExecutionRequest<'a> { pub extensions: Option, } -impl HttpExecutionRequest<'_> { +impl SubgraphExecutionRequest<'_> { pub fn add_request_extensions_field(&mut self, key: String, 
value: Value) { self.extensions .get_or_insert_with(HashMap::new) @@ -45,7 +45,9 @@ impl HttpExecutionRequest<'_> { } } +#[derive(Clone)] pub struct HttpExecutionResponse { pub body: Bytes, pub headers: HeaderMap, + pub status: http::StatusCode, } diff --git a/lib/executor/src/executors/http.rs b/lib/executor/src/executors/http.rs index 29b392567..f33cbaedc 100644 --- a/lib/executor/src/executors/http.rs +++ b/lib/executor/src/executors/http.rs @@ -2,6 +2,10 @@ use std::sync::Arc; use crate::executors::common::HttpExecutionResponse; use crate::executors::dedupe::{request_fingerprint, ABuildHasher, SharedResponse}; +use crate::hooks::on_subgraph_http_request::{ + OnSubgraphHttpRequestPayload, OnSubgraphHttpResponsePayload, +}; +use crate::plugin_trait::{ControlFlowResult, RouterPlugin}; use dashmap::DashMap; use hive_router_config::HiveRouterConfig; use tokio::sync::OnceCell; @@ -9,8 +13,8 @@ use tokio::sync::OnceCell; use async_trait::async_trait; use bytes::{BufMut, Bytes, BytesMut}; -use http::HeaderMap; use http::HeaderValue; +use http::{HeaderMap, StatusCode}; use http_body_util::BodyExt; use http_body_util::Full; use hyper::Version; @@ -19,7 +23,7 @@ use hyper_util::client::legacy::{connect::HttpConnector, Client}; use tokio::sync::Semaphore; use tracing::debug; -use crate::executors::common::HttpExecutionRequest; +use crate::executors::common::SubgraphExecutionRequest; use crate::executors::error::SubgraphExecutorError; use crate::response::graphql_error::GraphQLError; use crate::utils::consts::CLOSE_BRACE; @@ -28,7 +32,6 @@ use crate::utils::consts::COMMA; use crate::utils::consts::QUOTE; use crate::{executors::common::SubgraphExecutor, json_writer::write_and_escape_string}; -#[derive(Debug)] pub struct HTTPSubgraphExecutor { pub subgraph_name: String, pub endpoint: http::Uri, @@ -37,6 +40,7 @@ pub struct HTTPSubgraphExecutor { pub semaphore: Arc, pub config: Arc, pub in_flight_requests: Arc>, ABuildHasher>>, + pub plugins: Arc>>, } const FIRST_VARIABLE_STR: 
&[u8] = b",\"variables\":{"; @@ -52,6 +56,7 @@ impl HTTPSubgraphExecutor { semaphore: Arc, config: Arc, in_flight_requests: Arc>, ABuildHasher>>, + plugins: Arc>>, ) -> Self { let mut header_map = HeaderMap::new(); header_map.insert( @@ -71,12 +76,13 @@ impl HTTPSubgraphExecutor { semaphore, config, in_flight_requests, + plugins, } } fn build_request_body<'a>( &self, - execution_request: &HttpExecutionRequest<'a>, + execution_request: &SubgraphExecutionRequest<'a>, ) -> Result, SubgraphExecutorError> { let mut body = Vec::with_capacity(4096); body.put(FIRST_QUOTE_STR); @@ -133,57 +139,6 @@ impl HTTPSubgraphExecutor { Ok(body) } - async fn _send_request( - &self, - body: Vec, - headers: HeaderMap, - ) -> Result { - let mut req = hyper::Request::builder() - .method(http::Method::POST) - .uri(&self.endpoint) - .version(Version::HTTP_11) - .body(Full::new(Bytes::from(body))) - .map_err(|e| { - SubgraphExecutorError::RequestBuildFailure(self.endpoint.to_string(), e.to_string()) - })?; - - *req.headers_mut() = headers; - - debug!("making http request to {}", self.endpoint.to_string()); - - let res = self.http_client.request(req).await.map_err(|e| { - SubgraphExecutorError::RequestFailure(self.endpoint.to_string(), e.to_string()) - })?; - - debug!( - "http request to {} completed, status: {}", - self.endpoint.to_string(), - res.status() - ); - - let (parts, body) = res.into_parts(); - let body = body - .collect() - .await - .map_err(|e| { - SubgraphExecutorError::RequestFailure(self.endpoint.to_string(), e.to_string()) - })? 
- .to_bytes(); - - if body.is_empty() { - return Err(SubgraphExecutorError::RequestFailure( - self.endpoint.to_string(), - "Empty response body".to_string(), - )); - } - - Ok(SharedResponse { - status: parts.status, - body, - headers: parts.headers, - }) - } - fn error_to_graphql_bytes(&self, error: SubgraphExecutorError) -> Bytes { let graphql_error: GraphQLError = error.into(); let mut graphql_error = graphql_error.add_subgraph_name(&self.subgraph_name); @@ -207,12 +162,118 @@ impl HTTPSubgraphExecutor { } } +async fn send_request( + http_client: &Client, Full>, + subgraph_name: &str, + endpoint: &http::Uri, + method: http::Method, + body: Vec, + headers: HeaderMap, + plugins: Arc>>, +) -> Result { + let mut req = hyper::Request::builder() + .method(method) + .uri(endpoint) + .version(Version::HTTP_11) + .body(Full::new(Bytes::from(body))) + .map_err(|e| { + SubgraphExecutorError::RequestBuildFailure(endpoint.to_string(), e.to_string()) + })?; + + *req.headers_mut() = headers; + + let mut start_payload = OnSubgraphHttpRequestPayload { + subgraph_name, + request: req, + response: None, + }; + + let mut on_end_callbacks = vec![]; + + for plugin in plugins.as_ref() { + let result = plugin.on_subgraph_http_request(start_payload).await; + start_payload = result.payload; + match result.control_flow { + ControlFlowResult::Continue => { /* continue to next plugin */ } + ControlFlowResult::EndResponse(response) => { + // TODO: Fixx + return Ok(SharedResponse { + status: StatusCode::OK, + body: response.body.into(), + headers: response.headers, + }); + } + ControlFlowResult::OnEnd(callback) => { + on_end_callbacks.push(callback); + } + } + } + + debug!("making http request to {}", endpoint.to_string()); + + let req = start_payload.request; + + let res = http_client + .request(req) + .await + .map_err(|e| SubgraphExecutorError::RequestFailure(endpoint.to_string(), e.to_string()))?; + + debug!( + "http request to {} completed, status: {}", + endpoint.to_string(), + 
res.status() + ); + + let (parts, body) = res.into_parts(); + let body = body + .collect() + .await + .map_err(|e| SubgraphExecutorError::RequestFailure(endpoint.to_string(), e.to_string()))? + .to_bytes(); + + if body.is_empty() { + return Err(SubgraphExecutorError::RequestFailure( + endpoint.to_string(), + "Empty response body".to_string(), + )); + } + + let response = SharedResponse { + status: parts.status, + body, + headers: parts.headers, + }; + + let mut end_payload = OnSubgraphHttpResponsePayload { response }; + + for callback in on_end_callbacks { + let result = callback(end_payload); + end_payload = result.payload; + match result.control_flow { + ControlFlowResult::Continue => { /* continue to next callback */ } + ControlFlowResult::EndResponse(response) => { + return Ok(SharedResponse { + status: StatusCode::OK, + body: response.body.into(), + headers: response.headers, + }); + } + ControlFlowResult::OnEnd(_) => { + // on_end callbacks should not return OnEnd again + unreachable!("on_end callback returned OnEnd again"); + } + } + } + + Ok(end_payload.response) +} + #[async_trait] impl SubgraphExecutor for HTTPSubgraphExecutor { #[tracing::instrument(skip_all, fields(subgraph_name = self.subgraph_name))] async fn execute<'a>( &self, - execution_request: HttpExecutionRequest<'a>, + execution_request: SubgraphExecutionRequest<'a>, ) -> HttpExecutionResponse { let body = match self.build_request_body(&execution_request) { Ok(body) => body, @@ -221,6 +282,7 @@ impl SubgraphExecutor for HTTPSubgraphExecutor { return HttpExecutionResponse { body: self.error_to_graphql_bytes(e), headers: Default::default(), + status: StatusCode::OK, }; } }; @@ -230,26 +292,40 @@ impl SubgraphExecutor for HTTPSubgraphExecutor { headers.insert(key, value.clone()); }); + let method = http::Method::POST; + if !self.config.traffic_shaping.dedupe_enabled || !execution_request.dedupe { // This unwrap is safe because the semaphore is never closed during the application's lifecycle. 
// `acquire()` only fails if the semaphore is closed, so this will always return `Ok`. let _permit = self.semaphore.acquire().await.unwrap(); - return match self._send_request(body, headers).await { + return match send_request( + &self.http_client, + &self.subgraph_name, + &self.endpoint, + method, + body, + headers, + self.plugins.clone(), + ) + .await + { Ok(shared_response) => HttpExecutionResponse { body: shared_response.body, headers: shared_response.headers, + status: shared_response.status, }, Err(e) => { self.log_error(&e); HttpExecutionResponse { body: self.error_to_graphql_bytes(e), headers: Default::default(), + status: StatusCode::OK, } } }; } - let fingerprint = request_fingerprint(&http::Method::POST, &self.endpoint, &headers, &body); + let fingerprint = request_fingerprint(&method, &self.endpoint, &headers, &body); // Clone the cell from the map, dropping the lock from the DashMap immediately. // Prevents any deadlocks. @@ -266,7 +342,16 @@ impl SubgraphExecutor for HTTPSubgraphExecutor { // This unwrap is safe because the semaphore is never closed during the application's lifecycle. // `acquire()` only fails if the semaphore is closed, so this will always return `Ok`. let _permit = self.semaphore.acquire().await.unwrap(); - self._send_request(body, headers).await + send_request( + &self.http_client, + &self.subgraph_name, + &self.endpoint, + method, + body, + headers, + self.plugins.clone(), + ) + .await }; // It's important to remove the entry from the map before returning the result. // This ensures that once the OnceCell is set, no future requests can join it. 
@@ -280,12 +365,14 @@ impl SubgraphExecutor for HTTPSubgraphExecutor { Ok(shared_response) => HttpExecutionResponse { body: shared_response.body.clone(), headers: shared_response.headers.clone(), + status: shared_response.status, }, Err(e) => { self.log_error(&e); HttpExecutionResponse { body: self.error_to_graphql_bytes(e.clone()), headers: Default::default(), + status: StatusCode::OK, } } } diff --git a/lib/executor/src/executors/map.rs b/lib/executor/src/executors/map.rs index a3c297ad1..bf8eb7838 100644 --- a/lib/executor/src/executors/map.rs +++ b/lib/executor/src/executors/map.rs @@ -30,12 +30,16 @@ use crate::{ execution::client_request_details::ClientRequestDetails, executors::{ common::{ - HttpExecutionRequest, HttpExecutionResponse, SubgraphExecutor, SubgraphExecutorBoxedArc, + HttpExecutionResponse, SubgraphExecutionRequest, SubgraphExecutor, + SubgraphExecutorBoxedArc, }, dedupe::{ABuildHasher, SharedResponse}, error::SubgraphExecutorError, http::{HTTPSubgraphExecutor, HttpClient}, }, + hooks::on_subgraph_execute::{OnSubgraphExecuteEndPayload, OnSubgraphExecuteStartPayload}, + plugin_context::PluginManager, + plugin_trait::{ControlFlowResult, RouterPlugin}, response::graphql_error::GraphQLError, }; @@ -60,10 +64,14 @@ pub struct SubgraphExecutorMap { semaphores_by_origin: DashMap>, max_connections_per_host: usize, in_flight_requests: Arc>, ABuildHasher>>, + plugins: Arc>>, } impl SubgraphExecutorMap { - pub fn new(config: Arc) -> Self { + pub fn new( + config: Arc, + plugins: Arc>>, + ) -> Self { let https = HttpsConnector::new(); let client: HttpClient = Client::builder(TokioExecutor::new()) .pool_timer(TokioTimer::new()) @@ -85,14 +93,16 @@ impl SubgraphExecutorMap { semaphores_by_origin: Default::default(), max_connections_per_host, in_flight_requests: Arc::new(DashMap::with_hasher(ABuildHasher::default())), + plugins, } } pub fn from_http_endpoint_map( subgraph_endpoint_map: HashMap, config: Arc, + plugins: Arc>>, ) -> Result { - let mut 
subgraph_executor_map = SubgraphExecutorMap::new(config.clone()); + let mut subgraph_executor_map = SubgraphExecutorMap::new(config.clone(), plugins); for (subgraph_name, original_endpoint_str) in subgraph_endpoint_map.into_iter() { let endpoint_str = config @@ -115,13 +125,47 @@ impl SubgraphExecutorMap { Ok(subgraph_executor_map) } - pub async fn execute<'a, 'req>( + pub async fn execute<'exec, 'req>( &self, subgraph_name: &str, - execution_request: HttpExecutionRequest<'a>, - client_request: &ClientRequestDetails<'a, 'req>, + execution_request: SubgraphExecutionRequest<'exec>, + client_request: &ClientRequestDetails<'exec, 'req>, + plugin_manager: &PluginManager<'req>, ) -> HttpExecutionResponse { - match self.get_or_create_executor(subgraph_name, client_request) { + let mut start_payload = OnSubgraphExecuteStartPayload { + router_http_request: &plugin_manager.router_http_request, + context: &plugin_manager.context, + subgraph_name: subgraph_name.to_string(), + execution_request, + execution_result: None, + }; + + let mut on_end_callbacks = vec![]; + + for plugin in self.plugins.as_ref() { + let result = plugin.on_subgraph_execute(start_payload).await; + start_payload = result.payload; + match result.control_flow { + ControlFlowResult::Continue => { + // continue to next plugin + } + ControlFlowResult::EndResponse(response) => { + // TODO: FFIX + return HttpExecutionResponse { + body: response.body.into(), + headers: response.headers, + status: response.status, + }; + } + ControlFlowResult::OnEnd(callback) => { + on_end_callbacks.push(callback); + } + } + } + + let execution_request = start_payload.execution_request; + + let execution_result = match self.get_or_create_executor(subgraph_name, client_request) { Ok(Some(executor)) => executor.execute(execution_request).await, Err(err) => { error!( @@ -137,7 +181,35 @@ impl SubgraphExecutorMap { ); self.internal_server_error_response("Internal server error".into(), subgraph_name) } + }; + + let mut end_payload = 
OnSubgraphExecuteEndPayload { + context: &plugin_manager.context, + execution_result, + }; + + for callback in on_end_callbacks { + let result = callback(end_payload); + end_payload = result.payload; + match result.control_flow { + ControlFlowResult::Continue => { + // continue to next callback + } + ControlFlowResult::EndResponse(response) => { + // TODO: FFIX + return HttpExecutionResponse { + body: response.body.into(), + headers: response.headers, + status: response.status, + }; + } + ControlFlowResult::OnEnd(_) => { + unreachable!("End callbacks should not register further end callbacks"); + } + } } + + end_payload.execution_result } fn internal_server_error_response( @@ -155,6 +227,7 @@ impl SubgraphExecutorMap { HttpExecutionResponse { body: buffer.freeze(), headers: Default::default(), + status: http::StatusCode::INTERNAL_SERVER_ERROR, } } @@ -324,6 +397,7 @@ impl SubgraphExecutorMap { semaphore, self.config.clone(), self.in_flight_requests.clone(), + self.plugins.clone(), ); self.executors_by_subgraph diff --git a/lib/executor/src/headers/mod.rs b/lib/executor/src/headers/mod.rs index 62f9fe701..0338bf6df 100644 --- a/lib/executor/src/headers/mod.rs +++ b/lib/executor/src/headers/mod.rs @@ -82,7 +82,7 @@ mod tests { query: "{ __typename }", kind: "query", }, - jwt: &JwtRequestDetails::Unauthenticated, + jwt: JwtRequestDetails::Unauthenticated, }; let mut out = HeaderMap::new(); @@ -116,7 +116,7 @@ mod tests { query: "{ __typename }", kind: "query", }, - jwt: &JwtRequestDetails::Unauthenticated, + jwt: JwtRequestDetails::Unauthenticated, }; let mut out = HeaderMap::new(); modify_subgraph_request_headers(&plan, "any", &client_details, &mut out).unwrap(); @@ -163,7 +163,7 @@ mod tests { query: "{ __typename }", kind: "query", }, - jwt: &JwtRequestDetails::Unauthenticated, + jwt: JwtRequestDetails::Unauthenticated, }; let mut out = HeaderMap::new(); @@ -201,7 +201,7 @@ mod tests { query: "{ __typename }", kind: "query", }, - jwt: 
&JwtRequestDetails::Unauthenticated, + jwt: JwtRequestDetails::Unauthenticated, }; let mut out = HeaderMap::new(); @@ -235,7 +235,7 @@ mod tests { query: "{ __typename }", kind: "query", }, - jwt: &JwtRequestDetails::Unauthenticated, + jwt: JwtRequestDetails::Unauthenticated, }; let mut out = HeaderMap::new(); @@ -275,7 +275,7 @@ mod tests { query: "{ __typename }", kind: "query", }, - jwt: &JwtRequestDetails::Unauthenticated, + jwt: JwtRequestDetails::Unauthenticated, }; // For "accounts" subgraph, the specific rule should apply. @@ -319,7 +319,7 @@ mod tests { query: "{ __typename }", kind: "query", }, - jwt: &JwtRequestDetails::Unauthenticated, + jwt: JwtRequestDetails::Unauthenticated, }; let mut accumulator = ResponseHeaderAggregator::default(); @@ -384,7 +384,7 @@ mod tests { query: "{ __typename }", kind: "query", }, - jwt: &JwtRequestDetails::Unauthenticated, + jwt: JwtRequestDetails::Unauthenticated, }; let mut accumulator = ResponseHeaderAggregator::default(); @@ -448,7 +448,7 @@ mod tests { query: "{ __typename }", kind: "query", }, - jwt: &JwtRequestDetails::Unauthenticated, + jwt: JwtRequestDetails::Unauthenticated, }; let mut accumulator = ResponseHeaderAggregator::default(); @@ -505,7 +505,7 @@ mod tests { query: "{ __typename }", kind: "query", }, - jwt: &JwtRequestDetails::Unauthenticated, + jwt: JwtRequestDetails::Unauthenticated, }; let mut accumulator = ResponseHeaderAggregator::default(); @@ -563,7 +563,7 @@ mod tests { query: "{ __typename }", kind: "query", }, - jwt: &JwtRequestDetails::Unauthenticated, + jwt: JwtRequestDetails::Unauthenticated, }; let mut accumulator = ResponseHeaderAggregator::default(); @@ -622,7 +622,7 @@ mod tests { query: "{ __typename }", kind: "query", }, - jwt: &JwtRequestDetails::Unauthenticated, + jwt: JwtRequestDetails::Unauthenticated, }; let mut out = HeaderMap::new(); diff --git a/lib/executor/src/headers/request.rs b/lib/executor/src/headers/request.rs index 637ab0d58..7f362be73 100644 --- 
a/lib/executor/src/headers/request.rs +++ b/lib/executor/src/headers/request.rs @@ -45,9 +45,9 @@ pub fn modify_subgraph_request_headers( Ok(()) } -pub struct RequestExpressionContext<'a, 'req> { - pub subgraph_name: &'a str, - pub client_request: &'a ClientRequestDetails<'a, 'req>, +pub struct RequestExpressionContext<'exec, 'req> { + pub subgraph_name: &'exec str, + pub client_request: &'exec ClientRequestDetails<'exec, 'req>, } trait ApplyRequestHeader { @@ -117,7 +117,7 @@ impl ApplyRequestHeader for RequestPropagateRegex { ctx: &RequestExpressionContext, output_headers: &mut HeaderMap, ) -> Result<(), HeaderRuleRuntimeError> { - for (header_name, header_value) in ctx.client_request.headers { + for (header_name, header_value) in ctx.client_request.headers.iter() { if is_denied_header(header_name) { continue; } diff --git a/lib/executor/src/headers/response.rs b/lib/executor/src/headers/response.rs index 6a5c34444..b4942837f 100644 --- a/lib/executor/src/headers/response.rs +++ b/lib/executor/src/headers/response.rs @@ -50,10 +50,10 @@ pub fn apply_subgraph_response_headers( Ok(()) } -pub struct ResponseExpressionContext<'a, 'req> { - pub subgraph_name: &'a str, - pub client_request: &'a ClientRequestDetails<'a, 'req>, - pub subgraph_headers: &'a HeaderMap, +pub struct ResponseExpressionContext<'exec, 'req> { + pub subgraph_name: &'exec str, + pub client_request: &'exec ClientRequestDetails<'exec, 'req>, + pub subgraph_headers: &'exec HeaderMap, } trait ApplyResponseHeader { diff --git a/lib/executor/src/lib.rs b/lib/executor/src/lib.rs index 4f912a463..bdcbdadc0 100644 --- a/lib/executor/src/lib.rs +++ b/lib/executor/src/lib.rs @@ -4,10 +4,11 @@ pub mod executors; pub mod headers; pub mod introspection; pub mod json_writer; +pub mod plugins; pub mod projection; pub mod response; pub mod utils; pub mod variables; -pub use execution::plan::execute_query_plan; pub use executors::map::SubgraphExecutorMap; +pub use plugins::*; diff --git 
a/lib/executor/src/plugins/examples/apollo_sandbox.rs b/lib/executor/src/plugins/examples/apollo_sandbox.rs new file mode 100644 index 000000000..224654bdb --- /dev/null +++ b/lib/executor/src/plugins/examples/apollo_sandbox.rs @@ -0,0 +1,155 @@ +use ::serde::{Deserialize, Serialize}; +use ahash::HashMap; +use http::{HeaderMap, StatusCode}; + +use crate::{ + execution::plan::PlanExecutionOutput, + hooks::on_http_request::{OnHttpRequestPayload, OnHttpResponsePayload}, + plugin_trait::{HookResult, RouterPlugin, StartPayload}, +}; + +#[derive(Default, Serialize, Deserialize, Debug, Clone)] +#[serde(rename_all = "camelCase")] +pub struct ApolloSandboxOptions { + /** + * The URL of the GraphQL endpoint that Sandbox introspects on initial load. Sandbox populates its pages using the schema obtained from this endpoint. + * The default value is `http://localhost:4000`. + * You should only pass non-production endpoints to Sandbox. Sandbox is powered by schema introspection, and we recommend [disabling introspection in production](https://www.apollographql.com/blog/graphql/security/why-you-should-disable-graphql-introspection-in-production/). + * To provide a "Sandbox-like" experience for production endpoints, we recommend using either a [public variant](https://www.apollographql.com/docs/graphos/platform/graph-management/variants#public-variants) or the [embedded Explorer](https://www.apollographql.com/docs/graphos/platform/explorer/embed). + */ + pub initial_endpoint: String, + /** + * By default, the embedded Sandbox does not show the **Include cookies** toggle in its connection settings.Set `hideCookieToggle` to `false` to enable users of your embedded Sandbox instance to toggle the **Include cookies** setting. + */ + pub hide_cookie_toggle: bool, + /** + * By default, the embedded Sandbox has a URL input box that is editable by users.Set endpointIsEditable to false to prevent users of your embedded Sandbox instance from changing the endpoint URL. 
+ */ + pub endpoint_is_editable: bool, + /** + * You can set `includeCookies` to `true` if you instead want Sandbox to pass `{ credentials: 'include' }` for its requests.If you pass the `handleRequest` option, this option is ignored.Read more about the `fetch` API and credentials [here](https://developer.mozilla.org/en-US/docs/Web/API/fetch#credentials).This config option is deprecated in favor of using the connection settings cookie toggle in Sandbox and setting the default value via `initialState.includeCookies`. + */ + pub include_cookies: bool, + /** + * An object containing additional options related to the state of the embedded Sandbox on page load. + */ + pub initial_state: ApolloSandboxInitialStateOptions, +} + +#[derive(Serialize, Deserialize, Debug, Clone, Default)] +#[serde(rename_all = "camelCase")] +pub struct ApolloSandboxInitialStateOptions { + /** + * Set this value to `true` if you want Sandbox to pass `{ credentials: 'include' }` for its requests by default.If you set `hideCookieToggle` to `false`, users can override this default setting with the **Include cookies** toggle. (By default, the embedded Sandbox does not show the **Include cookies** toggle in its connection settings.)If you also pass the `handleRequest` option, this option is ignored.Read more about the `fetch` API and credentials [here](https://developer.mozilla.org/en-US/docs/Web/API/fetch#credentials). 
+ */ + pub include_cookies: bool, + /** + * A URI-encoded operation to populate in Sandbox's editor on load.If you omit this, Sandbox initially loads an example query based on your schema.Example: + * ```js + * initialState: { + * document: ` + * query ExampleQuery { + * books { + * title + * } + * } + * ` + * } + * ``` + */ + pub document: Option, + /** + * A URI-encoded, serialized object containing initial variable values to populate in Sandbox on load.If provided, these variables should apply to the initial query you provide for [`document`](https://www.apollographql.com/docs/apollo-sandbox#document).Example: + * + * ```js + * initialState: { + * variables: { + * userID: "abc123" + * }, + * } + * ``` + */ + pub variables: Option, + /** + * A URI-encoded, serialized object containing initial HTTP header values to populate in Sandbox on load.Example: + * + * + * ```js + * initialState: { + * headers: { + * authorization: "Bearer abc123"; + * } + * } + * ``` + */ + pub headers: Option, + /** + * The ID of a collection, paired with an operation ID to populate in Sandbox on load. You can find these values from a registered graph in Studio by clicking the **...** menu next to an operation in the Explorer of that graph and selecting **View operation details**.Example: + * + * ```js + * initialState: { + * collectionId: 'abc1234', + * operationId: 'xyz1234' + * } + * ``` + */ + pub collection_id: Option, + pub operation_id: Option, + /** + * If `true`, the embedded Sandbox periodically polls your `initialEndpoint` for schema updates.The default value is `true`.Example: + * + * ```js + * initialState: { + * pollForSchemaUpdates: false; + * } + * ``` + */ + pub poll_for_schema_updates: bool, + /** + * Headers that are applied by default to every operation executed by the embedded Sandbox. 
Users can turn off the application of these headers, but they can't modify their values.The embedded Sandbox always includes these headers in its introspection queries to your `initialEndpoint`.Example: + * + * ```js + * initialState: { + * sharedHeaders: { + * authorization: "Bearer abc123"; + * } + * } + * ``` + */ + pub shared_headers: HashMap, +} + +pub struct ApolloSandboxPlugin { + pub options: ApolloSandboxOptions, +} + +impl RouterPlugin for ApolloSandboxPlugin { + fn on_http_request<'req>( + &'req self, + payload: OnHttpRequestPayload<'req>, + ) -> HookResult<'req, OnHttpRequestPayload<'req>, OnHttpResponsePayload<'req>> { + if payload.router_http_request.path() == "/apollo-sandbox" { + let config = sonic_rs::to_string(&self.options).unwrap_or_else(|_| "{}".to_string()); + let html = format!( + r#" +
+ + + "#, + config + ); + let mut headers = HeaderMap::new(); + headers.insert("Content-Type", "text/html".parse().unwrap()); + return payload.end_response(PlanExecutionOutput { + body: html.into_bytes(), + headers, + status: StatusCode::OK, + }); + } + payload.cont() + } +} diff --git a/lib/executor/src/plugins/examples/apq.rs b/lib/executor/src/plugins/examples/apq.rs new file mode 100644 index 000000000..91a473b5e --- /dev/null +++ b/lib/executor/src/plugins/examples/apq.rs @@ -0,0 +1,62 @@ +use dashmap::DashMap; +use sonic_rs::{JsonContainerTrait, JsonValueTrait}; + +use crate::{ + hooks::on_graphql_params::{OnGraphQLParamsEndPayload, OnGraphQLParamsStartPayload}, + plugin_trait::{EndPayload, HookResult, RouterPlugin, StartPayload}, +}; + +pub struct APQPlugin { + cache: DashMap, +} + +#[async_trait::async_trait] +impl RouterPlugin for APQPlugin { + async fn on_graphql_params<'exec>( + &'exec self, + payload: OnGraphQLParamsStartPayload<'exec>, + ) -> HookResult<'exec, OnGraphQLParamsStartPayload<'exec>, OnGraphQLParamsEndPayload> { + payload.on_end(|mut payload| { + let persisted_query_ext = payload + .graphql_params + .extensions + .as_ref() + .and_then(|ext| ext.get("persistedQuery")) + .and_then(|pq| pq.as_object()); + if let Some(persisted_query_ext) = persisted_query_ext { + match persisted_query_ext.get(&"version").and_then(|v| v.as_str()) { + Some("1") => {} + _ => { + // TODO: Error for unsupported version + return payload.cont(); + } + } + let sha256_hash = match persisted_query_ext + .get(&"sha256Hash") + .and_then(|h| h.as_str()) + { + Some(h) => h, + None => { + return payload.cont(); + } + }; + if let Some(query_param) = &payload.graphql_params.query { + // Store the query in the cache + self.cache + .insert(sha256_hash.to_string(), query_param.to_string()); + } else { + // Try to get the query from the cache + if let Some(cached_query) = self.cache.get(sha256_hash) { + // Update the graphql_params with the cached query + 
payload.graphql_params.query = Some(cached_query.value().to_string()); + } else { + // Error + return payload.cont(); + } + } + } + + payload.cont() + }) + } +} diff --git a/lib/executor/src/plugins/examples/async_auth.rs b/lib/executor/src/plugins/examples/async_auth.rs new file mode 100644 index 000000000..ce6d73331 --- /dev/null +++ b/lib/executor/src/plugins/examples/async_auth.rs @@ -0,0 +1,113 @@ +use std::path::PathBuf; + +// From https://github.com/apollographql/router/blob/dev/examples/async-auth/rust/src/allow_client_id_from_file.rs +use serde::Deserialize; +use sonic_rs::json; + +use crate::{ + execution::plan::PlanExecutionOutput, + hooks::on_graphql_params::{OnGraphQLParamsEndPayload, OnGraphQLParamsStartPayload}, + plugin_trait::{HookResult, RouterPlugin, StartPayload}, +}; + +#[derive(Deserialize)] +pub struct AllowClientIdConfig { + pub header: String, + pub path: String, +} + +pub struct AllowClientIdFromFile { + header_key: String, + allowed_ids_path: PathBuf, +} + +#[async_trait::async_trait] +impl RouterPlugin for AllowClientIdFromFile { + // Whenever it is a GraphQL request, + // We don't use on_http_request here because we want to run this only when it is a GraphQL request + async fn on_graphql_params<'exec>( + &'exec self, + payload: OnGraphQLParamsStartPayload<'exec>, + ) -> HookResult<'exec, OnGraphQLParamsStartPayload<'exec>, OnGraphQLParamsEndPayload> { + let header = payload.router_http_request.headers.get(&self.header_key); + match header { + Some(client_id) => { + let client_id_str = client_id.to_str(); + match client_id_str { + Ok(client_id) => { + let allowed_clients: Vec = sonic_rs::from_str( + std::fs::read_to_string(self.allowed_ids_path.clone()) + .unwrap() + .as_str(), + ) + .unwrap(); + + if !allowed_clients.contains(&client_id.to_string()) { + // Prepare an HTTP 403 response with a GraphQL error message + let body = json!( + { + "errors": [ + { + "message": "client-id is not allowed", + "extensions": { + "code": 
"UNAUTHORIZED_CLIENT_ID" + } + } + ] + } + ); + return payload.end_response(PlanExecutionOutput { + body: sonic_rs::to_vec(&body).unwrap_or_default(), + headers: http::HeaderMap::new(), + status: http::StatusCode::FORBIDDEN, + }); + } + } + Err(_not_a_string_error) => { + let message = format!("'{}' value is not a string", self.header_key); + tracing::error!(message); + let body = json!( + { + "errors": [ + { + "message": message, + "extensions": { + "code": "BAD_CLIENT_ID" + } + } + ] + } + ); + return payload.end_response(PlanExecutionOutput { + body: sonic_rs::to_vec(&body).unwrap_or_default(), + headers: http::HeaderMap::new(), + status: http::StatusCode::BAD_REQUEST, + }); + } + } + } + None => { + let message = format!("Missing '{}' header", self.header_key); + tracing::error!(message); + let body = json!( + { + "errors": [ + { + "message": message, + "extensions": { + "code": "AUTH_ERROR" + } + } + ] + } + ); + return payload.end_response(PlanExecutionOutput { + body: sonic_rs::to_vec(&body).unwrap_or_default(), + headers: http::HeaderMap::new(), + status: http::StatusCode::UNAUTHORIZED, + }); + } + } + payload.cont() + } +} diff --git a/lib/executor/src/plugins/examples/context_data.rs b/lib/executor/src/plugins/examples/context_data.rs new file mode 100644 index 000000000..38265fe57 --- /dev/null +++ b/lib/executor/src/plugins/examples/context_data.rs @@ -0,0 +1,67 @@ +// From https://github.com/apollographql/router/blob/dev/examples/context/rust/src/context_data.rs + +use crate::{ + hooks::{ + on_graphql_params::{OnGraphQLParamsEndPayload, OnGraphQLParamsStartPayload}, + on_subgraph_execute::{OnSubgraphExecuteEndPayload, OnSubgraphExecuteStartPayload}, + }, + plugin_context::PluginContextMutEntry, + plugin_trait::{EndPayload, HookResult, RouterPlugin, StartPayload}, +}; + +pub struct ContextDataPlugin {} + +pub struct ContextData { + incoming_data: String, + response_count: u64, +} + +#[async_trait::async_trait] +impl RouterPlugin for ContextDataPlugin { 
+ async fn on_graphql_params<'exec>( + &'exec self, + payload: OnGraphQLParamsStartPayload<'exec>, + ) -> HookResult<'exec, OnGraphQLParamsStartPayload<'exec>, OnGraphQLParamsEndPayload> { + let context_data = ContextData { + incoming_data: "world".to_string(), + response_count: 0, + }; + + payload.context.insert(context_data); + + payload.on_end(|payload| { + let mut ctx_data_entry = payload.context.get_mut_entry(); + let context_data: Option<&mut ContextData> = ctx_data_entry.get_ref_mut(); + if let Some(context_data) = context_data { + context_data.response_count += 1; + tracing::info!("subrequest count {}", context_data.response_count); + } + payload.cont() + }) + } + async fn on_subgraph_execute<'exec>( + &'exec self, + mut payload: OnSubgraphExecuteStartPayload<'exec>, + ) -> HookResult<'exec, OnSubgraphExecuteStartPayload<'exec>, OnSubgraphExecuteEndPayload> { + let ctx_data_entry = payload.context.get_ref_entry(); + let context_data: Option<&ContextData> = ctx_data_entry.get_ref(); + if let Some(context_data) = context_data { + tracing::info!("hello {}", context_data.incoming_data); // Hello world! 
+ let new_header_value = format!("Hello {}", context_data.incoming_data); + payload.execution_request.headers.insert( + "x-hello", + http::HeaderValue::from_str(&new_header_value).unwrap(), + ); + } + payload.on_end(|payload: OnSubgraphExecuteEndPayload<'exec>| { + let mut ctx_data_entry: PluginContextMutEntry = + payload.context.get_mut_entry(); + let context_data: Option<&mut ContextData> = ctx_data_entry.get_ref_mut(); + if let Some(context_data) = context_data { + context_data.response_count += 1; + tracing::info!("subrequest count {}", context_data.response_count); + } + payload.cont() + }) + } +} diff --git a/lib/executor/src/plugins/examples/forbid_anonymous_operations.rs b/lib/executor/src/plugins/examples/forbid_anonymous_operations.rs new file mode 100644 index 000000000..a566d0e9d --- /dev/null +++ b/lib/executor/src/plugins/examples/forbid_anonymous_operations.rs @@ -0,0 +1,55 @@ +// Same with https://github.com/apollographql/router/blob/dev/examples/forbid-anonymous-operations/rust/src/forbid_anonymous_operations.rs + +use http::StatusCode; +use sonic_rs::json; + +use crate::{ + execution::plan::PlanExecutionOutput, + hooks::on_graphql_params::{OnGraphQLParamsEndPayload, OnGraphQLParamsStartPayload}, + plugin_trait::{HookResult, RouterPlugin, StartPayload}, +}; + +pub struct ForbidAnonymousOperations {} + +#[async_trait::async_trait] +impl RouterPlugin for ForbidAnonymousOperations { + async fn on_graphql_params<'exec>( + &'exec self, + payload: OnGraphQLParamsStartPayload<'exec>, + ) -> HookResult<'exec, OnGraphQLParamsStartPayload<'exec>, OnGraphQLParamsEndPayload> { + let maybe_operation_name = &payload + .graphql_params + .as_ref() + .and_then(|params| params.operation_name.as_ref()); + + if maybe_operation_name.is_none() + || maybe_operation_name + .expect("is_none() has been checked before; qed") + .is_empty() + { + // let's log the error + tracing::error!("Operation is not allowed!"); + + // Prepare an HTTP 400 response with a GraphQL error 
message + let response_body = json!({ + "errors": [ + { + "message": "Anonymous operations are not allowed", + "extensions": { + "code": "ANONYMOUS_OPERATION" + } + } + ] + }); + return payload.end_response(PlanExecutionOutput { + body: sonic_rs::to_vec(&response_body).unwrap_or_default(), + headers: http::HeaderMap::new(), + status: StatusCode::BAD_REQUEST, + }); + } else { + // we're good to go! + tracing::info!("operation is allowed!"); + return payload.cont(); + } + } +} diff --git a/lib/executor/src/plugins/examples/mod.rs b/lib/executor/src/plugins/examples/mod.rs new file mode 100644 index 000000000..eee9d7f7a --- /dev/null +++ b/lib/executor/src/plugins/examples/mod.rs @@ -0,0 +1,11 @@ +pub mod apollo_sandbox; +pub mod apq; +pub mod async_auth; +pub mod context_data; +pub mod forbid_anonymous_operations; +pub mod multipart; +pub mod one_of; +pub mod propagate_status_code; +pub mod response_cache; +pub mod root_field_limit; +pub mod subgraph_response_cache; diff --git a/lib/executor/src/plugins/examples/multipart.rs b/lib/executor/src/plugins/examples/multipart.rs new file mode 100644 index 000000000..6a6162bc6 --- /dev/null +++ b/lib/executor/src/plugins/examples/multipart.rs @@ -0,0 +1,159 @@ +use std::collections::HashMap; + +use crate::{ + executors::common::HttpExecutionResponse, + hooks::{ + on_graphql_params::{ + GraphQLParams, OnGraphQLParamsEndPayload, OnGraphQLParamsStartPayload, + }, + on_subgraph_execute::{OnSubgraphExecuteEndPayload, OnSubgraphExecuteStartPayload}, + }, + plugin_trait::{HookResult, RouterPlugin, StartPayload}, +}; +use bytes::Bytes; +use dashmap::DashMap; +use multer::Multipart; +use serde::Serialize; +pub struct MultipartPlugin {} + +pub struct MultipartFile { + pub filename: Option, + pub content_type: Option, + pub content: Bytes, +} + +pub struct MultipartContext { + pub file_map: HashMap>, + pub files: DashMap, +} + +#[derive(Serialize)] +struct MultipartOperations<'a> { + pub query: &'a str, + pub variables: Option<&'a 
HashMap<&'a str, &'a sonic_rs::Value>>, + pub operation_name: Option<&'a str>, +} + +#[async_trait::async_trait] +impl RouterPlugin for MultipartPlugin { + async fn on_graphql_params<'exec>( + &'exec self, + mut payload: OnGraphQLParamsStartPayload<'exec>, + ) -> HookResult<'exec, OnGraphQLParamsStartPayload<'exec>, OnGraphQLParamsEndPayload> { + if let Some(content_type) = payload.router_http_request.headers.get("content-type") { + if let Ok(content_type_str) = content_type.to_str() { + if content_type_str.starts_with("multipart/form-data") { + let boundary = multer::parse_boundary(content_type_str).unwrap(); + let body = payload.body.clone(); + let stream = futures_util::stream::once(async move { + Ok::(Bytes::from(body.to_vec())) + }); + let mut multipart = Multipart::new(stream, boundary); + while let Some(field) = multipart.next_field().await.unwrap() { + let field_name = field.name().unwrap().to_string(); + let filename = field.file_name().map(|s| s.to_string()); + let content_type = field.content_type().map(|s| s.to_string()); + let data = field.bytes().await.unwrap(); + match field_name.as_str() { + "operations" => { + let graphql_params: GraphQLParams = + sonic_rs::from_slice(&data).unwrap(); + payload.graphql_params = Some(graphql_params); + } + "map" => { + let file_map: HashMap> = + sonic_rs::from_slice(&data).unwrap(); + payload.context.insert(MultipartContext { + file_map, + files: DashMap::new(), + }); + } + field_name => { + let mut ctx_entry = payload.context.get_mut_entry(); + let multipart_ctx: Option<&mut MultipartContext> = + ctx_entry.get_ref_mut(); + if let Some(multipart_ctx) = multipart_ctx { + let multipart_file = MultipartFile { + filename, + content_type, + content: data, + }; + multipart_ctx + .files + .insert(field_name.to_string(), multipart_file); + } + } + } + } + } + } + } + payload.cont() + } + + async fn on_subgraph_execute<'exec>( + &'exec self, + mut payload: OnSubgraphExecuteStartPayload<'exec>, + ) -> HookResult<'exec, 
OnSubgraphExecuteStartPayload<'exec>, OnSubgraphExecuteEndPayload> { + if let Some(variables) = &payload.execution_request.variables { + let ctx_ref = payload.context.get_ref_entry(); + let multipart_ctx: Option<&MultipartContext> = ctx_ref.get_ref(); + if let Some(multipart_ctx) = multipart_ctx { + let mut file_map: HashMap> = HashMap::new(); + for variable_name in variables.keys() { + // Matching variables that are file references + for (files_ref, op_refs) in &multipart_ctx.file_map { + for op_ref in op_refs { + if op_ref.starts_with(format!("variables.{}", variable_name).as_str()) { + let op_refs_in_curr_map = + file_map.entry(files_ref.to_string()).or_default(); + op_refs_in_curr_map.push(op_ref.to_string()); + } + } + } + } + if !file_map.is_empty() { + let mut form = reqwest::multipart::Form::new(); + let operations_struct = MultipartOperations { + query: payload.execution_request.query, + variables: payload.execution_request.variables.as_ref(), + operation_name: payload.execution_request.operation_name, + }; + let operations = sonic_rs::to_string(&operations_struct).unwrap(); + form = form.text("operations", operations); + let file_map_str: String = sonic_rs::to_string(&file_map).unwrap(); + form = form.text("map", file_map_str); + for (file_ref, _op_refs) in file_map { + if let Some(file_field) = multipart_ctx.files.get(&file_ref) { + let mut part = + reqwest::multipart::Part::bytes(file_field.content.to_vec()); + if let Some(file_name) = &file_field.filename { + part = part.file_name(file_name.to_string()); + } + if let Some(content_type) = &file_field.content_type { + part = part.mime_str(&content_type.to_string()).unwrap(); + } + form = form.part(file_ref, part); + } + } + let resp = reqwest::Client::new() + .post("http://example.com/graphql") + // Using query as endpoint URL + .multipart(form) + .send() + .await + .unwrap(); + let headers = resp.headers().clone(); + let status = resp.status(); + let body = resp.bytes().await.unwrap(); + 
payload.execution_result = Some(HttpExecutionResponse { + body, + headers, + status, + }); + } + } + } + payload.cont() + } +} diff --git a/lib/executor/src/plugins/examples/one_of.rs b/lib/executor/src/plugins/examples/one_of.rs new file mode 100644 index 000000000..738ece25a --- /dev/null +++ b/lib/executor/src/plugins/examples/one_of.rs @@ -0,0 +1,213 @@ +// This example will show `@oneOf` input type validation in two steps: +// 1. During validation step +// 2. During execution step + +// We handle execution too to validate input objects at runtime as well. +/* + Let's say we have the following input type with `@oneOf` directive: + input PaymentMethod @oneOf { + creditCard: CreditCardInput + bankTransfer: BankTransferInput + paypal: PayPalInput + } + + During validation, if a variable of type `PaymentMethod` is provided with multiple fields set, + we will raise a validation error. + + ```graphql + mutation MakePayment { + makePayment(method: { + creditCard: { number: "1234", expiry: "12/24" }, + paypal: { email: "john@doe.com" } + }) { + success + } + } + ``` + + But since variables can be dynamic, we also validate during execution. If the input object has multiple fields set, + we return an error in the response. + + ```graphql + mutation MakePayment($method: PaymentMethod!) 
{ + makePayment(method: $method) { + success + } + } + ``` + + with variables: + { + "method": { + "creditCard": { "number": "1234", "expiry": "12/24" }, + "paypal": { "email": "john@doe.com" } + } + } +*/ + +use std::{collections::BTreeMap, sync::RwLock}; + +use crate::{ + execution::plan::PlanExecutionOutput, + hooks::{ + on_execute::{OnExecuteEndPayload, OnExecuteStartPayload}, + on_graphql_validation::{OnGraphQLValidationEndPayload, OnGraphQLValidationStartPayload}, + on_supergraph_load::{OnSupergraphLoadEndPayload, OnSupergraphLoadStartPayload}, + }, + plugin_trait::{HookResult, RouterPlugin, StartPayload}, +}; +use graphql_parser::{ + query::Value, + schema::{Definition, TypeDefinition}, +}; +use graphql_tools::ast::visit_document; +use graphql_tools::{ + ast::{OperationVisitor, OperationVisitorContext}, + validation::{ + rules::ValidationRule, + utils::{ValidationError, ValidationErrorContext}, + }, +}; +use sonic_rs::{json, JsonContainerTrait}; + +pub struct OneOfPlugin { + pub one_of_types: RwLock>, +} + +#[async_trait::async_trait] +impl RouterPlugin for OneOfPlugin { + // 1. During validation step + async fn on_graphql_validation<'exec>( + &'exec self, + mut payload: OnGraphQLValidationStartPayload<'exec>, + ) -> HookResult<'exec, OnGraphQLValidationStartPayload<'exec>, OnGraphQLValidationEndPayload> + { + let rule = OneOfValidationRule { + one_of_types: self.one_of_types.read().unwrap().clone(), + }; + payload.add_validation_rule(Box::new(rule)); + payload.cont() + } + // 2. 
During execution step
+    async fn on_execute<'exec>(
+        &'exec self,
+        payload: OnExecuteStartPayload<'exec>,
+    ) -> HookResult<'exec, OnExecuteStartPayload<'exec>, OnExecuteEndPayload> {
+        if let (Some(variable_values), Some(variable_defs)) = (
+            &payload.variable_values,
+            &payload.operation_for_plan.variable_definitions,
+        ) {
+            for def in variable_defs {
+                let variable_named_type = def.variable_type.inner_type();
+                let one_of_types = self.one_of_types.read().unwrap();
+                if one_of_types.contains(&variable_named_type.to_string()) {
+                    let var_name = &def.name;
+                    if let Some(value) = variable_values.get(var_name).and_then(|v| v.as_object()) {
+                        let keys_num = value.len();
+                        if keys_num > 1 {
+                            let err_msg = format!(
+                                "Variable '${}' of input object type '{}' with @oneOf directive has {} fields set. Only one field must be set.",
+                                var_name,
+                                variable_named_type,
+                                keys_num
+                            );
+                            return payload.end_response(PlanExecutionOutput {
+                                body: sonic_rs::to_vec(&json!({
+                                    "errors": [{
+                                        "message": err_msg,
+                                        "extensions": {
+                                            "code": "TOO_MANY_FIELDS_SET_IN_ONEOF"
+                                        }
+                                    }]
+                                }))
+                                .unwrap(),
+                                headers: Default::default(),
+                                status: http::StatusCode::BAD_REQUEST,
+                            });
+                        }
+                    }
+                }
+            }
+        }
+        payload.cont()
+    }
+    fn on_supergraph_reload<'exec>(
+        &'exec self,
+        start_payload: OnSupergraphLoadStartPayload,
+    ) -> HookResult<'exec, OnSupergraphLoadStartPayload, OnSupergraphLoadEndPayload> {
+        for def in start_payload.new_ast.definitions.iter() {
+            if let Definition::TypeDefinition(TypeDefinition::InputObject(input_obj)) = def {
+                for directive in input_obj.directives.iter() {
+                    if directive.name == "oneOf" {
+                        self.one_of_types
+                            .write()
+                            .unwrap()
+                            .push(input_obj.name.clone());
+                    }
+                }
+            }
+        }
+        start_payload.cont()
+    }
+}
+
+struct OneOfValidationRule {
+    one_of_types: Vec<String>,
+}
+
+impl ValidationRule for OneOfValidationRule {
+    fn error_code<'a>(&self) -> &'a str {
+        "TOO_MANY_FIELDS_SET_IN_ONEOF"
+    }
+    fn validate(
+        &self,
+        op_ctx: &mut OperationVisitorContext<'_>,
+        
validation_error_context: &mut ValidationErrorContext, + ) { + visit_document( + &mut OneOfValidation { + one_of_types: self.one_of_types.clone(), + }, + op_ctx.operation, + op_ctx, + validation_error_context, + ); + } +} + +struct OneOfValidation { + one_of_types: Vec, +} + +impl<'a> OperationVisitor<'a, ValidationErrorContext> for OneOfValidation { + fn enter_object_value( + &mut self, + visitor_context: &mut OperationVisitorContext<'a>, + user_context: &mut ValidationErrorContext, + fields: &BTreeMap, + ) { + if let Some(TypeDefinition::InputObject(input_type)) = visitor_context.current_input_type() + { + if self.one_of_types.contains(&input_type.name) { + let mut set_fields = vec![]; + for (field_name, field_value) in fields.iter() { + if !matches!(field_value, Value::Null) { + set_fields.push(field_name.clone()); + } + } + if set_fields.len() > 1 { + let err_msg = format!( + "Input object of type '{}' with @oneOf directive has multiple fields set: {:?}. Only one field must be set.", + input_type.name, + set_fields + ); + user_context.report_error(ValidationError { + error_code: "TOO_MANY_FIELDS_SET_IN_ONEOF", + locations: vec![], + message: err_msg, + }); + } + } + } + } +} diff --git a/lib/executor/src/plugins/examples/propagate_status_code.rs b/lib/executor/src/plugins/examples/propagate_status_code.rs new file mode 100644 index 000000000..1d519904e --- /dev/null +++ b/lib/executor/src/plugins/examples/propagate_status_code.rs @@ -0,0 +1,64 @@ +// From https://github.com/apollographql/router/blob/dev/examples/status-code-propagation/rust/src/propagate_status_code.rs + +use http::StatusCode; + +use crate::{ + hooks::{ + on_http_request::{OnHttpRequestPayload, OnHttpResponsePayload}, + on_subgraph_execute::{OnSubgraphExecuteEndPayload, OnSubgraphExecuteStartPayload}, + }, + plugin_trait::{EndPayload, HookResult, RouterPlugin, StartPayload}, +}; + +pub struct PropagateStatusCodePlugin { + pub status_codes: Vec, +} + +pub struct PropagateStatusCodeCtx { + pub 
status_code: StatusCode, +} + +#[async_trait::async_trait] +impl RouterPlugin for PropagateStatusCodePlugin { + async fn on_subgraph_execute<'exec>( + &'exec self, + payload: OnSubgraphExecuteStartPayload<'exec>, + ) -> HookResult<'exec, OnSubgraphExecuteStartPayload<'exec>, OnSubgraphExecuteEndPayload<'exec>> + { + payload.on_end(|payload| { + let status_code = payload.execution_result.status; + // if a response contains a status code we're watching... + if self.status_codes.contains(&status_code) { + // Checking if there is already a context entry + let mut ctx_entry = payload.context.get_mut_entry(); + let ctx: Option<&mut PropagateStatusCodeCtx> = ctx_entry.get_ref_mut(); + if let Some(ctx) = ctx { + // Update the status code if the new one is more severe (higher) + if status_code.as_u16() > ctx.status_code.as_u16() { + ctx.status_code = status_code; + } + } else { + // Insert a new context entry + let new_ctx = PropagateStatusCodeCtx { status_code }; + payload.context.insert(new_ctx); + } + } + payload.cont() + }) + } + fn on_http_request<'exec>( + &'exec self, + payload: OnHttpRequestPayload<'exec>, + ) -> HookResult<'exec, OnHttpRequestPayload<'exec>, OnHttpResponsePayload<'exec>> { + payload.on_end(|mut payload| { + // Checking if there is a context entry + let ctx_entry = payload.context.get_ref_entry(); + let ctx: Option<&PropagateStatusCodeCtx> = ctx_entry.get_ref(); + if let Some(ctx) = ctx { + // Update the HTTP response status code + *payload.response.response_mut().status_mut() = ctx.status_code; + } + payload.cont() + }) + } +} diff --git a/lib/executor/src/plugins/examples/response_cache.rs b/lib/executor/src/plugins/examples/response_cache.rs new file mode 100644 index 000000000..d144b07ec --- /dev/null +++ b/lib/executor/src/plugins/examples/response_cache.rs @@ -0,0 +1,121 @@ +use dashmap::DashMap; +use http::{HeaderMap, StatusCode}; +use redis::Commands; + +use crate::{ + execution::plan::PlanExecutionOutput, + hooks::{ + 
on_execute::{OnExecuteEndPayload, OnExecuteStartPayload}, + on_supergraph_load::{OnSupergraphLoadEndPayload, OnSupergraphLoadStartPayload}, + }, + plugin_trait::{EndPayload, HookResult, StartPayload}, + plugins::plugin_trait::RouterPlugin, + utils::consts::TYPENAME_FIELD_NAME, +}; + +pub struct ResponseCachePlugin { + redis_client: redis::Client, + ttl_per_type: DashMap, +} + +impl ResponseCachePlugin { + pub fn try_new(redis_url: &str) -> Result { + let redis_client = redis::Client::open(redis_url)?; + Ok(Self { + redis_client, + ttl_per_type: DashMap::new(), + }) + } +} + +#[async_trait::async_trait] +impl RouterPlugin for ResponseCachePlugin { + async fn on_execute<'exec>( + &'exec self, + payload: OnExecuteStartPayload<'exec>, + ) -> HookResult<'exec, OnExecuteStartPayload<'exec>, OnExecuteEndPayload<'exec>> { + let key = format!( + "response_cache:{}:{:?}", + payload.query_plan, payload.variable_values + ); + if let Ok(mut conn) = self.redis_client.get_connection() { + let cached_response: Option> = conn.get(&key).ok(); + if let Some(cached_response) = cached_response { + return payload.end_response(PlanExecutionOutput { + body: cached_response, + headers: HeaderMap::new(), + status: StatusCode::OK, + }); + } + return payload.on_end(move |mut payload: OnExecuteEndPayload<'exec>| { + // Do not cache if there are errors + if !payload.errors.is_empty() { + return payload.cont(); + } + + if let Ok(serialized) = sonic_rs::to_vec(&payload.data) { + // Decide on the ttl somehow + // Get the type names + let mut max_ttl = 0; + + // Imagine this code is traversing the response data to find type names + if let Some(obj) = payload.data.as_object() { + if let Some(typename) = obj + .iter() + .position(|(k, _)| k == &TYPENAME_FIELD_NAME) + .and_then(|idx| obj[idx].1.as_str()) + { + if let Some(ttl) = self.ttl_per_type.get(typename).map(|v| *v) { + max_ttl = max_ttl.max(ttl); + } + } + } + + // If no ttl found, default to 60 seconds + if max_ttl == 0 { + max_ttl = 60; + } + 
+ // Insert the ttl into extensions for client awareness + payload + .extensions + .get_or_insert_default() + .insert("response_cache_ttl".to_string(), sonic_rs::json!(max_ttl)); + + // Set the cache with the decided ttl + let _: () = conn.set_ex(key, serialized, max_ttl).unwrap_or(()); + } + payload.cont() + }); + } + payload.cont() + } + fn on_supergraph_reload<'a>( + &'a self, + payload: OnSupergraphLoadStartPayload, + ) -> HookResult<'a, OnSupergraphLoadStartPayload, OnSupergraphLoadEndPayload> { + // Visit the schema and update ttl_per_type based on some directive + payload.new_ast.definitions.iter().for_each(|def| { + if let graphql_parser::schema::Definition::TypeDefinition(type_def) = def { + if let graphql_parser::schema::TypeDefinition::Object(obj_type) = type_def { + for directive in &obj_type.directives { + if directive.name == "cacheControl" { + for arg in &directive.arguments { + if arg.0 == "maxAge" { + if let graphql_parser::query::Value::Int(max_age) = &arg.1 { + if let Some(max_age) = max_age.as_i64() { + self.ttl_per_type + .insert(obj_type.name.clone(), max_age as u64); + } + } + } + } + } + } + } + } + }); + + payload.cont() + } +} diff --git a/lib/executor/src/plugins/examples/root_field_limit.rs b/lib/executor/src/plugins/examples/root_field_limit.rs new file mode 100644 index 000000000..24872e737 --- /dev/null +++ b/lib/executor/src/plugins/examples/root_field_limit.rs @@ -0,0 +1,145 @@ +use graphql_tools::{ + ast::{visit_document, OperationVisitor, OperationVisitorContext, TypeDefinitionExtension}, + static_graphql, + validation::{ + rules::ValidationRule, + utils::{ValidationError, ValidationErrorContext}, + }, +}; +use hive_router_query_planner::ast::selection_item::SelectionItem; +use sonic_rs::json; + +use crate::{ + execution::plan::PlanExecutionOutput, + hooks::{ + on_graphql_validation::{OnGraphQLValidationEndPayload, OnGraphQLValidationStartPayload}, + on_query_plan::{OnQueryPlanEndPayload, OnQueryPlanStartPayload}, + }, + 
plugin_trait::{HookResult, RouterPlugin, StartPayload}, +}; + +// This example shows two ways of limiting the number of root fields in a query: +// 1. During validation step +// 2. During query planning step + +#[async_trait::async_trait] +impl RouterPlugin for RootFieldLimitPlugin { + // Using validation step + async fn on_graphql_validation<'exec>( + &'exec self, + mut payload: OnGraphQLValidationStartPayload<'exec>, + ) -> HookResult<'exec, OnGraphQLValidationStartPayload<'exec>, OnGraphQLValidationEndPayload> + { + let rule = RootFieldLimitRule { + max_root_fields: self.max_root_fields, + }; + payload.add_validation_rule(Box::new(rule)); + payload.cont() + } + // Or during query planning + async fn on_query_plan<'exec>( + &'exec self, + payload: OnQueryPlanStartPayload<'exec>, + ) -> HookResult<'exec, OnQueryPlanStartPayload<'exec>, OnQueryPlanEndPayload> { + let mut cnt = 0; + for selection in payload + .filtered_operation_for_plan + .selection_set + .items + .iter() + { + match selection { + SelectionItem::Field(_) => { + cnt += 1; + if cnt > self.max_root_fields { + let err_msg = format!( + "Query has too many root fields: {}, maximum allowed is {}", + cnt, self.max_root_fields + ); + tracing::warn!("{}", err_msg); + let body = json!({ + "errors": [{ + "message": err_msg, + "extensions": { + "code": "TOO_MANY_ROOT_FIELDS" + } + }] + }); + // Return error + return payload.end_response(PlanExecutionOutput { + body: sonic_rs::to_vec(&body).unwrap_or_default(), + headers: http::HeaderMap::new(), + status: http::StatusCode::PAYLOAD_TOO_LARGE, + }); + } + } + SelectionItem::InlineFragment(_) => { + unreachable!("Inline fragments should have been inlined before query planning"); + } + SelectionItem::FragmentSpread(_) => { + unreachable!("Fragment spreads should have been inlined before query planning"); + } + } + } + payload.cont() + } +} + +pub struct RootFieldLimitPlugin { + max_root_fields: usize, +} + +pub struct RootFieldLimitRule { + max_root_fields: usize, 
+} + +struct RootFieldSelections { + max_root_fields: usize, + count: usize, +} + +impl<'a> OperationVisitor<'a, ValidationErrorContext> for RootFieldSelections { + fn enter_field( + &mut self, + visitor_context: &mut OperationVisitorContext, + user_context: &mut ValidationErrorContext, + field: &static_graphql::query::Field, + ) { + let parent_type_name = visitor_context.current_parent_type().map(|t| t.name()); + if parent_type_name == Some("Query") { + self.count += 1; + if self.count > self.max_root_fields { + let err_msg = format!( + "Query has too many root fields: {}, maximum allowed is {}", + self.count, self.max_root_fields + ); + user_context.report_error(ValidationError { + error_code: "TOO_MANY_ROOT_FIELDS", + locations: vec![field.position], + message: err_msg, + }); + } + } + } +} + +impl ValidationRule for RootFieldLimitRule { + fn error_code<'a>(&self) -> &'a str { + "TOO_MANY_ROOT_FIELDS" + } + fn validate( + &self, + ctx: &mut OperationVisitorContext<'_>, + error_collector: &mut ValidationErrorContext, + ) { + visit_document( + &mut RootFieldSelections { + max_root_fields: self.max_root_fields, + count: 0, + }, + ctx.operation, + ctx, + error_collector, + ); + } +} diff --git a/lib/executor/src/plugins/examples/subgraph_response_cache.rs b/lib/executor/src/plugins/examples/subgraph_response_cache.rs new file mode 100644 index 000000000..4e4b36666 --- /dev/null +++ b/lib/executor/src/plugins/examples/subgraph_response_cache.rs @@ -0,0 +1,35 @@ +use dashmap::DashMap; + +use crate::{ + executors::common::HttpExecutionResponse, + hooks::on_subgraph_execute::{OnSubgraphExecuteEndPayload, OnSubgraphExecuteStartPayload}, + plugin_trait::{EndPayload, HookResult, RouterPlugin, StartPayload}, +}; + +pub struct SubgraphResponseCachePlugin { + cache: DashMap, +} + +#[async_trait::async_trait] +impl RouterPlugin for SubgraphResponseCachePlugin { + async fn on_subgraph_execute<'exec>( + &'exec self, + mut payload: OnSubgraphExecuteStartPayload<'exec>, + ) -> 
HookResult<'exec, OnSubgraphExecuteStartPayload<'exec>, OnSubgraphExecuteEndPayload> { + let key = format!( + "subgraph_response_cache:{}:{:?}", + payload.execution_request.query, payload.execution_request.variables + ); + if let Some(cached_response) = self.cache.get(&key) { + // Here payload.response is Option + // So it is bypassing the actual subgraph request + payload.execution_result = Some(cached_response.clone()); + return payload.cont(); + } + payload.on_end(move |payload: OnSubgraphExecuteEndPayload| { + // Here payload.response is not Option + self.cache.insert(key, payload.execution_result.clone()); + payload.cont() + }) + } +} diff --git a/lib/executor/src/plugins/hooks/mod.rs b/lib/executor/src/plugins/hooks/mod.rs new file mode 100644 index 000000000..64851d0fd --- /dev/null +++ b/lib/executor/src/plugins/hooks/mod.rs @@ -0,0 +1,9 @@ +pub mod on_execute; +pub mod on_graphql_params; +pub mod on_graphql_parse; +pub mod on_graphql_validation; +pub mod on_http_request; +pub mod on_query_plan; +pub mod on_subgraph_execute; +pub mod on_subgraph_http_request; +pub mod on_supergraph_load; diff --git a/lib/executor/src/plugins/hooks/on_execute.rs b/lib/executor/src/plugins/hooks/on_execute.rs new file mode 100644 index 000000000..b69ba3297 --- /dev/null +++ b/lib/executor/src/plugins/hooks/on_execute.rs @@ -0,0 +1,36 @@ +use std::collections::HashMap; + +use hive_router_query_planner::ast::operation::OperationDefinition; +use hive_router_query_planner::planner::plan_nodes::QueryPlan; + +use crate::plugin_context::{PluginContext, RouterHttpRequest}; +use crate::plugin_trait::{EndPayload, StartPayload}; +use crate::response::graphql_error::GraphQLError; +use crate::response::value::Value; + +pub struct OnExecuteStartPayload<'exec> { + pub router_http_request: &'exec RouterHttpRequest<'exec>, + pub context: &'exec PluginContext, + pub query_plan: &'exec QueryPlan, + pub operation_for_plan: &'exec OperationDefinition, + + pub data: Value<'exec>, + pub errors: 
Vec, + pub extensions: Option>, + + pub variable_values: &'exec Option>, + + pub dedupe_subgraph_requests: bool, +} + +impl<'exec> StartPayload> for OnExecuteStartPayload<'exec> {} + +pub struct OnExecuteEndPayload<'exec> { + pub data: Value<'exec>, + pub errors: Vec, + pub extensions: Option>, + + pub response_size_estimate: usize, +} + +impl<'exec> EndPayload for OnExecuteEndPayload<'exec> {} diff --git a/lib/executor/src/plugins/hooks/on_graphql_params.rs b/lib/executor/src/plugins/hooks/on_graphql_params.rs new file mode 100644 index 000000000..c69d094e4 --- /dev/null +++ b/lib/executor/src/plugins/hooks/on_graphql_params.rs @@ -0,0 +1,110 @@ +use core::fmt; + +use std::collections::HashMap; + +use ntex::util::Bytes; +use serde::{de, Deserialize, Deserializer}; +use sonic_rs::Value; + +use crate::plugin_context::PluginContext; +use crate::plugin_context::RouterHttpRequest; +use crate::plugin_trait::EndPayload; +use crate::plugin_trait::StartPayload; + +#[derive(Debug, Clone, Default)] +pub struct GraphQLParams { + pub query: Option, + pub operation_name: Option, + pub variables: HashMap, + // TODO: We don't use extensions yet, but we definitely will in the future. 
+ #[allow(dead_code)] + pub extensions: Option>, +} + +// Workaround for https://github.com/cloudwego/sonic-rs/issues/114 + +impl<'de> Deserialize<'de> for GraphQLParams { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + struct GraphQLErrorExtensionsVisitor; + + impl<'de> de::Visitor<'de> for GraphQLErrorExtensionsVisitor { + type Value = GraphQLParams; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("a map for GraphQLErrorExtensions") + } + + fn visit_map(self, mut map: A) -> Result + where + A: de::MapAccess<'de>, + { + let mut query = None; + let mut operation_name = None; + let mut variables: Option> = None; + let mut extensions: Option> = None; + let mut extra_params = HashMap::new(); + + while let Some(key) = map.next_key::()? { + match key.as_str() { + "query" => { + if query.is_some() { + return Err(de::Error::duplicate_field("query")); + } + query = map.next_value::>()?; + } + "operationName" => { + if operation_name.is_some() { + return Err(de::Error::duplicate_field("operationName")); + } + operation_name = map.next_value::>()?; + } + "variables" => { + if variables.is_some() { + return Err(de::Error::duplicate_field("variables")); + } + variables = map.next_value::>>()?; + } + "extensions" => { + if extensions.is_some() { + return Err(de::Error::duplicate_field("extensions")); + } + extensions = map.next_value::>>()?; + } + other => { + let value: Value = map.next_value()?; + extra_params.insert(other.to_string(), value); + } + } + } + + Ok(GraphQLParams { + query, + operation_name, + variables: variables.unwrap_or_default(), + extensions, + }) + } + } + + deserializer.deserialize_map(GraphQLErrorExtensionsVisitor) + } +} + +pub struct OnGraphQLParamsStartPayload<'exec> { + pub router_http_request: &'exec RouterHttpRequest<'exec>, + pub context: &'exec PluginContext, + pub body: Bytes, + pub graphql_params: Option, +} + +impl<'exec> StartPayload> for 
OnGraphQLParamsStartPayload<'exec> {} + +pub struct OnGraphQLParamsEndPayload<'exec> { + pub graphql_params: GraphQLParams, + pub context: &'exec PluginContext, +} + +impl<'exec> EndPayload for OnGraphQLParamsEndPayload<'exec> {} diff --git a/lib/executor/src/plugins/hooks/on_graphql_parse.rs b/lib/executor/src/plugins/hooks/on_graphql_parse.rs new file mode 100644 index 000000000..fa29e3b9d --- /dev/null +++ b/lib/executor/src/plugins/hooks/on_graphql_parse.rs @@ -0,0 +1,22 @@ +use graphql_tools::static_graphql::query::Document; + +use crate::{ + hooks::on_graphql_params::GraphQLParams, + plugin_context::{PluginContext, RouterHttpRequest}, + plugin_trait::{EndPayload, StartPayload}, +}; + +pub struct OnGraphQLParseStartPayload<'exec> { + pub router_http_request: &'exec RouterHttpRequest<'exec>, + pub context: &'exec PluginContext, + pub graphql_params: &'exec GraphQLParams, + pub document: Option, +} + +impl<'exec> StartPayload for OnGraphQLParseStartPayload<'exec> {} + +pub struct OnGraphQLParseEndPayload { + pub document: Document, +} + +impl EndPayload for OnGraphQLParseEndPayload {} diff --git a/lib/executor/src/plugins/hooks/on_graphql_validation.rs b/lib/executor/src/plugins/hooks/on_graphql_validation.rs new file mode 100644 index 000000000..c341a6a36 --- /dev/null +++ b/lib/executor/src/plugins/hooks/on_graphql_validation.rs @@ -0,0 +1,74 @@ +use graphql_tools::{ + static_graphql::query::Document, + validation::{ + rules::{default_rules_validation_plan, ValidationRule}, + utils::ValidationError, + validate::ValidationPlan, + }, +}; +use hive_router_query_planner::state::supergraph_state::SchemaDocument; + +use crate::{ + plugin_context::{PluginContext, PluginManager, RouterHttpRequest}, + plugin_trait::{EndPayload, StartPayload}, +}; + +pub struct OnGraphQLValidationStartPayload<'exec> { + pub router_http_request: &'exec RouterHttpRequest<'exec>, + pub context: &'exec PluginContext, + pub schema: &'exec SchemaDocument, + pub document: &'exec Document, + 
default_validation_plan: &'exec ValidationPlan, + new_validation_plan: Option, + pub errors: Option>, +} + +impl<'exec> StartPayload for OnGraphQLValidationStartPayload<'exec> {} + +impl<'exec> OnGraphQLValidationStartPayload<'exec> { + pub fn new( + plugin_manager: &'exec PluginManager<'exec>, + schema: &'exec SchemaDocument, + document: &'exec Document, + default_validation_plan: &'exec ValidationPlan, + ) -> Self { + OnGraphQLValidationStartPayload { + router_http_request: &plugin_manager.router_http_request, + context: &plugin_manager.context, + schema, + document, + default_validation_plan, + new_validation_plan: None, + errors: None, + } + } + + pub fn add_validation_rule(&mut self, rule: Box) { + self.new_validation_plan + .get_or_insert_with(default_rules_validation_plan) + .add_rule(rule); + } + + pub fn filter_validation_rules(&mut self, mut f: F) + where + F: FnMut(&Box) -> bool, + { + let plan = self + .new_validation_plan + .get_or_insert_with(default_rules_validation_plan); + plan.rules.retain(|rule| f(rule)); + } + + pub fn get_validation_plan(&self) -> &ValidationPlan { + match &self.new_validation_plan { + Some(plan) => plan, + None => self.default_validation_plan, + } + } +} + +pub struct OnGraphQLValidationEndPayload { + pub errors: Vec, +} + +impl EndPayload for OnGraphQLValidationEndPayload {} diff --git a/lib/executor/src/plugins/hooks/on_http_request.rs b/lib/executor/src/plugins/hooks/on_http_request.rs new file mode 100644 index 000000000..e9e857a80 --- /dev/null +++ b/lib/executor/src/plugins/hooks/on_http_request.rs @@ -0,0 +1,20 @@ +use ntex::web::{self, DefaultError, WebRequest}; + +use crate::{ + plugin_context::PluginContext, + plugin_trait::{EndPayload, StartPayload}, +}; + +pub struct OnHttpRequestPayload<'req> { + pub router_http_request: WebRequest, + pub context: &'req PluginContext, +} + +impl<'req> StartPayload> for OnHttpRequestPayload<'req> {} + +pub struct OnHttpResponsePayload<'req> { + pub response: web::WebResponse, + pub 
context: &'req PluginContext, +} + +impl<'req> EndPayload for OnHttpResponsePayload<'req> {} diff --git a/lib/executor/src/plugins/hooks/on_query_plan.rs b/lib/executor/src/plugins/hooks/on_query_plan.rs new file mode 100644 index 000000000..9b2110fd7 --- /dev/null +++ b/lib/executor/src/plugins/hooks/on_query_plan.rs @@ -0,0 +1,29 @@ +use hive_router_query_planner::{ + ast::operation::OperationDefinition, + graph::PlannerOverrideContext, + planner::{plan_nodes::QueryPlan, Planner}, + utils::cancellation::CancellationToken, +}; + +use crate::{ + plugin_context::{PluginContext, RouterHttpRequest}, + plugin_trait::{EndPayload, StartPayload}, +}; + +pub struct OnQueryPlanStartPayload<'exec> { + pub router_http_request: &'exec RouterHttpRequest<'exec>, + pub context: &'exec PluginContext, + pub filtered_operation_for_plan: &'exec OperationDefinition, + pub planner_override_context: PlannerOverrideContext, + pub cancellation_token: &'exec CancellationToken, + pub query_plan: Option, + pub planner: &'exec Planner, +} + +impl<'exec> StartPayload for OnQueryPlanStartPayload<'exec> {} + +pub struct OnQueryPlanEndPayload { + pub query_plan: QueryPlan, +} + +impl EndPayload for OnQueryPlanEndPayload {} diff --git a/lib/executor/src/plugins/hooks/on_subgraph_execute.rs b/lib/executor/src/plugins/hooks/on_subgraph_execute.rs new file mode 100644 index 000000000..18d037c10 --- /dev/null +++ b/lib/executor/src/plugins/hooks/on_subgraph_execute.rs @@ -0,0 +1,27 @@ +use crate::{ + executors::common::{HttpExecutionResponse, SubgraphExecutionRequest}, + plugin_context::{PluginContext, RouterHttpRequest}, + plugin_trait::{EndPayload, StartPayload}, +}; + +pub struct OnSubgraphExecuteStartPayload<'exec> { + pub router_http_request: &'exec RouterHttpRequest<'exec>, + pub context: &'exec PluginContext, + + pub subgraph_name: String, + + pub execution_request: SubgraphExecutionRequest<'exec>, + pub execution_result: Option, +} + +impl<'exec> StartPayload> + for 
OnSubgraphExecuteStartPayload<'exec> +{ +} + +pub struct OnSubgraphExecuteEndPayload<'exec> { + pub execution_result: HttpExecutionResponse, + pub context: &'exec PluginContext, +} + +impl<'exec> EndPayload for OnSubgraphExecuteEndPayload<'exec> {} diff --git a/lib/executor/src/plugins/hooks/on_subgraph_http_request.rs b/lib/executor/src/plugins/hooks/on_subgraph_http_request.rs new file mode 100644 index 000000000..1b50f001b --- /dev/null +++ b/lib/executor/src/plugins/hooks/on_subgraph_http_request.rs @@ -0,0 +1,25 @@ +use bytes::Bytes; +use http::Request; +use http_body_util::Full; + +use crate::{ + executors::dedupe::SharedResponse, + plugin_trait::{EndPayload, StartPayload}, +}; + +pub struct OnSubgraphHttpRequestPayload<'exec> { + pub subgraph_name: &'exec str, + // At this point, there is no point of mutating this + pub request: Request>, + + // Early response + pub response: Option, +} + +impl<'exec> StartPayload for OnSubgraphHttpRequestPayload<'exec> {} + +pub struct OnSubgraphHttpResponsePayload { + pub response: SharedResponse, +} + +impl EndPayload for OnSubgraphHttpResponsePayload {} diff --git a/lib/executor/src/plugins/hooks/on_supergraph_load.rs b/lib/executor/src/plugins/hooks/on_supergraph_load.rs new file mode 100644 index 000000000..21dfbf5a5 --- /dev/null +++ b/lib/executor/src/plugins/hooks/on_supergraph_load.rs @@ -0,0 +1,30 @@ +use std::sync::Arc; + +use arc_swap::ArcSwap; +use graphql_tools::static_graphql::schema::Document; +use hive_router_query_planner::planner::Planner; + +use crate::{ + introspection::schema::SchemaMetadata, + plugin_trait::{EndPayload, StartPayload}, + SubgraphExecutorMap, +}; + +pub struct SupergraphData { + pub metadata: SchemaMetadata, + pub planner: Planner, + pub subgraph_executor_map: SubgraphExecutorMap, +} + +pub struct OnSupergraphLoadStartPayload { + pub current_supergraph_data: Arc>>, + pub new_ast: Document, +} + +impl StartPayload for OnSupergraphLoadStartPayload {} + +pub struct 
OnSupergraphLoadEndPayload { + pub new_supergraph_data: SupergraphData, +} + +impl EndPayload for OnSupergraphLoadEndPayload {} diff --git a/lib/executor/src/plugins/mod.rs b/lib/executor/src/plugins/mod.rs new file mode 100644 index 000000000..3c24ff9f2 --- /dev/null +++ b/lib/executor/src/plugins/mod.rs @@ -0,0 +1,4 @@ +pub mod examples; +pub mod hooks; +pub mod plugin_context; +pub mod plugin_trait; diff --git a/lib/executor/src/plugins/plugin_context.rs b/lib/executor/src/plugins/plugin_context.rs new file mode 100644 index 000000000..d5ea9421f --- /dev/null +++ b/lib/executor/src/plugins/plugin_context.rs @@ -0,0 +1,138 @@ +use std::{ + any::{Any, TypeId}, + sync::Arc, +}; + +use dashmap::{ + mapref::one::{Ref, RefMut}, + DashMap, +}; +use http::Uri; +use ntex::router::Path; +use ntex_http::HeaderMap; + +use crate::plugin_trait::RouterPlugin; + +pub struct RouterHttpRequest<'exec> { + pub uri: &'exec Uri, + pub method: &'exec http::Method, + pub version: http::Version, + pub headers: &'exec HeaderMap, + pub path: &'exec str, + pub query_string: &'exec str, + pub match_info: &'exec Path, +} + +#[derive(Default)] +pub struct PluginContext { + inner: DashMap>, +} + +pub struct PluginContextRefEntry<'a, T> { + pub entry: Option>>, + phantom: std::marker::PhantomData, +} + +impl<'a, T: Any + Send + Sync> PluginContextRefEntry<'a, T> { + pub fn get_ref(&self) -> Option<&T> { + match &self.entry { + None => None, + Some(entry) => { + let boxed_any = entry.value(); + Some(boxed_any.downcast_ref::()?) + } + } + } +} +pub struct PluginContextMutEntry<'a, T> { + pub entry: Option>>, + phantom: std::marker::PhantomData, +} + +impl<'a, T: Any + Send + Sync> PluginContextMutEntry<'a, T> { + pub fn get_ref_mut(&mut self) -> Option<&mut T> { + match &mut self.entry { + None => None, + Some(entry) => { + let boxed_any = entry.value_mut(); + Some(boxed_any.downcast_mut::()?) 
+ } + } + } +} + +impl PluginContext { + pub fn contains(&self) -> bool { + let type_id = TypeId::of::(); + self.inner.contains_key(&type_id) + } + pub fn insert(&self, value: T) -> Option> { + let type_id = TypeId::of::(); + self.inner + .insert(type_id, Box::new(value)) + .and_then(|boxed_any| boxed_any.downcast::().ok()) + } + pub fn get_ref_entry(&self) -> PluginContextRefEntry<'_, T> { + let type_id = TypeId::of::(); + let entry = self.inner.get(&type_id); + PluginContextRefEntry { + entry, + phantom: std::marker::PhantomData, + } + } + pub fn get_mut_entry<'a, T: Any + Send + Sync>(&'a self) -> PluginContextMutEntry<'a, T> { + let type_id = TypeId::of::(); + let entry = self.inner.get_mut(&type_id); + + PluginContextMutEntry { + entry, + phantom: std::marker::PhantomData, + } + } +} + +pub struct PluginManager<'req> { + pub plugins: Arc>>, + pub router_http_request: RouterHttpRequest<'req>, + pub context: Arc, +} + +#[cfg(test)] +mod tests { + #[test] + fn inserts_and_gets_immut_ref() { + use super::PluginContext; + + struct TestCtx { + pub value: u32, + } + + let ctx = PluginContext::default(); + ctx.insert(TestCtx { value: 42 }); + + let entry = ctx.get_ref_entry(); + let ctx_ref: &TestCtx = entry.get_ref().unwrap(); + assert_eq!(ctx_ref.value, 42); + } + #[test] + fn inserts_and_mutates_with_mut_ref() { + use super::PluginContext; + + struct TestCtx { + pub value: u32, + } + + let ctx = PluginContext::default(); + ctx.insert(TestCtx { value: 42 }); + + { + let mut entry = ctx.get_mut_entry(); + let ctx_mut: &mut TestCtx = entry.get_ref_mut().unwrap(); + ctx_mut.value = 100; + } + + let entry = ctx.get_ref_entry(); + let ctx_ref: &TestCtx = entry.get_ref().unwrap(); + assert_eq!(ctx_ref.value, 100); + } +} diff --git a/lib/executor/src/plugins/plugin_trait.rs b/lib/executor/src/plugins/plugin_trait.rs new file mode 100644 index 000000000..12292be11 --- /dev/null +++ b/lib/executor/src/plugins/plugin_trait.rs @@ -0,0 +1,137 @@ +use 
crate::execution::plan::PlanExecutionOutput; +use crate::hooks::on_execute::{OnExecuteEndPayload, OnExecuteStartPayload}; +use crate::hooks::on_graphql_params::{OnGraphQLParamsEndPayload, OnGraphQLParamsStartPayload}; +use crate::hooks::on_graphql_parse::{OnGraphQLParseEndPayload, OnGraphQLParseStartPayload}; +use crate::hooks::on_graphql_validation::{ + OnGraphQLValidationEndPayload, OnGraphQLValidationStartPayload, +}; +use crate::hooks::on_http_request::{OnHttpRequestPayload, OnHttpResponsePayload}; +use crate::hooks::on_query_plan::{OnQueryPlanEndPayload, OnQueryPlanStartPayload}; +use crate::hooks::on_subgraph_execute::{ + OnSubgraphExecuteEndPayload, OnSubgraphExecuteStartPayload, +}; +use crate::hooks::on_subgraph_http_request::{ + OnSubgraphHttpRequestPayload, OnSubgraphHttpResponsePayload, +}; +use crate::hooks::on_supergraph_load::{OnSupergraphLoadEndPayload, OnSupergraphLoadStartPayload}; + +pub struct HookResult<'exec, TStartPayload, TEndPayload> { + pub payload: TStartPayload, + pub control_flow: ControlFlowResult<'exec, TEndPayload>, +} + +pub enum ControlFlowResult<'exec, TEndPayload> { + Continue, + EndResponse(PlanExecutionOutput), + OnEnd(Box HookResult<'exec, TEndPayload, ()> + Send + 'exec>), +} + +pub trait StartPayload +where + Self: Sized, +{ + fn cont<'exec>(self) -> HookResult<'exec, Self, TEndPayload> { + HookResult { + payload: self, + control_flow: ControlFlowResult::Continue, + } + } + + fn end_response<'exec>( + self, + output: PlanExecutionOutput, + ) -> HookResult<'exec, Self, TEndPayload> { + HookResult { + payload: self, + control_flow: ControlFlowResult::EndResponse(output), + } + } + + fn on_end<'exec, F>(self, f: F) -> HookResult<'exec, Self, TEndPayload> + where + F: FnOnce(TEndPayload) -> HookResult<'exec, TEndPayload, ()> + Send + 'exec, + { + HookResult { + payload: self, + control_flow: ControlFlowResult::OnEnd(Box::new(f)), + } + } +} + +pub trait EndPayload +where + Self: Sized, +{ + fn cont<'exec>(self) -> 
HookResult<'exec, Self, ()> { + HookResult { + payload: self, + control_flow: ControlFlowResult::Continue, + } + } + + fn end_response<'exec>(self, output: PlanExecutionOutput) -> HookResult<'exec, Self, ()> { + HookResult { + payload: self, + control_flow: ControlFlowResult::EndResponse(output), + } + } +} + +#[async_trait::async_trait] +pub trait RouterPlugin { + fn on_http_request<'req>( + &'req self, + start_payload: OnHttpRequestPayload<'req>, + ) -> HookResult<'req, OnHttpRequestPayload<'req>, OnHttpResponsePayload<'req>> { + start_payload.cont() + } + async fn on_graphql_params<'exec>( + &'exec self, + start_payload: OnGraphQLParamsStartPayload<'exec>, + ) -> HookResult<'exec, OnGraphQLParamsStartPayload<'exec>, OnGraphQLParamsEndPayload> { + start_payload.cont() + } + async fn on_graphql_parse<'exec>( + &'exec self, + start_payload: OnGraphQLParseStartPayload<'exec>, + ) -> HookResult<'exec, OnGraphQLParseStartPayload<'exec>, OnGraphQLParseEndPayload> { + start_payload.cont() + } + async fn on_graphql_validation<'exec>( + &'exec self, + start_payload: OnGraphQLValidationStartPayload<'exec>, + ) -> HookResult<'exec, OnGraphQLValidationStartPayload<'exec>, OnGraphQLValidationEndPayload> + { + start_payload.cont() + } + async fn on_query_plan<'exec>( + &'exec self, + start_payload: OnQueryPlanStartPayload<'exec>, + ) -> HookResult<'exec, OnQueryPlanStartPayload<'exec>, OnQueryPlanEndPayload> { + start_payload.cont() + } + async fn on_execute<'exec>( + &'exec self, + start_payload: OnExecuteStartPayload<'exec>, + ) -> HookResult<'exec, OnExecuteStartPayload<'exec>, OnExecuteEndPayload<'exec>> { + start_payload.cont() + } + async fn on_subgraph_execute<'exec>( + &'exec self, + start_payload: OnSubgraphExecuteStartPayload<'exec>, + ) -> HookResult<'exec, OnSubgraphExecuteStartPayload<'exec>, OnSubgraphExecuteEndPayload> { + start_payload.cont() + } + async fn on_subgraph_http_request<'exec>( + &'exec self, + start_payload: OnSubgraphHttpRequestPayload<'exec>, + ) 
-> HookResult<'exec, OnSubgraphHttpRequestPayload<'exec>, OnSubgraphHttpResponsePayload> { + start_payload.cont() + } + fn on_supergraph_reload<'exec>( + &'exec self, + start_payload: OnSupergraphLoadStartPayload, + ) -> HookResult<'exec, OnSupergraphLoadStartPayload, OnSupergraphLoadEndPayload> { + start_payload.cont() + } +}