From d34732ac3eb19edef397656f36ef5117b2c98a76 Mon Sep 17 00:00:00 2001 From: Arda TANRIKULU Date: Mon, 13 Oct 2025 16:16:49 +0300 Subject: [PATCH 01/15] Response Cache Plugin --- Cargo.lock | 24 ++++ lib/executor/Cargo.toml | 3 + lib/executor/src/lib.rs | 2 + lib/executor/src/plugins/mod.rs | 2 + lib/executor/src/plugins/response_cache.rs | 136 +++++++++++++++++++++ lib/executor/src/plugins/traits.rs | 61 +++++++++ 6 files changed, 228 insertions(+) create mode 100644 lib/executor/src/plugins/mod.rs create mode 100644 lib/executor/src/plugins/response_cache.rs create mode 100644 lib/executor/src/plugins/traits.rs diff --git a/Cargo.lock b/Cargo.lock index d828dd8d2..d5c05dfff 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2064,8 +2064,10 @@ dependencies = [ "indexmap 2.12.0", "insta", "itoa", + "ntex", "ntex-http", "ordered-float", + "redis", "regex-automata", "ryu", "serde", @@ -4217,6 +4219,22 @@ dependencies = [ "crossbeam-utils", ] +[[package]] +name = "redis" +version = "0.32.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "014cc767fefab6a3e798ca45112bccad9c6e0e218fbd49720042716c73cfef44" +dependencies = [ + "combine", + "itoa", + "num-bigint", + "percent-encoding", + "ryu", + "sha1_smol", + "socket2 0.6.1", + "url", +] + [[package]] name = "redox_syscall" version = "0.2.16" @@ -4902,6 +4920,12 @@ dependencies = [ "digest", ] +[[package]] +name = "sha1_smol" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbfa15b3dddfee50a0fff136974b3e1bde555604ba463834a7eb7deb6417705d" + [[package]] name = "sha2" version = "0.10.9" diff --git a/lib/executor/Cargo.toml b/lib/executor/Cargo.toml index 27f7af1bf..7dcfc03fb 100644 --- a/lib/executor/Cargo.toml +++ b/lib/executor/Cargo.toml @@ -34,6 +34,8 @@ vrl = { workspace = true } ahash = "0.8.12" regex-automata = "0.4.10" strum = { version = "0.27.2", features = ["derive"] } + +ntex = { version = "2", features = ["tokio"] } ntex-http = "0.1.15" ordered-float = "4.2.0" hyper-tls = { version = "0.6.0", features = ["vendored"] } @@ -49,6 +51,7 @@ itoa = "1.0.15" ryu = "1.0.20" indexmap = "2.10.0" bumpalo = "3.19.0" +redis = "0.32.7" [dev-dependencies] subgraphs = { path = "../../bench/subgraphs" } diff --git a/lib/executor/src/lib.rs b/lib/executor/src/lib.rs index 4f912a463..c245a7483 100644 --- a/lib/executor/src/lib.rs +++ b/lib/executor/src/lib.rs @@ -4,6 +4,7 @@ pub mod executors; pub mod headers; pub mod introspection; pub mod json_writer; +pub mod plugins; pub mod projection; pub mod response; pub mod utils; @@ -11,3 +12,4 @@ pub mod variables; pub use execution::plan::execute_query_plan; pub use executors::map::SubgraphExecutorMap; +pub use plugins::response_cache::*; diff --git a/lib/executor/src/plugins/mod.rs b/lib/executor/src/plugins/mod.rs new file mode 100644 index 000000000..0e59883b7 --- /dev/null +++ b/lib/executor/src/plugins/mod.rs @@ -0,0 +1,2 @@ +pub mod response_cache; +pub mod traits; diff --git a/lib/executor/src/plugins/response_cache.rs b/lib/executor/src/plugins/response_cache.rs new file mode 100644 index 000000000..545d018a3 --- /dev/null +++ b/lib/executor/src/plugins/response_cache.rs @@ -0,0 +1,136 @@ +use dashmap::DashMap; +use ntex::web::HttpResponse; +use redis::Commands; +use sonic_rs::json; + +use crate::{ + plugins::traits::{ + ControlFlow, OnExecuteEnd, OnExecuteEndPayload, OnExecuteStart, OnExecuteStartPayload, + OnSchemaReload, OnSchemaReloadPayload, + }, + utils::consts::TYPENAME_FIELD_NAME, +}; + +pub struct ResponseCachePlugin { + redis_client: redis::Client, + ttl_per_type: DashMap, +} + +impl ResponseCachePlugin { + pub fn try_new(redis_url: &str) -> Result { + let redis_client = redis::Client::open(redis_url)?; + Ok(Self { + redis_client, + ttl_per_type: DashMap::new(), + }) + } +} + +pub struct ResponseCacheContext { + key: String, +} + +impl OnExecuteStart for ResponseCachePlugin { + fn on_execute_start(&self, payload: OnExecuteStartPayload) -> ControlFlow { + let key = format!( + "response_cache:{}:{:?}", + payload.query_plan, payload.variable_values + ); + payload + .router_http_request + .extensions_mut() + .insert(ResponseCacheContext { key: key.clone() }); + if let Ok(mut conn) = self.redis_client.get_connection() { + let cached_response: Option> = conn.get(&key).ok(); + if let Some(cached_response) = cached_response { + return ControlFlow::Break( + HttpResponse::Ok() + .header("Content-Type", "application/json") + .body(cached_response), + ); + } + } + ControlFlow::Continue + } +} + +impl OnExecuteEnd for ResponseCachePlugin { + fn on_execute_end(&self, payload: OnExecuteEndPayload) -> ControlFlow { + // Do not cache if there are errors + if !payload.errors.is_empty() { + return ControlFlow::Continue; + } + if let Some(key) = payload + .router_http_request + .extensions() + .get::() + .map(|ctx| &ctx.key) + { + if let Ok(mut conn) = self.redis_client.get_connection() { + if let Ok(serialized) = sonic_rs::to_vec(&payload.data) { + // Decide on the ttl somehow + // Get the type names + let mut max_ttl = 0; + + // Imagine this code is traversing the response data to find type names + if let Some(obj) = payload.data.as_object() { + if let Some(typename) = obj + .iter() + .position(|(k, _)| k == &TYPENAME_FIELD_NAME) + .and_then(|idx| obj[idx].1.as_str()) + { + if let Some(ttl) = self.ttl_per_type.get(typename).map(|v| *v) { + max_ttl = max_ttl.max(ttl); + } + } + } + + // If no ttl found, default to 60 seconds + if max_ttl == 0 { + max_ttl = 60; + } + + // Insert the ttl into extensions for client awareness + payload + .extensions + .insert("response_cache_ttl".to_string(), json!(max_ttl)); + + // Set the cache with the decided ttl + let _: () = conn.set_ex(key, serialized, max_ttl).unwrap_or(()); + } + } + } + ControlFlow::Continue + } +} + +impl OnSchemaReload for ResponseCachePlugin { + fn on_schema_reload(&self, payload: OnSchemaReloadPayload) { + // Visit the schema and update ttl_per_type based on some directive + payload + .new_schema + .document + .definitions + .iter() + .for_each(|def| { + if let graphql_parser::schema::Definition::TypeDefinition(type_def) = def { + if let graphql_parser::schema::TypeDefinition::Object(obj_type) = type_def { + for directive in &obj_type.directives { + if directive.name == "cacheControl" { + for arg in &directive.arguments { + if arg.0 == "maxAge" { + if let graphql_parser::query::Value::Int(max_age) = &arg.1 { + if let Some(max_age) = max_age.as_i64() { + self.ttl_per_type + .insert(obj_type.name.clone(), max_age as u64); + } + } + } + } + } + } + } + } + }); + } +} diff --git a/lib/executor/src/plugins/traits.rs b/lib/executor/src/plugins/traits.rs new file mode 100644 index 000000000..4357e5eab --- /dev/null +++ b/lib/executor/src/plugins/traits.rs @@ -0,0 +1,61 @@ +use std::{collections::HashMap, sync::Arc}; + +use hive_router_query_planner::consumer_schema::ConsumerSchema; +use hive_router_query_planner::planner::plan_nodes::QueryPlan; +use ntex::web::HttpRequest; +use ntex::web::HttpResponse; + +use crate::response::graphql_error::GraphQLError; +use crate::response::value::Value; + +pub enum ControlFlow { + Continue, + Break(HttpResponse), +} + +pub struct ExecutionResult<'exec> { + pub data: &'exec mut Value<'exec>, + pub errors: &'exec mut Vec, + pub extensions: &'exec mut Option>>, +} + +pub struct OnExecuteStartPayload<'exec> { + pub router_http_request: &'exec HttpRequest, + pub query_plan: Arc, + + pub data: &'exec mut Value<'exec>, + pub errors: &'exec mut Vec, + pub extensions: Option<&'exec mut sonic_rs::Value>, + + pub skip_execution: bool, + + pub variable_values: &'exec Option>, +} + +pub trait OnExecuteStart { + fn on_execute_start(&self, payload: OnExecuteStartPayload) -> ControlFlow; +} + +pub struct OnExecuteEndPayload<'exec> { + pub router_http_request: &'exec HttpRequest, + pub query_plan: Arc, + + pub data: &'exec Value<'exec>, + pub errors: &'exec Vec, + pub extensions: &'exec mut HashMap, + + pub variable_values: &'exec Option>, +} + +pub trait OnExecuteEnd { + fn on_execute_end(&self, payload: OnExecuteEndPayload) -> ControlFlow; +} + +pub struct OnSchemaReloadPayload { + pub old_schema: &'static ConsumerSchema, + pub new_schema: &'static mut ConsumerSchema, +} + +pub trait OnSchemaReload { + fn on_schema_reload(&self, payload: OnSchemaReloadPayload); +} From d6a6b7a79dd3f0c9e7aa514af1a1577fde8e9fe9 Mon Sep 17 00:00:00 2001 From: Arda TANRIKULU Date: Tue, 14 Oct 2025 14:08:52 +0300 Subject: [PATCH 02/15] Iteration --- lib/executor/src/plugins/response_cache.rs | 49 +++++++--------------- lib/executor/src/plugins/traits.rs | 25 ++++++----- 2 files changed, 26 insertions(+), 48 deletions(-) diff --git a/lib/executor/src/plugins/response_cache.rs b/lib/executor/src/plugins/response_cache.rs index 545d018a3..01719deca 100644 --- a/lib/executor/src/plugins/response_cache.rs +++ b/lib/executor/src/plugins/response_cache.rs @@ -1,12 +1,10 @@ use dashmap::DashMap; use ntex::web::HttpResponse; use redis::Commands; -use sonic_rs::json; use crate::{ plugins::traits::{ - ControlFlow, OnExecuteEnd, OnExecuteEndPayload, OnExecuteStart, OnExecuteStartPayload, - OnSchemaReload, OnSchemaReloadPayload, + ControlFlow, OnExecutePayload, OnSchemaReloadPayload, RouterPlugin }, utils::consts::TYPENAME_FIELD_NAME, }; @@ -26,20 +24,15 @@ impl ResponseCachePlugin { } } -pub struct ResponseCacheContext { - key: String, -} - -impl OnExecuteStart for ResponseCachePlugin { - fn on_execute_start(&self, payload: OnExecuteStartPayload) -> ControlFlow { +impl RouterPlugin for ResponseCachePlugin { + fn on_execute<'exec>( + &self, + payload: OnExecutePayload<'exec>, + ) -> ControlFlow<'exec, OnExecutePayload<'exec>> { let key = format!( "response_cache:{}:{:?}", payload.query_plan, payload.variable_values ); - payload - .router_http_request - .extensions_mut() - .insert(ResponseCacheContext { key: key.clone() }); if let Ok(mut conn) = self.redis_client.get_connection() { let cached_response: Option> = conn.get(&key).ok(); if let Some(cached_response) = cached_response { @@ -49,24 +42,12 @@ impl OnExecuteStart for ResponseCachePlugin { .body(cached_response), ); } - } - ControlFlow::Continue - } -} + ControlFlow::OnEnd(Box::new(move |payload: OnExecutePayload| { + // Do not cache if there are errors + if !payload.errors.is_empty() { + return ControlFlow::Continue; + } -impl OnExecuteEnd for ResponseCachePlugin { - fn on_execute_end(&self, payload: OnExecuteEndPayload) -> ControlFlow { - // Do not cache if there are errors - if !payload.errors.is_empty() { - return ControlFlow::Continue; - } - if let Some(key) = payload - .router_http_request - .extensions() - .get::() - .map(|ctx| &ctx.key) - { - if let Ok(mut conn) = self.redis_client.get_connection() { if let Ok(serialized) = sonic_rs::to_vec(&payload.data) { // Decide on the ttl somehow // Get the type names @@ -93,18 +74,16 @@ impl OnExecuteEnd for ResponseCachePlugin { // Insert the ttl into extensions for client awareness payload .extensions - .insert("response_cache_ttl".to_string(), json!(max_ttl)); + .insert("response_cache_ttl".to_string(), sonic_rs::json!(max_ttl)); // Set the cache with the decided ttl let _: () = conn.set_ex(key, serialized, max_ttl).unwrap_or(()); } - } + ControlFlow::Continue + })); } ControlFlow::Continue } -} - -impl OnSchemaReload for ResponseCachePlugin { fn on_schema_reload(&self, payload: OnSchemaReloadPayload) { // Visit the schema and update ttl_per_type based on some directive payload diff --git a/lib/executor/src/plugins/traits.rs b/lib/executor/src/plugins/traits.rs index 4357e5eab..dd3db1731 100644 --- a/lib/executor/src/plugins/traits.rs +++ b/lib/executor/src/plugins/traits.rs @@ -8,9 +8,10 @@ use ntex::web::HttpResponse; use crate::response::graphql_error::GraphQLError; use crate::response::value::Value; -pub enum ControlFlow { +pub enum ControlFlow<'a, TPayload> { Continue, Break(HttpResponse), + OnEnd(Box ControlFlow<'a, ()> + Send + 'a>), } pub struct ExecutionResult<'exec> { @@ -19,21 +20,27 @@ pub struct ExecutionResult<'exec> { pub extensions: &'exec mut Option>>, } -pub struct OnExecuteStartPayload<'exec> { +pub struct OnExecutePayload<'exec> { pub router_http_request: &'exec HttpRequest, pub query_plan: Arc, pub data: &'exec mut Value<'exec>, pub errors: &'exec mut Vec, - pub extensions: Option<&'exec mut sonic_rs::Value>, + pub extensions: &'exec mut HashMap, pub skip_execution: bool, pub variable_values: &'exec Option>, } -pub trait OnExecuteStart { - fn on_execute_start(&self, payload: OnExecuteStartPayload) -> ControlFlow; +pub trait RouterPlugin { + fn on_execute<'exec>( + &self, + _payload: OnExecutePayload<'exec>, + ) -> ControlFlow<'exec, OnExecutePayload<'exec>> { + ControlFlow::Continue + } + fn on_schema_reload(&self, _payload: OnSchemaReloadPayload) {} } pub struct OnExecuteEndPayload<'exec> { @@ -47,15 +54,7 @@ pub struct OnExecuteEndPayload<'exec> { pub variable_values: &'exec Option>, } -pub trait OnExecuteEnd { - fn on_execute_end(&self, payload: OnExecuteEndPayload) -> ControlFlow; -} - pub struct OnSchemaReloadPayload { pub old_schema: &'static ConsumerSchema, pub new_schema: &'static mut ConsumerSchema, } - -pub trait OnSchemaReload { - fn on_schema_reload(&self, payload: OnSchemaReloadPayload); -} From 0ef2138233d76a257418121ac013c9e1b64f81d9 Mon Sep 17 00:00:00 2001 From: Arda TANRIKULU Date: Tue, 14 Oct 2025 17:09:42 +0300 Subject: [PATCH 03/15] New example --- lib/executor/src/lib.rs | 2 +- lib/executor/src/plugins/examples/mod.rs | 2 + .../plugins/{ => examples}/response_cache.rs | 7 +-- .../examples/subgraph_response_cache.rs | 32 ++++++++++ lib/executor/src/plugins/hooks/mod.rs | 3 + lib/executor/src/plugins/hooks/on_execute.rs | 33 ++++++++++ .../src/plugins/hooks/on_schema_reload.rs | 6 ++ .../src/plugins/hooks/on_subgraph_execute.rs | 47 +++++++++++++++ lib/executor/src/plugins/mod.rs | 5 +- lib/executor/src/plugins/plugin_trait.rs | 28 +++++++++ lib/executor/src/plugins/traits.rs | 60 ------------------- 11 files changed, 158 insertions(+), 67 deletions(-) create mode 100644 lib/executor/src/plugins/examples/mod.rs rename lib/executor/src/plugins/{ => examples}/response_cache.rs (95%) create mode 100644 lib/executor/src/plugins/examples/subgraph_response_cache.rs create mode 100644 lib/executor/src/plugins/hooks/mod.rs create mode 100644 lib/executor/src/plugins/hooks/on_execute.rs create mode 100644 lib/executor/src/plugins/hooks/on_schema_reload.rs create mode 100644 lib/executor/src/plugins/hooks/on_subgraph_execute.rs create mode 100644 lib/executor/src/plugins/plugin_trait.rs delete mode 100644 lib/executor/src/plugins/traits.rs diff --git a/lib/executor/src/lib.rs b/lib/executor/src/lib.rs index c245a7483..1f29c192e 100644 --- a/lib/executor/src/lib.rs +++ b/lib/executor/src/lib.rs @@ -12,4 +12,4 @@ pub mod variables; pub use execution::plan::execute_query_plan; pub use executors::map::SubgraphExecutorMap; -pub use plugins::response_cache::*; +pub use plugins::*; diff --git a/lib/executor/src/plugins/examples/mod.rs b/lib/executor/src/plugins/examples/mod.rs new file mode 100644 index 000000000..3d54dfbed --- /dev/null +++ b/lib/executor/src/plugins/examples/mod.rs @@ -0,0 +1,2 @@ +pub mod response_cache; +pub mod subgraph_response_cache; \ No newline at end of file diff --git a/lib/executor/src/plugins/response_cache.rs b/lib/executor/src/plugins/examples/response_cache.rs similarity index 95% rename from lib/executor/src/plugins/response_cache.rs rename to lib/executor/src/plugins/examples/response_cache.rs index 01719deca..ed8a106a6 100644 --- a/lib/executor/src/plugins/response_cache.rs +++ b/lib/executor/src/plugins/examples/response_cache.rs @@ -3,10 +3,9 @@ use ntex::web::HttpResponse; use redis::Commands; use crate::{ - plugins::traits::{ - ControlFlow, OnExecutePayload, OnSchemaReloadPayload, RouterPlugin - }, - utils::consts::TYPENAME_FIELD_NAME, + hooks::{on_execute::OnExecutePayload, on_schema_reload::OnSchemaReloadPayload}, plugins::plugin_trait::{ + ControlFlow, RouterPlugin + }, utils::consts::TYPENAME_FIELD_NAME }; pub struct ResponseCachePlugin { diff --git a/lib/executor/src/plugins/examples/subgraph_response_cache.rs b/lib/executor/src/plugins/examples/subgraph_response_cache.rs new file mode 100644 index 000000000..32739df1e --- /dev/null +++ b/lib/executor/src/plugins/examples/subgraph_response_cache.rs @@ -0,0 +1,32 @@ +use dashmap::DashMap; + +use crate::{hooks::on_subgraph_execute::{OnSubgraphExecuteEndPayload, OnSubgraphExecuteStartPayload, SubgraphExecutorResponse, SubgraphResponse}, plugin_trait::{ControlFlow, RouterPlugin}}; + +struct SubgraphResponseCachePlugin { + cache: DashMap>, +} + +impl RouterPlugin for SubgraphResponseCachePlugin { + fn on_subgraph_execute<'exec>( + &self, + payload: OnSubgraphExecuteStartPayload<'exec>, + ) -> ControlFlow<'exec, OnSubgraphExecuteEndPayload<'exec>> { + let key = format!( + "subgraph_response_cache:{}:{}:{:?}", + payload.subgraph_name, payload.execution_request.operation_name.unwrap_or(""), payload.execution_request.variables + ); + if let Some(cached_response) = self.cache.get(&key) { + *payload.response = Some(SubgraphExecutorResponse::RawResponse(cached_response)); + // Return early with the cached response + return ControlFlow::Continue; + } else { + ControlFlow::OnEnd(Box::new(move |payload: OnSubgraphExecuteEndPayload| { + let cacheable = payload.response.errors.is_none_or(|errors| errors.is_empty()); + if cacheable { + self.cache.insert(key, *payload.response); + } + ControlFlow::Continue + })) + } + } +} \ No newline at end of file diff --git a/lib/executor/src/plugins/hooks/mod.rs b/lib/executor/src/plugins/hooks/mod.rs new file mode 100644 index 000000000..1954154d8 --- /dev/null +++ b/lib/executor/src/plugins/hooks/mod.rs @@ -0,0 +1,3 @@ +pub mod on_execute; +pub mod on_schema_reload; +pub mod on_subgraph_execute; \ No newline at end of file diff --git a/lib/executor/src/plugins/hooks/on_execute.rs b/lib/executor/src/plugins/hooks/on_execute.rs new file mode 100644 index 000000000..0d5759de2 --- /dev/null +++ b/lib/executor/src/plugins/hooks/on_execute.rs @@ -0,0 +1,33 @@ +use std::collections::HashMap; +use std::sync::Arc; + +use hive_router_query_planner::planner::plan_nodes::QueryPlan; +use ntex::web::HttpRequest; + +use crate::response::{value::Value}; +use crate::response::graphql_error::GraphQLError; + +pub struct OnExecutePayload<'exec> { + pub router_http_request: &'exec HttpRequest, + pub query_plan: Arc, + + pub data: &'exec mut Value<'exec>, + pub errors: &'exec mut Vec, + pub extensions: &'exec mut HashMap, + + pub skip_execution: bool, + + pub variable_values: &'exec Option>, +} + +pub struct OnExecuteEndPayload<'exec> { + pub router_http_request: &'exec HttpRequest, + pub query_plan: Arc, + + pub data: &'exec Value<'exec>, + pub errors: &'exec Vec, + pub extensions: &'exec mut HashMap, + + pub variable_values: &'exec Option>, +} + diff --git a/lib/executor/src/plugins/hooks/on_schema_reload.rs b/lib/executor/src/plugins/hooks/on_schema_reload.rs new file mode 100644 index 000000000..29863d964 --- /dev/null +++ b/lib/executor/src/plugins/hooks/on_schema_reload.rs @@ -0,0 +1,6 @@ +use hive_router_query_planner::consumer_schema::ConsumerSchema; + +pub struct OnSchemaReloadPayload { + pub old_schema: &'static ConsumerSchema, + pub new_schema: &'static mut ConsumerSchema, +} diff --git a/lib/executor/src/plugins/hooks/on_subgraph_execute.rs b/lib/executor/src/plugins/hooks/on_subgraph_execute.rs new file mode 100644 index 000000000..4fde5afb8 --- /dev/null +++ b/lib/executor/src/plugins/hooks/on_subgraph_execute.rs @@ -0,0 +1,47 @@ +use std::collections::HashMap; + +use bytes::Bytes; +use hive_router_query_planner::ast::operation::SubgraphFetchOperation; +use ntex::web::HttpRequest; + +use crate::{executors::dedupe::SharedResponse, response::{graphql_error::GraphQLError, value::Value}}; + + + +pub struct OnSubgraphExecuteStartPayload<'exec> { + pub router_http_request: &'exec HttpRequest, + pub subgraph_name: &'exec str, + // The node that initiates this subgraph execution + pub execution_request: &'exec mut SubgraphExecutionRequest<'exec>, + // This will be tricky to implement with the current structure, + // but I'm sure we'll figure it out + pub response: &'exec mut Option>, +} + +pub struct SubgraphExecutionRequest<'exec> { + pub query: &'exec str, + // We can add the original operation here too + pub operation: &'exec SubgraphFetchOperation, + + pub dedupe: bool, + pub operation_name: Option<&'exec str>, + pub variables: Option>, + pub extensions: Option>, + pub representations: Option>, +} + +pub struct SubgraphResponse<'exec> { + pub data: Value<'exec>, + pub errors: Option>, + pub extensions: Option>>, +} + +pub struct OnSubgraphExecuteEndPayload<'exec> { + pub router_http_request: &'exec HttpRequest, + pub subgraph_name: &'exec str, + // The node that initiates this subgraph execution + pub execution_request: &'exec SubgraphExecutionRequest<'exec>, + // This will be tricky to implement with the current structure, + // but I'm sure we'll figure it out + pub response: &'exec SubgraphResponse<'exec>, +} diff --git a/lib/executor/src/plugins/mod.rs b/lib/executor/src/plugins/mod.rs index 0e59883b7..6c35286af 100644 --- a/lib/executor/src/plugins/mod.rs +++ b/lib/executor/src/plugins/mod.rs @@ -1,2 +1,3 @@ -pub mod response_cache; -pub mod traits; +pub mod examples; +pub mod plugin_trait; +pub mod hooks; \ No newline at end of file diff --git a/lib/executor/src/plugins/plugin_trait.rs b/lib/executor/src/plugins/plugin_trait.rs new file mode 100644 index 000000000..8d5a951e7 --- /dev/null +++ b/lib/executor/src/plugins/plugin_trait.rs @@ -0,0 +1,28 @@ +use ntex::web::HttpResponse; + +use crate::hooks::on_execute::OnExecutePayload; +use crate::hooks::on_schema_reload::OnSchemaReloadPayload; +use crate::hooks::on_subgraph_execute::OnSubgraphExecuteEndPayload; +use crate::hooks::on_subgraph_execute::OnSubgraphExecuteStartPayload; + +pub enum ControlFlow<'a, TPayload> { + Continue, + Break(HttpResponse), + OnEnd(Box ControlFlow<'a, ()> + Send + 'a>), +} + +pub trait RouterPlugin { + fn on_execute<'exec>( + &self, + _payload: OnExecutePayload<'exec>, + ) -> ControlFlow<'exec, OnExecutePayload<'exec>> { + ControlFlow::Continue + } + fn on_subgraph_execute<'exec>( + &self, + _payload: OnSubgraphExecuteStartPayload<'exec>, + ) -> ControlFlow<'exec, OnSubgraphExecuteEndPayload<'exec>> { + ControlFlow::Continue + } + fn on_schema_reload(&self, _payload: OnSchemaReloadPayload) {} +} \ No newline at end of file diff --git a/lib/executor/src/plugins/traits.rs b/lib/executor/src/plugins/traits.rs deleted file mode 100644 index dd3db1731..000000000 --- a/lib/executor/src/plugins/traits.rs +++ /dev/null @@ -1,60 +0,0 @@ -use std::{collections::HashMap, sync::Arc}; - -use hive_router_query_planner::consumer_schema::ConsumerSchema; -use hive_router_query_planner::planner::plan_nodes::QueryPlan; -use ntex::web::HttpRequest; -use ntex::web::HttpResponse; - -use crate::response::graphql_error::GraphQLError; -use crate::response::value::Value; - -pub enum ControlFlow<'a, TPayload> { - Continue, - Break(HttpResponse), - OnEnd(Box ControlFlow<'a, ()> + Send + 'a>), -} - -pub struct ExecutionResult<'exec> { - pub data: &'exec mut Value<'exec>, - pub errors: &'exec mut Vec, - pub extensions: &'exec mut Option>>, -} - -pub struct OnExecutePayload<'exec> { - pub router_http_request: &'exec HttpRequest, - pub query_plan: Arc, - - pub data: &'exec mut Value<'exec>, - pub errors: &'exec mut Vec, - pub extensions: &'exec mut HashMap, - - pub skip_execution: bool, - - pub variable_values: &'exec Option>, -} - -pub trait RouterPlugin { - fn on_execute<'exec>( - &self, - _payload: OnExecutePayload<'exec>, - ) -> ControlFlow<'exec, OnExecutePayload<'exec>> { - ControlFlow::Continue - } - fn on_schema_reload(&self, _payload: OnSchemaReloadPayload) {} -} - -pub struct OnExecuteEndPayload<'exec> { - pub router_http_request: &'exec HttpRequest, - pub query_plan: Arc, - - pub data: &'exec Value<'exec>, - pub errors: &'exec Vec, - pub extensions: &'exec mut HashMap, - - pub variable_values: &'exec Option>, -} - -pub struct OnSchemaReloadPayload { - pub old_schema: &'static ConsumerSchema, - pub new_schema: &'static mut ConsumerSchema, -} From 22279bf956a4546144de14917fb8da378766a6ec Mon Sep 17 00:00:00 2001 From: Arda TANRIKULU Date: Tue, 14 Oct 2025 17:30:13 +0300 Subject: [PATCH 04/15] Another plugin --- .../examples/subgraph_response_cache.rs | 36 +++++++++--------- lib/executor/src/plugins/hooks/mod.rs | 2 +- ...execute.rs => on_subgraph_http_request.rs} | 38 +++++++++---------- lib/executor/src/plugins/plugin_trait.rs | 15 ++++---- 4 files changed, 44 insertions(+), 47 deletions(-) rename lib/executor/src/plugins/hooks/{on_subgraph_execute.rs => on_subgraph_http_request.rs} (54%) diff --git a/lib/executor/src/plugins/examples/subgraph_response_cache.rs b/lib/executor/src/plugins/examples/subgraph_response_cache.rs index 32739df1e..d4742b2ad 100644 --- a/lib/executor/src/plugins/examples/subgraph_response_cache.rs +++ b/lib/executor/src/plugins/examples/subgraph_response_cache.rs @@ -1,32 +1,30 @@ use dashmap::DashMap; -use crate::{hooks::on_subgraph_execute::{OnSubgraphExecuteEndPayload, OnSubgraphExecuteStartPayload, SubgraphExecutorResponse, SubgraphResponse}, plugin_trait::{ControlFlow, RouterPlugin}}; +use crate::{executors::dedupe::SharedResponse, hooks::on_subgraph_http_request::{OnSubgraphHttpRequestPayload, OnSubgraphHttpResponsePayload}, plugin_trait::{ControlFlow, RouterPlugin}}; -struct SubgraphResponseCachePlugin { - cache: DashMap>, +pub struct SubgraphResponseCachePlugin { + cache: DashMap, } impl RouterPlugin for SubgraphResponseCachePlugin { - fn on_subgraph_execute<'exec>( - &self, - payload: OnSubgraphExecuteStartPayload<'exec>, - ) -> ControlFlow<'exec, OnSubgraphExecuteEndPayload<'exec>> { + fn on_subgraph_http_request<'exec>( + &'static self, + payload: OnSubgraphHttpRequestPayload<'exec>, + ) -> ControlFlow<'exec, OnSubgraphHttpResponsePayload<'exec>> { let key = format!( - "subgraph_response_cache:{}:{}:{:?}", - payload.subgraph_name, payload.execution_request.operation_name.unwrap_or(""), payload.execution_request.variables + "subgraph_response_cache:{}:{:?}", + payload.execution_request.query, payload.execution_request.variables ); if let Some(cached_response) = self.cache.get(&key) { - *payload.response = Some(SubgraphExecutorResponse::RawResponse(cached_response)); - // Return early with the cached response + // Here payload.response is Option + // So it is bypassing the actual subgraph request + *payload.response = Some(cached_response.clone()); return ControlFlow::Continue; - } else { - ControlFlow::OnEnd(Box::new(move |payload: OnSubgraphExecuteEndPayload| { - let cacheable = payload.response.errors.is_none_or(|errors| errors.is_empty()); - if cacheable { - self.cache.insert(key, *payload.response); - } - ControlFlow::Continue - })) } + ControlFlow::OnEnd(Box::new(move |payload: OnSubgraphHttpResponsePayload| { + // Here payload.response is not Option + self.cache.insert(key, payload.response.clone()); + ControlFlow::Continue + })) } } \ No newline at end of file diff --git a/lib/executor/src/plugins/hooks/mod.rs b/lib/executor/src/plugins/hooks/mod.rs index 1954154d8..5a4d94c22 100644 --- a/lib/executor/src/plugins/hooks/mod.rs +++ b/lib/executor/src/plugins/hooks/mod.rs @@ -1,3 +1,3 @@ pub mod on_execute; pub mod on_schema_reload; -pub mod on_subgraph_execute; \ No newline at end of file +pub mod on_subgraph_http_request; \ No newline at end of file diff --git a/lib/executor/src/plugins/hooks/on_subgraph_execute.rs b/lib/executor/src/plugins/hooks/on_subgraph_http_request.rs similarity index 54% rename from lib/executor/src/plugins/hooks/on_subgraph_execute.rs rename to lib/executor/src/plugins/hooks/on_subgraph_http_request.rs index 4fde5afb8..326e516f0 100644 --- a/lib/executor/src/plugins/hooks/on_subgraph_execute.rs +++ b/lib/executor/src/plugins/hooks/on_subgraph_http_request.rs @@ -1,28 +1,34 @@ use std::collections::HashMap; -use bytes::Bytes; use hive_router_query_planner::ast::operation::SubgraphFetchOperation; +use http::{HeaderMap, Uri}; use ntex::web::HttpRequest; -use crate::{executors::dedupe::SharedResponse, response::{graphql_error::GraphQLError, value::Value}}; +use crate:: + executors::dedupe::SharedResponse +; +pub struct OnSubgraphHttpRequestPayload<'exec> { + pub router_http_request: &'exec HttpRequest, + pub subgraph_name: &'exec str, + // At this point, there is no point of mutating this + pub execution_request: &'exec SubgraphExecutionRequest<'exec>, + pub endpoint: &'exec mut Uri, + // By default, it is POST + pub method: &'exec mut http::Method, + pub headers: &'exec mut HeaderMap, + pub request_body: &'exec mut Vec, -pub struct OnSubgraphExecuteStartPayload<'exec> { - pub router_http_request: &'exec HttpRequest, - pub subgraph_name: &'exec str, - // The node that initiates this subgraph execution - pub execution_request: &'exec mut SubgraphExecutionRequest<'exec>, - // This will be tricky to implement with the current structure, - // but I'm sure we'll figure it out - pub response: &'exec mut Option>, + // Early response + pub response: &'exec mut Option, } pub struct SubgraphExecutionRequest<'exec> { pub query: &'exec str, // We can add the original operation here too pub operation: &'exec SubgraphFetchOperation, - + pub dedupe: bool, pub operation_name: Option<&'exec str>, pub variables: Option>, @@ -30,18 +36,12 @@ pub struct SubgraphExecutionRequest<'exec> { pub representations: Option>, } -pub struct SubgraphResponse<'exec> { - pub data: Value<'exec>, - pub errors: Option>, - pub extensions: Option>>, -} - -pub struct OnSubgraphExecuteEndPayload<'exec> { +pub struct OnSubgraphHttpResponsePayload<'exec> { pub router_http_request: &'exec HttpRequest, pub subgraph_name: &'exec str, // The node that initiates this subgraph execution pub execution_request: &'exec SubgraphExecutionRequest<'exec>, // This will be tricky to implement with the current structure, // but I'm sure we'll figure it out - pub response: &'exec SubgraphResponse<'exec>, + pub response: &'exec mut SharedResponse, } diff --git a/lib/executor/src/plugins/plugin_trait.rs b/lib/executor/src/plugins/plugin_trait.rs index 8d5a951e7..53b0720d1 100644 --- a/lib/executor/src/plugins/plugin_trait.rs +++ b/lib/executor/src/plugins/plugin_trait.rs @@ -2,13 +2,12 @@ use ntex::web::HttpResponse; use crate::hooks::on_execute::OnExecutePayload; use crate::hooks::on_schema_reload::OnSchemaReloadPayload; -use crate::hooks::on_subgraph_execute::OnSubgraphExecuteEndPayload; -use crate::hooks::on_subgraph_execute::OnSubgraphExecuteStartPayload; +use crate::hooks::on_subgraph_http_request::{OnSubgraphHttpRequestPayload, OnSubgraphHttpResponsePayload}; -pub enum ControlFlow<'a, TPayload> { +pub enum ControlFlow<'exec, TPayload> { Continue, Break(HttpResponse), - OnEnd(Box ControlFlow<'a, ()> + Send + 'a>), + OnEnd(Box ControlFlow<'exec, ()> + 'exec>), } pub trait RouterPlugin { @@ -18,10 +17,10 @@ pub trait RouterPlugin { ) -> ControlFlow<'exec, OnExecutePayload<'exec>> { ControlFlow::Continue } - fn on_subgraph_execute<'exec>( - &self, - _payload: OnSubgraphExecuteStartPayload<'exec>, - ) -> ControlFlow<'exec, OnSubgraphExecuteEndPayload<'exec>> { + fn on_subgraph_http_request<'exec>( + &'static self, + _payload: OnSubgraphHttpRequestPayload<'exec>, + ) -> ControlFlow<'exec, OnSubgraphHttpResponsePayload<'exec>> { ControlFlow::Continue } fn on_schema_reload(&self, _payload: OnSchemaReloadPayload) {} From 8f0c0f99cd87f2400ec1979e54adc91142b753da Mon Sep 17 00:00:00 2001 From: Arda TANRIKULU Date: Tue, 18 Nov 2025 01:55:12 +0300 Subject: [PATCH 05/15] More --- bin/router/src/lib.rs | 7 +- bin/router/src/pipeline/coerce_variables.rs | 6 +- ...quest.rs => deserialize_graphql_params.rs} | 61 ++++------ bin/router/src/pipeline/mod.rs | 96 ++++++++++----- bin/router/src/pipeline/normalize.rs | 10 +- bin/router/src/pipeline/parser.rs | 77 ++++++++++-- bin/router/src/shared_state.rs | 3 + lib/executor/src/execution/plan.rs | 4 +- lib/executor/src/executors/common.rs | 6 +- lib/executor/src/executors/http.rs | 6 +- lib/executor/src/executors/map.rs | 4 +- lib/executor/src/plugins/examples/apq.rs | 55 +++++++++ lib/executor/src/plugins/examples/mod.rs | 3 +- .../src/plugins/examples/response_cache.rs | 32 ++--- .../examples/subgraph_response_cache.rs | 12 +- lib/executor/src/plugins/hooks/mod.rs | 8 +- .../src/plugins/hooks/on_deserialization.rs | 45 +++++++ lib/executor/src/plugins/hooks/on_execute.rs | 22 ++-- .../src/plugins/hooks/on_graphql_parse.rs | 19 +++ .../plugins/hooks/on_graphql_validation.rs | 25 ++++ .../src/plugins/hooks/on_http_request.rs | 16 +++ .../src/plugins/hooks/on_query_plan.rs | 23 ++++ .../src/plugins/hooks/on_schema_reload.rs | 6 +- .../src/plugins/hooks/on_subgraph_execute.rs | 34 ++++++ .../plugins/hooks/on_subgraph_http_request.rs | 24 +--- lib/executor/src/plugins/plugin_trait.rs | 114 +++++++++++++++--- 26 files changed, 555 insertions(+), 163 deletions(-) rename bin/router/src/pipeline/{execution_request.rs => deserialize_graphql_params.rs} (67%) create mode 100644 lib/executor/src/plugins/examples/apq.rs create mode 100644 lib/executor/src/plugins/hooks/on_deserialization.rs create mode 100644 lib/executor/src/plugins/hooks/on_graphql_parse.rs create mode 100644 lib/executor/src/plugins/hooks/on_graphql_validation.rs create mode 100644 lib/executor/src/plugins/hooks/on_http_request.rs create mode 100644 lib/executor/src/plugins/hooks/on_query_plan.rs create mode 100644 lib/executor/src/plugins/hooks/on_subgraph_execute.rs diff --git a/bin/router/src/lib.rs b/bin/router/src/lib.rs index 6a3f7f5c0..da799b4cc 100644 --- a/bin/router/src/lib.rs +++ b/bin/router/src/lib.rs @@ -27,8 +27,7 @@ pub use crate::{schema_state::SchemaState, shared_state::RouterSharedState}; use hive_router_config::{load_config, HiveRouterConfig}; use http::header::RETRY_AFTER; use ntex::{ - util::Bytes, - web::{self, HttpRequest}, + util::Bytes, web::{self, HttpRequest} }; use tracing::{info, warn}; @@ -121,7 +120,9 @@ pub async fn configure_app_from_config( } pub fn configure_ntex_app(cfg: &mut web::ServiceConfig) { - cfg.route("/graphql", web::to(graphql_endpoint_handler)) + cfg + .route("/graphql", web::to(graphql_endpoint_handler)) .route("/health", web::to(health_check_handler)) .route("/readiness", web::to(readiness_check_handler)); } + diff --git a/bin/router/src/pipeline/coerce_variables.rs b/bin/router/src/pipeline/coerce_variables.rs index 8c472695e..fa85223e0 100644 --- a/bin/router/src/pipeline/coerce_variables.rs +++ b/bin/router/src/pipeline/coerce_variables.rs @@ -1,6 +1,7 @@ use std::collections::HashMap; use std::sync::Arc; +use hive_router_plan_executor::hooks::on_deserialization::GraphQLParams; use hive_router_plan_executor::variables::collect_variables; use hive_router_query_planner::state::supergraph_state::OperationKind; use http::Method; @@ -9,7 +10,6 @@ use sonic_rs::Value; use tracing::{error, trace, warn}; use crate::pipeline::error::{PipelineError, PipelineErrorFromAcceptHeader, PipelineErrorVariant}; -use crate::pipeline::execution_request::ExecutionRequest; use crate::pipeline::normalize::GraphQLNormalizationPayload; use crate::schema_state::SupergraphData; @@ -22,7 +22,7 @@ pub struct CoerceVariablesPayload { pub fn coerce_request_variables( req: &HttpRequest, supergraph: &SupergraphData, - execution_params: &mut ExecutionRequest, + graphql_params: &mut GraphQLParams, normalized_operation: &Arc, ) -> Result { if req.method() == Method::GET { @@ -37,7 +37,7 @@ pub fn coerce_request_variables( match collect_variables( &normalized_operation.operation_for_plan, - &mut execution_params.variables, + &mut graphql_params.variables, &supergraph.metadata, ) { Ok(values) => { diff --git a/bin/router/src/pipeline/execution_request.rs b/bin/router/src/pipeline/deserialize_graphql_params.rs similarity index 67% rename from bin/router/src/pipeline/execution_request.rs rename to bin/router/src/pipeline/deserialize_graphql_params.rs index c17a6f355..1a769e5a2 100644 --- a/bin/router/src/pipeline/execution_request.rs +++ b/bin/router/src/pipeline/deserialize_graphql_params.rs @@ -1,11 +1,10 @@ use std::collections::HashMap; +use hive_router_plan_executor::hooks::on_deserialization::GraphQLParams; use http::Method; use ntex::util::Bytes; use ntex::web::types::Query; use ntex::web::HttpRequest; -use serde::{Deserialize, Deserializer}; -use sonic_rs::Value; use tracing::{trace, warn}; use crate::pipeline::error::{PipelineError, PipelineErrorFromAcceptHeader, PipelineErrorVariant}; @@ -20,36 +19,10 @@ struct GETQueryParams { pub extensions: Option, } -#[derive(Deserialize, Debug, Clone)] -#[serde(rename_all = "camelCase")] -pub struct ExecutionRequest { - pub query: String, - pub operation_name: Option, - #[serde(default, deserialize_with = "deserialize_null_default")] - pub variables: HashMap, - // TODO: We don't use extensions yet, but we definitely will in the future. - #[allow(dead_code)] - pub extensions: Option>, -} - -fn deserialize_null_default<'de, D, T>(deserializer: D) -> Result -where - T: Default + Deserialize<'de>, - D: Deserializer<'de>, -{ - let opt = Option::::deserialize(deserializer)?; - Ok(opt.unwrap_or_default()) -} - -impl TryInto for GETQueryParams { +impl TryInto for GETQueryParams { type Error = PipelineErrorVariant; - fn try_into(self) -> Result { - let query = match self.query { - Some(q) => q, - None => return Err(PipelineErrorVariant::GetMissingQueryParam("query")), - }; - + fn try_into(self) -> Result { let variables = match self.variables.as_deref() { Some(v_str) if !v_str.is_empty() => match sonic_rs::from_str(v_str) { Ok(vars) => vars, @@ -70,8 +43,8 @@ impl TryInto for GETQueryParams { _ => None, }; - let execution_request = ExecutionRequest { - query, + let execution_request = GraphQLParams { + query: self.query, operation_name: self.operation_name, variables, extensions, @@ -81,13 +54,25 @@ impl TryInto for GETQueryParams { } } +pub trait GetQueryStr { + fn get_query<'a>(&'a self) -> Result<&'a str, PipelineErrorVariant>; +} + +impl GetQueryStr for GraphQLParams { + fn get_query<'a>(&'a self) -> Result<&'a str, PipelineErrorVariant> { + self.query + .as_deref() + .ok_or(PipelineErrorVariant::GetMissingQueryParam("query")) + } +} + #[inline] -pub async fn get_execution_request( - req: &mut HttpRequest, +pub fn deserialize_graphql_params( + req: &HttpRequest, body_bytes: Bytes, -) -> Result { +) -> Result { let http_method = req.method(); - let execution_request: ExecutionRequest = match *http_method { + let graphql_params: GraphQLParams = match *http_method { Method::GET => { trace!("processing GET GraphQL operation"); let query_params_str = req.uri().query().ok_or_else(|| { @@ -111,7 +96,7 @@ pub async fn get_execution_request( req.assert_json_content_type()?; let execution_request = unsafe { - sonic_rs::from_slice_unchecked::(&body_bytes).map_err(|e| { + sonic_rs::from_slice_unchecked::(&body_bytes).map_err(|e| { warn!("Failed to parse body: {}", e); req.new_pipeline_error(PipelineErrorVariant::FailedToParseBody(e)) })? @@ -130,5 +115,5 @@ pub async fn get_execution_request( } }; - Ok(execution_request) + Ok(graphql_params) } diff --git a/bin/router/src/pipeline/mod.rs b/bin/router/src/pipeline/mod.rs index 2b4721972..61cb353bf 100644 --- a/bin/router/src/pipeline/mod.rs +++ b/bin/router/src/pipeline/mod.rs @@ -1,11 +1,17 @@ use std::sync::Arc; -use hive_router_plan_executor::execution::{ - client_request_details::{ClientRequestDetails, JwtRequestDetails, OperationDetails}, - plan::PlanExecutionOutput, +use hive_router_plan_executor::{ + execution::{ + client_request_details::{ClientRequestDetails, JwtRequestDetails, OperationDetails}, + plan::PlanExecutionOutput, + }, + hooks::on_deserialization::{ + OnDeserializationEndPayload, OnDeserializationStartPayload + }, + plugin_trait::ControlFlowResult, }; use hive_router_query_planner::{ - state::supergraph_state::OperationKind, utils::cancellation::CancellationToken, + state::supergraph_state::OperationKind, utils::cancellation::CancellationToken }; use http::{header::CONTENT_TYPE, HeaderValue, Method}; use ntex::{ @@ -16,20 +22,9 @@ use ntex::{ use crate::{ jwt::context::JwtRequestContext, pipeline::{ - coerce_variables::coerce_request_variables, - csrf_prevention::perform_csrf_prevention, - error::{PipelineError, PipelineErrorFromAcceptHeader, PipelineErrorVariant}, - execution::execute_plan, - execution_request::get_execution_request, - header::{ - RequestAccepts, APPLICATION_GRAPHQL_RESPONSE_JSON, - APPLICATION_GRAPHQL_RESPONSE_JSON_STR, APPLICATION_JSON, TEXT_HTML_CONTENT_TYPE, - }, - normalize::normalize_request_with_cache, - parser::parse_operation_with_cache, - progressive_override::request_override_context, - query_plan::plan_operation_with_cache, - validation::validate_operation_with_cache, + coerce_variables::coerce_request_variables, csrf_prevention::perform_csrf_prevention, deserialize_graphql_params::{GetQueryStr, deserialize_graphql_params}, error::{PipelineError, PipelineErrorFromAcceptHeader, PipelineErrorVariant}, execution::execute_plan, header::{ + APPLICATION_GRAPHQL_RESPONSE_JSON, APPLICATION_GRAPHQL_RESPONSE_JSON_STR, APPLICATION_JSON, RequestAccepts, TEXT_HTML_CONTENT_TYPE + }, normalize::normalize_request_with_cache, parser::parse_operation_with_cache, progressive_override::request_override_context, query_plan::plan_operation_with_cache, validation::validate_operation_with_cache }, schema_state::{SchemaState, SupergraphData}, shared_state::RouterSharedState, @@ -40,7 +35,7 @@ pub mod cors; pub mod csrf_prevention; pub mod error; pub mod execution; -pub mod execution_request; +pub mod deserialize_graphql_params; pub mod header; pub mod normalize; pub mod parser; @@ -104,17 +99,61 @@ pub async fn graphql_request_handler( #[inline] #[allow(clippy::await_holding_refcell_ref)] -pub async fn execute_pipeline( - req: &mut HttpRequest, - body_bytes: Bytes, +pub async fn execute_pipeline<'req>( + req: &'req mut HttpRequest, + body: Bytes, supergraph: &SupergraphData, - shared_state: &Arc, + shared_state: &'req Arc, schema_state: &Arc, ) -> Result { perform_csrf_prevention(req, &shared_state.router_config.csrf)?; - let mut execution_request = get_execution_request(req, body_bytes).await?; - let parser_payload = parse_operation_with_cache(req, shared_state, &execution_request).await?; + /* Handle on_deserialize hook in the plugins - START */ + let mut deserialization_end_callbacks = vec![]; + let mut deserialization_payload: OnDeserializationStartPayload<'req> = OnDeserializationStartPayload { + router_http_request: req, + body, + graphql_params: None, + }; + for plugin in &shared_state.plugins { + let result = plugin.on_deserialization(deserialization_payload); + deserialization_payload = result.start_payload; + match result.control_flow { + ControlFlowResult::Continue => { /* continue to next plugin */ } + ControlFlowResult::EndResponse(response) => { + return Ok(response); + } + ControlFlowResult::OnEnd(callback) => { + deserialization_end_callbacks.push(callback); + } + } + } + let graphql_params = deserialization_payload.graphql_params.unwrap_or_else(|| { + deserialize_graphql_params(req, deserialization_payload.body).expect("Failed to parse execution request") + }); + + let mut payload: OnDeserializationEndPayload<'req> = OnDeserializationEndPayload { + router_http_request: req, + graphql_params, + }; + for deserialization_end_callback in deserialization_end_callbacks { + let result = deserialization_end_callback(payload); + payload = result.start_payload; + match result.control_flow { + ControlFlowResult::Continue => { /* continue to next plugin */ } + ControlFlowResult::EndResponse(response) => { + return Ok(response); + }, + ControlFlowResult::OnEnd(_) => { + // on_end callbacks should not return OnEnd again + unreachable!("on_end callback returned OnEnd again"); + } + } + } + let mut graphql_params = payload.graphql_params; + /* Handle on_deserialize hook in the plugins - END */ + + let parser_payload = parse_operation_with_cache(req, shared_state, &graphql_params).await?; validate_operation_with_cache(req, supergraph, schema_state, shared_state, &parser_payload) .await?; @@ -122,12 +161,13 @@ pub async fn execute_pipeline( req, supergraph, schema_state, - &execution_request, + &graphql_params, &parser_payload, ) .await?; + let variable_payload = - coerce_request_variables(req, supergraph, &mut execution_request, &normalize_payload)?; + coerce_request_variables(req, supergraph, &mut graphql_params, &normalize_payload)?; let query_plan_cancellation_token = CancellationToken::with_timeout(shared_state.router_config.query_planner.timeout); @@ -158,7 +198,7 @@ pub async fn execute_pipeline( Some(OperationKind::Subscription) => "subscription", None => "query", }, - query: &execution_request.query, + query: graphql_params.get_query().map_err(|err| req.new_pipeline_error(err))?, }, jwt: &jwt_request_details, }; diff --git a/bin/router/src/pipeline/normalize.rs b/bin/router/src/pipeline/normalize.rs index 4fc2cc5ef..f3a07ea95 100644 --- a/bin/router/src/pipeline/normalize.rs +++ b/bin/router/src/pipeline/normalize.rs @@ -1,6 +1,7 @@ use std::hash::{Hash, Hasher}; use std::sync::Arc; +use hive_router_plan_executor::hooks::on_deserialization::GraphQLParams; use hive_router_plan_executor::introspection::partition::partition_operation; use hive_router_plan_executor::projection::plan::FieldProjectionPlan; use hive_router_query_planner::ast::normalization::normalize_operation; @@ -9,7 +10,6 @@ use ntex::web::HttpRequest; use xxhash_rust::xxh3::Xxh3; use crate::pipeline::error::{PipelineError, PipelineErrorFromAcceptHeader, PipelineErrorVariant}; -use crate::pipeline::execution_request::ExecutionRequest; use crate::pipeline::parser::GraphQLParserPayload; use crate::schema_state::{SchemaState, SupergraphData}; use tracing::{error, trace}; @@ -28,13 +28,13 @@ pub async fn normalize_request_with_cache( req: &HttpRequest, supergraph: &SupergraphData, schema_state: &Arc, - execution_params: &ExecutionRequest, + graphql_params: &GraphQLParams, parser_payload: &GraphQLParserPayload, ) -> Result, PipelineError> { - let cache_key = match &execution_params.operation_name { + let cache_key = match &graphql_params.operation_name { Some(operation_name) => { let mut hasher = Xxh3::new(); - execution_params.query.hash(&mut hasher); + graphql_params.query.hash(&mut hasher); operation_name.hash(&mut hasher); hasher.finish() } @@ -54,7 +54,7 @@ pub async fn normalize_request_with_cache( None => match normalize_operation( &supergraph.planner.supergraph, &parser_payload.parsed_operation, - execution_params.operation_name.as_deref(), + graphql_params.operation_name.as_deref(), ) { Ok(doc) => { trace!( diff --git a/bin/router/src/pipeline/parser.rs b/bin/router/src/pipeline/parser.rs index 6e8a37141..1f3357428 100644 --- a/bin/router/src/pipeline/parser.rs +++ b/bin/router/src/pipeline/parser.rs @@ -2,12 +2,15 @@ use std::hash::{Hash, Hasher}; use std::sync::Arc; use graphql_parser::query::Document; +use hive_router_plan_executor::hooks::on_deserialization::GraphQLParams; +use hive_router_plan_executor::hooks::on_graphql_parse::{OnGraphQLParseEndPayload, OnGraphQLParseStartPayload}; +use hive_router_plan_executor::plugin_trait::ControlFlowResult; use hive_router_query_planner::utils::parsing::safe_parse_operation; use ntex::web::HttpRequest; use xxhash_rust::xxh3::Xxh3; +use crate::pipeline::deserialize_graphql_params::GetQueryStr; use crate::pipeline::error::{PipelineError, PipelineErrorFromAcceptHeader, PipelineErrorVariant}; -use crate::pipeline::execution_request::ExecutionRequest; use crate::shared_state::RouterSharedState; use tracing::{error, trace}; @@ -21,11 +24,11 @@ pub struct GraphQLParserPayload { pub async fn parse_operation_with_cache( req: &HttpRequest, app_state: &Arc, - execution_params: &ExecutionRequest, + graphql_params: &GraphQLParams, ) -> Result { let cache_key = { let mut hasher = Xxh3::new(); - execution_params.query.hash(&mut hasher); + graphql_params.query.hash(&mut hasher); hasher.finish() }; @@ -33,12 +36,68 @@ pub async fn parse_operation_with_cache( trace!("Found cached parsed operation for query"); cached } else { - let parsed = safe_parse_operation(&execution_params.query).map_err(|err| { - error!("Failed to parse GraphQL operation: {}", err); - req.new_pipeline_error(PipelineErrorVariant::FailedToParseOperation(err)) - })?; - trace!("sucessfully parsed GraphQL operation"); - let parsed_arc = Arc::new(parsed); + /* Handle on_graphql_parse hook in the plugins - START */ + let mut start_payload = OnGraphQLParseStartPayload { + router_http_request: req, + graphql_params, + document: None, + }; + let mut on_end_callbacks = vec![]; + for plugin in &app_state.plugins { + let result = plugin.on_graphql_parse(start_payload); + start_payload = result.start_payload; + match result.control_flow { + ControlFlowResult::Continue => { + // continue to next plugin + } + ControlFlowResult::EndResponse(response) => { + todo!() + } + ControlFlowResult::OnEnd(callback) => { + // store the callback to be called later + on_end_callbacks.push(callback); + } + } + } + let document = match start_payload.document { + Some(parsed) => parsed, + None => { + let query_str = graphql_params.get_query().map_err(|err| { + req.new_pipeline_error(err) + })?; + let parsed = safe_parse_operation(query_str).map_err(|err| { + error!("Failed to parse GraphQL operation: {}", err); + req.new_pipeline_error(PipelineErrorVariant::FailedToParseOperation(err)) + })?; + trace!("successfully parsed GraphQL operation"); + parsed + } + }; + let mut end_payload = OnGraphQLParseEndPayload { + router_http_request: req, + graphql_params, + document, + }; + for callback in on_end_callbacks { + let result = callback(end_payload); + end_payload = result.start_payload; + match result.control_flow { + ControlFlowResult::Continue => { + // continue to next callback + } + ControlFlowResult::EndResponse(response) => { + todo!() + } + ControlFlowResult::OnEnd(_) => { + // on_end callbacks should not return OnEnd again + unreachable!(); + } + } + } + let document = end_payload.document; + /* Handle on_graphql_parse hook in the plugins - END */ + + let parsed_arc = Arc::new(document); app_state .parse_cache .insert(cache_key, parsed_arc.clone()) diff --git a/bin/router/src/shared_state.rs b/bin/router/src/shared_state.rs index f36bda6cd..06446102a 100644 --- a/bin/router/src/shared_state.rs +++ b/bin/router/src/shared_state.rs @@ -3,6 +3,7 @@ use hive_router_config::HiveRouterConfig; use hive_router_plan_executor::headers::{ compile::compile_headers_plan, errors::HeaderRuleCompileError, plan::HeaderRulesPlan, }; +use hive_router_plan_executor::plugin_trait::RouterPlugin; use moka::future::Cache; use std::sync::Arc; @@ -18,6 +19,7 @@ pub struct RouterSharedState { pub override_labels_evaluator: OverrideLabelsEvaluator, pub cors_runtime: Option, pub jwt_auth_runtime: Option, + pub plugins: Vec>, } impl RouterSharedState { @@ -36,6 +38,7 @@ impl RouterSharedState { ) .map_err(Box::new)?, jwt_auth_runtime, + plugins: Vec::new(), }) } } diff --git a/lib/executor/src/execution/plan.rs b/lib/executor/src/execution/plan.rs index f86356312..6bc314516 100644 --- a/lib/executor/src/execution/plan.rs +++ b/lib/executor/src/execution/plan.rs @@ -19,7 +19,7 @@ use crate::{ rewrites::FetchRewriteExt, }, executors::{ - common::{HttpExecutionRequest, HttpExecutionResponse}, + common::{SubgraphExecutionRequest, HttpExecutionResponse}, map::SubgraphExecutorMap, }, headers::{ @@ -700,7 +700,7 @@ impl<'exec, 'req> Executor<'exec, 'req> { let variable_refs = select_fetch_variables(self.variable_values, node.variable_usages.as_ref()); - let mut subgraph_request = HttpExecutionRequest { + let mut subgraph_request = SubgraphExecutionRequest { query: node.operation.document_str.as_str(), dedupe: self.dedupe_subgraph_requests, operation_name: node.operation_name.as_deref(), diff --git a/lib/executor/src/executors/common.rs b/lib/executor/src/executors/common.rs index bdcd4d819..ba13b8707 100644 --- a/lib/executor/src/executors/common.rs +++ b/lib/executor/src/executors/common.rs @@ -9,7 +9,7 @@ use sonic_rs::Value; pub trait SubgraphExecutor { async fn execute<'a>( &self, - execution_request: HttpExecutionRequest<'a>, + execution_request: SubgraphExecutionRequest<'a>, ) -> HttpExecutionResponse; fn to_boxed_arc<'a>(self) -> Arc> @@ -26,7 +26,7 @@ pub type SubgraphExecutorBoxedArc = Arc>; pub type SubgraphRequestExtensions = HashMap; -pub struct HttpExecutionRequest<'a> { +pub struct SubgraphExecutionRequest<'a> { pub query: &'a str, pub dedupe: bool, pub operation_name: Option<&'a str>, @@ -37,7 +37,7 @@ pub struct HttpExecutionRequest<'a> { pub extensions: Option, } -impl HttpExecutionRequest<'_> { +impl SubgraphExecutionRequest<'_> { pub fn add_request_extensions_field(&mut self, key: String, value: Value) { self.extensions .get_or_insert_with(HashMap::new) diff --git a/lib/executor/src/executors/http.rs b/lib/executor/src/executors/http.rs index 29b392567..5947cd7d3 100644 --- a/lib/executor/src/executors/http.rs +++ b/lib/executor/src/executors/http.rs @@ -19,7 +19,7 @@ use hyper_util::client::legacy::{connect::HttpConnector, Client}; use tokio::sync::Semaphore; use tracing::debug; -use crate::executors::common::HttpExecutionRequest; +use crate::executors::common::SubgraphExecutionRequest; use crate::executors::error::SubgraphExecutorError; use crate::response::graphql_error::GraphQLError; use crate::utils::consts::CLOSE_BRACE; @@ -76,7 +76,7 @@ impl HTTPSubgraphExecutor { fn build_request_body<'a>( &self, - execution_request: &HttpExecutionRequest<'a>, + execution_request: &SubgraphExecutionRequest<'a>, ) -> Result, SubgraphExecutorError> { let mut body = Vec::with_capacity(4096); body.put(FIRST_QUOTE_STR); @@ -212,7 +212,7 @@ impl SubgraphExecutor for HTTPSubgraphExecutor { #[tracing::instrument(skip_all, fields(subgraph_name = self.subgraph_name))] async fn execute<'a>( &self, - execution_request: HttpExecutionRequest<'a>, + execution_request: SubgraphExecutionRequest<'a>, ) -> HttpExecutionResponse { let body = match self.build_request_body(&execution_request) { Ok(body) => body, diff --git a/lib/executor/src/executors/map.rs b/lib/executor/src/executors/map.rs index a3c297ad1..2e8ac78ae 100644 --- a/lib/executor/src/executors/map.rs +++ b/lib/executor/src/executors/map.rs @@ -30,7 +30,7 @@ use crate::{ execution::client_request_details::ClientRequestDetails, executors::{ common::{ - HttpExecutionRequest, HttpExecutionResponse, SubgraphExecutor, SubgraphExecutorBoxedArc, + SubgraphExecutionRequest, HttpExecutionResponse, SubgraphExecutor, SubgraphExecutorBoxedArc, }, dedupe::{ABuildHasher, SharedResponse}, error::SubgraphExecutorError, @@ -118,7 +118,7 @@ impl SubgraphExecutorMap { pub async fn execute<'a, 'req>( &self, subgraph_name: &str, - execution_request: HttpExecutionRequest<'a>, + execution_request: SubgraphExecutionRequest<'a>, client_request: &ClientRequestDetails<'a, 'req>, ) -> HttpExecutionResponse { match self.get_or_create_executor(subgraph_name, client_request) { diff --git a/lib/executor/src/plugins/examples/apq.rs b/lib/executor/src/plugins/examples/apq.rs new file mode 100644 index 000000000..7d6ac9256 --- /dev/null +++ b/lib/executor/src/plugins/examples/apq.rs @@ -0,0 +1,55 @@ +use dashmap::DashMap; +use sonic_rs::{JsonContainerTrait, JsonValueTrait}; + +use crate::{ + hooks::on_deserialization::{OnDeserializationEndPayload, OnDeserializationStartPayload}, + plugin_trait::{EndPayload, HookResult, RouterPlugin, StartPayload}, +}; + +pub struct APQPlugin { + cache: DashMap, +} + +impl RouterPlugin for APQPlugin { + fn on_deserialization<'exec>( + &'exec self, + start_payload: OnDeserializationStartPayload<'exec>, + ) -> HookResult<'exec, OnDeserializationStartPayload<'exec>, OnDeserializationEndPayload<'exec>> + { + start_payload.on_end(|mut end_payload| { + let persisted_query_ext = end_payload.graphql_params.extensions.as_ref() + .and_then(|ext| ext.get("persistedQuery")) + .and_then(|pq| pq.as_object()); + if let Some(persisted_query_ext) = persisted_query_ext { + match persisted_query_ext.get(&"version").and_then(|v| v.as_str()) { + Some("1") => {} + _ => { + // TODO: Error for unsupported version + return end_payload.cont(); + } + } + let sha256_hash = match persisted_query_ext.get(&"sha256Hash").and_then(|h| h.as_str()) { + Some(h) => h, + None => { + return end_payload.cont(); + } + }; + if let Some(query_param) = &end_payload.graphql_params.query { + // Store the query in the cache + self.cache.insert(sha256_hash.to_string(), query_param.to_string()); + } else { + // Try to get the query from the cache + if let Some(cached_query) = self.cache.get(sha256_hash) { + // Update the graphql_params with the cached query + end_payload.graphql_params.query = Some(cached_query.value().to_string()); + } else { + // Error + return end_payload.cont(); + } + } + } + + end_payload.cont() + }) + } +} diff --git a/lib/executor/src/plugins/examples/mod.rs b/lib/executor/src/plugins/examples/mod.rs index 3d54dfbed..68e3e7092 100644 --- a/lib/executor/src/plugins/examples/mod.rs +++ b/lib/executor/src/plugins/examples/mod.rs @@ -1,2 +1,3 @@ pub mod response_cache; -pub mod subgraph_response_cache; \ No newline at end of file +pub mod subgraph_response_cache; +pub mod apq; \ No newline at end of file diff --git a/lib/executor/src/plugins/examples/response_cache.rs b/lib/executor/src/plugins/examples/response_cache.rs index ed8a106a6..5942e6d91 100644 --- a/lib/executor/src/plugins/examples/response_cache.rs +++ b/lib/executor/src/plugins/examples/response_cache.rs @@ -1,11 +1,9 @@ use dashmap::DashMap; -use ntex::web::HttpResponse; +use http::HeaderMap; use redis::Commands; use crate::{ - hooks::{on_execute::OnExecutePayload, on_schema_reload::OnSchemaReloadPayload}, plugins::plugin_trait::{ - ControlFlow, RouterPlugin - }, utils::consts::TYPENAME_FIELD_NAME + execution::plan::PlanExecutionOutput, hooks::{on_execute::{OnExecuteEndPayload, OnExecuteStartPayload}, on_schema_reload::OnSchemaReloadPayload}, plugin_trait::{EndPayload, HookResult, StartPayload}, plugins::plugin_trait::RouterPlugin, utils::consts::TYPENAME_FIELD_NAME }; pub struct ResponseCachePlugin { @@ -25,9 +23,9 @@ impl ResponseCachePlugin { impl RouterPlugin for ResponseCachePlugin { fn on_execute<'exec>( - &self, - payload: OnExecutePayload<'exec>, - ) -> ControlFlow<'exec, OnExecutePayload<'exec>> { + &'exec self, + payload: OnExecuteStartPayload<'exec>, + ) -> HookResult<'exec, OnExecuteStartPayload<'exec>, OnExecuteEndPayload<'exec>> { let key = format!( "response_cache:{}:{:?}", payload.query_plan, payload.variable_values @@ -35,16 +33,18 @@ impl RouterPlugin for ResponseCachePlugin { if let Ok(mut conn) = self.redis_client.get_connection() { let cached_response: Option> = conn.get(&key).ok(); if let Some(cached_response) = cached_response { - return ControlFlow::Break( - HttpResponse::Ok() - .header("Content-Type", "application/json") - .body(cached_response), + return payload.end_response( + + PlanExecutionOutput { + body: cached_response, + headers: HeaderMap::new(), + } ); } - ControlFlow::OnEnd(Box::new(move |payload: OnExecutePayload| { + return payload.on_end(move |payload: OnExecuteEndPayload<'exec>| { // Do not cache if there are errors if !payload.errors.is_empty() { - return ControlFlow::Continue; + return payload.cont(); } if let Ok(serialized) = sonic_rs::to_vec(&payload.data) { @@ -78,10 +78,10 @@ impl RouterPlugin for ResponseCachePlugin { // Set the cache with the decided ttl let _: () = conn.set_ex(key, serialized, max_ttl).unwrap_or(()); } - ControlFlow::Continue - })); + payload.cont() + }); } - ControlFlow::Continue + payload.cont() } fn on_schema_reload(&self, payload: OnSchemaReloadPayload) { // Visit the schema and update ttl_per_type based on some directive diff --git a/lib/executor/src/plugins/examples/subgraph_response_cache.rs b/lib/executor/src/plugins/examples/subgraph_response_cache.rs index d4742b2ad..55d98a893 100644 --- a/lib/executor/src/plugins/examples/subgraph_response_cache.rs +++ b/lib/executor/src/plugins/examples/subgraph_response_cache.rs @@ -1,6 +1,6 @@ use dashmap::DashMap; -use crate::{executors::dedupe::SharedResponse, hooks::on_subgraph_http_request::{OnSubgraphHttpRequestPayload, OnSubgraphHttpResponsePayload}, plugin_trait::{ControlFlow, RouterPlugin}}; +use crate::{executors::dedupe::SharedResponse, hooks::on_subgraph_http_request::{OnSubgraphHttpRequestPayload, OnSubgraphHttpResponsePayload}, plugin_trait::{ EndPayload, HookResult, RouterPlugin, StartPayload}}; pub struct SubgraphResponseCachePlugin { cache: DashMap, @@ -10,7 +10,7 @@ impl RouterPlugin for SubgraphResponseCachePlugin { fn on_subgraph_http_request<'exec>( &'static self, payload: OnSubgraphHttpRequestPayload<'exec>, - ) -> ControlFlow<'exec, OnSubgraphHttpResponsePayload<'exec>> { + ) -> HookResult<'exec, OnSubgraphHttpRequestPayload<'exec>, OnSubgraphHttpResponsePayload<'exec>> { let key = format!( "subgraph_response_cache:{}:{:?}", payload.execution_request.query, payload.execution_request.variables @@ -19,12 +19,12 @@ impl RouterPlugin for SubgraphResponseCachePlugin { // Here payload.response is Option // So it is bypassing the actual subgraph request *payload.response = Some(cached_response.clone()); - return ControlFlow::Continue; + return payload.cont(); } - ControlFlow::OnEnd(Box::new(move |payload: OnSubgraphHttpResponsePayload| { + payload.on_end(move |payload: OnSubgraphHttpResponsePayload<'exec>| { // Here payload.response is not Option self.cache.insert(key, payload.response.clone()); - ControlFlow::Continue - })) + payload.cont() + }) } } \ No newline at end of file diff --git a/lib/executor/src/plugins/hooks/mod.rs b/lib/executor/src/plugins/hooks/mod.rs index 5a4d94c22..65ccf6f4d 100644 --- a/lib/executor/src/plugins/hooks/mod.rs +++ b/lib/executor/src/plugins/hooks/mod.rs @@ -1,3 +1,9 @@ pub mod on_execute; pub mod on_schema_reload; -pub mod on_subgraph_http_request; \ No newline at end of file +pub mod on_subgraph_http_request; +pub mod on_http_request; +pub mod on_deserialization; +pub mod on_graphql_parse; +pub mod on_graphql_validation; +pub mod on_query_plan; +pub mod on_subgraph_execute; \ No newline at end of file diff --git a/lib/executor/src/plugins/hooks/on_deserialization.rs b/lib/executor/src/plugins/hooks/on_deserialization.rs new file mode 100644 index 000000000..84991ff56 --- /dev/null +++ b/lib/executor/src/plugins/hooks/on_deserialization.rs @@ -0,0 +1,45 @@ +use std::collections::HashMap; + +use ntex::util::Bytes; +use serde::Deserialize; +use serde::Deserializer; +use sonic_rs::Value; + +use crate::plugin_trait::EndPayload; +use crate::plugin_trait::StartPayload; + +#[derive(Deserialize, Debug, Clone)] +#[serde(rename_all = "camelCase")] +pub struct GraphQLParams { + pub query: Option, + pub operation_name: Option, + #[serde(default, deserialize_with = "deserialize_null_default")] + pub variables: HashMap, + // TODO: We don't use extensions yet, but we definitely will in the future. + #[allow(dead_code)] + pub extensions: Option>, +} + +fn deserialize_null_default<'de, D, T>(deserializer: D) -> Result +where + T: Default + Deserialize<'de>, + D: Deserializer<'de>, +{ + let opt = Option::::deserialize(deserializer)?; + Ok(opt.unwrap_or_default()) +} + +pub struct OnDeserializationStartPayload<'exec> { + pub router_http_request: &'exec ntex::web::HttpRequest, + pub body: Bytes, + pub graphql_params: Option, +} + +impl<'exec> StartPayload> for OnDeserializationStartPayload<'exec> {} + +pub struct OnDeserializationEndPayload<'exec> { + pub router_http_request: &'exec ntex::web::HttpRequest, + pub graphql_params: GraphQLParams, +} + +impl<'exec> EndPayload for OnDeserializationEndPayload<'exec> {} diff --git a/lib/executor/src/plugins/hooks/on_execute.rs b/lib/executor/src/plugins/hooks/on_execute.rs index 0d5759de2..5057075e3 100644 --- a/lib/executor/src/plugins/hooks/on_execute.rs +++ b/lib/executor/src/plugins/hooks/on_execute.rs @@ -1,33 +1,41 @@ use std::collections::HashMap; -use std::sync::Arc; use hive_router_query_planner::planner::plan_nodes::QueryPlan; use ntex::web::HttpRequest; +use crate::plugin_trait::{EndPayload, StartPayload}; use crate::response::{value::Value}; use crate::response::graphql_error::GraphQLError; -pub struct OnExecutePayload<'exec> { +pub struct OnExecuteStartPayload<'exec> { pub router_http_request: &'exec HttpRequest, - pub query_plan: Arc, + pub query_plan: &'exec QueryPlan, pub data: &'exec mut Value<'exec>, pub errors: &'exec mut Vec, pub extensions: &'exec mut HashMap, - pub skip_execution: bool, + pub skip_execution: &'exec mut bool, pub variable_values: &'exec Option>, + + pub dedupe_subgraph_requests: &'exec mut bool, } +impl<'exec> StartPayload> for OnExecuteStartPayload<'exec> {} + pub struct OnExecuteEndPayload<'exec> { pub router_http_request: &'exec HttpRequest, - pub query_plan: Arc, + pub query_plan: &'exec QueryPlan, + - pub data: &'exec Value<'exec>, - pub errors: &'exec Vec, + pub data: &'exec mut Value<'exec>, + pub errors: &'exec mut Vec, pub extensions: &'exec mut HashMap, pub variable_values: &'exec Option>, + + pub dedupe_subgraph_requests: &'exec mut bool, } +impl<'exec> EndPayload for OnExecuteEndPayload<'exec> {} diff --git a/lib/executor/src/plugins/hooks/on_graphql_parse.rs b/lib/executor/src/plugins/hooks/on_graphql_parse.rs new file mode 100644 index 000000000..8719cdac3 --- /dev/null +++ b/lib/executor/src/plugins/hooks/on_graphql_parse.rs @@ -0,0 +1,19 @@ +use graphql_tools::static_graphql::query::Document; + +use crate::{hooks::on_deserialization::GraphQLParams, plugin_trait::{EndPayload, StartPayload}}; + +pub struct OnGraphQLParseStartPayload<'exec> { + pub router_http_request: &'exec ntex::web::HttpRequest, + pub graphql_params: &'exec GraphQLParams, + pub document: Option, +} + +impl<'exec> StartPayload> for OnGraphQLParseStartPayload<'exec> {} + +pub struct OnGraphQLParseEndPayload<'exec> { + pub router_http_request: &'exec ntex::web::HttpRequest, + pub graphql_params: &'exec GraphQLParams, + pub document: Document, +} + +impl<'exec> EndPayload for OnGraphQLParseEndPayload<'exec> {} \ No newline at end of file diff --git a/lib/executor/src/plugins/hooks/on_graphql_validation.rs b/lib/executor/src/plugins/hooks/on_graphql_validation.rs new file mode 100644 index 000000000..e5ecf898f --- /dev/null +++ b/lib/executor/src/plugins/hooks/on_graphql_validation.rs @@ -0,0 +1,25 @@ +use graphql_tools::{static_graphql::query::Document, validation::{utils::ValidationError, validate::ValidationPlan}}; +use hive_router_query_planner::state::supergraph_state::SchemaDocument; + +use crate::{hooks::on_deserialization::GraphQLParams, plugin_trait::{EndPayload, StartPayload}}; + +pub struct OnGraphQLValidationStartPayload<'exec> { + pub router_http_request: &'exec ntex::web::HttpRequest, + pub graphql_params: &'exec GraphQLParams, + pub schema: &'exec SchemaDocument, + pub document: &'exec Document, + pub validation_plan: &'exec mut ValidationPlan, + pub errors: &'exec mut Option> +} + +impl<'exec> StartPayload> for OnGraphQLValidationStartPayload<'exec> {} + +pub struct OnGraphQLValidationEndPayload<'exec> { + pub router_http_request: &'exec ntex::web::HttpRequest, + pub graphql_params: &'exec GraphQLParams, + pub schema: &'exec SchemaDocument, + pub document: &'exec Document, + pub errors: &'exec mut Vec, +} + +impl<'exec> EndPayload for OnGraphQLValidationEndPayload<'exec> {} \ No newline at end of file diff --git a/lib/executor/src/plugins/hooks/on_http_request.rs b/lib/executor/src/plugins/hooks/on_http_request.rs new file mode 100644 index 000000000..847e7465e --- /dev/null +++ b/lib/executor/src/plugins/hooks/on_http_request.rs @@ -0,0 +1,16 @@ +use ntex::{http::Response, web::HttpRequest}; + +use crate::plugin_trait::{EndPayload, StartPayload}; + +pub struct OnHttpRequestPayload<'exec> { + pub router_http_request: &'exec HttpRequest, +} + +impl<'exec> StartPayload> for OnHttpRequestPayload<'exec> {} + +pub struct OnHttpResponse<'exec> { + pub router_http_request: &'exec HttpRequest, + pub response: &'exec mut Response, +} + +impl<'exec> EndPayload for OnHttpResponse<'exec> {} \ No newline at end of file diff --git a/lib/executor/src/plugins/hooks/on_query_plan.rs b/lib/executor/src/plugins/hooks/on_query_plan.rs new file mode 100644 index 000000000..7963524ad --- /dev/null +++ b/lib/executor/src/plugins/hooks/on_query_plan.rs @@ -0,0 +1,23 @@ +use graphql_tools::static_graphql::query::Document; +use hive_router_query_planner::planner::{Planner, plan_nodes::QueryPlan}; + +use crate::plugin_trait::{EndPayload, StartPayload}; + +pub struct OnQueryPlanStartPayload<'exec> { + pub router_http_request: &'exec ntex::web::HttpRequest, + pub document: &'exec Document, + // Other params + pub query_plan: &'exec mut Option, + pub planner: &'exec Planner, +} + +impl<'exec> StartPayload> for OnQueryPlanStartPayload<'exec> {} + +pub struct OnQueryPlanEndPayload<'exec> { + pub router_http_request: &'exec ntex::web::HttpRequest, + pub document: &'exec Document, + // Other params + pub query_plan: &'exec mut QueryPlan, +} + +impl<'exec> EndPayload for OnQueryPlanEndPayload<'exec> {} \ No newline at end of file diff --git a/lib/executor/src/plugins/hooks/on_schema_reload.rs b/lib/executor/src/plugins/hooks/on_schema_reload.rs index 29863d964..a96d6c240 100644 --- a/lib/executor/src/plugins/hooks/on_schema_reload.rs +++ b/lib/executor/src/plugins/hooks/on_schema_reload.rs @@ -1,6 +1,6 @@ use hive_router_query_planner::consumer_schema::ConsumerSchema; -pub struct OnSchemaReloadPayload { - pub old_schema: &'static ConsumerSchema, - pub new_schema: &'static mut ConsumerSchema, +pub struct OnSchemaReloadPayload<'a> { + pub old_schema: &'a ConsumerSchema, + pub new_schema: &'a mut ConsumerSchema, } diff --git a/lib/executor/src/plugins/hooks/on_subgraph_execute.rs b/lib/executor/src/plugins/hooks/on_subgraph_execute.rs new file mode 100644 index 000000000..167340bc8 --- /dev/null +++ b/lib/executor/src/plugins/hooks/on_subgraph_execute.rs @@ -0,0 +1,34 @@ +use bytes::Bytes; +use hive_router_query_planner::planner::plan_nodes::FetchNode; + +use crate::{executors::common::{SubgraphExecutionRequest, SubgraphExecutorBoxedArc}, plugin_trait::{EndPayload, StartPayload}, response::subgraph_response::SubgraphResponse}; + + +pub struct OnSubgraphExecuteStartPayload<'exec> { + pub router_http_request: &'exec ntex::web::HttpRequest, + pub executor: &'exec SubgraphExecutorBoxedArc, + pub subgraph_name: &'exec str, + + pub node: &'exec mut FetchNode, + pub execution_request: &'exec mut SubgraphExecutionRequest<'exec>, + pub response: &'exec mut Option>, +} + +impl<'exec> StartPayload> for OnSubgraphExecuteStartPayload<'exec> {} + +pub enum SubgraphExecutorResponse<'exec> { + Bytes(Bytes), + SubgraphResponse(SubgraphResponse<'exec>), +} + +pub struct OnSubgraphExecuteEndPayload<'exec> { + pub router_http_request: &'exec ntex::web::HttpRequest, + pub executor: &'exec SubgraphExecutorBoxedArc, + pub subgraph_name: &'exec str, + + pub node: &'exec FetchNode, + pub execution_request: &'exec SubgraphExecutionRequest<'exec>, + pub response: &'exec mut SubgraphExecutorResponse<'exec>, +} + +impl<'exec> EndPayload for OnSubgraphExecuteEndPayload<'exec> {} \ No newline at end of file diff --git a/lib/executor/src/plugins/hooks/on_subgraph_http_request.rs b/lib/executor/src/plugins/hooks/on_subgraph_http_request.rs index 326e516f0..ac720b870 100644 --- a/lib/executor/src/plugins/hooks/on_subgraph_http_request.rs +++ b/lib/executor/src/plugins/hooks/on_subgraph_http_request.rs @@ -1,11 +1,8 @@ -use std::collections::HashMap; - -use hive_router_query_planner::ast::operation::SubgraphFetchOperation; use http::{HeaderMap, Uri}; use ntex::web::HttpRequest; -use crate:: - executors::dedupe::SharedResponse +use crate::{ + executors::{common::SubgraphExecutionRequest, dedupe::SharedResponse}, plugin_trait::{EndPayload, StartPayload}} ; pub struct OnSubgraphHttpRequestPayload<'exec> { @@ -24,24 +21,13 @@ pub struct OnSubgraphHttpRequestPayload<'exec> { pub response: &'exec mut Option, } -pub struct SubgraphExecutionRequest<'exec> { - pub query: &'exec str, - // We can add the original operation here too - pub operation: &'exec SubgraphFetchOperation, - - pub dedupe: bool, - pub operation_name: Option<&'exec str>, - pub variables: Option>, - pub extensions: Option>, - pub representations: Option>, -} +impl<'exec> StartPayload> for OnSubgraphHttpRequestPayload<'exec> {} pub struct OnSubgraphHttpResponsePayload<'exec> { pub router_http_request: &'exec HttpRequest, pub subgraph_name: &'exec str, - // The node that initiates this subgraph execution pub execution_request: &'exec SubgraphExecutionRequest<'exec>, - // This will be tricky to implement with the current structure, - // but I'm sure we'll figure it out pub response: &'exec mut SharedResponse, } + +impl<'exec> EndPayload for OnSubgraphHttpResponsePayload<'exec> {} diff --git a/lib/executor/src/plugins/plugin_trait.rs b/lib/executor/src/plugins/plugin_trait.rs index 53b0720d1..220c5b88c 100644 --- a/lib/executor/src/plugins/plugin_trait.rs +++ b/lib/executor/src/plugins/plugin_trait.rs @@ -1,27 +1,113 @@ -use ntex::web::HttpResponse; - -use crate::hooks::on_execute::OnExecutePayload; +use crate::execution::plan::PlanExecutionOutput; +use crate::hooks::on_deserialization::{OnDeserializationEndPayload, OnDeserializationStartPayload}; +use crate::hooks::on_execute::{OnExecuteEndPayload, OnExecuteStartPayload}; +use crate::hooks::on_graphql_parse::{OnGraphQLParseEndPayload, OnGraphQLParseStartPayload}; +use crate::hooks::on_graphql_validation::{OnGraphQLValidationEndPayload, OnGraphQLValidationStartPayload}; +use crate::hooks::on_http_request::{OnHttpRequestPayload, OnHttpResponse}; +use crate::hooks::on_query_plan::{OnQueryPlanEndPayload, OnQueryPlanStartPayload}; use crate::hooks::on_schema_reload::OnSchemaReloadPayload; use crate::hooks::on_subgraph_http_request::{OnSubgraphHttpRequestPayload, OnSubgraphHttpResponsePayload}; -pub enum ControlFlow<'exec, TPayload> { +pub struct HookResult<'exec, TStartPayload, TEndPayload> { + pub start_payload: TStartPayload, + pub control_flow: ControlFlowResult<'exec, TEndPayload>, +} + +pub enum ControlFlowResult<'exec, TEndPayload> { Continue, - Break(HttpResponse), - OnEnd(Box ControlFlow<'exec, ()> + 'exec>), + EndResponse(PlanExecutionOutput), + OnEnd(Box HookResult<'exec, TEndPayload, ()> + 'exec>), +} + +pub trait StartPayload + where Self: Sized + { + + fn cont<'exec>(self) -> HookResult<'exec, Self, TEndPayload> { + HookResult { + start_payload: self, + control_flow: ControlFlowResult::Continue, + } + } + + fn end_response<'exec>(self, output: PlanExecutionOutput) -> HookResult<'exec, Self, TEndPayload> { + HookResult { + start_payload: self, + control_flow: ControlFlowResult::EndResponse(output), + } + } + + fn on_end<'exec, F>(self, f: F) -> HookResult<'exec, Self, TEndPayload> + where F: FnOnce(TEndPayload) -> HookResult<'exec, TEndPayload, ()> + 'exec, + { + HookResult { + start_payload: self, + control_flow: ControlFlowResult::OnEnd(Box::new(f)), + } + } +} + +pub trait EndPayload + where Self: Sized + { + fn cont<'exec>(self) -> HookResult<'exec, Self, ()> { + HookResult { + start_payload: self, + control_flow: ControlFlowResult::Continue, + } + } + + fn end_response<'exec>(self, output: PlanExecutionOutput) -> HookResult<'exec, Self, ()> { + HookResult { + start_payload: self, + control_flow: ControlFlowResult::EndResponse(output), + } + } } +// Add sync send etc pub trait RouterPlugin { - fn on_execute<'exec>( + fn on_http_request<'exec>( + &self, + start_payload: OnHttpRequestPayload<'exec>, + ) -> HookResult<'exec, OnHttpRequestPayload<'exec>, OnHttpResponse<'exec>> { + start_payload.cont() + } + fn on_deserialization<'exec>( + &'exec self, + start_payload: OnDeserializationStartPayload<'exec>, + ) -> HookResult<'exec, OnDeserializationStartPayload<'exec>, OnDeserializationEndPayload<'exec>> { + start_payload.cont() + } + fn on_graphql_parse<'exec>( &self, - _payload: OnExecutePayload<'exec>, - ) -> ControlFlow<'exec, OnExecutePayload<'exec>> { - ControlFlow::Continue + start_payload: OnGraphQLParseStartPayload<'exec>, + ) -> HookResult<'exec, OnGraphQLParseStartPayload<'exec>, OnGraphQLParseEndPayload<'exec>> { + start_payload.cont() + } + fn on_graphql_validation<'exec>( + &self, + start_payload: OnGraphQLValidationStartPayload<'exec>, + ) -> HookResult<'exec, OnGraphQLValidationStartPayload<'exec>, OnGraphQLValidationEndPayload<'exec>> { + start_payload.cont() + } + fn on_query_plan<'exec>( + &self, + start_payload: OnQueryPlanStartPayload<'exec>, + ) -> HookResult<'exec, OnQueryPlanStartPayload<'exec>, OnQueryPlanEndPayload<'exec>> { + start_payload.cont() + } + fn on_execute<'exec>( + &'exec self, + start_payload: OnExecuteStartPayload<'exec>, + ) -> HookResult<'exec, OnExecuteStartPayload<'exec>, OnExecuteEndPayload<'exec>> { + start_payload.cont() } fn on_subgraph_http_request<'exec>( &'static self, - _payload: OnSubgraphHttpRequestPayload<'exec>, - ) -> ControlFlow<'exec, OnSubgraphHttpResponsePayload<'exec>> { - ControlFlow::Continue + start_payload: OnSubgraphHttpRequestPayload<'exec>, + ) -> HookResult<'exec, OnSubgraphHttpRequestPayload<'exec>, OnSubgraphHttpResponsePayload<'exec>> { + start_payload.cont() } - fn on_schema_reload(&self, _payload: OnSchemaReloadPayload) {} + fn on_schema_reload<'a>(&'a self, _start_payload: OnSchemaReloadPayload) {} } \ No newline at end of file From 0bbe2f8684f6023a49b36174f46a1213d9684d4b Mon Sep 17 00:00:00 2001 From: Arda TANRIKULU Date: Tue, 18 Nov 2025 19:59:16 +0300 Subject: [PATCH 06/15] More --- Cargo.lock | 1 + bin/router/src/lib.rs | 4 +- bin/router/src/pipeline/coerce_variables.rs | 4 +- .../pipeline/deserialize_graphql_params.rs | 2 +- bin/router/src/pipeline/error.rs | 2 +- bin/router/src/pipeline/execution.rs | 4 +- bin/router/src/pipeline/mod.rs | 41 ++++--- bin/router/src/pipeline/normalize.rs | 5 +- bin/router/src/pipeline/parser.rs | 30 +++-- bin/router/src/pipeline/query_plan.rs | 101 ++++++++++++++-- bin/router/src/pipeline/validation.rs | 70 ++++++++++- bin/router/src/schema_state.rs | 79 +++++++++--- bin/router/src/shared_state.rs | 4 +- lib/executor/Cargo.toml | 1 + .../src/execution/client_request_details.rs | 2 +- lib/executor/src/execution/error.rs | 2 +- lib/executor/src/execution/plan.rs | 92 ++++++++++---- lib/executor/src/executors/http.rs | 5 +- lib/executor/src/executors/map.rs | 81 +++++++++++-- lib/executor/src/plugins/examples/apq.rs | 24 ++-- .../src/plugins/examples/response_cache.rs | 29 +++-- .../examples/subgraph_response_cache.rs | 14 +-- lib/executor/src/plugins/hooks/mod.rs | 4 +- lib/executor/src/plugins/hooks/on_execute.rs | 24 ++-- ...eserialization.rs => on_graphql_params.rs} | 9 +- .../src/plugins/hooks/on_graphql_parse.rs | 2 +- .../plugins/hooks/on_graphql_validation.rs | 64 ++++++++-- .../src/plugins/hooks/on_http_request.rs | 2 +- .../src/plugins/hooks/on_query_plan.rs | 18 +-- .../src/plugins/hooks/on_schema_reload.rs | 6 - .../src/plugins/hooks/on_subgraph_execute.rs | 33 ++--- .../plugins/hooks/on_subgraph_http_request.rs | 26 ++-- .../src/plugins/hooks/on_supergraph_load.rs | 27 +++++ lib/executor/src/plugins/plugin_trait.rs | 114 +++++++++++------- 34 files changed, 655 insertions(+), 271 deletions(-) rename lib/executor/src/plugins/hooks/{on_deserialization.rs => on_graphql_params.rs} (76%) delete mode 100644 lib/executor/src/plugins/hooks/on_schema_reload.rs create mode 100644 lib/executor/src/plugins/hooks/on_supergraph_load.rs diff --git a/Cargo.lock b/Cargo.lock index d5c05dfff..b81088873 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2046,6 +2046,7 @@ name = "hive-router-plan-executor" version = "6.0.1" dependencies = [ "ahash", + "arc-swap", "async-trait", "bumpalo", "bytes", diff --git a/bin/router/src/lib.rs b/bin/router/src/lib.rs index da799b4cc..fb19df5c8 100644 --- a/bin/router/src/lib.rs +++ b/bin/router/src/lib.rs @@ -111,10 +111,10 @@ pub async fn configure_app_from_config( }; let router_config_arc = Arc::new(router_config); + let shared_state = Arc::new(RouterSharedState::new(router_config_arc.clone(), jwt_runtime)?); let schema_state = - SchemaState::new_from_config(bg_tasks_manager, router_config_arc.clone()).await?; + SchemaState::new_from_config(bg_tasks_manager, router_config_arc.clone(), shared_state.clone()).await?; let schema_state_arc = Arc::new(schema_state); - let shared_state = Arc::new(RouterSharedState::new(router_config_arc, jwt_runtime)?); Ok((shared_state, schema_state_arc)) } diff --git a/bin/router/src/pipeline/coerce_variables.rs b/bin/router/src/pipeline/coerce_variables.rs index fa85223e0..b159f244e 100644 --- a/bin/router/src/pipeline/coerce_variables.rs +++ b/bin/router/src/pipeline/coerce_variables.rs @@ -1,7 +1,8 @@ use std::collections::HashMap; use std::sync::Arc; -use hive_router_plan_executor::hooks::on_deserialization::GraphQLParams; +use hive_router_plan_executor::hooks::on_graphql_params::GraphQLParams; +use hive_router_plan_executor::hooks::on_supergraph_load::SupergraphData; use hive_router_plan_executor::variables::collect_variables; use hive_router_query_planner::state::supergraph_state::OperationKind; use http::Method; @@ -11,7 +12,6 @@ use tracing::{error, trace, warn}; use crate::pipeline::error::{PipelineError, PipelineErrorFromAcceptHeader, PipelineErrorVariant}; use crate::pipeline::normalize::GraphQLNormalizationPayload; -use crate::schema_state::SupergraphData; #[derive(Clone, Debug)] pub struct CoerceVariablesPayload { diff --git a/bin/router/src/pipeline/deserialize_graphql_params.rs b/bin/router/src/pipeline/deserialize_graphql_params.rs index 1a769e5a2..3c0eb5f12 100644 --- a/bin/router/src/pipeline/deserialize_graphql_params.rs +++ b/bin/router/src/pipeline/deserialize_graphql_params.rs @@ -1,6 +1,6 @@ use std::collections::HashMap; -use hive_router_plan_executor::hooks::on_deserialization::GraphQLParams; +use hive_router_plan_executor::hooks::on_graphql_params::GraphQLParams; use http::Method; use ntex::util::Bytes; use ntex::web::types::Query; diff --git a/bin/router/src/pipeline/error.rs b/bin/router/src/pipeline/error.rs index eec36ea76..71e0a197d 100644 --- a/bin/router/src/pipeline/error.rs +++ b/bin/router/src/pipeline/error.rs @@ -78,7 +78,7 @@ pub enum PipelineErrorVariant { #[error("Failed to execute a plan: {0}")] PlanExecutionError(PlanExecutionError), #[error("Failed to produce a plan: {0}")] - PlannerError(Arc), + PlannerError(PlannerError), #[error(transparent)] LabelEvaluationError(LabelEvaluationError), diff --git a/bin/router/src/pipeline/execution.rs b/bin/router/src/pipeline/execution.rs index 42ace79ce..56f92fece 100644 --- a/bin/router/src/pipeline/execution.rs +++ b/bin/router/src/pipeline/execution.rs @@ -4,12 +4,12 @@ use std::sync::Arc; use crate::pipeline::coerce_variables::CoerceVariablesPayload; use crate::pipeline::error::{PipelineError, PipelineErrorFromAcceptHeader, PipelineErrorVariant}; use crate::pipeline::normalize::GraphQLNormalizationPayload; -use crate::schema_state::SupergraphData; use crate::shared_state::RouterSharedState; use hive_router_plan_executor::execute_query_plan; use hive_router_plan_executor::execution::client_request_details::ClientRequestDetails; use hive_router_plan_executor::execution::jwt_forward::JwtAuthForwardingPlan; use hive_router_plan_executor::execution::plan::{PlanExecutionOutput, QueryPlanExecutionContext}; +use hive_router_plan_executor::hooks::on_supergraph_load::SupergraphData; use hive_router_plan_executor::introspection::resolve::IntrospectionContext; use hive_router_query_planner::planner::plan_nodes::QueryPlan; use http::HeaderName; @@ -85,6 +85,7 @@ pub async fn execute_plan( }; execute_query_plan(QueryPlanExecutionContext { + router_http_request: req, query_plan: query_plan_payload, projection_plan: &normalized_payload.projection_plan, headers_plan: &app_state.headers_plan, @@ -95,6 +96,7 @@ pub async fn execute_plan( operation_type_name: normalized_payload.root_type_name, jwt_auth_forwarding: &jwt_forward_plan, executors: &supergraph.subgraph_executor_map, + plugins: &app_state.plugins, }) .await .map_err(|err| { diff --git a/bin/router/src/pipeline/mod.rs b/bin/router/src/pipeline/mod.rs index 61cb353bf..ddc949c4c 100644 --- a/bin/router/src/pipeline/mod.rs +++ b/bin/router/src/pipeline/mod.rs @@ -5,9 +5,9 @@ use hive_router_plan_executor::{ client_request_details::{ClientRequestDetails, JwtRequestDetails, OperationDetails}, plan::PlanExecutionOutput, }, - hooks::on_deserialization::{ - OnDeserializationEndPayload, OnDeserializationStartPayload - }, + hooks::{on_graphql_params::{ + OnGraphQLParamsEndPayload, OnGraphQLParamsStartPayload + }, on_supergraph_load::SupergraphData}, plugin_trait::ControlFlowResult, }; use hive_router_query_planner::{ @@ -24,9 +24,9 @@ use crate::{ pipeline::{ coerce_variables::coerce_request_variables, csrf_prevention::perform_csrf_prevention, deserialize_graphql_params::{GetQueryStr, deserialize_graphql_params}, error::{PipelineError, PipelineErrorFromAcceptHeader, PipelineErrorVariant}, execution::execute_plan, header::{ APPLICATION_GRAPHQL_RESPONSE_JSON, APPLICATION_GRAPHQL_RESPONSE_JSON_STR, APPLICATION_JSON, RequestAccepts, TEXT_HTML_CONTENT_TYPE - }, normalize::normalize_request_with_cache, parser::parse_operation_with_cache, progressive_override::request_override_context, query_plan::plan_operation_with_cache, validation::validate_operation_with_cache + }, normalize::normalize_request_with_cache, parser::{ParseResult, parse_operation_with_cache}, progressive_override::request_override_context, query_plan::{QueryPlanResult, plan_operation_with_cache}, validation::validate_operation_with_cache }, - schema_state::{SchemaState, SupergraphData}, + schema_state::{SchemaState}, shared_state::RouterSharedState, }; @@ -110,14 +110,14 @@ pub async fn execute_pipeline<'req>( /* Handle on_deserialize hook in the plugins - START */ let mut deserialization_end_callbacks = vec![]; - let mut deserialization_payload: OnDeserializationStartPayload<'req> = OnDeserializationStartPayload { + let mut deserialization_payload: OnGraphQLParamsStartPayload<'req> = OnGraphQLParamsStartPayload { router_http_request: req, body, graphql_params: None, }; - for plugin in &shared_state.plugins { - let result = plugin.on_deserialization(deserialization_payload); - deserialization_payload = result.start_payload; + for plugin in shared_state.plugins.as_ref() { + let result = plugin.on_graphql_params(deserialization_payload); + deserialization_payload = result.payload; match result.control_flow { ControlFlowResult::Continue => { /* continue to next plugin */ } ControlFlowResult::EndResponse(response) => { @@ -132,13 +132,12 @@ pub async fn execute_pipeline<'req>( deserialize_graphql_params(req, deserialization_payload.body).expect("Failed to parse execution request") }); - let mut payload: OnDeserializationEndPayload<'req> = OnDeserializationEndPayload { - router_http_request: req, + let mut payload = OnGraphQLParamsEndPayload { graphql_params, }; for deserialization_end_callback in deserialization_end_callbacks { let result = deserialization_end_callback(payload); - payload = result.start_payload; + payload = result.payload; match result.control_flow { ControlFlowResult::Continue => { /* continue to next plugin */ } ControlFlowResult::EndResponse(response) => { @@ -153,7 +152,13 @@ pub async fn execute_pipeline<'req>( let mut graphql_params = payload.graphql_params; /* Handle on_deserialize hook in the plugins - END */ - let parser_payload = parse_operation_with_cache(req, shared_state, &graphql_params).await?; + let parser_payload = match parse_operation_with_cache(req, shared_state, &graphql_params).await? { + ParseResult::Payload(payload) => payload, + ParseResult::Response(response) => { + return Ok(response); + } + }; + validate_operation_with_cache(req, supergraph, schema_state, shared_state, &parser_payload) .await?; @@ -209,15 +214,21 @@ pub async fn execute_pipeline<'req>( ) .map_err(|error| req.new_pipeline_error(PipelineErrorVariant::LabelEvaluationError(error)))?; - let query_plan_payload = plan_operation_with_cache( + let query_plan_payload = match plan_operation_with_cache( req, supergraph, schema_state, &normalize_payload, &progressive_override_ctx, &query_plan_cancellation_token, + shared_state, ) - .await?; + .await? { + QueryPlanResult::QueryPlan(query_plan_payload) => query_plan_payload, + QueryPlanResult::Response(response) => { + return Ok(response); + } + }; let execution_result = execute_plan( req, diff --git a/bin/router/src/pipeline/normalize.rs b/bin/router/src/pipeline/normalize.rs index f3a07ea95..c57e2d566 100644 --- a/bin/router/src/pipeline/normalize.rs +++ b/bin/router/src/pipeline/normalize.rs @@ -1,7 +1,8 @@ use std::hash::{Hash, Hasher}; use std::sync::Arc; -use hive_router_plan_executor::hooks::on_deserialization::GraphQLParams; +use hive_router_plan_executor::hooks::on_graphql_params::GraphQLParams; +use hive_router_plan_executor::hooks::on_supergraph_load::SupergraphData; use hive_router_plan_executor::introspection::partition::partition_operation; use hive_router_plan_executor::projection::plan::FieldProjectionPlan; use hive_router_query_planner::ast::normalization::normalize_operation; @@ -11,7 +12,7 @@ use xxhash_rust::xxh3::Xxh3; use crate::pipeline::error::{PipelineError, PipelineErrorFromAcceptHeader, PipelineErrorVariant}; use crate::pipeline::parser::GraphQLParserPayload; -use crate::schema_state::{SchemaState, SupergraphData}; +use crate::schema_state::{SchemaState}; use tracing::{error, trace}; #[derive(Debug)] diff --git a/bin/router/src/pipeline/parser.rs b/bin/router/src/pipeline/parser.rs index 1f3357428..18365a3e2 100644 --- a/bin/router/src/pipeline/parser.rs +++ b/bin/router/src/pipeline/parser.rs @@ -2,7 +2,8 @@ use std::hash::{Hash, Hasher}; use std::sync::Arc; use graphql_parser::query::Document; -use hive_router_plan_executor::hooks::on_deserialization::GraphQLParams; +use hive_router_plan_executor::execution::plan::PlanExecutionOutput; +use hive_router_plan_executor::hooks::on_graphql_params::GraphQLParams; use hive_router_plan_executor::hooks::on_graphql_parse::{OnGraphQLParseEndPayload, OnGraphQLParseStartPayload}; use hive_router_plan_executor::plugin_trait::ControlFlowResult; use hive_router_query_planner::utils::parsing::safe_parse_operation; @@ -20,12 +21,17 @@ pub struct GraphQLParserPayload { pub cache_key: u64, } +pub enum ParseResult { + Payload(GraphQLParserPayload), + Response(PlanExecutionOutput), +} + #[inline] pub async fn parse_operation_with_cache( req: &HttpRequest, app_state: &Arc, graphql_params: &GraphQLParams, -) -> Result { +) -> Result { let cache_key = { let mut hasher = Xxh3::new(); graphql_params.query.hash(&mut hasher); @@ -43,15 +49,15 @@ pub async fn parse_operation_with_cache( document: None, }; let mut on_end_callbacks = vec![]; - for plugin in &app_state.plugins { + for plugin in app_state.plugins.as_ref() { let result = plugin.on_graphql_parse(start_payload); - start_payload = result.start_payload; + start_payload = result.payload; match result.control_flow { ControlFlowResult::Continue => { // continue to next plugin } ControlFlowResult::EndResponse(response) => { - todo!() + return Ok(ParseResult::Response(response)); } ControlFlowResult::OnEnd(callback) => { // store the callback to be called later @@ -80,13 +86,13 @@ pub async fn parse_operation_with_cache( }; for callback in on_end_callbacks { let result = callback(end_payload); - end_payload = result.start_payload; + end_payload = result.payload; match result.control_flow { ControlFlowResult::Continue => { // continue to next callback } ControlFlowResult::EndResponse(response) => { - todo!() + return Ok(ParseResult::Response(response)); } ControlFlowResult::OnEnd(_) => { // on_end callbacks should not return OnEnd again @@ -105,8 +111,10 @@ pub async fn parse_operation_with_cache( parsed_arc }; - Ok(GraphQLParserPayload { - parsed_operation, - cache_key, - }) + Ok( + ParseResult::Payload(GraphQLParserPayload { + parsed_operation, + cache_key, + }) + ) } diff --git a/bin/router/src/pipeline/query_plan.rs b/bin/router/src/pipeline/query_plan.rs index b2f730be7..58b4e7475 100644 --- a/bin/router/src/pipeline/query_plan.rs +++ b/bin/router/src/pipeline/query_plan.rs @@ -4,12 +4,28 @@ use std::sync::Arc; use crate::pipeline::error::{PipelineError, PipelineErrorFromAcceptHeader, PipelineErrorVariant}; use crate::pipeline::normalize::GraphQLNormalizationPayload; use crate::pipeline::progressive_override::{RequestOverrideContext, StableOverrideContext}; -use crate::schema_state::{SchemaState, SupergraphData}; +use crate::schema_state::{SchemaState}; +use crate::RouterSharedState; +use hive_router_plan_executor::execution::plan::PlanExecutionOutput; +use hive_router_plan_executor::hooks::on_query_plan::OnQueryPlanStartPayload; +use hive_router_plan_executor::hooks::on_supergraph_load::SupergraphData; +use hive_router_plan_executor::plugin_trait::ControlFlowResult; use hive_router_query_planner::planner::plan_nodes::QueryPlan; +use hive_router_query_planner::planner::PlannerError; use hive_router_query_planner::utils::cancellation::CancellationToken; use ntex::web::HttpRequest; use xxhash_rust::xxh3::Xxh3; +pub enum QueryPlanResult { + QueryPlan(Arc), + Response(PlanExecutionOutput), +} + +pub enum QueryPlanGetterError { + Planner(PlannerError), + Response(PlanExecutionOutput), +} + #[inline] pub async fn plan_operation_with_cache( req: &HttpRequest, @@ -18,7 +34,8 @@ pub async fn plan_operation_with_cache( normalized_operation: &Arc, request_override_context: &RequestOverrideContext, cancellation_token: &CancellationToken, -) -> Result, PipelineError> { + app_state: &Arc, +) -> Result { let stable_override_context = StableOverrideContext::new(&supergraph.planner.supergraph, request_override_context); @@ -38,20 +55,80 @@ pub async fn plan_operation_with_cache( })); } - supergraph - .planner - .plan_from_normalized_operation( - filtered_operation_for_plan, - (&request_override_context.clone()).into(), - cancellation_token, - ) - .map(Arc::new) + /* Handle on_query_plan hook in the plugins - START */ + let mut start_payload = OnQueryPlanStartPayload { + router_http_request: req, + filtered_operation_for_plan, + planner_override_context: (&request_override_context.clone()).into(), + cancellation_token, + query_plan: None, + planner: &supergraph.planner, + }; + + let mut on_end_callbacks = vec![]; + for plugin in app_state.plugins.as_ref() { + let result = plugin.on_query_plan(start_payload); + start_payload = result.payload; + match result.control_flow { + ControlFlowResult::Continue => { + // continue to next plugin + } + ControlFlowResult::EndResponse(response) => { + return Err(QueryPlanGetterError::Response(response)); + } + ControlFlowResult::OnEnd(callback) => { + on_end_callbacks.push(callback); + } + } + } + let query_plan = match start_payload.query_plan { + Some(plan) => plan, + None => supergraph + .planner + .plan_from_normalized_operation( + filtered_operation_for_plan, + (&request_override_context.clone()).into(), + cancellation_token, + ) + .map_err(|e| QueryPlanGetterError::Planner(e))?, + }; + + let mut end_payload = hive_router_plan_executor::hooks::on_query_plan::OnQueryPlanEndPayload { + router_http_request: req, + filtered_operation_for_plan, + planner_override_context: (&request_override_context.clone()).into(), + cancellation_token, + query_plan, + planner: &supergraph.planner, + }; + + for callback in on_end_callbacks { + let result = callback(end_payload); + end_payload = result.payload; + match result.control_flow { + ControlFlowResult::Continue => { + // continue to next callback + } + ControlFlowResult::EndResponse(response) => { + return Err(QueryPlanGetterError::Response(response)); + } + ControlFlowResult::OnEnd(_) => { + // on_end callbacks should not return OnEnd again + } + } + } + + Ok(Arc::new(end_payload.query_plan)) + /* Handle on_query_plan hook in the plugins - END */ }) .await; match plan_result { - Ok(plan) => Ok(plan), - Err(e) => Err(req.new_pipeline_error(PipelineErrorVariant::PlannerError(e.clone()))), + Ok(plan) => Ok(QueryPlanResult::QueryPlan(plan)), + Err(e) => match e.as_ref() { + QueryPlanGetterError::Planner(e) => Err(req.new_pipeline_error(PipelineErrorVariant::PlannerError(e.clone()))), + QueryPlanGetterError::Response(response) => Ok(QueryPlanResult::Response(response.clone())), + }, } } diff --git a/bin/router/src/pipeline/validation.rs b/bin/router/src/pipeline/validation.rs index 85d44c2f1..f97aa1661 100644 --- a/bin/router/src/pipeline/validation.rs +++ b/bin/router/src/pipeline/validation.rs @@ -2,9 +2,13 @@ use std::sync::Arc; use crate::pipeline::error::{PipelineError, PipelineErrorFromAcceptHeader, PipelineErrorVariant}; use crate::pipeline::parser::GraphQLParserPayload; -use crate::schema_state::{SchemaState, SupergraphData}; +use crate::schema_state::{SchemaState}; use crate::shared_state::RouterSharedState; use graphql_tools::validation::validate::validate; +use hive_router_plan_executor::execution::plan::PlanExecutionOutput; +use hive_router_plan_executor::hooks::on_graphql_validation::{OnGraphQLValidationEndPayload, OnGraphQLValidationStartPayload}; +use hive_router_plan_executor::hooks::on_supergraph_load::SupergraphData; +use hive_router_plan_executor::plugin_trait::ControlFlowResult; use ntex::web::HttpRequest; use tracing::{error, trace}; @@ -15,7 +19,7 @@ pub async fn validate_operation_with_cache( schema_state: &Arc, app_state: &Arc, parser_payload: &GraphQLParserPayload, -) -> Result<(), PipelineError> { +) -> Result, PipelineError> { let consumer_schema_ast = &supergraph.planner.consumer_schema.document; let validation_result = match schema_state @@ -36,13 +40,67 @@ pub async fn validate_operation_with_cache( "validation result of hash {} does not exists in cache", parser_payload.cache_key ); - - let res = validate( + + /* Handle on_graphql_validate hook in the plugins - START */ + let mut start_payload = OnGraphQLValidationStartPayload::new( + req, consumer_schema_ast, &parser_payload.parsed_operation, &app_state.validation_plan, ); - let arc_res = Arc::new(res); + let mut on_end_callbacks = vec![]; + for plugin in app_state.plugins.as_ref() { + let result = plugin.on_graphql_validation(start_payload); + start_payload = result.payload; + match result.control_flow { + ControlFlowResult::Continue => { + // continue to next plugin + } + ControlFlowResult::EndResponse(response) => { + return Ok(Some(response)); + } + ControlFlowResult::OnEnd(callback) => { + on_end_callbacks.push(callback); + } + } + } + + let errors = match start_payload.errors { + Some(errors) => errors, + None => { + validate( + consumer_schema_ast, + &start_payload.document, + start_payload.get_validation_plan(), + ) + } + }; + + let mut end_payload = OnGraphQLValidationEndPayload { + router_http_request: req, + schema: consumer_schema_ast, + document: &parser_payload.parsed_operation, + errors, + }; + + for callback in on_end_callbacks { + let result = callback(end_payload); + end_payload = result.payload; + match result.control_flow { + ControlFlowResult::Continue => { + // continue to next callback + } + ControlFlowResult::EndResponse(response) => { + return Ok(Some(response)); + } + ControlFlowResult::OnEnd(_) => { + // on_end callbacks should not return OnEnd again + } + } + } + /* Handle on_graphql_validate hook in the plugins - END */ + + let arc_res = Arc::new(end_payload.errors); schema_state .validate_cache @@ -64,5 +122,5 @@ pub async fn validate_operation_with_cache( ); } - Ok(()) + Ok(None) } diff --git a/bin/router/src/schema_state.rs b/bin/router/src/schema_state.rs index f14cc6cf0..3f79a06aa 100644 --- a/bin/router/src/schema_state.rs +++ b/bin/router/src/schema_state.rs @@ -1,11 +1,9 @@ use arc_swap::{ArcSwap, Guard}; use async_trait::async_trait; -use graphql_tools::validation::utils::ValidationError; +use graphql_tools::{static_graphql::schema::Document, validation::utils::ValidationError}; use hive_router_config::{supergraph::SupergraphSource, HiveRouterConfig}; use hive_router_plan_executor::{ - executors::error::SubgraphExecutorError, - introspection::schema::{SchemaMetadata, SchemaWithMetadata}, - SubgraphExecutorMap, + SubgraphExecutorMap, executors::error::SubgraphExecutorError, hooks::on_supergraph_load::{OnSupergraphLoadEndPayload, OnSupergraphLoadStartPayload, SupergraphData}, introspection::schema::SchemaWithMetadata, plugin_trait::{ControlFlowResult, RouterPlugin} }; use hive_router_query_planner::planner::plan_nodes::QueryPlan; use hive_router_query_planner::{ @@ -20,12 +18,10 @@ use tokio_util::sync::CancellationToken; use tracing::{debug, error, trace}; use crate::{ - background_tasks::{BackgroundTask, BackgroundTasksManager}, - pipeline::normalize::GraphQLNormalizationPayload, - supergraph::{ + RouterSharedState, background_tasks::{BackgroundTask, BackgroundTasksManager}, pipeline::normalize::GraphQLNormalizationPayload, supergraph::{ base::{LoadSupergraphError, ReloadSupergraphResult, SupergraphLoader}, resolve_from_config, - }, + } }; pub struct SchemaState { @@ -35,12 +31,6 @@ pub struct SchemaState { pub normalize_cache: Cache>, } -pub struct SupergraphData { - pub metadata: SchemaMetadata, - pub planner: Planner, - pub subgraph_executor_map: SubgraphExecutorMap, -} - #[derive(Debug, thiserror::Error)] pub enum SupergraphManagerError { #[error("Failed to load supergraph: {0}")] @@ -65,6 +55,7 @@ impl SchemaState { pub async fn new_from_config( bg_tasks_manager: &mut BackgroundTasksManager, router_config: Arc, + app_state: Arc ) -> Result { let (tx, mut rx) = mpsc::channel::(1); let background_loader = SupergraphBackgroundLoader::new(&router_config.supergraph, tx)?; @@ -85,9 +76,58 @@ impl SchemaState { while let Some(new_sdl) = rx.recv().await { debug!("Received new supergraph SDL, building new supergraph state..."); - match Self::build_data(router_config.clone(), &new_sdl) { - Ok(new_data) => { - swappable_data_spawn_clone.store(Arc::new(Some(new_data))); + let new_ast = parse_schema(&new_sdl); + + let mut start_payload = OnSupergraphLoadStartPayload { + current_supergraph_data: swappable_data_spawn_clone.clone(), + new_ast, + }; + + let mut on_end_callbacks = vec![]; + + for plugin in app_state.plugins.as_ref() { + let result = plugin.on_supergraph_reload(start_payload); + start_payload = result.payload; + match result.control_flow { + ControlFlowResult::Continue => { + // continue to next plugin + }, + ControlFlowResult::EndResponse(_) => { + unreachable!("Plugins should not end supergraph reload processing"); + }, + ControlFlowResult::OnEnd(callback) => { + on_end_callbacks.push(callback); + } + } + } + + let new_ast = start_payload.new_ast; + + match Self::build_data(router_config.clone(), &new_ast, app_state.plugins.clone()) { + Ok(new_supergraph_data) => { + let mut end_payload = OnSupergraphLoadEndPayload { + new_supergraph_data, + }; + + for callback in on_end_callbacks { + let result = callback(end_payload); + end_payload = result.payload; + match result.control_flow { + ControlFlowResult::Continue => { + // continue to next callback + }, + ControlFlowResult::EndResponse(_) => { + unreachable!("Plugins should not end supergraph reload processing"); + }, + ControlFlowResult::OnEnd(_) => { + unreachable!("End callbacks should not register further end callbacks"); + } + } + } + + let new_supergraph_data = end_payload.new_supergraph_data; + + swappable_data_spawn_clone.store(Arc::new(Some(new_supergraph_data))); debug!("Supergraph updated successfully"); task_plan_cache.invalidate_all(); @@ -112,15 +152,16 @@ impl SchemaState { fn build_data( router_config: Arc, - supergraph_sdl: &str, + parsed_supergraph_sdl: &Document, + plugins: Arc>>, ) -> Result { - let parsed_supergraph_sdl = parse_schema(supergraph_sdl); let supergraph_state = SupergraphState::new(&parsed_supergraph_sdl); let planner = Planner::new_from_supergraph(&parsed_supergraph_sdl)?; let metadata = planner.consumer_schema.schema_metadata(); let subgraph_executor_map = SubgraphExecutorMap::from_http_endpoint_map( supergraph_state.subgraph_endpoint_map, router_config, + plugins.clone(), )?; Ok(SupergraphData { diff --git a/bin/router/src/shared_state.rs b/bin/router/src/shared_state.rs index 06446102a..877ffa0e3 100644 --- a/bin/router/src/shared_state.rs +++ b/bin/router/src/shared_state.rs @@ -19,7 +19,7 @@ pub struct RouterSharedState { pub override_labels_evaluator: OverrideLabelsEvaluator, pub cors_runtime: Option, pub jwt_auth_runtime: Option, - pub plugins: Vec>, + pub plugins: Arc>>, } impl RouterSharedState { @@ -38,7 +38,7 @@ impl RouterSharedState { ) .map_err(Box::new)?, jwt_auth_runtime, - plugins: Vec::new(), + plugins: Arc::new(vec![]), }) } } diff --git a/lib/executor/Cargo.toml b/lib/executor/Cargo.toml index 7dcfc03fb..39d51ee7e 100644 --- a/lib/executor/Cargo.toml +++ b/lib/executor/Cargo.toml @@ -35,6 +35,7 @@ ahash = "0.8.12" regex-automata = "0.4.10" strum = { version = "0.27.2", features = ["derive"] } +arc-swap = "1.7.1" ntex = { version = "2", features = ["tokio"] } ntex-http = "0.1.15" ordered-float = "4.2.0" diff --git a/lib/executor/src/execution/client_request_details.rs b/lib/executor/src/execution/client_request_details.rs index 6985376cc..35540dab2 100644 --- a/lib/executor/src/execution/client_request_details.rs +++ b/lib/executor/src/execution/client_request_details.rs @@ -1,4 +1,4 @@ -use std::collections::BTreeMap; +use std::{collections::BTreeMap}; use bytes::Bytes; use http::Method; diff --git a/lib/executor/src/execution/error.rs b/lib/executor/src/execution/error.rs index 63460eb48..aaaf9a729 100644 --- a/lib/executor/src/execution/error.rs +++ b/lib/executor/src/execution/error.rs @@ -116,7 +116,7 @@ impl IntoPlanExecutionError for Result { let kind = PlanExecutionErrorKind::ProjectionFailure(source); PlanExecutionError::new(kind, context) }) - } + } } impl IntoPlanExecutionError for Result { diff --git a/lib/executor/src/execution/plan.rs b/lib/executor/src/execution/plan.rs index 6bc314516..520f429c8 100644 --- a/lib/executor/src/execution/plan.rs +++ b/lib/executor/src/execution/plan.rs @@ -7,48 +7,44 @@ use hive_router_query_planner::planner::plan_nodes::{ QueryPlan, SequenceNode, }; use http::HeaderMap; +use ntex::web::HttpRequest; use serde::Deserialize; use sonic_rs::ValueRef; use crate::{ - context::ExecutionContext, - execution::{ + context::ExecutionContext, execution::{ client_request_details::ClientRequestDetails, error::{IntoPlanExecutionError, LazyPlanContext, PlanExecutionError}, jwt_forward::JwtAuthForwardingPlan, rewrites::FetchRewriteExt, - }, - executors::{ - common::{SubgraphExecutionRequest, HttpExecutionResponse}, + }, executors::{ + common::{HttpExecutionResponse, SubgraphExecutionRequest}, map::SubgraphExecutorMap, - }, - headers::{ + }, headers::{ plan::HeaderRulesPlan, request::modify_subgraph_request_headers, response::{apply_subgraph_response_headers, modify_client_response_headers}, - }, - introspection::{ - resolve::{resolve_introspection, IntrospectionContext}, + }, hooks::on_execute::{OnExecuteEndPayload, OnExecuteStartPayload}, introspection::{ + resolve::{IntrospectionContext, resolve_introspection}, schema::SchemaMetadata, - }, - projection::{ + }, plugin_trait::{ControlFlowResult, RouterPlugin}, projection::{ plan::FieldProjectionPlan, - request::{project_requires, RequestProjectionContext}, + request::{RequestProjectionContext, project_requires}, response::project_by_operation, - }, - response::{ + }, response::{ graphql_error::{GraphQLError, GraphQLErrorExtensions, GraphQLErrorPath}, merge::deep_merge, subgraph_response::SubgraphResponse, value::Value, - }, - utils::{ + }, utils::{ consts::{CLOSE_BRACKET, OPEN_BRACKET}, traverse::{traverse_and_callback, traverse_and_callback_mut}, - }, + } }; pub struct QueryPlanExecutionContext<'exec, 'req> { + pub router_http_request: &'exec HttpRequest, + pub plugins: &'exec Vec>, pub query_plan: &'exec QueryPlan, pub projection_plan: &'exec Vec, pub headers_plan: &'exec HeaderRulesPlan, @@ -61,6 +57,7 @@ pub struct QueryPlanExecutionContext<'exec, 'req> { pub jwt_auth_forwarding: &'exec Option, } +#[derive(Clone)] pub struct PlanExecutionOutput { pub body: Vec, pub headers: HeaderMap, @@ -75,6 +72,36 @@ pub async fn execute_query_plan<'exec, 'req>( Value::Null }; + let dedupe_subgraph_requests = ctx.operation_type_name == "Query"; + + let mut start_payload = OnExecuteStartPayload { + router_http_request: ctx.router_http_request, + query_plan: ctx.query_plan, + data: init_value, + errors: Vec::new(), + extensions: ctx.extensions.clone(), + variable_values: ctx.variable_values, + dedupe_subgraph_requests, + }; + + let mut on_end_callbacks = vec![]; + + for plugin in ctx.plugins { + let result = plugin.on_execute(start_payload); + start_payload = result.payload; + match result.control_flow { + ControlFlowResult::Continue => { /* continue to next plugin */ }, + ControlFlowResult::EndResponse(response) => { + return Ok(response); + }, + ControlFlowResult::OnEnd(callback) => { + on_end_callbacks.push(callback); + } + } + } + + let init_value = start_payload.data; + let mut exec_ctx = ExecutionContext::new(ctx.query_plan, init_value); let executor = Executor::new( ctx.variable_values, @@ -100,15 +127,36 @@ pub async fn execute_query_plan<'exec, 'req>( affected_path: || None, })?; - let final_response = &exec_ctx.final_response; + let mut end_payload = OnExecuteEndPayload { + data: exec_ctx.final_response, + errors: exec_ctx.errors, + extensions: start_payload.extensions, + response_size_estimate: exec_ctx.response_storage.estimate_final_response_size(), + }; + + for callback in on_end_callbacks { + let result = callback(end_payload); + end_payload = result.payload; + match result.control_flow { + ControlFlowResult::Continue => { /* continue to next callback */ }, + ControlFlowResult::EndResponse(response) => { + return Ok(response); + }, + ControlFlowResult::OnEnd(_) => { + // on_end callbacks should not return OnEnd again + unreachable!("on_end callback returned OnEnd again"); + } + } + } + let body = project_by_operation( - final_response, - exec_ctx.errors, + &end_payload.data, + end_payload.errors, &ctx.extensions, ctx.operation_type_name, ctx.projection_plan, ctx.variable_values, - exec_ctx.response_storage.estimate_final_response_size(), + end_payload.response_size_estimate, ) .with_plan_context(LazyPlanContext { subgraph_name: || None, diff --git a/lib/executor/src/executors/http.rs b/lib/executor/src/executors/http.rs index 5947cd7d3..c09e01067 100644 --- a/lib/executor/src/executors/http.rs +++ b/lib/executor/src/executors/http.rs @@ -2,6 +2,7 @@ use std::sync::Arc; use crate::executors::common::HttpExecutionResponse; use crate::executors::dedupe::{request_fingerprint, ABuildHasher, SharedResponse}; +use crate::plugin_trait::RouterPlugin; use dashmap::DashMap; use hive_router_config::HiveRouterConfig; use tokio::sync::OnceCell; @@ -28,7 +29,6 @@ use crate::utils::consts::COMMA; use crate::utils::consts::QUOTE; use crate::{executors::common::SubgraphExecutor, json_writer::write_and_escape_string}; -#[derive(Debug)] pub struct HTTPSubgraphExecutor { pub subgraph_name: String, pub endpoint: http::Uri, @@ -37,6 +37,7 @@ pub struct HTTPSubgraphExecutor { pub semaphore: Arc, pub config: Arc, pub in_flight_requests: Arc>, ABuildHasher>>, + pub plugins: Arc>>, } const FIRST_VARIABLE_STR: &[u8] = b",\"variables\":{"; @@ -52,6 +53,7 @@ impl HTTPSubgraphExecutor { semaphore: Arc, config: Arc, in_flight_requests: Arc>, ABuildHasher>>, + plugins: Arc>>, ) -> Self { let mut header_map = HeaderMap::new(); header_map.insert( @@ -71,6 +73,7 @@ impl HTTPSubgraphExecutor { semaphore, config, in_flight_requests, + plugins, } } diff --git a/lib/executor/src/executors/map.rs b/lib/executor/src/executors/map.rs index 2e8ac78ae..6f780b76f 100644 --- a/lib/executor/src/executors/map.rs +++ b/lib/executor/src/executors/map.rs @@ -27,16 +27,14 @@ use vrl::{ }; use crate::{ - execution::client_request_details::ClientRequestDetails, - executors::{ + execution::client_request_details::ClientRequestDetails, executors::{ common::{ - SubgraphExecutionRequest, HttpExecutionResponse, SubgraphExecutor, SubgraphExecutorBoxedArc, + HttpExecutionResponse, SubgraphExecutionRequest, SubgraphExecutor, SubgraphExecutorBoxedArc }, dedupe::{ABuildHasher, SharedResponse}, error::SubgraphExecutorError, http::{HTTPSubgraphExecutor, HttpClient}, - }, - response::graphql_error::GraphQLError, + }, hooks::on_subgraph_execute::{OnSubgraphExecuteEndPayload, OnSubgraphExecuteStartPayload}, plugin_trait::{ControlFlowResult, RouterPlugin}, response::graphql_error::GraphQLError }; type SubgraphName = String; @@ -60,10 +58,14 @@ pub struct SubgraphExecutorMap { semaphores_by_origin: DashMap>, max_connections_per_host: usize, in_flight_requests: Arc>, ABuildHasher>>, + plugins: Arc>>, } impl SubgraphExecutorMap { - pub fn new(config: Arc) -> Self { + pub fn new( + config: Arc, + plugins: Arc>>, + ) -> Self { let https = HttpsConnector::new(); let client: HttpClient = Client::builder(TokioExecutor::new()) .pool_timer(TokioTimer::new()) @@ -85,14 +87,16 @@ impl SubgraphExecutorMap { semaphores_by_origin: Default::default(), max_connections_per_host, in_flight_requests: Arc::new(DashMap::with_hasher(ABuildHasher::default())), + plugins, } } pub fn from_http_endpoint_map( subgraph_endpoint_map: HashMap, config: Arc, + plugins: Arc>>, ) -> Result { - let mut subgraph_executor_map = SubgraphExecutorMap::new(config.clone()); + let mut subgraph_executor_map = SubgraphExecutorMap::new(config.clone(), plugins); for (subgraph_name, original_endpoint_str) in subgraph_endpoint_map.into_iter() { let endpoint_str = config @@ -121,8 +125,40 @@ impl SubgraphExecutorMap { execution_request: SubgraphExecutionRequest<'a>, client_request: &ClientRequestDetails<'a, 'req>, ) -> HttpExecutionResponse { - match self.get_or_create_executor(subgraph_name, client_request) { - Ok(Some(executor)) => executor.execute(execution_request).await, + let mut start_payload = OnSubgraphExecuteStartPayload { + subgraph_name: subgraph_name.to_string(), + execution_request, + execution_result: None, + }; + + let mut on_end_callbacks = vec![]; + + for plugin in self.plugins.as_ref() { + let result = plugin.on_subgraph_execute(start_payload); + start_payload = result.payload; + match result.control_flow { + ControlFlowResult::Continue => { + // continue to next plugin + } + ControlFlowResult::EndResponse(response) => { + // TODO: FFIX + return HttpExecutionResponse { + body: response.body.into(), + headers: response.headers, + }; + } + ControlFlowResult::OnEnd(callback) => { + on_end_callbacks.push(callback); + } + } + } + + let execution_request = start_payload.execution_request; + + let execution_result = match self.get_or_create_executor(subgraph_name, client_request) { + Ok(Some(executor)) => executor + .execute(execution_request) + .await, Err(err) => { error!( "Subgraph executor error for subgraph '{}': {}", @@ -137,7 +173,33 @@ impl SubgraphExecutorMap { ); self.internal_server_error_response("Internal server error".into(), subgraph_name) } + }; + + let mut end_payload = OnSubgraphExecuteEndPayload { + execution_result + }; + + for callback in on_end_callbacks { + let result = callback(end_payload); + end_payload = result.payload; + match result.control_flow { + ControlFlowResult::Continue => { + // continue to next callback + } + ControlFlowResult::EndResponse(response) => { + // TODO: FFIX + return HttpExecutionResponse { + body: response.body.into(), + headers: response.headers, + }; + } + ControlFlowResult::OnEnd(_) => { + unreachable!("End callbacks should not register further end callbacks"); + } + } } + + end_payload.execution_result } fn internal_server_error_response( @@ -324,6 +386,7 @@ impl SubgraphExecutorMap { semaphore, self.config.clone(), self.in_flight_requests.clone(), + self.plugins.clone(), ); self.executors_by_subgraph diff --git a/lib/executor/src/plugins/examples/apq.rs b/lib/executor/src/plugins/examples/apq.rs index 7d6ac9256..d5400d314 100644 --- a/lib/executor/src/plugins/examples/apq.rs +++ b/lib/executor/src/plugins/examples/apq.rs @@ -2,7 +2,7 @@ use dashmap::DashMap; use sonic_rs::{JsonContainerTrait, JsonValueTrait}; use crate::{ - hooks::on_deserialization::{OnDeserializationEndPayload, OnDeserializationStartPayload}, + hooks::on_graphql_params::{OnGraphQLParamsEndPayload, OnGraphQLParamsStartPayload}, plugin_trait::{EndPayload, HookResult, RouterPlugin, StartPayload}, }; @@ -11,13 +11,13 @@ pub struct APQPlugin { } impl RouterPlugin for APQPlugin { - fn on_deserialization<'exec>( + fn on_graphql_params<'exec>( &'exec self, - start_payload: OnDeserializationStartPayload<'exec>, - ) -> HookResult<'exec, OnDeserializationStartPayload<'exec>, OnDeserializationEndPayload<'exec>> + payload: OnGraphQLParamsStartPayload<'exec>, + ) -> HookResult<'exec, OnGraphQLParamsStartPayload<'exec>, OnGraphQLParamsEndPayload> { - start_payload.on_end(|mut end_payload| { - let persisted_query_ext = end_payload.graphql_params.extensions.as_ref() + payload.on_end(|mut payload| { + let persisted_query_ext = payload.graphql_params.extensions.as_ref() .and_then(|ext| ext.get("persistedQuery")) .and_then(|pq| pq.as_object()); if let Some(persisted_query_ext) = persisted_query_ext { @@ -25,31 +25,31 @@ impl RouterPlugin for APQPlugin { Some("1") => {} _ => { // TODO: Error for unsupported version - return end_payload.cont(); + return payload.cont(); } } let sha256_hash = match persisted_query_ext.get(&"sha256Hash").and_then(|h| h.as_str()) { Some(h) => h, None => { - return end_payload.cont(); + return payload.cont(); } }; - if let Some(query_param) = &end_payload.graphql_params.query { + if let Some(query_param) = &payload.graphql_params.query { // Store the query in the cache self.cache.insert(sha256_hash.to_string(), query_param.to_string()); } else { // Try to get the query from the cache if let Some(cached_query) = self.cache.get(sha256_hash) { // Update the graphql_params with the cached query - end_payload.graphql_params.query = Some(cached_query.value().to_string()); + payload.graphql_params.query = Some(cached_query.value().to_string()); } else { // Error - return end_payload.cont(); + return payload.cont(); } } } - end_payload.cont() + payload.cont() }) } } diff --git a/lib/executor/src/plugins/examples/response_cache.rs b/lib/executor/src/plugins/examples/response_cache.rs index 5942e6d91..d9d611307 100644 --- a/lib/executor/src/plugins/examples/response_cache.rs +++ b/lib/executor/src/plugins/examples/response_cache.rs @@ -3,7 +3,13 @@ use http::HeaderMap; use redis::Commands; use crate::{ - execution::plan::PlanExecutionOutput, hooks::{on_execute::{OnExecuteEndPayload, OnExecuteStartPayload}, on_schema_reload::OnSchemaReloadPayload}, plugin_trait::{EndPayload, HookResult, StartPayload}, plugins::plugin_trait::RouterPlugin, utils::consts::TYPENAME_FIELD_NAME + execution::plan::PlanExecutionOutput, + hooks::{ + on_execute::{OnExecuteEndPayload, OnExecuteStartPayload}, on_supergraph_load::{OnSupergraphLoadEndPayload, OnSupergraphLoadStartPayload}, + }, + plugin_trait::{EndPayload, HookResult, StartPayload}, + plugins::plugin_trait::RouterPlugin, + utils::consts::TYPENAME_FIELD_NAME, }; pub struct ResponseCachePlugin { @@ -33,15 +39,12 @@ impl RouterPlugin for ResponseCachePlugin { if let Ok(mut conn) = self.redis_client.get_connection() { let cached_response: Option> = conn.get(&key).ok(); if let Some(cached_response) = cached_response { - return payload.end_response( - - PlanExecutionOutput { - body: cached_response, - headers: HeaderMap::new(), - } - ); + return payload.end_response(PlanExecutionOutput { + body: cached_response, + headers: HeaderMap::new(), + }); } - return payload.on_end(move |payload: OnExecuteEndPayload<'exec>| { + return payload.on_end(move |mut payload: OnExecuteEndPayload<'exec>| { // Do not cache if there are errors if !payload.errors.is_empty() { return payload.cont(); @@ -73,6 +76,7 @@ impl RouterPlugin for ResponseCachePlugin { // Insert the ttl into extensions for client awareness payload .extensions + .get_or_insert_default() .insert("response_cache_ttl".to_string(), sonic_rs::json!(max_ttl)); // Set the cache with the decided ttl @@ -83,11 +87,10 @@ impl RouterPlugin for ResponseCachePlugin { } payload.cont() } - fn on_schema_reload(&self, payload: OnSchemaReloadPayload) { + fn on_supergraph_reload<'a>(&'a self, payload: OnSupergraphLoadStartPayload) -> HookResult<'a, OnSupergraphLoadStartPayload, OnSupergraphLoadEndPayload> { // Visit the schema and update ttl_per_type based on some directive payload - .new_schema - .document + .new_ast .definitions .iter() .for_each(|def| { @@ -110,5 +113,7 @@ impl RouterPlugin for ResponseCachePlugin { } } }); + + payload.cont() } } diff --git a/lib/executor/src/plugins/examples/subgraph_response_cache.rs b/lib/executor/src/plugins/examples/subgraph_response_cache.rs index 55d98a893..037a314d0 100644 --- a/lib/executor/src/plugins/examples/subgraph_response_cache.rs +++ b/lib/executor/src/plugins/examples/subgraph_response_cache.rs @@ -1,16 +1,16 @@ use dashmap::DashMap; -use crate::{executors::dedupe::SharedResponse, hooks::on_subgraph_http_request::{OnSubgraphHttpRequestPayload, OnSubgraphHttpResponsePayload}, plugin_trait::{ EndPayload, HookResult, RouterPlugin, StartPayload}}; +use crate::{executors::dedupe::SharedResponse, hooks::{on_subgraph_execute::{OnSubgraphExecuteEndPayload, OnSubgraphExecuteStartPayload}, on_subgraph_http_request::{OnSubgraphHttpRequestPayload, OnSubgraphHttpResponsePayload}}, plugin_trait::{ EndPayload, HookResult, RouterPlugin, StartPayload}}; pub struct SubgraphResponseCachePlugin { cache: DashMap, } impl RouterPlugin for SubgraphResponseCachePlugin { - fn on_subgraph_http_request<'exec>( - &'static self, - payload: OnSubgraphHttpRequestPayload<'exec>, - ) -> HookResult<'exec, OnSubgraphHttpRequestPayload<'exec>, OnSubgraphHttpResponsePayload<'exec>> { + fn on_subgraph_execute<'exec>( + &'exec self, + payload: OnSubgraphExecuteStartPayload<'exec>, + ) -> HookResult<'exec, OnSubgraphExecuteStartPayload<'exec>, OnSubgraphExecuteEndPayload> { let key = format!( "subgraph_response_cache:{}:{:?}", payload.execution_request.query, payload.execution_request.variables @@ -21,9 +21,9 @@ impl RouterPlugin for SubgraphResponseCachePlugin { *payload.response = Some(cached_response.clone()); return payload.cont(); } - payload.on_end(move |payload: OnSubgraphHttpResponsePayload<'exec>| { + payload.on_end(move |payload: OnSubgraphExecuteEndPayload| { // Here payload.response is not Option - self.cache.insert(key, payload.response.clone()); + self.cache.insert(key, payload.execution_result.body.as_ref()); payload.cont() }) } diff --git a/lib/executor/src/plugins/hooks/mod.rs b/lib/executor/src/plugins/hooks/mod.rs index 65ccf6f4d..453c84c98 100644 --- a/lib/executor/src/plugins/hooks/mod.rs +++ b/lib/executor/src/plugins/hooks/mod.rs @@ -1,8 +1,8 @@ pub mod on_execute; -pub mod on_schema_reload; +pub mod on_supergraph_load; pub mod on_subgraph_http_request; pub mod on_http_request; -pub mod on_deserialization; +pub mod on_graphql_params; pub mod on_graphql_parse; pub mod on_graphql_validation; pub mod on_query_plan; diff --git a/lib/executor/src/plugins/hooks/on_execute.rs b/lib/executor/src/plugins/hooks/on_execute.rs index 5057075e3..dfcdaceb8 100644 --- a/lib/executor/src/plugins/hooks/on_execute.rs +++ b/lib/executor/src/plugins/hooks/on_execute.rs @@ -11,31 +11,23 @@ pub struct OnExecuteStartPayload<'exec> { pub router_http_request: &'exec HttpRequest, pub query_plan: &'exec QueryPlan, - pub data: &'exec mut Value<'exec>, - pub errors: &'exec mut Vec, - pub extensions: &'exec mut HashMap, - - pub skip_execution: &'exec mut bool, + pub data: Value<'exec>, + pub errors: Vec, + pub extensions: Option>, pub variable_values: &'exec Option>, - pub dedupe_subgraph_requests: &'exec mut bool, + pub dedupe_subgraph_requests: bool, } impl<'exec> StartPayload> for OnExecuteStartPayload<'exec> {} pub struct OnExecuteEndPayload<'exec> { - pub router_http_request: &'exec HttpRequest, - pub query_plan: &'exec QueryPlan, - - - pub data: &'exec mut Value<'exec>, - pub errors: &'exec mut Vec, - pub extensions: &'exec mut HashMap, - - pub variable_values: &'exec Option>, + pub data: Value<'exec>, + pub errors: Vec, + pub extensions: Option>, - pub dedupe_subgraph_requests: &'exec mut bool, + pub response_size_estimate: usize, } impl<'exec> EndPayload for OnExecuteEndPayload<'exec> {} diff --git a/lib/executor/src/plugins/hooks/on_deserialization.rs b/lib/executor/src/plugins/hooks/on_graphql_params.rs similarity index 76% rename from lib/executor/src/plugins/hooks/on_deserialization.rs rename to lib/executor/src/plugins/hooks/on_graphql_params.rs index 84991ff56..5e6ce1c47 100644 --- a/lib/executor/src/plugins/hooks/on_deserialization.rs +++ b/lib/executor/src/plugins/hooks/on_graphql_params.rs @@ -29,17 +29,16 @@ where Ok(opt.unwrap_or_default()) } -pub struct OnDeserializationStartPayload<'exec> { +pub struct OnGraphQLParamsStartPayload<'exec> { pub router_http_request: &'exec ntex::web::HttpRequest, pub body: Bytes, pub graphql_params: Option, } -impl<'exec> StartPayload> for OnDeserializationStartPayload<'exec> {} +impl<'exec> StartPayload for OnGraphQLParamsStartPayload<'exec> {} -pub struct OnDeserializationEndPayload<'exec> { - pub router_http_request: &'exec ntex::web::HttpRequest, +pub struct OnGraphQLParamsEndPayload { pub graphql_params: GraphQLParams, } -impl<'exec> EndPayload for OnDeserializationEndPayload<'exec> {} +impl EndPayload for OnGraphQLParamsEndPayload {} diff --git a/lib/executor/src/plugins/hooks/on_graphql_parse.rs b/lib/executor/src/plugins/hooks/on_graphql_parse.rs index 8719cdac3..162a7eee2 100644 --- a/lib/executor/src/plugins/hooks/on_graphql_parse.rs +++ b/lib/executor/src/plugins/hooks/on_graphql_parse.rs @@ -1,6 +1,6 @@ use graphql_tools::static_graphql::query::Document; -use crate::{hooks::on_deserialization::GraphQLParams, plugin_trait::{EndPayload, StartPayload}}; +use crate::{hooks::on_graphql_params::GraphQLParams, plugin_trait::{EndPayload, StartPayload}}; pub struct OnGraphQLParseStartPayload<'exec> { pub router_http_request: &'exec ntex::web::HttpRequest, diff --git a/lib/executor/src/plugins/hooks/on_graphql_validation.rs b/lib/executor/src/plugins/hooks/on_graphql_validation.rs index e5ecf898f..a789cb5fd 100644 --- a/lib/executor/src/plugins/hooks/on_graphql_validation.rs +++ b/lib/executor/src/plugins/hooks/on_graphql_validation.rs @@ -1,25 +1,71 @@ -use graphql_tools::{static_graphql::query::Document, validation::{utils::ValidationError, validate::ValidationPlan}}; +use graphql_tools::{ + static_graphql::query::Document, + validation::{rules::{ValidationRule, default_rules_validation_plan}, utils::ValidationError, validate::ValidationPlan}, +}; use hive_router_query_planner::state::supergraph_state::SchemaDocument; -use crate::{hooks::on_deserialization::GraphQLParams, plugin_trait::{EndPayload, StartPayload}}; +use crate::plugin_trait::{EndPayload, StartPayload}; pub struct OnGraphQLValidationStartPayload<'exec> { pub router_http_request: &'exec ntex::web::HttpRequest, - pub graphql_params: &'exec GraphQLParams, pub schema: &'exec SchemaDocument, pub document: &'exec Document, - pub validation_plan: &'exec mut ValidationPlan, - pub errors: &'exec mut Option> + default_validation_plan: &'exec ValidationPlan, + new_validation_plan: Option, + pub errors: Option>, } -impl<'exec> StartPayload> for OnGraphQLValidationStartPayload<'exec> {} +impl<'exec> StartPayload> + for OnGraphQLValidationStartPayload<'exec> +{ +} + +impl<'exec> OnGraphQLValidationStartPayload<'exec> { + pub fn new( + router_http_request: &'exec ntex::web::HttpRequest, + schema: &'exec SchemaDocument, + document: &'exec Document, + default_validation_plan: &'exec ValidationPlan, + ) -> Self { + OnGraphQLValidationStartPayload { + router_http_request, + schema, + document, + default_validation_plan, + new_validation_plan: None, + errors: None, + } + } + + pub fn add_validation_rule(&mut self, rule: Box) { + self.new_validation_plan + .get_or_insert_with(|| default_rules_validation_plan()) + .add_rule(rule); + } + + pub fn filter_validation_rules(&mut self, mut f: F) + where + F: FnMut(&Box) -> bool, + { + let plan = self + .new_validation_plan + .get_or_insert_with(|| default_rules_validation_plan()); + plan.rules.retain(|rule| f(rule)); + } + + pub fn get_validation_plan(&self) -> &ValidationPlan { + match &self.new_validation_plan { + Some(plan) => plan, + None => self.default_validation_plan, + } + } +} pub struct OnGraphQLValidationEndPayload<'exec> { pub router_http_request: &'exec ntex::web::HttpRequest, - pub graphql_params: &'exec GraphQLParams, pub schema: &'exec SchemaDocument, pub document: &'exec Document, - pub errors: &'exec mut Vec, + pub errors: Vec, } -impl<'exec> EndPayload for OnGraphQLValidationEndPayload<'exec> {} \ No newline at end of file +impl<'exec> EndPayload for OnGraphQLValidationEndPayload<'exec> {} diff --git a/lib/executor/src/plugins/hooks/on_http_request.rs b/lib/executor/src/plugins/hooks/on_http_request.rs index 847e7465e..29a8344e5 100644 --- a/lib/executor/src/plugins/hooks/on_http_request.rs +++ b/lib/executor/src/plugins/hooks/on_http_request.rs @@ -3,7 +3,7 @@ use ntex::{http::Response, web::HttpRequest}; use crate::plugin_trait::{EndPayload, StartPayload}; pub struct OnHttpRequestPayload<'exec> { - pub router_http_request: &'exec HttpRequest, + pub client_request: &'exec HttpRequest, } impl<'exec> StartPayload> for OnHttpRequestPayload<'exec> {} diff --git a/lib/executor/src/plugins/hooks/on_query_plan.rs b/lib/executor/src/plugins/hooks/on_query_plan.rs index 7963524ad..39ae3c2d6 100644 --- a/lib/executor/src/plugins/hooks/on_query_plan.rs +++ b/lib/executor/src/plugins/hooks/on_query_plan.rs @@ -1,13 +1,13 @@ -use graphql_tools::static_graphql::query::Document; -use hive_router_query_planner::planner::{Planner, plan_nodes::QueryPlan}; +use hive_router_query_planner::{ast::operation::OperationDefinition, graph::PlannerOverrideContext, planner::{Planner, plan_nodes::QueryPlan}, utils::cancellation::CancellationToken}; use crate::plugin_trait::{EndPayload, StartPayload}; pub struct OnQueryPlanStartPayload<'exec> { pub router_http_request: &'exec ntex::web::HttpRequest, - pub document: &'exec Document, - // Other params - pub query_plan: &'exec mut Option, + pub filtered_operation_for_plan: &'exec OperationDefinition, + pub planner_override_context: PlannerOverrideContext, + pub cancellation_token: &'exec CancellationToken, + pub query_plan: Option, pub planner: &'exec Planner, } @@ -15,9 +15,11 @@ impl<'exec> StartPayload> for OnQueryPlanStartPaylo pub struct OnQueryPlanEndPayload<'exec> { pub router_http_request: &'exec ntex::web::HttpRequest, - pub document: &'exec Document, - // Other params - pub query_plan: &'exec mut QueryPlan, + pub filtered_operation_for_plan: &'exec OperationDefinition, + pub planner_override_context: PlannerOverrideContext, + pub cancellation_token: &'exec CancellationToken, + pub query_plan: QueryPlan, + pub planner: &'exec Planner, } impl<'exec> EndPayload for OnQueryPlanEndPayload<'exec> {} \ No newline at end of file diff --git a/lib/executor/src/plugins/hooks/on_schema_reload.rs b/lib/executor/src/plugins/hooks/on_schema_reload.rs deleted file mode 100644 index a96d6c240..000000000 --- a/lib/executor/src/plugins/hooks/on_schema_reload.rs +++ /dev/null @@ -1,6 +0,0 @@ -use hive_router_query_planner::consumer_schema::ConsumerSchema; - -pub struct OnSchemaReloadPayload<'a> { - pub old_schema: &'a ConsumerSchema, - pub new_schema: &'a mut ConsumerSchema, -} diff --git a/lib/executor/src/plugins/hooks/on_subgraph_execute.rs b/lib/executor/src/plugins/hooks/on_subgraph_execute.rs index 167340bc8..6a514006f 100644 --- a/lib/executor/src/plugins/hooks/on_subgraph_execute.rs +++ b/lib/executor/src/plugins/hooks/on_subgraph_execute.rs @@ -1,34 +1,19 @@ -use bytes::Bytes; -use hive_router_query_planner::planner::plan_nodes::FetchNode; -use crate::{executors::common::{SubgraphExecutionRequest, SubgraphExecutorBoxedArc}, plugin_trait::{EndPayload, StartPayload}, response::subgraph_response::SubgraphResponse}; + +use crate::{executors::common::{HttpExecutionResponse, SubgraphExecutionRequest}, plugin_trait::{EndPayload, StartPayload}}; pub struct OnSubgraphExecuteStartPayload<'exec> { - pub router_http_request: &'exec ntex::web::HttpRequest, - pub executor: &'exec SubgraphExecutorBoxedArc, - pub subgraph_name: &'exec str, + pub subgraph_name: String, - pub node: &'exec mut FetchNode, - pub execution_request: &'exec mut SubgraphExecutionRequest<'exec>, - pub response: &'exec mut Option>, + pub execution_request: SubgraphExecutionRequest<'exec>, + pub execution_result: Option, } -impl<'exec> StartPayload> for OnSubgraphExecuteStartPayload<'exec> {} - -pub enum SubgraphExecutorResponse<'exec> { - Bytes(Bytes), - SubgraphResponse(SubgraphResponse<'exec>), -} - -pub struct OnSubgraphExecuteEndPayload<'exec> { - pub router_http_request: &'exec ntex::web::HttpRequest, - pub executor: &'exec SubgraphExecutorBoxedArc, - pub subgraph_name: &'exec str, +impl<'exec> StartPayload for OnSubgraphExecuteStartPayload<'exec> {} - pub node: &'exec FetchNode, - pub execution_request: &'exec SubgraphExecutionRequest<'exec>, - pub response: &'exec mut SubgraphExecutorResponse<'exec>, +pub struct OnSubgraphExecuteEndPayload { + pub execution_result: HttpExecutionResponse, } -impl<'exec> EndPayload for OnSubgraphExecuteEndPayload<'exec> {} \ No newline at end of file +impl<'exec> EndPayload for OnSubgraphExecuteEndPayload {} \ No newline at end of file diff --git a/lib/executor/src/plugins/hooks/on_subgraph_http_request.rs b/lib/executor/src/plugins/hooks/on_subgraph_http_request.rs index ac720b870..44bfa9dc9 100644 --- a/lib/executor/src/plugins/hooks/on_subgraph_http_request.rs +++ b/lib/executor/src/plugins/hooks/on_subgraph_http_request.rs @@ -1,4 +1,6 @@ -use http::{HeaderMap, Uri}; +use bytes::Bytes; +use http::{HeaderMap, Request, Uri}; +use http_body_util::Full; use ntex::web::HttpRequest; use crate::{ @@ -6,28 +8,18 @@ use crate::{ ; pub struct OnSubgraphHttpRequestPayload<'exec> { - pub router_http_request: &'exec HttpRequest, pub subgraph_name: &'exec str, // At this point, there is no point of mutating this - pub execution_request: &'exec SubgraphExecutionRequest<'exec>, - - pub endpoint: &'exec mut Uri, - // By default, it is POST - pub method: &'exec mut http::Method, - pub headers: &'exec mut HeaderMap, - pub request_body: &'exec mut Vec, + pub request: Request>, // Early response - pub response: &'exec mut Option, + pub response: Option, } -impl<'exec> StartPayload> for OnSubgraphHttpRequestPayload<'exec> {} +impl<'exec> StartPayload for OnSubgraphHttpRequestPayload<'exec> {} -pub struct OnSubgraphHttpResponsePayload<'exec> { - pub router_http_request: &'exec HttpRequest, - pub subgraph_name: &'exec str, - pub execution_request: &'exec SubgraphExecutionRequest<'exec>, - pub response: &'exec mut SharedResponse, +pub struct OnSubgraphHttpResponsePayload { + pub response: SharedResponse, } -impl<'exec> EndPayload for OnSubgraphHttpResponsePayload<'exec> {} +impl<'exec> EndPayload for OnSubgraphHttpResponsePayload {} diff --git a/lib/executor/src/plugins/hooks/on_supergraph_load.rs b/lib/executor/src/plugins/hooks/on_supergraph_load.rs new file mode 100644 index 000000000..e4e68ca35 --- /dev/null +++ b/lib/executor/src/plugins/hooks/on_supergraph_load.rs @@ -0,0 +1,27 @@ +use std::sync::Arc; + +use graphql_tools::static_graphql::schema::Document; +use hive_router_query_planner::{planner::Planner}; +use arc_swap::{ArcSwap}; + +use crate::{SubgraphExecutorMap, introspection::schema::SchemaMetadata, plugin_trait::{EndPayload, StartPayload}}; + + +pub struct SupergraphData { + pub metadata: SchemaMetadata, + pub planner: Planner, + pub subgraph_executor_map: SubgraphExecutorMap, +} + +pub struct OnSupergraphLoadStartPayload { + pub current_supergraph_data: Arc>>, + pub new_ast: Document, +} + +impl StartPayload for OnSupergraphLoadStartPayload {} + +pub struct OnSupergraphLoadEndPayload { + pub new_supergraph_data: SupergraphData, +} + +impl EndPayload for OnSupergraphLoadEndPayload {} \ No newline at end of file diff --git a/lib/executor/src/plugins/plugin_trait.rs b/lib/executor/src/plugins/plugin_trait.rs index 220c5b88c..d56856652 100644 --- a/lib/executor/src/plugins/plugin_trait.rs +++ b/lib/executor/src/plugins/plugin_trait.rs @@ -1,113 +1,141 @@ use crate::execution::plan::PlanExecutionOutput; -use crate::hooks::on_deserialization::{OnDeserializationEndPayload, OnDeserializationStartPayload}; use crate::hooks::on_execute::{OnExecuteEndPayload, OnExecuteStartPayload}; +use crate::hooks::on_graphql_params::{OnGraphQLParamsEndPayload, OnGraphQLParamsStartPayload}; use crate::hooks::on_graphql_parse::{OnGraphQLParseEndPayload, OnGraphQLParseStartPayload}; -use crate::hooks::on_graphql_validation::{OnGraphQLValidationEndPayload, OnGraphQLValidationStartPayload}; +use crate::hooks::on_graphql_validation::{ + OnGraphQLValidationEndPayload, OnGraphQLValidationStartPayload, +}; use crate::hooks::on_http_request::{OnHttpRequestPayload, OnHttpResponse}; use crate::hooks::on_query_plan::{OnQueryPlanEndPayload, OnQueryPlanStartPayload}; -use crate::hooks::on_schema_reload::OnSchemaReloadPayload; -use crate::hooks::on_subgraph_http_request::{OnSubgraphHttpRequestPayload, OnSubgraphHttpResponsePayload}; +use crate::hooks::on_subgraph_execute::{ + OnSubgraphExecuteEndPayload, OnSubgraphExecuteStartPayload, +}; +use crate::hooks::on_subgraph_http_request::{ + OnSubgraphHttpRequestPayload, OnSubgraphHttpResponsePayload, +}; +use crate::hooks::on_supergraph_load::{OnSupergraphLoadEndPayload, OnSupergraphLoadStartPayload}; pub struct HookResult<'exec, TStartPayload, TEndPayload> { - pub start_payload: TStartPayload, + pub payload: TStartPayload, pub control_flow: ControlFlowResult<'exec, TEndPayload>, } pub enum ControlFlowResult<'exec, TEndPayload> { Continue, EndResponse(PlanExecutionOutput), - OnEnd(Box HookResult<'exec, TEndPayload, ()> + 'exec>), + OnEnd(Box HookResult<'exec, TEndPayload, ()> + Send + 'exec>), } pub trait StartPayload - where Self: Sized - { - +where + Self: Sized, +{ fn cont<'exec>(self) -> HookResult<'exec, Self, TEndPayload> { HookResult { - start_payload: self, + payload: self, control_flow: ControlFlowResult::Continue, } } - fn end_response<'exec>(self, output: PlanExecutionOutput) -> HookResult<'exec, Self, TEndPayload> { + fn end_response<'exec>( + self, + output: PlanExecutionOutput, + ) -> HookResult<'exec, Self, TEndPayload> { HookResult { - start_payload: self, + payload: self, control_flow: ControlFlowResult::EndResponse(output), } } fn on_end<'exec, F>(self, f: F) -> HookResult<'exec, Self, TEndPayload> - where F: FnOnce(TEndPayload) -> HookResult<'exec, TEndPayload, ()> + 'exec, + where + F: FnOnce(TEndPayload) -> HookResult<'exec, TEndPayload, ()> + Send + 'exec, { HookResult { - start_payload: self, + payload: self, control_flow: ControlFlowResult::OnEnd(Box::new(f)), } } } pub trait EndPayload - where Self: Sized - { - fn cont<'exec>(self) -> HookResult<'exec, Self, ()> { - HookResult { - start_payload: self, - control_flow: ControlFlowResult::Continue, - } +where + Self: Sized, +{ + fn cont<'exec>(self) -> HookResult<'exec, Self, ()> { + HookResult { + payload: self, + control_flow: ControlFlowResult::Continue, } + } - fn end_response<'exec>(self, output: PlanExecutionOutput) -> HookResult<'exec, Self, ()> { - HookResult { - start_payload: self, - control_flow: ControlFlowResult::EndResponse(output), - } + fn end_response<'exec>(self, output: PlanExecutionOutput) -> HookResult<'exec, Self, ()> { + HookResult { + payload: self, + control_flow: ControlFlowResult::EndResponse(output), } + } } -// Add sync send etc pub trait RouterPlugin { fn on_http_request<'exec>( - &self, + &self, start_payload: OnHttpRequestPayload<'exec>, ) -> HookResult<'exec, OnHttpRequestPayload<'exec>, OnHttpResponse<'exec>> { start_payload.cont() } - fn on_deserialization<'exec>( - &'exec self, - start_payload: OnDeserializationStartPayload<'exec>, - ) -> HookResult<'exec, OnDeserializationStartPayload<'exec>, OnDeserializationEndPayload<'exec>> { + fn on_graphql_params<'exec>( + &'exec self, + start_payload: OnGraphQLParamsStartPayload<'exec>, + ) -> HookResult<'exec, OnGraphQLParamsStartPayload<'exec>, OnGraphQLParamsEndPayload> { start_payload.cont() } fn on_graphql_parse<'exec>( - &self, + &self, start_payload: OnGraphQLParseStartPayload<'exec>, ) -> HookResult<'exec, OnGraphQLParseStartPayload<'exec>, OnGraphQLParseEndPayload<'exec>> { start_payload.cont() } fn on_graphql_validation<'exec>( - &self, + &self, start_payload: OnGraphQLValidationStartPayload<'exec>, - ) -> HookResult<'exec, OnGraphQLValidationStartPayload<'exec>, OnGraphQLValidationEndPayload<'exec>> { + ) -> HookResult< + 'exec, + OnGraphQLValidationStartPayload<'exec>, + OnGraphQLValidationEndPayload<'exec>, + > { start_payload.cont() } fn on_query_plan<'exec>( - &self, + &self, start_payload: OnQueryPlanStartPayload<'exec>, - ) -> HookResult<'exec, OnQueryPlanStartPayload<'exec>, OnQueryPlanEndPayload<'exec>> { + ) -> HookResult<'exec, OnQueryPlanStartPayload<'exec>, OnQueryPlanEndPayload<'exec>> { start_payload.cont() } fn on_execute<'exec>( - &'exec self, + &'exec self, start_payload: OnExecuteStartPayload<'exec>, - ) -> HookResult<'exec, OnExecuteStartPayload<'exec>, OnExecuteEndPayload<'exec>> { + ) -> HookResult<'exec, OnExecuteStartPayload<'exec>, OnExecuteEndPayload<'exec>> { + start_payload.cont() + } + fn on_subgraph_execute<'exec>( + &'exec self, + start_payload: OnSubgraphExecuteStartPayload<'exec>, + ) -> HookResult<'exec, OnSubgraphExecuteStartPayload<'exec>, OnSubgraphExecuteEndPayload> + { start_payload.cont() } fn on_subgraph_http_request<'exec>( - &'static self, + &'exec self, start_payload: OnSubgraphHttpRequestPayload<'exec>, - ) -> HookResult<'exec, OnSubgraphHttpRequestPayload<'exec>, OnSubgraphHttpResponsePayload<'exec>> { + ) -> HookResult<'exec, OnSubgraphHttpRequestPayload<'exec>, OnSubgraphHttpResponsePayload> + { + start_payload.cont() + } + fn on_supergraph_reload<'exec>( + &'exec self, + start_payload: OnSupergraphLoadStartPayload, + ) -> HookResult<'exec, OnSupergraphLoadStartPayload, OnSupergraphLoadEndPayload> { start_payload.cont() } - fn on_schema_reload<'a>(&'a self, _start_payload: OnSchemaReloadPayload) {} -} \ No newline at end of file +} From 134dda5cd6c93989f360906102c0a6a18fdf48a4 Mon Sep 17 00:00:00 2001 From: Arda TANRIKULU Date: Tue, 18 Nov 2025 22:14:26 +0300 Subject: [PATCH 07/15] More --- lib/executor/src/executors/common.rs | 1 + lib/executor/src/executors/http.rs | 145 +++++++++++++----- .../examples/subgraph_response_cache.rs | 10 +- .../plugins/hooks/on_subgraph_http_request.rs | 5 +- 4 files changed, 111 insertions(+), 50 deletions(-) diff --git a/lib/executor/src/executors/common.rs b/lib/executor/src/executors/common.rs index ba13b8707..9044c062e 100644 --- a/lib/executor/src/executors/common.rs +++ b/lib/executor/src/executors/common.rs @@ -45,6 +45,7 @@ impl SubgraphExecutionRequest<'_> { } } +#[derive(Clone)] pub struct HttpExecutionResponse { pub body: Bytes, pub headers: HeaderMap, diff --git a/lib/executor/src/executors/http.rs b/lib/executor/src/executors/http.rs index c09e01067..1e00e6fce 100644 --- a/lib/executor/src/executors/http.rs +++ b/lib/executor/src/executors/http.rs @@ -2,7 +2,8 @@ use std::sync::Arc; use crate::executors::common::HttpExecutionResponse; use crate::executors::dedupe::{request_fingerprint, ABuildHasher, SharedResponse}; -use crate::plugin_trait::RouterPlugin; +use crate::hooks::on_subgraph_http_request::{OnSubgraphHttpRequestPayload, OnSubgraphHttpResponsePayload}; +use crate::plugin_trait::{ControlFlowResult, RouterPlugin}; use dashmap::DashMap; use hive_router_config::HiveRouterConfig; use tokio::sync::OnceCell; @@ -10,7 +11,7 @@ use tokio::sync::OnceCell; use async_trait::async_trait; use bytes::{BufMut, Bytes, BytesMut}; -use http::HeaderMap; +use http::{HeaderMap, StatusCode}; use http::HeaderValue; use http_body_util::BodyExt; use http_body_util::Full; @@ -136,31 +137,87 @@ impl HTTPSubgraphExecutor { Ok(body) } - async fn _send_request( - &self, - body: Vec, - headers: HeaderMap, - ) -> Result { + fn error_to_graphql_bytes(&self, error: SubgraphExecutorError) -> Bytes { + let graphql_error: GraphQLError = error.into(); + let mut graphql_error = graphql_error.add_subgraph_name(&self.subgraph_name); + graphql_error.message = "Failed to execute request to subgraph".to_string(); + + let errors = vec![graphql_error]; + // This unwrap is safe as GraphQLError serialization shouldn't fail. + let errors_bytes = sonic_rs::to_vec(&errors).unwrap(); + let mut buffer = BytesMut::new(); + buffer.put_slice(b"{\"errors\":"); + buffer.put_slice(&errors_bytes); + buffer.put_slice(b"}"); + buffer.freeze() + } + + fn log_error(&self, error: &SubgraphExecutorError) { + tracing::error!( + error = error as &dyn std::error::Error, + "Subgraph executor error" + ); + } +} + +async fn send_request( + http_client: &Client, Full>, + subgraph_name: &str, + endpoint: &http::Uri, + method: http::Method, + body: Vec, + headers: HeaderMap, + plugins: Arc>>, +) -> Result { let mut req = hyper::Request::builder() - .method(http::Method::POST) - .uri(&self.endpoint) + .method(method) + .uri(endpoint) .version(Version::HTTP_11) .body(Full::new(Bytes::from(body))) .map_err(|e| { - SubgraphExecutorError::RequestBuildFailure(self.endpoint.to_string(), e.to_string()) + SubgraphExecutorError::RequestBuildFailure(endpoint.to_string(), e.to_string()) })?; *req.headers_mut() = headers; - debug!("making http request to {}", self.endpoint.to_string()); + let mut start_payload = OnSubgraphHttpRequestPayload { + subgraph_name, + request: req, + response: None, + }; + + let mut on_end_callbacks = vec![]; + + for plugin in plugins.as_ref() { + let result = plugin.on_subgraph_http_request(start_payload); + start_payload = result.payload; + match result.control_flow { + ControlFlowResult::Continue => { /* continue to next plugin */ } + ControlFlowResult::EndResponse(response) => { + // TODO: Fixx + return Ok(SharedResponse { + status: StatusCode::OK, + body: response.body.into(), + headers: response.headers, + }); + } + ControlFlowResult::OnEnd(callback) => { + on_end_callbacks.push(callback); + } + } + } + + debug!("making http request to {}", endpoint.to_string()); - let res = self.http_client.request(req).await.map_err(|e| { - SubgraphExecutorError::RequestFailure(self.endpoint.to_string(), e.to_string()) + let req = start_payload.request; + + let res = http_client.request(req).await.map_err(|e| { + SubgraphExecutorError::RequestFailure(endpoint.to_string(), e.to_string()) })?; debug!( "http request to {} completed, status: {}", - self.endpoint.to_string(), + endpoint.to_string(), res.status() ); @@ -169,45 +226,47 @@ impl HTTPSubgraphExecutor { .collect() .await .map_err(|e| { - SubgraphExecutorError::RequestFailure(self.endpoint.to_string(), e.to_string()) + SubgraphExecutorError::RequestFailure(endpoint.to_string(), e.to_string()) })? .to_bytes(); if body.is_empty() { return Err(SubgraphExecutorError::RequestFailure( - self.endpoint.to_string(), + endpoint.to_string(), "Empty response body".to_string(), )); } - Ok(SharedResponse { + let response = SharedResponse { status: parts.status, - body, + body: body, headers: parts.headers, - }) - } + }; - fn error_to_graphql_bytes(&self, error: SubgraphExecutorError) -> Bytes { - let graphql_error: GraphQLError = error.into(); - let mut graphql_error = graphql_error.add_subgraph_name(&self.subgraph_name); - graphql_error.message = "Failed to execute request to subgraph".to_string(); + let mut end_payload = OnSubgraphHttpResponsePayload { + response, + }; - let errors = vec![graphql_error]; - // This unwrap is safe as GraphQLError serialization shouldn't fail. - let errors_bytes = sonic_rs::to_vec(&errors).unwrap(); - let mut buffer = BytesMut::new(); - buffer.put_slice(b"{\"errors\":"); - buffer.put_slice(&errors_bytes); - buffer.put_slice(b"}"); - buffer.freeze() - } + for callback in on_end_callbacks { + let result = callback(end_payload); + end_payload = result.payload; + match result.control_flow { + ControlFlowResult::Continue => { /* continue to next callback */ } + ControlFlowResult::EndResponse(response) => { + return Ok(SharedResponse { + status: StatusCode::OK, + body: response.body.into(), + headers: response.headers, + }); + } + ControlFlowResult::OnEnd(_) => { + // on_end callbacks should not return OnEnd again + unreachable!("on_end callback returned OnEnd again"); + } + } + } - fn log_error(&self, error: &SubgraphExecutorError) { - tracing::error!( - error = error as &dyn std::error::Error, - "Subgraph executor error" - ); - } + Ok(end_payload.response) } #[async_trait] @@ -233,11 +292,13 @@ impl SubgraphExecutor for HTTPSubgraphExecutor { headers.insert(key, value.clone()); }); + let method = http::Method::POST; + if !self.config.traffic_shaping.dedupe_enabled || !execution_request.dedupe { // This unwrap is safe because the semaphore is never closed during the application's lifecycle. // `acquire()` only fails if the semaphore is closed, so this will always return `Ok`. let _permit = self.semaphore.acquire().await.unwrap(); - return match self._send_request(body, headers).await { + return match send_request(&self.http_client, &self.subgraph_name, &self.endpoint, method, body, headers, self.plugins.clone()).await { Ok(shared_response) => HttpExecutionResponse { body: shared_response.body, headers: shared_response.headers, @@ -252,7 +313,7 @@ impl SubgraphExecutor for HTTPSubgraphExecutor { }; } - let fingerprint = request_fingerprint(&http::Method::POST, &self.endpoint, &headers, &body); + let fingerprint = request_fingerprint(&method, &self.endpoint, &headers, &body); // Clone the cell from the map, dropping the lock from the DashMap immediately. // Prevents any deadlocks. @@ -269,7 +330,7 @@ impl SubgraphExecutor for HTTPSubgraphExecutor { // This unwrap is safe because the semaphore is never closed during the application's lifecycle. // `acquire()` only fails if the semaphore is closed, so this will always return `Ok`. let _permit = self.semaphore.acquire().await.unwrap(); - self._send_request(body, headers).await + send_request(&self.http_client, &self.subgraph_name, &self.endpoint, method, body, headers, self.plugins.clone()).await }; // It's important to remove the entry from the map before returning the result. // This ensures that once the OnceCell is set, no future requests can join it. diff --git a/lib/executor/src/plugins/examples/subgraph_response_cache.rs b/lib/executor/src/plugins/examples/subgraph_response_cache.rs index 037a314d0..71ff8c1d9 100644 --- a/lib/executor/src/plugins/examples/subgraph_response_cache.rs +++ b/lib/executor/src/plugins/examples/subgraph_response_cache.rs @@ -1,15 +1,15 @@ use dashmap::DashMap; -use crate::{executors::dedupe::SharedResponse, hooks::{on_subgraph_execute::{OnSubgraphExecuteEndPayload, OnSubgraphExecuteStartPayload}, on_subgraph_http_request::{OnSubgraphHttpRequestPayload, OnSubgraphHttpResponsePayload}}, plugin_trait::{ EndPayload, HookResult, RouterPlugin, StartPayload}}; +use crate::{executors::{common::HttpExecutionResponse}, hooks::{on_subgraph_execute::{OnSubgraphExecuteEndPayload, OnSubgraphExecuteStartPayload}}, plugin_trait::{ EndPayload, HookResult, RouterPlugin, StartPayload}}; pub struct SubgraphResponseCachePlugin { - cache: DashMap, + cache: DashMap, } impl RouterPlugin for SubgraphResponseCachePlugin { fn on_subgraph_execute<'exec>( &'exec self, - payload: OnSubgraphExecuteStartPayload<'exec>, + mut payload: OnSubgraphExecuteStartPayload<'exec>, ) -> HookResult<'exec, OnSubgraphExecuteStartPayload<'exec>, OnSubgraphExecuteEndPayload> { let key = format!( "subgraph_response_cache:{}:{:?}", @@ -18,12 +18,12 @@ impl RouterPlugin for SubgraphResponseCachePlugin { if let Some(cached_response) = self.cache.get(&key) { // Here payload.response is Option // So it is bypassing the actual subgraph request - *payload.response = Some(cached_response.clone()); + payload.execution_result = Some(cached_response.clone()); return payload.cont(); } payload.on_end(move |payload: OnSubgraphExecuteEndPayload| { // Here payload.response is not Option - self.cache.insert(key, payload.execution_result.body.as_ref()); + self.cache.insert(key, payload.execution_result.clone()); payload.cont() }) } diff --git a/lib/executor/src/plugins/hooks/on_subgraph_http_request.rs b/lib/executor/src/plugins/hooks/on_subgraph_http_request.rs index 44bfa9dc9..f7c798ed7 100644 --- a/lib/executor/src/plugins/hooks/on_subgraph_http_request.rs +++ b/lib/executor/src/plugins/hooks/on_subgraph_http_request.rs @@ -1,10 +1,9 @@ use bytes::Bytes; -use http::{HeaderMap, Request, Uri}; +use http::{Request}; use http_body_util::Full; -use ntex::web::HttpRequest; use crate::{ - executors::{common::SubgraphExecutionRequest, dedupe::SharedResponse}, plugin_trait::{EndPayload, StartPayload}} + executors::{dedupe::SharedResponse}, plugin_trait::{EndPayload, StartPayload}} ; pub struct OnSubgraphHttpRequestPayload<'exec> { From 9791e1ceb96da52581cfeb89a202856152e76a2a Mon Sep 17 00:00:00 2001 From: Arda TANRIKULU Date: Wed, 19 Nov 2025 18:56:36 +0300 Subject: [PATCH 08/15] Owned values --- bin/router/src/jwt/mod.rs | 21 +- bin/router/src/lib.rs | 53 ++-- bin/router/src/pipeline/coerce_variables.rs | 8 +- bin/router/src/pipeline/csrf_prevention.rs | 8 +- .../pipeline/deserialize_graphql_params.rs | 33 +-- bin/router/src/pipeline/error.rs | 33 +-- bin/router/src/pipeline/execution.rs | 35 +-- bin/router/src/pipeline/header.rs | 18 +- bin/router/src/pipeline/mod.rs | 219 +++++++++------ bin/router/src/pipeline/normalize.rs | 12 +- bin/router/src/pipeline/parser.rs | 57 ++-- .../src/pipeline/progressive_override.rs | 8 +- bin/router/src/pipeline/query_plan.rs | 45 ++- bin/router/src/pipeline/validation.rs | 41 ++- bin/router/src/schema_state.rs | 37 ++- .../src/execution/client_request_details.rs | 28 +- lib/executor/src/execution/error.rs | 2 +- lib/executor/src/execution/jwt_forward.rs | 2 +- lib/executor/src/execution/plan.rs | 257 ++++++++++-------- lib/executor/src/executors/http.rs | 196 +++++++------ lib/executor/src/executors/map.rs | 39 +-- lib/executor/src/headers/expression.rs | 4 +- lib/executor/src/headers/mod.rs | 96 +++---- lib/executor/src/headers/request.rs | 6 +- lib/executor/src/headers/response.rs | 4 +- lib/executor/src/lib.rs | 1 - lib/executor/src/plugins/examples/apq.rs | 18 +- lib/executor/src/plugins/examples/mod.rs | 2 +- .../src/plugins/examples/multipart.rs | 0 .../src/plugins/examples/response_cache.rs | 38 +-- .../examples/subgraph_response_cache.rs | 16 +- lib/executor/src/plugins/hooks/mod.rs | 8 +- lib/executor/src/plugins/hooks/on_execute.rs | 7 +- .../src/plugins/hooks/on_graphql_params.rs | 92 ++++++- .../src/plugins/hooks/on_graphql_parse.rs | 15 +- .../plugins/hooks/on_graphql_validation.rs | 28 +- .../src/plugins/hooks/on_http_request.rs | 2 +- .../src/plugins/hooks/on_query_plan.rs | 20 +- .../src/plugins/hooks/on_subgraph_execute.rs | 12 +- .../plugins/hooks/on_subgraph_http_request.rs | 11 +- .../src/plugins/hooks/on_supergraph_load.rs | 13 +- lib/executor/src/plugins/mod.rs | 2 +- lib/executor/src/plugins/plugin_trait.rs | 21 +- 43 files changed, 868 insertions(+), 700 deletions(-) create mode 100644 lib/executor/src/plugins/examples/multipart.rs diff --git a/bin/router/src/jwt/mod.rs b/bin/router/src/jwt/mod.rs index d95854b1c..d7804e156 100644 --- a/bin/router/src/jwt/mod.rs +++ b/bin/router/src/jwt/mod.rs @@ -265,26 +265,27 @@ impl JwtAuthRuntime { Ok(token_data) } - pub fn validate_request(&self, request: &mut HttpRequest) -> Result<(), JwtError> { + pub fn validate_request( + &self, + request: &HttpRequest, + ) -> Result, JwtError> { let valid_jwks = self.jwks.all(); match self.authenticate(&valid_jwks, request) { - Ok((token_payload, maybe_token_prefix, token)) => { - request.extensions_mut().insert(JwtRequestContext { - token_payload, - token_raw: token, - token_prefix: maybe_token_prefix, - }); - } + Ok((token_payload, maybe_token_prefix, token)) => Ok(Some(JwtRequestContext { + token_payload, + token_raw: token, + token_prefix: maybe_token_prefix, + })), Err(e) => { warn!("jwt token error: {:?}", e); if self.config.require_authentication.is_some_and(|v| v) { return Err(e); } + + Ok(None) } } - - Ok(()) } } diff --git a/bin/router/src/lib.rs b/bin/router/src/lib.rs index fb19df5c8..416021b3a 100644 --- a/bin/router/src/lib.rs +++ b/bin/router/src/lib.rs @@ -19,7 +19,11 @@ use crate::{ }, jwt::JwtAuthRuntime, logger::configure_logging, - pipeline::graphql_request_handler, + pipeline::{ + error::PipelineError, + graphql_request_handler, + header::{RequestAccepts, APPLICATION_GRAPHQL_RESPONSE_JSON_STR}, + }, }; pub use crate::{schema_state::SchemaState, shared_state::RouterSharedState}; @@ -27,12 +31,13 @@ pub use crate::{schema_state::SchemaState, shared_state::RouterSharedState}; use hive_router_config::{load_config, HiveRouterConfig}; use http::header::RETRY_AFTER; use ntex::{ - util::Bytes, web::{self, HttpRequest} + util::Bytes, + web::{self, HttpRequest}, }; use tracing::{info, warn}; async fn graphql_endpoint_handler( - mut request: HttpRequest, + req: HttpRequest, body_bytes: Bytes, schema_state: web::types::State>, app_state: web::types::State>, @@ -44,26 +49,35 @@ async fn graphql_endpoint_handler( if let Some(early_response) = app_state .cors_runtime .as_ref() - .and_then(|cors| cors.get_early_response(&request)) + .and_then(|cors| cors.get_early_response(&req)) { return early_response; } - let mut res = graphql_request_handler( - &mut request, + let accept_ok = !req.accepts_content_type(&APPLICATION_GRAPHQL_RESPONSE_JSON_STR); + + let result = match graphql_request_handler( + req, body_bytes, supergraph, - app_state.get_ref(), - schema_state.get_ref(), + app_state.get_ref().clone(), + schema_state.get_ref().clone(), ) - .await; + .await + { + Ok(response_with_req) => response_with_req, + Err(error) => return PipelineError { accept_ok, error }.into(), + }; + + let mut response = result.result; + let req = result.request; // Apply CORS headers to the final response if CORS is configured. if let Some(cors) = app_state.cors_runtime.as_ref() { - cors.set_headers(&request, res.headers_mut()); + cors.set_headers(&req, response.headers_mut()); } - res + response } else { warn!("No supergraph available yet, unable to process request"); @@ -111,18 +125,23 @@ pub async fn configure_app_from_config( }; let router_config_arc = Arc::new(router_config); - let shared_state = Arc::new(RouterSharedState::new(router_config_arc.clone(), jwt_runtime)?); - let schema_state = - SchemaState::new_from_config(bg_tasks_manager, router_config_arc.clone(), shared_state.clone()).await?; + let shared_state = Arc::new(RouterSharedState::new( + router_config_arc.clone(), + jwt_runtime, + )?); + let schema_state = SchemaState::new_from_config( + bg_tasks_manager, + router_config_arc.clone(), + shared_state.clone(), + ) + .await?; let schema_state_arc = Arc::new(schema_state); Ok((shared_state, schema_state_arc)) } pub fn configure_ntex_app(cfg: &mut web::ServiceConfig) { - cfg - .route("/graphql", web::to(graphql_endpoint_handler)) + cfg.route("/graphql", web::to(graphql_endpoint_handler)) .route("/health", web::to(health_check_handler)) .route("/readiness", web::to(readiness_check_handler)); } - diff --git a/bin/router/src/pipeline/coerce_variables.rs b/bin/router/src/pipeline/coerce_variables.rs index b159f244e..d10fbb6c4 100644 --- a/bin/router/src/pipeline/coerce_variables.rs +++ b/bin/router/src/pipeline/coerce_variables.rs @@ -10,7 +10,7 @@ use ntex::web::HttpRequest; use sonic_rs::Value; use tracing::{error, trace, warn}; -use crate::pipeline::error::{PipelineError, PipelineErrorFromAcceptHeader, PipelineErrorVariant}; +use crate::pipeline::error::PipelineErrorVariant; use crate::pipeline::normalize::GraphQLNormalizationPayload; #[derive(Clone, Debug)] @@ -24,14 +24,14 @@ pub fn coerce_request_variables( supergraph: &SupergraphData, graphql_params: &mut GraphQLParams, normalized_operation: &Arc, -) -> Result { +) -> Result { if req.method() == Method::GET { if let Some(OperationKind::Mutation) = normalized_operation.operation_for_plan.operation_kind { error!("Mutation is not allowed over GET, stopping"); - return Err(req.new_pipeline_error(PipelineErrorVariant::MutationNotAllowedOverHttpGet)); + return Err(PipelineErrorVariant::MutationNotAllowedOverHttpGet); } } @@ -55,7 +55,7 @@ pub fn coerce_request_variables( "failed to collect variables from incoming request: {}", err_msg ); - Err(req.new_pipeline_error(PipelineErrorVariant::VariablesCoercionError(err_msg))) + Err(PipelineErrorVariant::VariablesCoercionError(err_msg)) } } } diff --git a/bin/router/src/pipeline/csrf_prevention.rs b/bin/router/src/pipeline/csrf_prevention.rs index 51561dd99..37c063b09 100644 --- a/bin/router/src/pipeline/csrf_prevention.rs +++ b/bin/router/src/pipeline/csrf_prevention.rs @@ -1,7 +1,7 @@ use hive_router_config::csrf::CSRFPreventionConfig; use ntex::web::HttpRequest; -use crate::pipeline::error::{PipelineError, PipelineErrorFromAcceptHeader, PipelineErrorVariant}; +use crate::pipeline::error::PipelineErrorVariant; // NON_PREFLIGHTED_CONTENT_TYPES are content types that do not require a preflight // OPTIONS request. These are content types that are considered "simple" by the CORS @@ -15,9 +15,9 @@ const NON_PREFLIGHTED_CONTENT_TYPES: [&str; 3] = [ #[inline] pub fn perform_csrf_prevention( - req: &mut HttpRequest, + req: &HttpRequest, csrf_config: &CSRFPreventionConfig, -) -> Result<(), PipelineError> { +) -> Result<(), PipelineErrorVariant> { // If CSRF prevention is not configured or disabled, skip the checks. if !csrf_config.enabled || csrf_config.required_headers.is_empty() { return Ok(()); @@ -39,7 +39,7 @@ pub fn perform_csrf_prevention( if has_required_header { Ok(()) } else { - Err(req.new_pipeline_error(PipelineErrorVariant::CsrfPreventionFailed)) + Err(PipelineErrorVariant::CsrfPreventionFailed) } } diff --git a/bin/router/src/pipeline/deserialize_graphql_params.rs b/bin/router/src/pipeline/deserialize_graphql_params.rs index 3c0eb5f12..b22b18a3a 100644 --- a/bin/router/src/pipeline/deserialize_graphql_params.rs +++ b/bin/router/src/pipeline/deserialize_graphql_params.rs @@ -7,7 +7,7 @@ use ntex::web::types::Query; use ntex::web::HttpRequest; use tracing::{trace, warn}; -use crate::pipeline::error::{PipelineError, PipelineErrorFromAcceptHeader, PipelineErrorVariant}; +use crate::pipeline::error::PipelineErrorVariant; use crate::pipeline::header::AssertRequestJson; #[derive(serde::Deserialize, Debug)] @@ -55,11 +55,11 @@ impl TryInto for GETQueryParams { } pub trait GetQueryStr { - fn get_query<'a>(&'a self) -> Result<&'a str, PipelineErrorVariant>; + fn get_query(&self) -> Result<&str, PipelineErrorVariant>; } impl GetQueryStr for GraphQLParams { - fn get_query<'a>(&'a self) -> Result<&'a str, PipelineErrorVariant> { + fn get_query(&self) -> Result<&str, PipelineErrorVariant> { self.query .as_deref() .ok_or(PipelineErrorVariant::GetMissingQueryParam("query")) @@ -70,25 +70,22 @@ impl GetQueryStr for GraphQLParams { pub fn deserialize_graphql_params( req: &HttpRequest, body_bytes: Bytes, -) -> Result { +) -> Result { let http_method = req.method(); let graphql_params: GraphQLParams = match *http_method { Method::GET => { trace!("processing GET GraphQL operation"); - let query_params_str = req.uri().query().ok_or_else(|| { - req.new_pipeline_error(PipelineErrorVariant::GetInvalidQueryParams) - })?; + let query_params_str = req + .uri() + .query() + .ok_or_else(|| PipelineErrorVariant::GetInvalidQueryParams)?; let query_params = Query::::from_query(query_params_str) - .map_err(|e| { - req.new_pipeline_error(PipelineErrorVariant::GetUnprocessableQueryParams(e)) - })? + .map_err(PipelineErrorVariant::GetUnprocessableQueryParams)? .0; trace!("parsed GET query params: {:?}", query_params); - query_params - .try_into() - .map_err(|err| req.new_pipeline_error(err))? + query_params.try_into()? } Method::POST => { trace!("Processing POST GraphQL request"); @@ -98,7 +95,7 @@ pub fn deserialize_graphql_params( let execution_request = unsafe { sonic_rs::from_slice_unchecked::(&body_bytes).map_err(|e| { warn!("Failed to parse body: {}", e); - req.new_pipeline_error(PipelineErrorVariant::FailedToParseBody(e)) + PipelineErrorVariant::FailedToParseBody(e) })? }; @@ -107,11 +104,9 @@ pub fn deserialize_graphql_params( _ => { warn!("unsupported HTTP method: {}", http_method); - return Err( - req.new_pipeline_error(PipelineErrorVariant::UnsupportedHttpMethod( - http_method.to_owned(), - )), - ); + return Err(PipelineErrorVariant::UnsupportedHttpMethod( + http_method.to_owned(), + )); } }; diff --git a/bin/router/src/pipeline/error.rs b/bin/router/src/pipeline/error.rs index 71e0a197d..1d856d62e 100644 --- a/bin/router/src/pipeline/error.rs +++ b/bin/router/src/pipeline/error.rs @@ -10,15 +10,12 @@ use hive_router_query_planner::{ }; use http::{HeaderName, Method, StatusCode}; use ntex::{ - http::ResponseBuilder, - web::{self, error::QueryPayloadError, HttpRequest}, + http::{Response, ResponseBuilder}, + web::error::QueryPayloadError, }; use serde::{Deserialize, Serialize}; -use crate::pipeline::{ - header::{RequestAccepts, APPLICATION_GRAPHQL_RESPONSE_JSON_STR}, - progressive_override::LabelEvaluationError, -}; +use crate::pipeline::progressive_override::LabelEvaluationError; #[derive(Debug)] pub struct PipelineError { @@ -26,18 +23,6 @@ pub struct PipelineError { pub error: PipelineErrorVariant, } -pub trait PipelineErrorFromAcceptHeader { - fn new_pipeline_error(&self, error: PipelineErrorVariant) -> PipelineError; -} - -impl PipelineErrorFromAcceptHeader for HttpRequest { - #[inline] - fn new_pipeline_error(&self, error: PipelineErrorVariant) -> PipelineError { - let accept_ok = !self.accepts_content_type(&APPLICATION_GRAPHQL_RESPONSE_JSON_STR); - PipelineError { accept_ok, error } - } -} - #[derive(Debug, thiserror::Error)] pub enum PipelineErrorVariant { // HTTP-related errors @@ -156,11 +141,11 @@ pub struct FailedExecutionResult { pub errors: Option>, } -impl PipelineError { - pub fn into_response(self) -> web::HttpResponse { - let status = self.error.default_status_code(self.accept_ok); +impl From for Response { + fn from(val: PipelineError) -> Self { + let status = val.error.default_status_code(val.accept_ok); - if let PipelineErrorVariant::ValidationErrors(validation_errors) = self.error { + if let PipelineErrorVariant::ValidationErrors(validation_errors) = val.error { let validation_error_result = FailedExecutionResult { errors: Some(validation_errors.iter().map(|error| error.into()).collect()), }; @@ -168,8 +153,8 @@ impl PipelineError { return ResponseBuilder::new(status).json(&validation_error_result); } - let code = self.error.graphql_error_code(); - let message = self.error.graphql_error_message(); + let code = val.error.graphql_error_code(); + let message = val.error.graphql_error_message(); let graphql_error = GraphQLError::from_message_and_extensions( message, diff --git a/bin/router/src/pipeline/execution.rs b/bin/router/src/pipeline/execution.rs index 56f92fece..e69a89179 100644 --- a/bin/router/src/pipeline/execution.rs +++ b/bin/router/src/pipeline/execution.rs @@ -2,13 +2,14 @@ use std::collections::HashMap; use std::sync::Arc; use crate::pipeline::coerce_variables::CoerceVariablesPayload; -use crate::pipeline::error::{PipelineError, PipelineErrorFromAcceptHeader, PipelineErrorVariant}; +use crate::pipeline::error::PipelineErrorVariant; use crate::pipeline::normalize::GraphQLNormalizationPayload; use crate::shared_state::RouterSharedState; -use hive_router_plan_executor::execute_query_plan; use hive_router_plan_executor::execution::client_request_details::ClientRequestDetails; use hive_router_plan_executor::execution::jwt_forward::JwtAuthForwardingPlan; -use hive_router_plan_executor::execution::plan::{PlanExecutionOutput, QueryPlanExecutionContext}; +use hive_router_plan_executor::execution::plan::{ + PlanExecutionOutput, QueryPlanExecutionContext, ResultWithRequest, +}; use hive_router_plan_executor::hooks::on_supergraph_load::SupergraphData; use hive_router_plan_executor::introspection::resolve::IntrospectionContext; use hive_router_query_planner::planner::plan_nodes::QueryPlan; @@ -26,14 +27,14 @@ enum ExposeQueryPlanMode { #[inline] pub async fn execute_plan( - req: &HttpRequest, + req: HttpRequest, supergraph: &SupergraphData, - app_state: &Arc, - normalized_payload: &Arc, - query_plan_payload: &Arc, + app_state: Arc, + normalized_payload: Arc, + query_plan_payload: Arc, variable_payload: &CoerceVariablesPayload, - client_request_details: &ClientRequestDetails<'_, '_>, -) -> Result { + client_request_details: &ClientRequestDetails<'_>, +) -> Result, PipelineErrorVariant> { let mut expose_query_plan = ExposeQueryPlanMode::No; if app_state.router_config.query_planner.allow_expose { @@ -65,7 +66,7 @@ pub async fn execute_plan( metadata: &supergraph.metadata, }; - let jwt_forward_plan: Option = if app_state + let jwt_auth_forwarding: Option = if app_state .router_config .jwt .is_jwt_extensions_forwarding_enabled() @@ -79,12 +80,12 @@ pub async fn execute_plan( .forward_claims_to_upstream_extensions .field_name, ) - .map_err(|e| req.new_pipeline_error(PipelineErrorVariant::JwtForwardingError(e)))? + .map_err(PipelineErrorVariant::JwtForwardingError)? } else { None }; - execute_query_plan(QueryPlanExecutionContext { + let ctx = QueryPlanExecutionContext { router_http_request: req, query_plan: query_plan_payload, projection_plan: &normalized_payload.projection_plan, @@ -94,13 +95,13 @@ pub async fn execute_plan( client_request: client_request_details, introspection_context: &introspection_context, operation_type_name: normalized_payload.root_type_name, - jwt_auth_forwarding: &jwt_forward_plan, + jwt_auth_forwarding, executors: &supergraph.subgraph_executor_map, plugins: &app_state.plugins, - }) - .await - .map_err(|err| { + }; + + ctx.execute_query_plan().await.map_err(|err| { tracing::error!("Failed to execute query plan: {}", err); - req.new_pipeline_error(PipelineErrorVariant::PlanExecutionError(err)) + PipelineErrorVariant::PlanExecutionError(err) }) } diff --git a/bin/router/src/pipeline/header.rs b/bin/router/src/pipeline/header.rs index 92a591235..19ea8c7af 100644 --- a/bin/router/src/pipeline/header.rs +++ b/bin/router/src/pipeline/header.rs @@ -6,7 +6,7 @@ use lazy_static::lazy_static; use ntex::web::HttpRequest; use tracing::{trace, warn}; -use crate::pipeline::error::{PipelineError, PipelineErrorFromAcceptHeader, PipelineErrorVariant}; +use crate::pipeline::error::PipelineErrorVariant; lazy_static! { pub static ref APPLICATION_JSON_STR: &'static str = "application/json"; @@ -34,31 +34,29 @@ impl RequestAccepts for HttpRequest { } pub trait AssertRequestJson { - fn assert_json_content_type(&self) -> Result<(), PipelineError>; + fn assert_json_content_type(&self) -> Result<(), PipelineErrorVariant>; } impl AssertRequestJson for HttpRequest { #[inline] - fn assert_json_content_type(&self) -> Result<(), PipelineError> { + fn assert_json_content_type(&self) -> Result<(), PipelineErrorVariant> { match self.headers().get(CONTENT_TYPE) { Some(value) => { - let content_type_str = value.to_str().map_err(|_| { - self.new_pipeline_error(PipelineErrorVariant::InvalidHeaderValue(CONTENT_TYPE)) - })?; + let content_type_str = value + .to_str() + .map_err(|_| PipelineErrorVariant::InvalidHeaderValue(CONTENT_TYPE))?; if !content_type_str.contains(*APPLICATION_JSON_STR) { warn!( "Invalid content type on a POST request: {}", content_type_str ); - return Err( - self.new_pipeline_error(PipelineErrorVariant::UnsupportedContentType) - ); + return Err(PipelineErrorVariant::UnsupportedContentType); } Ok(()) } None => { trace!("POST without content type detected"); - Err(self.new_pipeline_error(PipelineErrorVariant::MissingContentTypeHeader)) + Err(PipelineErrorVariant::MissingContentTypeHeader) } } } diff --git a/bin/router/src/pipeline/mod.rs b/bin/router/src/pipeline/mod.rs index ddc949c4c..b97f7e67f 100644 --- a/bin/router/src/pipeline/mod.rs +++ b/bin/router/src/pipeline/mod.rs @@ -3,15 +3,16 @@ use std::sync::Arc; use hive_router_plan_executor::{ execution::{ client_request_details::{ClientRequestDetails, JwtRequestDetails, OperationDetails}, - plan::PlanExecutionOutput, + plan::{PlanExecutionOutput, ResultWithRequest, WithResult}, + }, + hooks::{ + on_graphql_params::{OnGraphQLParamsEndPayload, OnGraphQLParamsStartPayload}, + on_supergraph_load::SupergraphData, }, - hooks::{on_graphql_params::{ - OnGraphQLParamsEndPayload, OnGraphQLParamsStartPayload - }, on_supergraph_load::SupergraphData}, plugin_trait::ControlFlowResult, }; use hive_router_query_planner::{ - state::supergraph_state::OperationKind, utils::cancellation::CancellationToken + state::supergraph_state::OperationKind, utils::cancellation::CancellationToken, }; use http::{header::CONTENT_TYPE, HeaderValue, Method}; use ntex::{ @@ -22,20 +23,31 @@ use ntex::{ use crate::{ jwt::context::JwtRequestContext, pipeline::{ - coerce_variables::coerce_request_variables, csrf_prevention::perform_csrf_prevention, deserialize_graphql_params::{GetQueryStr, deserialize_graphql_params}, error::{PipelineError, PipelineErrorFromAcceptHeader, PipelineErrorVariant}, execution::execute_plan, header::{ - APPLICATION_GRAPHQL_RESPONSE_JSON, APPLICATION_GRAPHQL_RESPONSE_JSON_STR, APPLICATION_JSON, RequestAccepts, TEXT_HTML_CONTENT_TYPE - }, normalize::normalize_request_with_cache, parser::{ParseResult, parse_operation_with_cache}, progressive_override::request_override_context, query_plan::{QueryPlanResult, plan_operation_with_cache}, validation::validate_operation_with_cache + coerce_variables::coerce_request_variables, + csrf_prevention::perform_csrf_prevention, + deserialize_graphql_params::{deserialize_graphql_params, GetQueryStr}, + error::PipelineErrorVariant, + execution::execute_plan, + header::{ + RequestAccepts, APPLICATION_GRAPHQL_RESPONSE_JSON, + APPLICATION_GRAPHQL_RESPONSE_JSON_STR, APPLICATION_JSON, TEXT_HTML_CONTENT_TYPE, + }, + normalize::normalize_request_with_cache, + parser::{parse_operation_with_cache, ParseResult}, + progressive_override::request_override_context, + query_plan::{plan_operation_with_cache, QueryPlanResult}, + validation::validate_operation_with_cache, }, - schema_state::{SchemaState}, + schema_state::SchemaState, shared_state::RouterSharedState, }; pub mod coerce_variables; pub mod cors; pub mod csrf_prevention; +pub mod deserialize_graphql_params; pub mod error; pub mod execution; -pub mod deserialize_graphql_params; pub mod header; pub mod normalize; pub mod parser; @@ -47,70 +59,82 @@ static GRAPHIQL_HTML: &str = include_str!("../../static/graphiql.html"); #[inline] pub async fn graphql_request_handler( - req: &mut HttpRequest, + req: HttpRequest, body_bytes: Bytes, supergraph: &SupergraphData, - shared_state: &Arc, - schema_state: &Arc, -) -> web::HttpResponse { + shared_state: Arc, + schema_state: Arc, +) -> Result, PipelineErrorVariant> { if req.method() == Method::GET && req.accepts_content_type(*TEXT_HTML_CONTENT_TYPE) { if shared_state.router_config.graphiql.enabled { - return web::HttpResponse::Ok() - .header(CONTENT_TYPE, *TEXT_HTML_CONTENT_TYPE) - .body(GRAPHIQL_HTML); + return Ok(req.with_result( + web::HttpResponse::Ok() + .header(CONTENT_TYPE, *TEXT_HTML_CONTENT_TYPE) + .body(GRAPHIQL_HTML), + )); } else { - return web::HttpResponse::NotFound().into(); + return Ok(req.with_result(web::HttpResponse::NotFound().into())); } } - if let Some(jwt) = &shared_state.jwt_auth_runtime { - match jwt.validate_request(req) { - Ok(_) => (), - Err(err) => return err.make_response(), + let jwt_context = if let Some(jwt) = &shared_state.jwt_auth_runtime { + match jwt.validate_request(&req) { + Ok(jwt_context) => jwt_context, + Err(err) => return Ok(req.with_result(err.make_response())), } - } + } else { + None + }; - match execute_pipeline(req, body_bytes, supergraph, shared_state, schema_state).await { - Ok(response) => { - let response_bytes = Bytes::from(response.body); - let response_headers = response.headers; - - let response_content_type: &'static HeaderValue = - if req.accepts_content_type(*APPLICATION_GRAPHQL_RESPONSE_JSON_STR) { - &APPLICATION_GRAPHQL_RESPONSE_JSON - } else { - &APPLICATION_JSON - }; - - let mut response_builder = web::HttpResponse::Ok(); - for (header_name, header_value) in response_headers { - if let Some(header_name) = header_name { - response_builder.header(header_name, header_value); - } - } + let response_content_type: &'static HeaderValue = + if req.accepts_content_type(*APPLICATION_GRAPHQL_RESPONSE_JSON_STR) { + &APPLICATION_GRAPHQL_RESPONSE_JSON + } else { + &APPLICATION_JSON + }; - response_builder - .header(http::header::CONTENT_TYPE, response_content_type) - .body(response_bytes) + let execution_result_with_req = execute_pipeline( + req, + body_bytes, + supergraph, + shared_state, + schema_state, + jwt_context, + ) + .await?; + let response = execution_result_with_req.result; + let response_bytes = Bytes::from(response.body); + let response_headers = response.headers; + + let mut response_builder = web::HttpResponse::Ok(); + for (header_name, header_value) in response_headers { + if let Some(header_name) = header_name { + response_builder.header(header_name, header_value); } - Err(err) => err.into_response(), } + + Ok(execution_result_with_req.request.with_result( + response_builder + .header(http::header::CONTENT_TYPE, response_content_type) + .body(response_bytes), + )) } #[inline] #[allow(clippy::await_holding_refcell_ref)] -pub async fn execute_pipeline<'req>( - req: &'req mut HttpRequest, +pub async fn execute_pipeline( + req: HttpRequest, body: Bytes, supergraph: &SupergraphData, - shared_state: &'req Arc, - schema_state: &Arc, -) -> Result { - perform_csrf_prevention(req, &shared_state.router_config.csrf)?; + shared_state: Arc, + schema_state: Arc, + jwt_context: Option, +) -> Result, PipelineErrorVariant> { + perform_csrf_prevention(&req, &shared_state.router_config.csrf)?; /* Handle on_deserialize hook in the plugins - START */ let mut deserialization_end_callbacks = vec![]; - let mut deserialization_payload: OnGraphQLParamsStartPayload<'req> = OnGraphQLParamsStartPayload { + let mut deserialization_payload: OnGraphQLParamsStartPayload = OnGraphQLParamsStartPayload { router_http_request: req, body, graphql_params: None, @@ -121,7 +145,9 @@ pub async fn execute_pipeline<'req>( match result.control_flow { ControlFlowResult::Continue => { /* continue to next plugin */ } ControlFlowResult::EndResponse(response) => { - return Ok(response); + return Ok(deserialization_payload + .router_http_request + .with_result(response)); } ControlFlowResult::OnEnd(callback) => { deserialization_end_callbacks.push(callback); @@ -129,20 +155,24 @@ pub async fn execute_pipeline<'req>( } } let graphql_params = deserialization_payload.graphql_params.unwrap_or_else(|| { - deserialize_graphql_params(req, deserialization_payload.body).expect("Failed to parse execution request") + deserialize_graphql_params( + &deserialization_payload.router_http_request, + deserialization_payload.body, + ) + .expect("Failed to parse execution request") }); - let mut payload = OnGraphQLParamsEndPayload { - graphql_params, - }; + let mut payload = OnGraphQLParamsEndPayload { graphql_params }; for deserialization_end_callback in deserialization_end_callbacks { let result = deserialization_end_callback(payload); payload = result.payload; match result.control_flow { ControlFlowResult::Continue => { /* continue to next plugin */ } ControlFlowResult::EndResponse(response) => { - return Ok(response); - }, + return Ok(deserialization_payload + .router_http_request + .with_result(response)); + } ControlFlowResult::OnEnd(_) => { // on_end callbacks should not return OnEnd again unreachable!("on_end callback returned OnEnd again"); @@ -152,49 +182,58 @@ pub async fn execute_pipeline<'req>( let mut graphql_params = payload.graphql_params; /* Handle on_deserialize hook in the plugins - END */ - let parser_payload = match parse_operation_with_cache(req, shared_state, &graphql_params).await? { + let req = deserialization_payload.router_http_request; + let parser_result = + parse_operation_with_cache(req, shared_state.clone(), &graphql_params).await?; + + let mut req = parser_result.request; + + let parser_payload = match parser_result.result { ParseResult::Payload(payload) => payload, ParseResult::Response(response) => { - return Ok(response); + return Ok(req.with_result(response)); } }; - validate_operation_with_cache(req, supergraph, schema_state, shared_state, &parser_payload) - .await?; + validate_operation_with_cache( + &mut req, + supergraph, + schema_state.clone(), + shared_state.clone(), + &parser_payload, + ) + .await?; let normalize_payload = normalize_request_with_cache( - req, supergraph, - schema_state, + schema_state.clone(), &graphql_params, &parser_payload, ) .await?; - + let variable_payload = - coerce_request_variables(req, supergraph, &mut graphql_params, &normalize_payload)?; + coerce_request_variables(&req, supergraph, &mut graphql_params, &normalize_payload)?; let query_plan_cancellation_token = CancellationToken::with_timeout(shared_state.router_config.query_planner.timeout); - let req_extensions = req.extensions(); - let jwt_context = req_extensions.get::(); let jwt_request_details = match jwt_context { Some(jwt_context) => JwtRequestDetails::Authenticated { - token: jwt_context.token_raw.as_str(), - prefix: jwt_context.token_prefix.as_deref(), scopes: jwt_context.extract_scopes(), - claims: &jwt_context + claims: jwt_context .get_claims_value() - .map_err(|e| req.new_pipeline_error(PipelineErrorVariant::JwtForwardingError(e)))?, + .map_err(PipelineErrorVariant::JwtForwardingError)?, + token: jwt_context.token_raw, + prefix: jwt_context.token_prefix, }, None => JwtRequestDetails::Unauthenticated, }; let client_request_details = ClientRequestDetails { - method: req.method(), - url: req.uri(), - headers: req.headers(), + method: req.method().clone(), + url: req.uri().clone(), + headers: req.headers().clone(), operation: OperationDetails { name: normalize_payload.operation_for_plan.name.as_deref(), kind: match normalize_payload.operation_for_plan.operation_kind { @@ -203,39 +242,41 @@ pub async fn execute_pipeline<'req>( Some(OperationKind::Subscription) => "subscription", None => "query", }, - query: graphql_params.get_query().map_err(|err| req.new_pipeline_error(err))?, + query: graphql_params.get_query()?, }, - jwt: &jwt_request_details, + jwt: jwt_request_details, }; let progressive_override_ctx = request_override_context( &shared_state.override_labels_evaluator, &client_request_details, ) - .map_err(|error| req.new_pipeline_error(PipelineErrorVariant::LabelEvaluationError(error)))?; + .map_err(PipelineErrorVariant::LabelEvaluationError)?; - let query_plan_payload = match plan_operation_with_cache( + let query_plan_result = plan_operation_with_cache( req, supergraph, - schema_state, - &normalize_payload, + schema_state.clone(), + normalize_payload.clone(), &progressive_override_ctx, &query_plan_cancellation_token, - shared_state, + shared_state.clone(), ) - .await? { - QueryPlanResult::QueryPlan(query_plan_payload) => query_plan_payload, + .await?; + let req = query_plan_result.request; + let query_plan_payload = match query_plan_result.result { + QueryPlanResult::QueryPlan(plan) => plan, QueryPlanResult::Response(response) => { - return Ok(response); + return Ok(req.with_result(response)); } }; let execution_result = execute_plan( req, supergraph, - shared_state, - &normalize_payload, - &query_plan_payload, + shared_state.clone(), + normalize_payload.clone(), + query_plan_payload, &variable_payload, &client_request_details, ) diff --git a/bin/router/src/pipeline/normalize.rs b/bin/router/src/pipeline/normalize.rs index c57e2d566..54093d065 100644 --- a/bin/router/src/pipeline/normalize.rs +++ b/bin/router/src/pipeline/normalize.rs @@ -7,12 +7,11 @@ use hive_router_plan_executor::introspection::partition::partition_operation; use hive_router_plan_executor::projection::plan::FieldProjectionPlan; use hive_router_query_planner::ast::normalization::normalize_operation; use hive_router_query_planner::ast::operation::OperationDefinition; -use ntex::web::HttpRequest; use xxhash_rust::xxh3::Xxh3; -use crate::pipeline::error::{PipelineError, PipelineErrorFromAcceptHeader, PipelineErrorVariant}; +use crate::pipeline::error::PipelineErrorVariant; use crate::pipeline::parser::GraphQLParserPayload; -use crate::schema_state::{SchemaState}; +use crate::schema_state::SchemaState; use tracing::{error, trace}; #[derive(Debug)] @@ -26,12 +25,11 @@ pub struct GraphQLNormalizationPayload { #[inline] pub async fn normalize_request_with_cache( - req: &HttpRequest, supergraph: &SupergraphData, - schema_state: &Arc, + schema_state: Arc, graphql_params: &GraphQLParams, parser_payload: &GraphQLParserPayload, -) -> Result, PipelineError> { +) -> Result, PipelineErrorVariant> { let cache_key = match &graphql_params.operation_name { Some(operation_name) => { let mut hasher = Xxh3::new(); @@ -87,7 +85,7 @@ pub async fn normalize_request_with_cache( error!("Failed to normalize GraphQL operation: {}", err); trace!("{:?}", err); - Err(req.new_pipeline_error(PipelineErrorVariant::NormalizationError(err))) + Err(PipelineErrorVariant::NormalizationError(err)) } }, } diff --git a/bin/router/src/pipeline/parser.rs b/bin/router/src/pipeline/parser.rs index 18365a3e2..e6194fbd5 100644 --- a/bin/router/src/pipeline/parser.rs +++ b/bin/router/src/pipeline/parser.rs @@ -2,16 +2,20 @@ use std::hash::{Hash, Hasher}; use std::sync::Arc; use graphql_parser::query::Document; -use hive_router_plan_executor::execution::plan::PlanExecutionOutput; +use hive_router_plan_executor::execution::plan::{ + PlanExecutionOutput, ResultWithRequest, WithResult, +}; use hive_router_plan_executor::hooks::on_graphql_params::GraphQLParams; -use hive_router_plan_executor::hooks::on_graphql_parse::{OnGraphQLParseEndPayload, OnGraphQLParseStartPayload}; +use hive_router_plan_executor::hooks::on_graphql_parse::{ + OnGraphQLParseEndPayload, OnGraphQLParseStartPayload, +}; use hive_router_plan_executor::plugin_trait::ControlFlowResult; use hive_router_query_planner::utils::parsing::safe_parse_operation; use ntex::web::HttpRequest; use xxhash_rust::xxh3::Xxh3; use crate::pipeline::deserialize_graphql_params::GetQueryStr; -use crate::pipeline::error::{PipelineError, PipelineErrorFromAcceptHeader, PipelineErrorVariant}; +use crate::pipeline::error::PipelineErrorVariant; use crate::shared_state::RouterSharedState; use tracing::{error, trace}; @@ -28,26 +32,26 @@ pub enum ParseResult { #[inline] pub async fn parse_operation_with_cache( - req: &HttpRequest, - app_state: &Arc, + req: HttpRequest, + app_state: Arc, graphql_params: &GraphQLParams, -) -> Result { +) -> Result, PipelineErrorVariant> { let cache_key = { let mut hasher = Xxh3::new(); graphql_params.query.hash(&mut hasher); hasher.finish() }; + /* Handle on_graphql_parse hook in the plugins - START */ + let mut start_payload = OnGraphQLParseStartPayload { + router_http_request: req, + graphql_params, + document: None, + }; let parsed_operation = if let Some(cached) = app_state.parse_cache.get(&cache_key).await { trace!("Found cached parsed operation for query"); cached } else { - /* Handle on_graphql_parse hook in the plugins - START */ - let mut start_payload = OnGraphQLParseStartPayload { - router_http_request: req, - graphql_params, - document: None, - }; let mut on_end_callbacks = vec![]; for plugin in app_state.plugins.as_ref() { let result = plugin.on_graphql_parse(start_payload); @@ -57,7 +61,9 @@ pub async fn parse_operation_with_cache( // continue to next plugin } ControlFlowResult::EndResponse(response) => { - return Ok(ParseResult::Response(response)); + return Ok(start_payload + .router_http_request + .with_result(ParseResult::Response(response))); } ControlFlowResult::OnEnd(callback) => { // store the callback to be called later @@ -65,25 +71,20 @@ pub async fn parse_operation_with_cache( } } } + let document = match start_payload.document { Some(parsed) => parsed, None => { - let query_str = graphql_params.get_query().map_err(|err| { - req.new_pipeline_error(err) - })?; + let query_str = graphql_params.get_query()?; let parsed = safe_parse_operation(query_str).map_err(|err| { error!("Failed to parse GraphQL operation: {}", err); - req.new_pipeline_error(PipelineErrorVariant::FailedToParseOperation(err)) + PipelineErrorVariant::FailedToParseOperation(err) })?; trace!("successfully parsed GraphQL operation"); parsed } }; - let mut end_payload = OnGraphQLParseEndPayload { - router_http_request: req, - graphql_params, - document, - }; + let mut end_payload = OnGraphQLParseEndPayload { document }; for callback in on_end_callbacks { let result = callback(end_payload); end_payload = result.payload; @@ -92,7 +93,9 @@ pub async fn parse_operation_with_cache( // continue to next callback } ControlFlowResult::EndResponse(response) => { - return Ok(ParseResult::Response(response)); + return Ok(start_payload + .router_http_request + .with_result(ParseResult::Response(response))); } ControlFlowResult::OnEnd(_) => { // on_end callbacks should not return OnEnd again @@ -111,10 +114,10 @@ pub async fn parse_operation_with_cache( parsed_arc }; - Ok( - ParseResult::Payload(GraphQLParserPayload { + Ok(start_payload + .router_http_request + .with_result(ParseResult::Payload(GraphQLParserPayload { parsed_operation, cache_key, - }) - ) + }))) } diff --git a/bin/router/src/pipeline/progressive_override.rs b/bin/router/src/pipeline/progressive_override.rs index d0b09c183..dc28a0d60 100644 --- a/bin/router/src/pipeline/progressive_override.rs +++ b/bin/router/src/pipeline/progressive_override.rs @@ -51,9 +51,9 @@ pub struct RequestOverrideContext { } #[inline] -pub fn request_override_context<'exec, 'req>( +pub fn request_override_context<'exec>( override_labels_evaluator: &OverrideLabelsEvaluator, - client_request_details: &ClientRequestDetails<'exec, 'req>, + client_request_details: &ClientRequestDetails<'exec>, ) -> Result { let active_flags = override_labels_evaluator.evaluate(client_request_details)?; @@ -158,9 +158,9 @@ impl OverrideLabelsEvaluator { }) } - pub(crate) fn evaluate<'exec, 'req>( + pub(crate) fn evaluate<'exec>( &self, - client_request: &ClientRequestDetails<'exec, 'req>, + client_request: &ClientRequestDetails<'exec>, ) -> Result, LabelEvaluationError> { let mut active_flags = self.static_enabled_labels.clone(); diff --git a/bin/router/src/pipeline/query_plan.rs b/bin/router/src/pipeline/query_plan.rs index 58b4e7475..807946296 100644 --- a/bin/router/src/pipeline/query_plan.rs +++ b/bin/router/src/pipeline/query_plan.rs @@ -1,13 +1,17 @@ use std::hash::{Hash, Hasher}; use std::sync::Arc; -use crate::pipeline::error::{PipelineError, PipelineErrorFromAcceptHeader, PipelineErrorVariant}; +use crate::pipeline::error::PipelineErrorVariant; use crate::pipeline::normalize::GraphQLNormalizationPayload; use crate::pipeline::progressive_override::{RequestOverrideContext, StableOverrideContext}; -use crate::schema_state::{SchemaState}; +use crate::schema_state::SchemaState; use crate::RouterSharedState; -use hive_router_plan_executor::execution::plan::PlanExecutionOutput; -use hive_router_plan_executor::hooks::on_query_plan::OnQueryPlanStartPayload; +use hive_router_plan_executor::execution::plan::{ + PlanExecutionOutput, ResultWithRequest, WithResult, +}; +use hive_router_plan_executor::hooks::on_query_plan::{ + OnQueryPlanEndPayload, OnQueryPlanStartPayload, +}; use hive_router_plan_executor::hooks::on_supergraph_load::SupergraphData; use hive_router_plan_executor::plugin_trait::ControlFlowResult; use hive_router_query_planner::planner::plan_nodes::QueryPlan; @@ -28,14 +32,14 @@ pub enum QueryPlanGetterError { #[inline] pub async fn plan_operation_with_cache( - req: &HttpRequest, + mut req: HttpRequest, supergraph: &SupergraphData, - schema_state: &Arc, - normalized_operation: &Arc, + schema_state: Arc, + normalized_operation: Arc, request_override_context: &RequestOverrideContext, cancellation_token: &CancellationToken, - app_state: &Arc, -) -> Result { + app_state: Arc, +) -> Result, PipelineErrorVariant> { let stable_override_context = StableOverrideContext::new(&supergraph.planner.supergraph, request_override_context); @@ -47,7 +51,7 @@ pub async fn plan_operation_with_cache( let plan_result = schema_state .plan_cache - .try_get_with(plan_cache_key, async move { + .try_get_with(plan_cache_key, async { if is_pure_introspection { return Ok(Arc::new(QueryPlan { kind: "QueryPlan".to_string(), @@ -57,7 +61,7 @@ pub async fn plan_operation_with_cache( /* Handle on_query_plan hook in the plugins - START */ let mut start_payload = OnQueryPlanStartPayload { - router_http_request: req, + router_http_request: &mut req, filtered_operation_for_plan, planner_override_context: (&request_override_context.clone()).into(), cancellation_token, @@ -90,17 +94,10 @@ pub async fn plan_operation_with_cache( (&request_override_context.clone()).into(), cancellation_token, ) - .map_err(|e| QueryPlanGetterError::Planner(e))?, + .map_err(QueryPlanGetterError::Planner)?, }; - let mut end_payload = hive_router_plan_executor::hooks::on_query_plan::OnQueryPlanEndPayload { - router_http_request: req, - filtered_operation_for_plan, - planner_override_context: (&request_override_context.clone()).into(), - cancellation_token, - query_plan, - planner: &supergraph.planner, - }; + let mut end_payload = OnQueryPlanEndPayload { query_plan }; for callback in on_end_callbacks { let result = callback(end_payload); @@ -124,10 +121,12 @@ pub async fn plan_operation_with_cache( .await; match plan_result { - Ok(plan) => Ok(QueryPlanResult::QueryPlan(plan)), + Ok(plan) => Ok(req.with_result(QueryPlanResult::QueryPlan(plan))), Err(e) => match e.as_ref() { - QueryPlanGetterError::Planner(e) => Err(req.new_pipeline_error(PipelineErrorVariant::PlannerError(e.clone()))), - QueryPlanGetterError::Response(response) => Ok(QueryPlanResult::Response(response.clone())), + QueryPlanGetterError::Planner(e) => Err(PipelineErrorVariant::PlannerError(e.clone())), + QueryPlanGetterError::Response(response) => { + Ok(req.with_result(QueryPlanResult::Response(response.clone()))) + } }, } } diff --git a/bin/router/src/pipeline/validation.rs b/bin/router/src/pipeline/validation.rs index f97aa1661..3efbddcca 100644 --- a/bin/router/src/pipeline/validation.rs +++ b/bin/router/src/pipeline/validation.rs @@ -1,12 +1,14 @@ use std::sync::Arc; -use crate::pipeline::error::{PipelineError, PipelineErrorFromAcceptHeader, PipelineErrorVariant}; +use crate::pipeline::error::PipelineErrorVariant; use crate::pipeline::parser::GraphQLParserPayload; -use crate::schema_state::{SchemaState}; +use crate::schema_state::SchemaState; use crate::shared_state::RouterSharedState; use graphql_tools::validation::validate::validate; use hive_router_plan_executor::execution::plan::PlanExecutionOutput; -use hive_router_plan_executor::hooks::on_graphql_validation::{OnGraphQLValidationEndPayload, OnGraphQLValidationStartPayload}; +use hive_router_plan_executor::hooks::on_graphql_validation::{ + OnGraphQLValidationEndPayload, OnGraphQLValidationStartPayload, +}; use hive_router_plan_executor::hooks::on_supergraph_load::SupergraphData; use hive_router_plan_executor::plugin_trait::ControlFlowResult; use ntex::web::HttpRequest; @@ -14,12 +16,12 @@ use tracing::{error, trace}; #[inline] pub async fn validate_operation_with_cache( - req: &HttpRequest, + req: &mut HttpRequest, supergraph: &SupergraphData, - schema_state: &Arc, - app_state: &Arc, + schema_state: Arc, + app_state: Arc, parser_payload: &GraphQLParserPayload, -) -> Result, PipelineError> { +) -> Result, PipelineErrorVariant> { let consumer_schema_ast = &supergraph.planner.consumer_schema.document; let validation_result = match schema_state @@ -40,7 +42,7 @@ pub async fn validate_operation_with_cache( "validation result of hash {} does not exists in cache", parser_payload.cache_key ); - + /* Handle on_graphql_validate hook in the plugins - START */ let mut start_payload = OnGraphQLValidationStartPayload::new( req, @@ -67,21 +69,14 @@ pub async fn validate_operation_with_cache( let errors = match start_payload.errors { Some(errors) => errors, - None => { - validate( - consumer_schema_ast, - &start_payload.document, - start_payload.get_validation_plan(), - ) - } + None => validate( + consumer_schema_ast, + start_payload.document, + start_payload.get_validation_plan(), + ), }; - let mut end_payload = OnGraphQLValidationEndPayload { - router_http_request: req, - schema: consumer_schema_ast, - document: &parser_payload.parsed_operation, - errors, - }; + let mut end_payload = OnGraphQLValidationEndPayload { errors }; for callback in on_end_callbacks { let result = callback(end_payload); @@ -117,9 +112,7 @@ pub async fn validate_operation_with_cache( ); trace!("Validation errors: {:?}", validation_result); - return Err( - req.new_pipeline_error(PipelineErrorVariant::ValidationErrors(validation_result)) - ); + return Err(PipelineErrorVariant::ValidationErrors(validation_result)); } Ok(None) diff --git a/bin/router/src/schema_state.rs b/bin/router/src/schema_state.rs index 3f79a06aa..69db5db3b 100644 --- a/bin/router/src/schema_state.rs +++ b/bin/router/src/schema_state.rs @@ -3,7 +3,13 @@ use async_trait::async_trait; use graphql_tools::{static_graphql::schema::Document, validation::utils::ValidationError}; use hive_router_config::{supergraph::SupergraphSource, HiveRouterConfig}; use hive_router_plan_executor::{ - SubgraphExecutorMap, executors::error::SubgraphExecutorError, hooks::on_supergraph_load::{OnSupergraphLoadEndPayload, OnSupergraphLoadStartPayload, SupergraphData}, introspection::schema::SchemaWithMetadata, plugin_trait::{ControlFlowResult, RouterPlugin} + executors::error::SubgraphExecutorError, + hooks::on_supergraph_load::{ + OnSupergraphLoadEndPayload, OnSupergraphLoadStartPayload, SupergraphData, + }, + introspection::schema::SchemaWithMetadata, + plugin_trait::{ControlFlowResult, RouterPlugin}, + SubgraphExecutorMap, }; use hive_router_query_planner::planner::plan_nodes::QueryPlan; use hive_router_query_planner::{ @@ -18,10 +24,13 @@ use tokio_util::sync::CancellationToken; use tracing::{debug, error, trace}; use crate::{ - RouterSharedState, background_tasks::{BackgroundTask, BackgroundTasksManager}, pipeline::normalize::GraphQLNormalizationPayload, supergraph::{ + background_tasks::{BackgroundTask, BackgroundTasksManager}, + pipeline::normalize::GraphQLNormalizationPayload, + supergraph::{ base::{LoadSupergraphError, ReloadSupergraphResult, SupergraphLoader}, resolve_from_config, - } + }, + RouterSharedState, }; pub struct SchemaState { @@ -55,7 +64,7 @@ impl SchemaState { pub async fn new_from_config( bg_tasks_manager: &mut BackgroundTasksManager, router_config: Arc, - app_state: Arc + app_state: Arc, ) -> Result { let (tx, mut rx) = mpsc::channel::(1); let background_loader = SupergraphBackgroundLoader::new(&router_config.supergraph, tx)?; @@ -91,10 +100,10 @@ impl SchemaState { match result.control_flow { ControlFlowResult::Continue => { // continue to next plugin - }, + } ControlFlowResult::EndResponse(_) => { unreachable!("Plugins should not end supergraph reload processing"); - }, + } ControlFlowResult::OnEnd(callback) => { on_end_callbacks.push(callback); } @@ -115,12 +124,16 @@ impl SchemaState { match result.control_flow { ControlFlowResult::Continue => { // continue to next callback - }, + } ControlFlowResult::EndResponse(_) => { - unreachable!("Plugins should not end supergraph reload processing"); - }, + unreachable!( + "Plugins should not end supergraph reload processing" + ); + } ControlFlowResult::OnEnd(_) => { - unreachable!("End callbacks should not register further end callbacks"); + unreachable!( + "End callbacks should not register further end callbacks" + ); } } } @@ -155,8 +168,8 @@ impl SchemaState { parsed_supergraph_sdl: &Document, plugins: Arc>>, ) -> Result { - let supergraph_state = SupergraphState::new(&parsed_supergraph_sdl); - let planner = Planner::new_from_supergraph(&parsed_supergraph_sdl)?; + let supergraph_state = SupergraphState::new(parsed_supergraph_sdl); + let planner = Planner::new_from_supergraph(parsed_supergraph_sdl)?; let metadata = planner.consumer_schema.schema_metadata(); let subgraph_executor_map = SubgraphExecutorMap::from_http_endpoint_map( supergraph_state.subgraph_endpoint_map, diff --git a/lib/executor/src/execution/client_request_details.rs b/lib/executor/src/execution/client_request_details.rs index 35540dab2..71f28b746 100644 --- a/lib/executor/src/execution/client_request_details.rs +++ b/lib/executor/src/execution/client_request_details.rs @@ -1,4 +1,4 @@ -use std::{collections::BTreeMap}; +use std::collections::BTreeMap; use bytes::Bytes; use http::Method; @@ -13,28 +13,28 @@ pub struct OperationDetails<'exec> { pub kind: &'static str, } -pub struct ClientRequestDetails<'exec, 'req> { - pub method: &'req Method, - pub url: &'req http::Uri, - pub headers: &'req NtexHeaderMap, +pub struct ClientRequestDetails<'exec> { + pub method: Method, + pub url: http::Uri, + pub headers: NtexHeaderMap, pub operation: OperationDetails<'exec>, - pub jwt: &'exec JwtRequestDetails<'req>, + pub jwt: JwtRequestDetails, } -pub enum JwtRequestDetails<'exec> { +pub enum JwtRequestDetails { Authenticated { - token: &'exec str, - prefix: Option<&'exec str>, - claims: &'exec sonic_rs::Value, + token: String, + prefix: Option, + claims: sonic_rs::Value, scopes: Option>, }, Unauthenticated, } -impl From<&ClientRequestDetails<'_, '_>> for Value { +impl From<&ClientRequestDetails<'_>> for Value { fn from(details: &ClientRequestDetails) -> Self { // .request.headers - let headers_value = client_header_map_to_vrl_value(details.headers); + let headers_value = client_header_map_to_vrl_value(&details.headers); // .request.url let url_value = Self::Object(BTreeMap::from([ @@ -67,7 +67,7 @@ impl From<&ClientRequestDetails<'_, '_>> for Value { ])); // .request.jwt - let jwt_value = match details.jwt { + let jwt_value = match &details.jwt { JwtRequestDetails::Authenticated { token, prefix, @@ -78,7 +78,7 @@ impl From<&ClientRequestDetails<'_, '_>> for Value { ("token".into(), token.to_string().into()), ( "prefix".into(), - prefix.unwrap_or_default().to_string().into(), + prefix.as_deref().unwrap_or_default().to_string().into(), ), ("claims".into(), sonic_value_to_vrl_value(claims)), ( diff --git a/lib/executor/src/execution/error.rs b/lib/executor/src/execution/error.rs index aaaf9a729..63460eb48 100644 --- a/lib/executor/src/execution/error.rs +++ b/lib/executor/src/execution/error.rs @@ -116,7 +116,7 @@ impl IntoPlanExecutionError for Result { let kind = PlanExecutionErrorKind::ProjectionFailure(source); PlanExecutionError::new(kind, context) }) - } + } } impl IntoPlanExecutionError for Result { diff --git a/lib/executor/src/execution/jwt_forward.rs b/lib/executor/src/execution/jwt_forward.rs index 24c19ff9f..9aefc601c 100644 --- a/lib/executor/src/execution/jwt_forward.rs +++ b/lib/executor/src/execution/jwt_forward.rs @@ -8,7 +8,7 @@ pub struct JwtAuthForwardingPlan { pub extension_field_value: Value, } -impl JwtRequestDetails<'_> { +impl JwtRequestDetails { pub fn build_forwarding_plan( &self, extension_field_name: &str, diff --git a/lib/executor/src/execution/plan.rs b/lib/executor/src/execution/plan.rs index 520f429c8..5ea8bd758 100644 --- a/lib/executor/src/execution/plan.rs +++ b/lib/executor/src/execution/plan.rs @@ -1,4 +1,7 @@ -use std::collections::{BTreeSet, HashMap}; +use std::{ + collections::{BTreeSet, HashMap}, + sync::Arc, +}; use bytes::{BufMut, Bytes}; use futures::{future::BoxFuture, stream::FuturesUnordered, StreamExt}; @@ -12,49 +15,58 @@ use serde::Deserialize; use sonic_rs::ValueRef; use crate::{ - context::ExecutionContext, execution::{ + context::ExecutionContext, + execution::{ client_request_details::ClientRequestDetails, error::{IntoPlanExecutionError, LazyPlanContext, PlanExecutionError}, jwt_forward::JwtAuthForwardingPlan, rewrites::FetchRewriteExt, - }, executors::{ + }, + executors::{ common::{HttpExecutionResponse, SubgraphExecutionRequest}, map::SubgraphExecutorMap, - }, headers::{ + }, + headers::{ plan::HeaderRulesPlan, request::modify_subgraph_request_headers, response::{apply_subgraph_response_headers, modify_client_response_headers}, - }, hooks::on_execute::{OnExecuteEndPayload, OnExecuteStartPayload}, introspection::{ - resolve::{IntrospectionContext, resolve_introspection}, + }, + hooks::on_execute::{OnExecuteEndPayload, OnExecuteStartPayload}, + introspection::{ + resolve::{resolve_introspection, IntrospectionContext}, schema::SchemaMetadata, - }, plugin_trait::{ControlFlowResult, RouterPlugin}, projection::{ + }, + plugin_trait::{ControlFlowResult, RouterPlugin}, + projection::{ plan::FieldProjectionPlan, - request::{RequestProjectionContext, project_requires}, + request::{project_requires, RequestProjectionContext}, response::project_by_operation, - }, response::{ + }, + response::{ graphql_error::{GraphQLError, GraphQLErrorExtensions, GraphQLErrorPath}, merge::deep_merge, subgraph_response::SubgraphResponse, value::Value, - }, utils::{ + }, + utils::{ consts::{CLOSE_BRACKET, OPEN_BRACKET}, traverse::{traverse_and_callback, traverse_and_callback_mut}, - } + }, }; -pub struct QueryPlanExecutionContext<'exec, 'req> { - pub router_http_request: &'exec HttpRequest, +pub struct QueryPlanExecutionContext<'exec> { + pub router_http_request: HttpRequest, pub plugins: &'exec Vec>, - pub query_plan: &'exec QueryPlan, + pub query_plan: Arc, pub projection_plan: &'exec Vec, pub headers_plan: &'exec HeaderRulesPlan, pub variable_values: &'exec Option>, pub extensions: Option>, - pub client_request: &'exec ClientRequestDetails<'exec, 'req>, + pub client_request: &'exec ClientRequestDetails<'exec>, pub introspection_context: &'exec IntrospectionContext<'exec, 'static>, pub operation_type_name: &'exec str, pub executors: &'exec SubgraphExecutorMap, - pub jwt_auth_forwarding: &'exec Option, + pub jwt_auth_forwarding: Option, } #[derive(Clone)] @@ -63,119 +75,140 @@ pub struct PlanExecutionOutput { pub headers: HeaderMap, } -pub async fn execute_query_plan<'exec, 'req>( - ctx: QueryPlanExecutionContext<'exec, 'req>, -) -> Result { - let init_value = if let Some(introspection_query) = ctx.introspection_context.query { - resolve_introspection(introspection_query, ctx.introspection_context) - } else { - Value::Null - }; +pub struct ResultWithRequest { + pub result: T, + pub request: HttpRequest, +} - let dedupe_subgraph_requests = ctx.operation_type_name == "Query"; +pub trait WithResult { + fn with_result(self, result: T) -> ResultWithRequest; +} - let mut start_payload = OnExecuteStartPayload { - router_http_request: ctx.router_http_request, - query_plan: ctx.query_plan, - data: init_value, - errors: Vec::new(), - extensions: ctx.extensions.clone(), - variable_values: ctx.variable_values, - dedupe_subgraph_requests, - }; +impl WithResult for HttpRequest { + fn with_result(self, result: T) -> ResultWithRequest { + ResultWithRequest { result, request: self } + } +} - let mut on_end_callbacks = vec![]; +impl<'exec> QueryPlanExecutionContext<'exec> { + pub async fn execute_query_plan( + self, + ) -> Result, PlanExecutionError> { + let init_value = if let Some(introspection_query) = self.introspection_context.query { + resolve_introspection(introspection_query, self.introspection_context) + } else { + Value::Null + }; - for plugin in ctx.plugins { - let result = plugin.on_execute(start_payload); - start_payload = result.payload; - match result.control_flow { - ControlFlowResult::Continue => { /* continue to next plugin */ }, - ControlFlowResult::EndResponse(response) => { - return Ok(response); - }, - ControlFlowResult::OnEnd(callback) => { - on_end_callbacks.push(callback); + let dedupe_subgraph_requests = self.operation_type_name == "Query"; + + let mut start_payload = OnExecuteStartPayload { + router_http_request: self.router_http_request, + query_plan: self.query_plan, + data: init_value, + errors: Vec::new(), + extensions: self.extensions.clone(), + variable_values: self.variable_values, + dedupe_subgraph_requests, + }; + + let mut on_end_callbacks = vec![]; + + for plugin in self.plugins { + let result = plugin.on_execute(start_payload); + start_payload = result.payload; + match result.control_flow { + ControlFlowResult::Continue => { /* continue to next plugin */ } + ControlFlowResult::EndResponse(response) => { + return Ok(start_payload.router_http_request.with_result(response)); + } + ControlFlowResult::OnEnd(callback) => { + on_end_callbacks.push(callback); + } } } - } - let init_value = start_payload.data; - - let mut exec_ctx = ExecutionContext::new(ctx.query_plan, init_value); - let executor = Executor::new( - ctx.variable_values, - ctx.executors, - ctx.introspection_context.metadata, - ctx.client_request, - ctx.headers_plan, - ctx.jwt_auth_forwarding, - // Deduplicate subgraph requests only if the operation type is a query - ctx.operation_type_name == "Query", - ); - - if ctx.query_plan.node.is_some() { - executor - .execute(&mut exec_ctx, ctx.query_plan.node.as_ref()) - .await?; - } + let query_plan = start_payload.query_plan; + + let init_value = start_payload.data; - let mut response_headers = HeaderMap::new(); - modify_client_response_headers(exec_ctx.response_headers_aggregator, &mut response_headers) + let mut exec_ctx = ExecutionContext::new(&query_plan, init_value); + let executor = Executor::new( + self.variable_values, + self.executors, + self.introspection_context.metadata, + self.client_request, + self.headers_plan, + self.jwt_auth_forwarding, + // Deduplicate subgraph requests only if the operation type is a query + self.operation_type_name == "Query", + ); + + if query_plan.node.is_some() { + executor + .execute(&mut exec_ctx, query_plan.node.as_ref()) + .await?; + } + + let mut response_headers = HeaderMap::new(); + modify_client_response_headers(exec_ctx.response_headers_aggregator, &mut response_headers) + .with_plan_context(LazyPlanContext { + subgraph_name: || None, + affected_path: || None, + })?; + + let mut end_payload = OnExecuteEndPayload { + data: exec_ctx.final_response, + errors: exec_ctx.errors, + extensions: start_payload.extensions, + response_size_estimate: exec_ctx.response_storage.estimate_final_response_size(), + }; + + for callback in on_end_callbacks { + let result = callback(end_payload); + end_payload = result.payload; + match result.control_flow { + ControlFlowResult::Continue => { /* continue to next callback */ } + ControlFlowResult::EndResponse(output) => { + return Ok(start_payload.router_http_request.with_result(output)); + } + ControlFlowResult::OnEnd(_) => { + // on_end callbacks should not return OnEnd again + unreachable!("on_end callback returned OnEnd again"); + } + } + } + + let body = project_by_operation( + &end_payload.data, + end_payload.errors, + &self.extensions, + self.operation_type_name, + self.projection_plan, + self.variable_values, + end_payload.response_size_estimate, + ) .with_plan_context(LazyPlanContext { subgraph_name: || None, affected_path: || None, })?; - let mut end_payload = OnExecuteEndPayload { - data: exec_ctx.final_response, - errors: exec_ctx.errors, - extensions: start_payload.extensions, - response_size_estimate: exec_ctx.response_storage.estimate_final_response_size(), - }; - - for callback in on_end_callbacks { - let result = callback(end_payload); - end_payload = result.payload; - match result.control_flow { - ControlFlowResult::Continue => { /* continue to next callback */ }, - ControlFlowResult::EndResponse(response) => { - return Ok(response); - }, - ControlFlowResult::OnEnd(_) => { - // on_end callbacks should not return OnEnd again - unreachable!("on_end callback returned OnEnd again"); - } - } + Ok(start_payload + .router_http_request + .with_result(PlanExecutionOutput { + body, + headers: response_headers, + })) } - - let body = project_by_operation( - &end_payload.data, - end_payload.errors, - &ctx.extensions, - ctx.operation_type_name, - ctx.projection_plan, - ctx.variable_values, - end_payload.response_size_estimate, - ) - .with_plan_context(LazyPlanContext { - subgraph_name: || None, - affected_path: || None, - })?; - - Ok(PlanExecutionOutput { - body, - headers: response_headers, - }) } -pub struct Executor<'exec, 'req> { +pub struct Executor<'exec> { variable_values: &'exec Option>, schema_metadata: &'exec SchemaMetadata, executors: &'exec SubgraphExecutorMap, - client_request: &'exec ClientRequestDetails<'exec, 'req>, + client_request: &'exec ClientRequestDetails<'exec>, headers_plan: &'exec HeaderRulesPlan, - jwt_forwarding_plan: &'exec Option, + jwt_forwarding_plan: Option, dedupe_subgraph_requests: bool, } @@ -263,14 +296,14 @@ struct PreparedFlattenData { representation_hash_to_index: HashMap, } -impl<'exec, 'req> Executor<'exec, 'req> { +impl<'exec> Executor<'exec> { pub fn new( variable_values: &'exec Option>, executors: &'exec SubgraphExecutorMap, schema_metadata: &'exec SchemaMetadata, - client_request: &'exec ClientRequestDetails<'exec, 'req>, + client_request: &'exec ClientRequestDetails<'exec>, headers_plan: &'exec HeaderRulesPlan, - jwt_forwarding_plan: &'exec Option, + jwt_forwarding_plan: Option, dedupe_subgraph_requests: bool, ) -> Self { Executor { diff --git a/lib/executor/src/executors/http.rs b/lib/executor/src/executors/http.rs index 1e00e6fce..a5a5ed263 100644 --- a/lib/executor/src/executors/http.rs +++ b/lib/executor/src/executors/http.rs @@ -2,7 +2,9 @@ use std::sync::Arc; use crate::executors::common::HttpExecutionResponse; use crate::executors::dedupe::{request_fingerprint, ABuildHasher, SharedResponse}; -use crate::hooks::on_subgraph_http_request::{OnSubgraphHttpRequestPayload, OnSubgraphHttpResponsePayload}; +use crate::hooks::on_subgraph_http_request::{ + OnSubgraphHttpRequestPayload, OnSubgraphHttpResponsePayload, +}; use crate::plugin_trait::{ControlFlowResult, RouterPlugin}; use dashmap::DashMap; use hive_router_config::HiveRouterConfig; @@ -11,8 +13,8 @@ use tokio::sync::OnceCell; use async_trait::async_trait; use bytes::{BufMut, Bytes, BytesMut}; -use http::{HeaderMap, StatusCode}; use http::HeaderValue; +use http::{HeaderMap, StatusCode}; use http_body_util::BodyExt; use http_body_util::Full; use hyper::Version; @@ -169,104 +171,101 @@ async fn send_request( headers: HeaderMap, plugins: Arc>>, ) -> Result { - let mut req = hyper::Request::builder() - .method(method) - .uri(endpoint) - .version(Version::HTTP_11) - .body(Full::new(Bytes::from(body))) - .map_err(|e| { - SubgraphExecutorError::RequestBuildFailure(endpoint.to_string(), e.to_string()) - })?; - - *req.headers_mut() = headers; - - let mut start_payload = OnSubgraphHttpRequestPayload { - subgraph_name, - request: req, - response: None, - }; + let mut req = hyper::Request::builder() + .method(method) + .uri(endpoint) + .version(Version::HTTP_11) + .body(Full::new(Bytes::from(body))) + .map_err(|e| { + SubgraphExecutorError::RequestBuildFailure(endpoint.to_string(), e.to_string()) + })?; - let mut on_end_callbacks = vec![]; - - for plugin in plugins.as_ref() { - let result = plugin.on_subgraph_http_request(start_payload); - start_payload = result.payload; - match result.control_flow { - ControlFlowResult::Continue => { /* continue to next plugin */ } - ControlFlowResult::EndResponse(response) => { - // TODO: Fixx - return Ok(SharedResponse { - status: StatusCode::OK, - body: response.body.into(), - headers: response.headers, - }); - } - ControlFlowResult::OnEnd(callback) => { - on_end_callbacks.push(callback); - } + *req.headers_mut() = headers; + + let mut start_payload = OnSubgraphHttpRequestPayload { + subgraph_name, + request: req, + response: None, + }; + + let mut on_end_callbacks = vec![]; + + for plugin in plugins.as_ref() { + let result = plugin.on_subgraph_http_request(start_payload); + start_payload = result.payload; + match result.control_flow { + ControlFlowResult::Continue => { /* continue to next plugin */ } + ControlFlowResult::EndResponse(response) => { + // TODO: Fixx + return Ok(SharedResponse { + status: StatusCode::OK, + body: response.body.into(), + headers: response.headers, + }); + } + ControlFlowResult::OnEnd(callback) => { + on_end_callbacks.push(callback); } } + } - debug!("making http request to {}", endpoint.to_string()); + debug!("making http request to {}", endpoint.to_string()); - let req = start_payload.request; + let req = start_payload.request; - let res = http_client.request(req).await.map_err(|e| { - SubgraphExecutorError::RequestFailure(endpoint.to_string(), e.to_string()) - })?; + let res = http_client + .request(req) + .await + .map_err(|e| SubgraphExecutorError::RequestFailure(endpoint.to_string(), e.to_string()))?; - debug!( - "http request to {} completed, status: {}", - endpoint.to_string(), - res.status() - ); + debug!( + "http request to {} completed, status: {}", + endpoint.to_string(), + res.status() + ); - let (parts, body) = res.into_parts(); - let body = body - .collect() - .await - .map_err(|e| { - SubgraphExecutorError::RequestFailure(endpoint.to_string(), e.to_string()) - })? - .to_bytes(); - - if body.is_empty() { - return Err(SubgraphExecutorError::RequestFailure( - endpoint.to_string(), - "Empty response body".to_string(), - )); - } - - let response = SharedResponse { - status: parts.status, - body: body, - headers: parts.headers, - }; + let (parts, body) = res.into_parts(); + let body = body + .collect() + .await + .map_err(|e| SubgraphExecutorError::RequestFailure(endpoint.to_string(), e.to_string()))? + .to_bytes(); - let mut end_payload = OnSubgraphHttpResponsePayload { - response, - }; + if body.is_empty() { + return Err(SubgraphExecutorError::RequestFailure( + endpoint.to_string(), + "Empty response body".to_string(), + )); + } - for callback in on_end_callbacks { - let result = callback(end_payload); - end_payload = result.payload; - match result.control_flow { - ControlFlowResult::Continue => { /* continue to next callback */ } - ControlFlowResult::EndResponse(response) => { - return Ok(SharedResponse { - status: StatusCode::OK, - body: response.body.into(), - headers: response.headers, - }); - } - ControlFlowResult::OnEnd(_) => { - // on_end callbacks should not return OnEnd again - unreachable!("on_end callback returned OnEnd again"); - } + let response = SharedResponse { + status: parts.status, + body, + headers: parts.headers, + }; + + let mut end_payload = OnSubgraphHttpResponsePayload { response }; + + for callback in on_end_callbacks { + let result = callback(end_payload); + end_payload = result.payload; + match result.control_flow { + ControlFlowResult::Continue => { /* continue to next callback */ } + ControlFlowResult::EndResponse(response) => { + return Ok(SharedResponse { + status: StatusCode::OK, + body: response.body.into(), + headers: response.headers, + }); + } + ControlFlowResult::OnEnd(_) => { + // on_end callbacks should not return OnEnd again + unreachable!("on_end callback returned OnEnd again"); } } + } - Ok(end_payload.response) + Ok(end_payload.response) } #[async_trait] @@ -298,7 +297,17 @@ impl SubgraphExecutor for HTTPSubgraphExecutor { // This unwrap is safe because the semaphore is never closed during the application's lifecycle. // `acquire()` only fails if the semaphore is closed, so this will always return `Ok`. let _permit = self.semaphore.acquire().await.unwrap(); - return match send_request(&self.http_client, &self.subgraph_name, &self.endpoint, method, body, headers, self.plugins.clone()).await { + return match send_request( + &self.http_client, + &self.subgraph_name, + &self.endpoint, + method, + body, + headers, + self.plugins.clone(), + ) + .await + { Ok(shared_response) => HttpExecutionResponse { body: shared_response.body, headers: shared_response.headers, @@ -330,7 +339,16 @@ impl SubgraphExecutor for HTTPSubgraphExecutor { // This unwrap is safe because the semaphore is never closed during the application's lifecycle. // `acquire()` only fails if the semaphore is closed, so this will always return `Ok`. let _permit = self.semaphore.acquire().await.unwrap(); - send_request(&self.http_client, &self.subgraph_name, &self.endpoint, method, body, headers, self.plugins.clone()).await + send_request( + &self.http_client, + &self.subgraph_name, + &self.endpoint, + method, + body, + headers, + self.plugins.clone(), + ) + .await }; // It's important to remove the entry from the map before returning the result. // This ensures that once the OnceCell is set, no future requests can join it. diff --git a/lib/executor/src/executors/map.rs b/lib/executor/src/executors/map.rs index 6f780b76f..cf2f6dc13 100644 --- a/lib/executor/src/executors/map.rs +++ b/lib/executor/src/executors/map.rs @@ -27,14 +27,19 @@ use vrl::{ }; use crate::{ - execution::client_request_details::ClientRequestDetails, executors::{ + execution::client_request_details::ClientRequestDetails, + executors::{ common::{ - HttpExecutionResponse, SubgraphExecutionRequest, SubgraphExecutor, SubgraphExecutorBoxedArc + HttpExecutionResponse, SubgraphExecutionRequest, SubgraphExecutor, + SubgraphExecutorBoxedArc, }, dedupe::{ABuildHasher, SharedResponse}, error::SubgraphExecutorError, http::{HTTPSubgraphExecutor, HttpClient}, - }, hooks::on_subgraph_execute::{OnSubgraphExecuteEndPayload, OnSubgraphExecuteStartPayload}, plugin_trait::{ControlFlowResult, RouterPlugin}, response::graphql_error::GraphQLError + }, + hooks::on_subgraph_execute::{OnSubgraphExecuteEndPayload, OnSubgraphExecuteStartPayload}, + plugin_trait::{ControlFlowResult, RouterPlugin}, + response::graphql_error::GraphQLError, }; type SubgraphName = String; @@ -119,11 +124,11 @@ impl SubgraphExecutorMap { Ok(subgraph_executor_map) } - pub async fn execute<'a, 'req>( + pub async fn execute<'a>( &self, subgraph_name: &str, execution_request: SubgraphExecutionRequest<'a>, - client_request: &ClientRequestDetails<'a, 'req>, + client_request: &ClientRequestDetails<'a>, ) -> HttpExecutionResponse { let mut start_payload = OnSubgraphExecuteStartPayload { subgraph_name: subgraph_name.to_string(), @@ -156,9 +161,7 @@ impl SubgraphExecutorMap { let execution_request = start_payload.execution_request; let execution_result = match self.get_or_create_executor(subgraph_name, client_request) { - Ok(Some(executor)) => executor - .execute(execution_request) - .await, + Ok(Some(executor)) => executor.execute(execution_request).await, Err(err) => { error!( "Subgraph executor error for subgraph '{}': {}", @@ -174,10 +177,8 @@ impl SubgraphExecutorMap { self.internal_server_error_response("Internal server error".into(), subgraph_name) } }; - - let mut end_payload = OnSubgraphExecuteEndPayload { - execution_result - }; + + let mut end_payload = OnSubgraphExecuteEndPayload { execution_result }; for callback in on_end_callbacks { let result = callback(end_payload); @@ -187,11 +188,11 @@ impl SubgraphExecutorMap { // continue to next callback } ControlFlowResult::EndResponse(response) => { - // TODO: FFIX - return HttpExecutionResponse { - body: response.body.into(), - headers: response.headers, - }; + // TODO: FFIX + return HttpExecutionResponse { + body: response.body.into(), + headers: response.headers, + }; } ControlFlowResult::OnEnd(_) => { unreachable!("End callbacks should not register further end callbacks"); @@ -226,7 +227,7 @@ impl SubgraphExecutorMap { fn get_or_create_executor( &self, subgraph_name: &str, - client_request: &ClientRequestDetails<'_, '_>, + client_request: &ClientRequestDetails<'_>, ) -> Result, SubgraphExecutorError> { let from_expression = self.get_or_create_executor_from_expression(subgraph_name, client_request)?; @@ -245,7 +246,7 @@ impl SubgraphExecutorMap { fn get_or_create_executor_from_expression( &self, subgraph_name: &str, - client_request: &ClientRequestDetails<'_, '_>, + client_request: &ClientRequestDetails<'_>, ) -> Result, SubgraphExecutorError> { if let Some(expression) = self.expressions_by_subgraph.get(subgraph_name) { let original_url_value = VrlValue::Bytes(Bytes::from( diff --git a/lib/executor/src/headers/expression.rs b/lib/executor/src/headers/expression.rs index ed63ebfc0..5852f2b75 100644 --- a/lib/executor/src/headers/expression.rs +++ b/lib/executor/src/headers/expression.rs @@ -46,7 +46,7 @@ fn header_map_to_vrl_value(headers: &HeaderMap) -> Value { Value::Object(obj) } -impl From<&RequestExpressionContext<'_, '_>> for Value { +impl From<&RequestExpressionContext<'_>> for Value { /// NOTE: If performance becomes an issue, consider pre-computing parts of this context that do not change fn from(ctx: &RequestExpressionContext) -> Self { // .subgraph @@ -65,7 +65,7 @@ impl From<&RequestExpressionContext<'_, '_>> for Value { } } -impl From<&ResponseExpressionContext<'_, '_>> for Value { +impl From<&ResponseExpressionContext<'_>> for Value { /// NOTE: If performance becomes an issue, consider pre-computing parts of this context that do not change fn from(ctx: &ResponseExpressionContext) -> Self { // .subgraph diff --git a/lib/executor/src/headers/mod.rs b/lib/executor/src/headers/mod.rs index 62f9fe701..c617edfa0 100644 --- a/lib/executor/src/headers/mod.rs +++ b/lib/executor/src/headers/mod.rs @@ -74,15 +74,15 @@ mod tests { ); let client_details = ClientRequestDetails { - method: &http::Method::POST, - url: &"http://example.com".parse().unwrap(), - headers: &client_headers, + method: http::Method::POST, + url: "http://example.com".parse().unwrap(), + headers: client_headers, operation: OperationDetails { name: None, query: "{ __typename }", kind: "query", }, - jwt: &JwtRequestDetails::Unauthenticated, + jwt: JwtRequestDetails::Unauthenticated, }; let mut out = HeaderMap::new(); @@ -108,15 +108,15 @@ mod tests { let plan = compile_headers_plan(&config.headers).unwrap(); let client_headers = NtexHeaderMap::new(); let client_details = ClientRequestDetails { - method: &http::Method::POST, - url: &"http://example.com".parse().unwrap(), - headers: &client_headers, + method: http::Method::POST, + url: "http://example.com".parse().unwrap(), + headers: client_headers, operation: OperationDetails { name: None, query: "{ __typename }", kind: "query", }, - jwt: &JwtRequestDetails::Unauthenticated, + jwt: JwtRequestDetails::Unauthenticated, }; let mut out = HeaderMap::new(); modify_subgraph_request_headers(&plan, "any", &client_details, &mut out).unwrap(); @@ -155,15 +155,15 @@ mod tests { ); let client_details = ClientRequestDetails { - method: &http::Method::POST, - url: &"http://example.com".parse().unwrap(), - headers: &client_headers, + method: http::Method::POST, + url: "http://example.com".parse().unwrap(), + headers: client_headers, operation: OperationDetails { name: None, query: "{ __typename }", kind: "query", }, - jwt: &JwtRequestDetails::Unauthenticated, + jwt: JwtRequestDetails::Unauthenticated, }; let mut out = HeaderMap::new(); @@ -193,15 +193,15 @@ mod tests { let plan = compile_headers_plan(&config.headers).unwrap(); let client_headers = NtexHeaderMap::new(); let client_details = ClientRequestDetails { - method: &http::Method::POST, - url: &"http://example.com".parse().unwrap(), - headers: &client_headers, + method: http::Method::POST, + url: "http://example.com".parse().unwrap(), + headers: client_headers, operation: OperationDetails { name: Some("MyQuery"), query: "{ __typename }", kind: "query", }, - jwt: &JwtRequestDetails::Unauthenticated, + jwt: JwtRequestDetails::Unauthenticated, }; let mut out = HeaderMap::new(); @@ -227,15 +227,15 @@ mod tests { let plan = compile_headers_plan(&config.headers).unwrap(); let client_headers = NtexHeaderMap::new(); let client_details = ClientRequestDetails { - method: &http::Method::POST, - url: &"http://example.com".parse().unwrap(), - headers: &client_headers, + method: http::Method::POST, + url: "http://example.com".parse().unwrap(), + headers: client_headers, operation: OperationDetails { name: None, query: "{ __typename }", kind: "query", }, - jwt: &JwtRequestDetails::Unauthenticated, + jwt: JwtRequestDetails::Unauthenticated, }; let mut out = HeaderMap::new(); @@ -267,15 +267,15 @@ mod tests { let plan = compile_headers_plan(&config.headers).unwrap(); let client_headers = NtexHeaderMap::new(); let client_details = ClientRequestDetails { - method: &http::Method::POST, - url: &"http://example.com".parse().unwrap(), - headers: &client_headers, + method: http::Method::POST, + url: "http://example.com".parse().unwrap(), + headers: client_headers, operation: OperationDetails { name: None, query: "{ __typename }", kind: "query", }, - jwt: &JwtRequestDetails::Unauthenticated, + jwt: JwtRequestDetails::Unauthenticated, }; // For "accounts" subgraph, the specific rule should apply. @@ -311,15 +311,15 @@ mod tests { let plan = compile_headers_plan(&config.headers).unwrap(); let client_headers = NtexHeaderMap::new(); let client_details = ClientRequestDetails { - method: &http::Method::POST, - url: &"http://example.com".parse().unwrap(), - headers: &client_headers, + method: http::Method::POST, + url: "http://example.com".parse().unwrap(), + headers: client_headers, operation: OperationDetails { name: None, query: "{ __typename }", kind: "query", }, - jwt: &JwtRequestDetails::Unauthenticated, + jwt: JwtRequestDetails::Unauthenticated, }; let mut accumulator = ResponseHeaderAggregator::default(); @@ -376,15 +376,15 @@ mod tests { let plan = compile_headers_plan(&config.headers).unwrap(); let client_headers = NtexHeaderMap::new(); let client_details = ClientRequestDetails { - method: &http::Method::POST, - url: &"http://example.com".parse().unwrap(), - headers: &client_headers, + method: http::Method::POST, + url: "http://example.com".parse().unwrap(), + headers: client_headers, operation: OperationDetails { name: None, query: "{ __typename }", kind: "query", }, - jwt: &JwtRequestDetails::Unauthenticated, + jwt: JwtRequestDetails::Unauthenticated, }; let mut accumulator = ResponseHeaderAggregator::default(); @@ -440,15 +440,15 @@ mod tests { let plan = compile_headers_plan(&config.headers).unwrap(); let client_headers = NtexHeaderMap::new(); let client_details = ClientRequestDetails { - method: &http::Method::POST, - url: &"http://example.com".parse().unwrap(), - headers: &client_headers, + method: http::Method::POST, + url: "http://example.com".parse().unwrap(), + headers: client_headers, operation: OperationDetails { name: None, query: "{ __typename }", kind: "query", }, - jwt: &JwtRequestDetails::Unauthenticated, + jwt: JwtRequestDetails::Unauthenticated, }; let mut accumulator = ResponseHeaderAggregator::default(); @@ -497,15 +497,15 @@ mod tests { let plan = compile_headers_plan(&config.headers).unwrap(); let client_headers = NtexHeaderMap::new(); let client_details = ClientRequestDetails { - method: &http::Method::POST, - url: &"http://example.com".parse().unwrap(), - headers: &client_headers, + method: http::Method::POST, + url: "http://example.com".parse().unwrap(), + headers: client_headers, operation: OperationDetails { name: None, query: "{ __typename }", kind: "query", }, - jwt: &JwtRequestDetails::Unauthenticated, + jwt: JwtRequestDetails::Unauthenticated, }; let mut accumulator = ResponseHeaderAggregator::default(); @@ -555,15 +555,15 @@ mod tests { let plan = compile_headers_plan(&config.headers).unwrap(); let client_headers = NtexHeaderMap::new(); let client_details = ClientRequestDetails { - method: &http::Method::POST, - url: &"http://example.com".parse().unwrap(), - headers: &client_headers, + method: http::Method::POST, + url: "http://example.com".parse().unwrap(), + headers: client_headers, operation: OperationDetails { name: None, query: "{ __typename }", kind: "query", }, - jwt: &JwtRequestDetails::Unauthenticated, + jwt: JwtRequestDetails::Unauthenticated, }; let mut accumulator = ResponseHeaderAggregator::default(); @@ -614,15 +614,15 @@ mod tests { client_headers.insert(header_name_owned("x-keep"), header_value_owned("hi").into()); let client_details = ClientRequestDetails { - method: &http::Method::POST, - url: &"http://example.com".parse().unwrap(), - headers: &client_headers, + method: http::Method::POST, + url: "http://example.com".parse().unwrap(), + headers: client_headers, operation: OperationDetails { name: None, query: "{ __typename }", kind: "query", }, - jwt: &JwtRequestDetails::Unauthenticated, + jwt: JwtRequestDetails::Unauthenticated, }; let mut out = HeaderMap::new(); diff --git a/lib/executor/src/headers/request.rs b/lib/executor/src/headers/request.rs index 637ab0d58..44b6ed8b9 100644 --- a/lib/executor/src/headers/request.rs +++ b/lib/executor/src/headers/request.rs @@ -45,9 +45,9 @@ pub fn modify_subgraph_request_headers( Ok(()) } -pub struct RequestExpressionContext<'a, 'req> { +pub struct RequestExpressionContext<'a> { pub subgraph_name: &'a str, - pub client_request: &'a ClientRequestDetails<'a, 'req>, + pub client_request: &'a ClientRequestDetails<'a>, } trait ApplyRequestHeader { @@ -117,7 +117,7 @@ impl ApplyRequestHeader for RequestPropagateRegex { ctx: &RequestExpressionContext, output_headers: &mut HeaderMap, ) -> Result<(), HeaderRuleRuntimeError> { - for (header_name, header_value) in ctx.client_request.headers { + for (header_name, header_value) in ctx.client_request.headers.iter() { if is_denied_header(header_name) { continue; } diff --git a/lib/executor/src/headers/response.rs b/lib/executor/src/headers/response.rs index 6a5c34444..94019d585 100644 --- a/lib/executor/src/headers/response.rs +++ b/lib/executor/src/headers/response.rs @@ -50,9 +50,9 @@ pub fn apply_subgraph_response_headers( Ok(()) } -pub struct ResponseExpressionContext<'a, 'req> { +pub struct ResponseExpressionContext<'a> { pub subgraph_name: &'a str, - pub client_request: &'a ClientRequestDetails<'a, 'req>, + pub client_request: &'a ClientRequestDetails<'a>, pub subgraph_headers: &'a HeaderMap, } diff --git a/lib/executor/src/lib.rs b/lib/executor/src/lib.rs index 1f29c192e..bdcbdadc0 100644 --- a/lib/executor/src/lib.rs +++ b/lib/executor/src/lib.rs @@ -10,6 +10,5 @@ pub mod response; pub mod utils; pub mod variables; -pub use execution::plan::execute_query_plan; pub use executors::map::SubgraphExecutorMap; pub use plugins::*; diff --git a/lib/executor/src/plugins/examples/apq.rs b/lib/executor/src/plugins/examples/apq.rs index d5400d314..f5e380973 100644 --- a/lib/executor/src/plugins/examples/apq.rs +++ b/lib/executor/src/plugins/examples/apq.rs @@ -13,11 +13,13 @@ pub struct APQPlugin { impl RouterPlugin for APQPlugin { fn on_graphql_params<'exec>( &'exec self, - payload: OnGraphQLParamsStartPayload<'exec>, - ) -> HookResult<'exec, OnGraphQLParamsStartPayload<'exec>, OnGraphQLParamsEndPayload> - { + payload: OnGraphQLParamsStartPayload, + ) -> HookResult<'exec, OnGraphQLParamsStartPayload, OnGraphQLParamsEndPayload> { payload.on_end(|mut payload| { - let persisted_query_ext = payload.graphql_params.extensions.as_ref() + let persisted_query_ext = payload + .graphql_params + .extensions + .as_ref() .and_then(|ext| ext.get("persistedQuery")) .and_then(|pq| pq.as_object()); if let Some(persisted_query_ext) = persisted_query_ext { @@ -28,7 +30,10 @@ impl RouterPlugin for APQPlugin { return payload.cont(); } } - let sha256_hash = match persisted_query_ext.get(&"sha256Hash").and_then(|h| h.as_str()) { + let sha256_hash = match persisted_query_ext + .get(&"sha256Hash") + .and_then(|h| h.as_str()) + { Some(h) => h, None => { return payload.cont(); @@ -36,7 +41,8 @@ impl RouterPlugin for APQPlugin { }; if let Some(query_param) = &payload.graphql_params.query { // Store the query in the cache - self.cache.insert(sha256_hash.to_string(), query_param.to_string()); + self.cache + .insert(sha256_hash.to_string(), query_param.to_string()); } else { // Try to get the query from the cache if let Some(cached_query) = self.cache.get(sha256_hash) { diff --git a/lib/executor/src/plugins/examples/mod.rs b/lib/executor/src/plugins/examples/mod.rs index 68e3e7092..a6d766a9c 100644 --- a/lib/executor/src/plugins/examples/mod.rs +++ b/lib/executor/src/plugins/examples/mod.rs @@ -1,3 +1,3 @@ +pub mod apq; pub mod response_cache; pub mod subgraph_response_cache; -pub mod apq; \ No newline at end of file diff --git a/lib/executor/src/plugins/examples/multipart.rs b/lib/executor/src/plugins/examples/multipart.rs new file mode 100644 index 000000000..e69de29bb diff --git a/lib/executor/src/plugins/examples/response_cache.rs b/lib/executor/src/plugins/examples/response_cache.rs index d9d611307..fc6276643 100644 --- a/lib/executor/src/plugins/examples/response_cache.rs +++ b/lib/executor/src/plugins/examples/response_cache.rs @@ -5,7 +5,8 @@ use redis::Commands; use crate::{ execution::plan::PlanExecutionOutput, hooks::{ - on_execute::{OnExecuteEndPayload, OnExecuteStartPayload}, on_supergraph_load::{OnSupergraphLoadEndPayload, OnSupergraphLoadStartPayload}, + on_execute::{OnExecuteEndPayload, OnExecuteStartPayload}, + on_supergraph_load::{OnSupergraphLoadEndPayload, OnSupergraphLoadStartPayload}, }, plugin_trait::{EndPayload, HookResult, StartPayload}, plugins::plugin_trait::RouterPlugin, @@ -87,24 +88,22 @@ impl RouterPlugin for ResponseCachePlugin { } payload.cont() } - fn on_supergraph_reload<'a>(&'a self, payload: OnSupergraphLoadStartPayload) -> HookResult<'a, OnSupergraphLoadStartPayload, OnSupergraphLoadEndPayload> { + fn on_supergraph_reload<'a>( + &'a self, + payload: OnSupergraphLoadStartPayload, + ) -> HookResult<'a, OnSupergraphLoadStartPayload, OnSupergraphLoadEndPayload> { // Visit the schema and update ttl_per_type based on some directive - payload - .new_ast - .definitions - .iter() - .for_each(|def| { - if let graphql_parser::schema::Definition::TypeDefinition(type_def) = def { - if let graphql_parser::schema::TypeDefinition::Object(obj_type) = type_def { - for directive in &obj_type.directives { - if directive.name == "cacheControl" { - for arg in &directive.arguments { - if arg.0 == "maxAge" { - if let graphql_parser::query::Value::Int(max_age) = &arg.1 { - if let Some(max_age) = max_age.as_i64() { - self.ttl_per_type - .insert(obj_type.name.clone(), max_age as u64); - } + payload.new_ast.definitions.iter().for_each(|def| { + if let graphql_parser::schema::Definition::TypeDefinition(type_def) = def { + if let graphql_parser::schema::TypeDefinition::Object(obj_type) = type_def { + for directive in &obj_type.directives { + if directive.name == "cacheControl" { + for arg in &directive.arguments { + if arg.0 == "maxAge" { + if let graphql_parser::query::Value::Int(max_age) = &arg.1 { + if let Some(max_age) = max_age.as_i64() { + self.ttl_per_type + .insert(obj_type.name.clone(), max_age as u64); } } } @@ -112,7 +111,8 @@ impl RouterPlugin for ResponseCachePlugin { } } } - }); + } + }); payload.cont() } diff --git a/lib/executor/src/plugins/examples/subgraph_response_cache.rs b/lib/executor/src/plugins/examples/subgraph_response_cache.rs index 71ff8c1d9..4d192dd39 100644 --- a/lib/executor/src/plugins/examples/subgraph_response_cache.rs +++ b/lib/executor/src/plugins/examples/subgraph_response_cache.rs @@ -1,6 +1,10 @@ use dashmap::DashMap; -use crate::{executors::{common::HttpExecutionResponse}, hooks::{on_subgraph_execute::{OnSubgraphExecuteEndPayload, OnSubgraphExecuteStartPayload}}, plugin_trait::{ EndPayload, HookResult, RouterPlugin, StartPayload}}; +use crate::{ + executors::common::HttpExecutionResponse, + hooks::on_subgraph_execute::{OnSubgraphExecuteEndPayload, OnSubgraphExecuteStartPayload}, + plugin_trait::{EndPayload, HookResult, RouterPlugin, StartPayload}, +}; pub struct SubgraphResponseCachePlugin { cache: DashMap, @@ -8,9 +12,9 @@ pub struct SubgraphResponseCachePlugin { impl RouterPlugin for SubgraphResponseCachePlugin { fn on_subgraph_execute<'exec>( - &'exec self, - mut payload: OnSubgraphExecuteStartPayload<'exec>, - ) -> HookResult<'exec, OnSubgraphExecuteStartPayload<'exec>, OnSubgraphExecuteEndPayload> { + &'exec self, + mut payload: OnSubgraphExecuteStartPayload<'exec>, + ) -> HookResult<'exec, OnSubgraphExecuteStartPayload<'exec>, OnSubgraphExecuteEndPayload> { let key = format!( "subgraph_response_cache:{}:{:?}", payload.execution_request.query, payload.execution_request.variables @@ -21,10 +25,10 @@ impl RouterPlugin for SubgraphResponseCachePlugin { payload.execution_result = Some(cached_response.clone()); return payload.cont(); } - payload.on_end(move |payload: OnSubgraphExecuteEndPayload| { + payload.on_end(move |payload: OnSubgraphExecuteEndPayload| { // Here payload.response is not Option self.cache.insert(key, payload.execution_result.clone()); payload.cont() }) } -} \ No newline at end of file +} diff --git a/lib/executor/src/plugins/hooks/mod.rs b/lib/executor/src/plugins/hooks/mod.rs index 453c84c98..64851d0fd 100644 --- a/lib/executor/src/plugins/hooks/mod.rs +++ b/lib/executor/src/plugins/hooks/mod.rs @@ -1,9 +1,9 @@ pub mod on_execute; -pub mod on_supergraph_load; -pub mod on_subgraph_http_request; -pub mod on_http_request; pub mod on_graphql_params; pub mod on_graphql_parse; pub mod on_graphql_validation; +pub mod on_http_request; pub mod on_query_plan; -pub mod on_subgraph_execute; \ No newline at end of file +pub mod on_subgraph_execute; +pub mod on_subgraph_http_request; +pub mod on_supergraph_load; diff --git a/lib/executor/src/plugins/hooks/on_execute.rs b/lib/executor/src/plugins/hooks/on_execute.rs index dfcdaceb8..9a77679b9 100644 --- a/lib/executor/src/plugins/hooks/on_execute.rs +++ b/lib/executor/src/plugins/hooks/on_execute.rs @@ -1,15 +1,16 @@ use std::collections::HashMap; +use std::sync::Arc; use hive_router_query_planner::planner::plan_nodes::QueryPlan; use ntex::web::HttpRequest; use crate::plugin_trait::{EndPayload, StartPayload}; -use crate::response::{value::Value}; use crate::response::graphql_error::GraphQLError; +use crate::response::value::Value; pub struct OnExecuteStartPayload<'exec> { - pub router_http_request: &'exec HttpRequest, - pub query_plan: &'exec QueryPlan, + pub router_http_request: HttpRequest, + pub query_plan: Arc, pub data: Value<'exec>, pub errors: Vec, diff --git a/lib/executor/src/plugins/hooks/on_graphql_params.rs b/lib/executor/src/plugins/hooks/on_graphql_params.rs index 5e6ce1c47..a9afabed1 100644 --- a/lib/executor/src/plugins/hooks/on_graphql_params.rs +++ b/lib/executor/src/plugins/hooks/on_graphql_params.rs @@ -1,41 +1,103 @@ +use core::fmt; + use std::collections::HashMap; use ntex::util::Bytes; -use serde::Deserialize; -use serde::Deserializer; +use serde::{de, Deserialize, Deserializer}; use sonic_rs::Value; use crate::plugin_trait::EndPayload; use crate::plugin_trait::StartPayload; -#[derive(Deserialize, Debug, Clone)] -#[serde(rename_all = "camelCase")] +#[derive(Debug, Clone, Default)] pub struct GraphQLParams { pub query: Option, pub operation_name: Option, - #[serde(default, deserialize_with = "deserialize_null_default")] pub variables: HashMap, // TODO: We don't use extensions yet, but we definitely will in the future. #[allow(dead_code)] pub extensions: Option>, } -fn deserialize_null_default<'de, D, T>(deserializer: D) -> Result -where - T: Default + Deserialize<'de>, - D: Deserializer<'de>, -{ - let opt = Option::::deserialize(deserializer)?; - Ok(opt.unwrap_or_default()) +// Workaround for https://github.com/cloudwego/sonic-rs/issues/114 + +impl<'de> Deserialize<'de> for GraphQLParams { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + struct GraphQLErrorExtensionsVisitor; + + impl<'de> de::Visitor<'de> for GraphQLErrorExtensionsVisitor { + type Value = GraphQLParams; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("a map for GraphQLErrorExtensions") + } + + fn visit_map(self, mut map: A) -> Result + where + A: de::MapAccess<'de>, + { + let mut query = None; + let mut operation_name = None; + let mut variables: Option> = None; + let mut extensions: Option> = None; + let mut extra_params = HashMap::new(); + + while let Some(key) = map.next_key::()? { + match key.as_str() { + "query" => { + if query.is_some() { + return Err(de::Error::duplicate_field("query")); + } + query = map.next_value::>()?; + } + "operationName" => { + if operation_name.is_some() { + return Err(de::Error::duplicate_field("operationName")); + } + operation_name = map.next_value::>()?; + } + "variables" => { + if variables.is_some() { + return Err(de::Error::duplicate_field("variables")); + } + variables = map.next_value::>>()?; + } + "extensions" => { + if extensions.is_some() { + return Err(de::Error::duplicate_field("extensions")); + } + extensions = map.next_value::>>()?; + } + other => { + let value: Value = map.next_value()?; + extra_params.insert(other.to_string(), value); + } + } + } + + Ok(GraphQLParams { + query, + operation_name, + variables: variables.unwrap_or_default(), + extensions, + }) + } + } + + deserializer.deserialize_map(GraphQLErrorExtensionsVisitor) + } } -pub struct OnGraphQLParamsStartPayload<'exec> { - pub router_http_request: &'exec ntex::web::HttpRequest, +pub struct OnGraphQLParamsStartPayload { + pub router_http_request: ntex::web::HttpRequest, pub body: Bytes, pub graphql_params: Option, } -impl<'exec> StartPayload for OnGraphQLParamsStartPayload<'exec> {} +impl StartPayload for OnGraphQLParamsStartPayload {} pub struct OnGraphQLParamsEndPayload { pub graphql_params: GraphQLParams, diff --git a/lib/executor/src/plugins/hooks/on_graphql_parse.rs b/lib/executor/src/plugins/hooks/on_graphql_parse.rs index 162a7eee2..df9b4e480 100644 --- a/lib/executor/src/plugins/hooks/on_graphql_parse.rs +++ b/lib/executor/src/plugins/hooks/on_graphql_parse.rs @@ -1,19 +1,20 @@ use graphql_tools::static_graphql::query::Document; -use crate::{hooks::on_graphql_params::GraphQLParams, plugin_trait::{EndPayload, StartPayload}}; +use crate::{ + hooks::on_graphql_params::GraphQLParams, + plugin_trait::{EndPayload, StartPayload}, +}; pub struct OnGraphQLParseStartPayload<'exec> { - pub router_http_request: &'exec ntex::web::HttpRequest, + pub router_http_request: ntex::web::HttpRequest, pub graphql_params: &'exec GraphQLParams, pub document: Option, } -impl<'exec> StartPayload> for OnGraphQLParseStartPayload<'exec> {} +impl<'exec> StartPayload for OnGraphQLParseStartPayload<'exec> {} -pub struct OnGraphQLParseEndPayload<'exec> { - pub router_http_request: &'exec ntex::web::HttpRequest, - pub graphql_params: &'exec GraphQLParams, +pub struct OnGraphQLParseEndPayload { pub document: Document, } -impl<'exec> EndPayload for OnGraphQLParseEndPayload<'exec> {} \ No newline at end of file +impl EndPayload for OnGraphQLParseEndPayload {} diff --git a/lib/executor/src/plugins/hooks/on_graphql_validation.rs b/lib/executor/src/plugins/hooks/on_graphql_validation.rs index a789cb5fd..f6bb55004 100644 --- a/lib/executor/src/plugins/hooks/on_graphql_validation.rs +++ b/lib/executor/src/plugins/hooks/on_graphql_validation.rs @@ -1,13 +1,17 @@ use graphql_tools::{ static_graphql::query::Document, - validation::{rules::{ValidationRule, default_rules_validation_plan}, utils::ValidationError, validate::ValidationPlan}, + validation::{ + rules::{default_rules_validation_plan, ValidationRule}, + utils::ValidationError, + validate::ValidationPlan, + }, }; use hive_router_query_planner::state::supergraph_state::SchemaDocument; use crate::plugin_trait::{EndPayload, StartPayload}; pub struct OnGraphQLValidationStartPayload<'exec> { - pub router_http_request: &'exec ntex::web::HttpRequest, + pub router_http_request: &'exec mut ntex::web::HttpRequest, pub schema: &'exec SchemaDocument, pub document: &'exec Document, default_validation_plan: &'exec ValidationPlan, @@ -15,14 +19,11 @@ pub struct OnGraphQLValidationStartPayload<'exec> { pub errors: Option>, } -impl<'exec> StartPayload> - for OnGraphQLValidationStartPayload<'exec> -{ -} +impl<'exec> StartPayload for OnGraphQLValidationStartPayload<'exec> {} impl<'exec> OnGraphQLValidationStartPayload<'exec> { pub fn new( - router_http_request: &'exec ntex::web::HttpRequest, + router_http_request: &'exec mut ntex::web::HttpRequest, schema: &'exec SchemaDocument, document: &'exec Document, default_validation_plan: &'exec ValidationPlan, @@ -39,17 +40,17 @@ impl<'exec> OnGraphQLValidationStartPayload<'exec> { pub fn add_validation_rule(&mut self, rule: Box) { self.new_validation_plan - .get_or_insert_with(|| default_rules_validation_plan()) + .get_or_insert_with(default_rules_validation_plan) .add_rule(rule); } - + pub fn filter_validation_rules(&mut self, mut f: F) where F: FnMut(&Box) -> bool, { let plan = self .new_validation_plan - .get_or_insert_with(|| default_rules_validation_plan()); + .get_or_insert_with(default_rules_validation_plan); plan.rules.retain(|rule| f(rule)); } @@ -61,11 +62,8 @@ impl<'exec> OnGraphQLValidationStartPayload<'exec> { } } -pub struct OnGraphQLValidationEndPayload<'exec> { - pub router_http_request: &'exec ntex::web::HttpRequest, - pub schema: &'exec SchemaDocument, - pub document: &'exec Document, +pub struct OnGraphQLValidationEndPayload { pub errors: Vec, } -impl<'exec> EndPayload for OnGraphQLValidationEndPayload<'exec> {} +impl EndPayload for OnGraphQLValidationEndPayload {} diff --git a/lib/executor/src/plugins/hooks/on_http_request.rs b/lib/executor/src/plugins/hooks/on_http_request.rs index 29a8344e5..a7f6f6bb5 100644 --- a/lib/executor/src/plugins/hooks/on_http_request.rs +++ b/lib/executor/src/plugins/hooks/on_http_request.rs @@ -13,4 +13,4 @@ pub struct OnHttpResponse<'exec> { pub response: &'exec mut Response, } -impl<'exec> EndPayload for OnHttpResponse<'exec> {} \ No newline at end of file +impl<'exec> EndPayload for OnHttpResponse<'exec> {} diff --git a/lib/executor/src/plugins/hooks/on_query_plan.rs b/lib/executor/src/plugins/hooks/on_query_plan.rs index 39ae3c2d6..fd2089ec6 100644 --- a/lib/executor/src/plugins/hooks/on_query_plan.rs +++ b/lib/executor/src/plugins/hooks/on_query_plan.rs @@ -1,9 +1,14 @@ -use hive_router_query_planner::{ast::operation::OperationDefinition, graph::PlannerOverrideContext, planner::{Planner, plan_nodes::QueryPlan}, utils::cancellation::CancellationToken}; +use hive_router_query_planner::{ + ast::operation::OperationDefinition, + graph::PlannerOverrideContext, + planner::{plan_nodes::QueryPlan, Planner}, + utils::cancellation::CancellationToken, +}; use crate::plugin_trait::{EndPayload, StartPayload}; pub struct OnQueryPlanStartPayload<'exec> { - pub router_http_request: &'exec ntex::web::HttpRequest, + pub router_http_request: &'exec mut ntex::web::HttpRequest, pub filtered_operation_for_plan: &'exec OperationDefinition, pub planner_override_context: PlannerOverrideContext, pub cancellation_token: &'exec CancellationToken, @@ -11,15 +16,10 @@ pub struct OnQueryPlanStartPayload<'exec> { pub planner: &'exec Planner, } -impl<'exec> StartPayload> for OnQueryPlanStartPayload<'exec> {} +impl<'exec> StartPayload for OnQueryPlanStartPayload<'exec> {} -pub struct OnQueryPlanEndPayload<'exec> { - pub router_http_request: &'exec ntex::web::HttpRequest, - pub filtered_operation_for_plan: &'exec OperationDefinition, - pub planner_override_context: PlannerOverrideContext, - pub cancellation_token: &'exec CancellationToken, +pub struct OnQueryPlanEndPayload { pub query_plan: QueryPlan, - pub planner: &'exec Planner, } -impl<'exec> EndPayload for OnQueryPlanEndPayload<'exec> {} \ No newline at end of file +impl EndPayload for OnQueryPlanEndPayload {} diff --git a/lib/executor/src/plugins/hooks/on_subgraph_execute.rs b/lib/executor/src/plugins/hooks/on_subgraph_execute.rs index 6a514006f..5a6fcc6a6 100644 --- a/lib/executor/src/plugins/hooks/on_subgraph_execute.rs +++ b/lib/executor/src/plugins/hooks/on_subgraph_execute.rs @@ -1,14 +1,14 @@ - - -use crate::{executors::common::{HttpExecutionResponse, SubgraphExecutionRequest}, plugin_trait::{EndPayload, StartPayload}}; - +use crate::{ + executors::common::{HttpExecutionResponse, SubgraphExecutionRequest}, + plugin_trait::{EndPayload, StartPayload}, +}; pub struct OnSubgraphExecuteStartPayload<'exec> { pub subgraph_name: String, pub execution_request: SubgraphExecutionRequest<'exec>, pub execution_result: Option, -} +} impl<'exec> StartPayload for OnSubgraphExecuteStartPayload<'exec> {} @@ -16,4 +16,4 @@ pub struct OnSubgraphExecuteEndPayload { pub execution_result: HttpExecutionResponse, } -impl<'exec> EndPayload for OnSubgraphExecuteEndPayload {} \ No newline at end of file +impl EndPayload for OnSubgraphExecuteEndPayload {} diff --git a/lib/executor/src/plugins/hooks/on_subgraph_http_request.rs b/lib/executor/src/plugins/hooks/on_subgraph_http_request.rs index f7c798ed7..1b50f001b 100644 --- a/lib/executor/src/plugins/hooks/on_subgraph_http_request.rs +++ b/lib/executor/src/plugins/hooks/on_subgraph_http_request.rs @@ -1,10 +1,11 @@ use bytes::Bytes; -use http::{Request}; +use http::Request; use http_body_util::Full; use crate::{ - executors::{dedupe::SharedResponse}, plugin_trait::{EndPayload, StartPayload}} -; + executors::dedupe::SharedResponse, + plugin_trait::{EndPayload, StartPayload}, +}; pub struct OnSubgraphHttpRequestPayload<'exec> { pub subgraph_name: &'exec str, @@ -18,7 +19,7 @@ pub struct OnSubgraphHttpRequestPayload<'exec> { impl<'exec> StartPayload for OnSubgraphHttpRequestPayload<'exec> {} pub struct OnSubgraphHttpResponsePayload { - pub response: SharedResponse, + pub response: SharedResponse, } -impl<'exec> EndPayload for OnSubgraphHttpResponsePayload {} +impl EndPayload for OnSubgraphHttpResponsePayload {} diff --git a/lib/executor/src/plugins/hooks/on_supergraph_load.rs b/lib/executor/src/plugins/hooks/on_supergraph_load.rs index e4e68ca35..21dfbf5a5 100644 --- a/lib/executor/src/plugins/hooks/on_supergraph_load.rs +++ b/lib/executor/src/plugins/hooks/on_supergraph_load.rs @@ -1,11 +1,14 @@ use std::sync::Arc; +use arc_swap::ArcSwap; use graphql_tools::static_graphql::schema::Document; -use hive_router_query_planner::{planner::Planner}; -use arc_swap::{ArcSwap}; - -use crate::{SubgraphExecutorMap, introspection::schema::SchemaMetadata, plugin_trait::{EndPayload, StartPayload}}; +use hive_router_query_planner::planner::Planner; +use crate::{ + introspection::schema::SchemaMetadata, + plugin_trait::{EndPayload, StartPayload}, + SubgraphExecutorMap, +}; pub struct SupergraphData { pub metadata: SchemaMetadata, @@ -24,4 +27,4 @@ pub struct OnSupergraphLoadEndPayload { pub new_supergraph_data: SupergraphData, } -impl EndPayload for OnSupergraphLoadEndPayload {} \ No newline at end of file +impl EndPayload for OnSupergraphLoadEndPayload {} diff --git a/lib/executor/src/plugins/mod.rs b/lib/executor/src/plugins/mod.rs index 6c35286af..02490fb5e 100644 --- a/lib/executor/src/plugins/mod.rs +++ b/lib/executor/src/plugins/mod.rs @@ -1,3 +1,3 @@ pub mod examples; +pub mod hooks; pub mod plugin_trait; -pub mod hooks; \ No newline at end of file diff --git a/lib/executor/src/plugins/plugin_trait.rs b/lib/executor/src/plugins/plugin_trait.rs index d56856652..c502ba087 100644 --- a/lib/executor/src/plugins/plugin_trait.rs +++ b/lib/executor/src/plugins/plugin_trait.rs @@ -86,30 +86,27 @@ pub trait RouterPlugin { } fn on_graphql_params<'exec>( &'exec self, - start_payload: OnGraphQLParamsStartPayload<'exec>, - ) -> HookResult<'exec, OnGraphQLParamsStartPayload<'exec>, OnGraphQLParamsEndPayload> { + start_payload: OnGraphQLParamsStartPayload, + ) -> HookResult<'exec, OnGraphQLParamsStartPayload, OnGraphQLParamsEndPayload> { start_payload.cont() } fn on_graphql_parse<'exec>( &self, start_payload: OnGraphQLParseStartPayload<'exec>, - ) -> HookResult<'exec, OnGraphQLParseStartPayload<'exec>, OnGraphQLParseEndPayload<'exec>> { + ) -> HookResult<'exec, OnGraphQLParseStartPayload<'exec>, OnGraphQLParseEndPayload> { start_payload.cont() } fn on_graphql_validation<'exec>( &self, start_payload: OnGraphQLValidationStartPayload<'exec>, - ) -> HookResult< - 'exec, - OnGraphQLValidationStartPayload<'exec>, - OnGraphQLValidationEndPayload<'exec>, - > { + ) -> HookResult<'exec, OnGraphQLValidationStartPayload<'exec>, OnGraphQLValidationEndPayload> + { start_payload.cont() } fn on_query_plan<'exec>( &self, start_payload: OnQueryPlanStartPayload<'exec>, - ) -> HookResult<'exec, OnQueryPlanStartPayload<'exec>, OnQueryPlanEndPayload<'exec>> { + ) -> HookResult<'exec, OnQueryPlanStartPayload<'exec>, OnQueryPlanEndPayload> { start_payload.cont() } fn on_execute<'exec>( @@ -121,15 +118,13 @@ pub trait RouterPlugin { fn on_subgraph_execute<'exec>( &'exec self, start_payload: OnSubgraphExecuteStartPayload<'exec>, - ) -> HookResult<'exec, OnSubgraphExecuteStartPayload<'exec>, OnSubgraphExecuteEndPayload> - { + ) -> HookResult<'exec, OnSubgraphExecuteStartPayload<'exec>, OnSubgraphExecuteEndPayload> { start_payload.cont() } fn on_subgraph_http_request<'exec>( &'exec self, start_payload: OnSubgraphHttpRequestPayload<'exec>, - ) -> HookResult<'exec, OnSubgraphHttpRequestPayload<'exec>, OnSubgraphHttpResponsePayload> - { + ) -> HookResult<'exec, OnSubgraphHttpRequestPayload<'exec>, OnSubgraphHttpResponsePayload> { start_payload.cont() } fn on_supergraph_reload<'exec>( From 9ebd0fb895a8bf24ed0e847d2153925cb5f8be81 Mon Sep 17 00:00:00 2001 From: Arda TANRIKULU Date: Thu, 20 Nov 2025 18:09:14 +0300 Subject: [PATCH 09/15] Async Hooks and Shared Context --- Cargo.lock | 2 + bin/router/Cargo.toml | 2 + bin/router/src/lib.rs | 10 +- bin/router/src/pipeline/execution.rs | 17 ++- bin/router/src/pipeline/mod.rs | 109 ++++++++++-------- bin/router/src/pipeline/parser.rs | 35 +++--- .../src/pipeline/progressive_override.rs | 8 +- bin/router/src/pipeline/query_plan.rs | 21 ++-- bin/router/src/pipeline/validation.rs | 8 +- bin/router/src/plugins/mod.rs | 1 + bin/router/src/plugins/plugins_service.rs | 109 ++++++++++++++++++ .../src/execution/client_request_details.rs | 12 +- lib/executor/src/execution/plan.rs | 76 ++++++------ lib/executor/src/executors/http.rs | 2 +- lib/executor/src/executors/map.rs | 16 ++- lib/executor/src/headers/expression.rs | 4 +- lib/executor/src/headers/mod.rs | 72 ++++++------ lib/executor/src/headers/request.rs | 6 +- lib/executor/src/headers/response.rs | 8 +- lib/executor/src/plugins/examples/apq.rs | 7 +- .../src/plugins/examples/response_cache.rs | 3 +- .../examples/subgraph_response_cache.rs | 3 +- lib/executor/src/plugins/hooks/on_execute.rs | 8 +- .../src/plugins/hooks/on_graphql_params.rs | 9 +- .../src/plugins/hooks/on_graphql_parse.rs | 4 +- .../plugins/hooks/on_graphql_validation.rs | 13 ++- .../src/plugins/hooks/on_http_request.rs | 22 ++-- .../src/plugins/hooks/on_query_plan.rs | 8 +- .../src/plugins/hooks/on_subgraph_execute.rs | 4 + lib/executor/src/plugins/mod.rs | 1 + lib/executor/src/plugins/plugin_context.rs | 43 +++++++ lib/executor/src/plugins/plugin_trait.rs | 27 ++--- 32 files changed, 424 insertions(+), 246 deletions(-) create mode 100644 bin/router/src/plugins/mod.rs create mode 100644 bin/router/src/plugins/plugins_service.rs create mode 100644 lib/executor/src/plugins/plugin_context.rs diff --git a/Cargo.lock b/Cargo.lock index b81088873..52b09935b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2004,6 +2004,8 @@ dependencies = [ "mimalloc", "moka", "ntex", + "ntex-service", + "ntex-util", "rand 0.9.2", "regex-automata", "reqwest", diff --git a/bin/router/Cargo.toml b/bin/router/Cargo.toml index 4969b0a46..15d5e2051 100644 --- a/bin/router/Cargo.toml +++ b/bin/router/Cargo.toml @@ -53,3 +53,5 @@ tokio-util = "0.7.16" cookie = "0.18.1" regex-automata = "0.4.10" arc-swap = "1.7.1" +ntex-util = "2.15.0" +ntex-service = "3.5.0" \ No newline at end of file diff --git a/bin/router/src/lib.rs b/bin/router/src/lib.rs index 416021b3a..dec90d730 100644 --- a/bin/router/src/lib.rs +++ b/bin/router/src/lib.rs @@ -4,6 +4,7 @@ mod http_utils; mod jwt; mod logger; mod pipeline; +mod plugins; mod schema_state; mod shared_state; mod supergraph; @@ -24,6 +25,7 @@ use crate::{ graphql_request_handler, header::{RequestAccepts, APPLICATION_GRAPHQL_RESPONSE_JSON_STR}, }, + plugins::plugins_service::PluginService, }; pub use crate::{schema_state::SchemaState, shared_state::RouterSharedState}; @@ -56,8 +58,8 @@ async fn graphql_endpoint_handler( let accept_ok = !req.accepts_content_type(&APPLICATION_GRAPHQL_RESPONSE_JSON_STR); - let result = match graphql_request_handler( - req, + let mut response = match graphql_request_handler( + &req, body_bytes, supergraph, app_state.get_ref().clone(), @@ -69,9 +71,6 @@ async fn graphql_endpoint_handler( Err(error) => return PipelineError { accept_ok, error }.into(), }; - let mut response = result.result; - let req = result.request; - // Apply CORS headers to the final response if CORS is configured. if let Some(cors) = app_state.cors_runtime.as_ref() { cors.set_headers(&req, response.headers_mut()); @@ -99,6 +98,7 @@ pub async fn router_entrypoint() -> Result<(), Box> { let maybe_error = web::HttpServer::new(move || { web::App::new() + .wrap(PluginService) .state(shared_state.clone()) .state(schema_state.clone()) .configure(configure_ntex_app) diff --git a/bin/router/src/pipeline/execution.rs b/bin/router/src/pipeline/execution.rs index e69a89179..e40fc02c3 100644 --- a/bin/router/src/pipeline/execution.rs +++ b/bin/router/src/pipeline/execution.rs @@ -7,11 +7,10 @@ use crate::pipeline::normalize::GraphQLNormalizationPayload; use crate::shared_state::RouterSharedState; use hive_router_plan_executor::execution::client_request_details::ClientRequestDetails; use hive_router_plan_executor::execution::jwt_forward::JwtAuthForwardingPlan; -use hive_router_plan_executor::execution::plan::{ - PlanExecutionOutput, QueryPlanExecutionContext, ResultWithRequest, -}; +use hive_router_plan_executor::execution::plan::{PlanExecutionOutput, QueryPlanExecutionContext}; use hive_router_plan_executor::hooks::on_supergraph_load::SupergraphData; use hive_router_plan_executor::introspection::resolve::IntrospectionContext; +use hive_router_plan_executor::plugin_context::PluginManager; use hive_router_query_planner::planner::plan_nodes::QueryPlan; use http::HeaderName; use ntex::web::HttpRequest; @@ -26,15 +25,16 @@ enum ExposeQueryPlanMode { } #[inline] -pub async fn execute_plan( - req: HttpRequest, +pub async fn execute_plan<'exec, 'req>( + req: &HttpRequest, supergraph: &SupergraphData, app_state: Arc, normalized_payload: Arc, query_plan_payload: Arc, variable_payload: &CoerceVariablesPayload, - client_request_details: &ClientRequestDetails<'_>, -) -> Result, PipelineErrorVariant> { + client_request_details: &ClientRequestDetails<'exec, 'req>, + plugin_manager: PluginManager<'req>, +) -> Result { let mut expose_query_plan = ExposeQueryPlanMode::No; if app_state.router_config.query_planner.allow_expose { @@ -86,7 +86,7 @@ pub async fn execute_plan( }; let ctx = QueryPlanExecutionContext { - router_http_request: req, + plugin_manager: &plugin_manager, query_plan: query_plan_payload, projection_plan: &normalized_payload.projection_plan, headers_plan: &app_state.headers_plan, @@ -97,7 +97,6 @@ pub async fn execute_plan( operation_type_name: normalized_payload.root_type_name, jwt_auth_forwarding, executors: &supergraph.subgraph_executor_map, - plugins: &app_state.plugins, }; ctx.execute_query_plan().await.map_err(|err| { diff --git a/bin/router/src/pipeline/mod.rs b/bin/router/src/pipeline/mod.rs index b97f7e67f..860a13c91 100644 --- a/bin/router/src/pipeline/mod.rs +++ b/bin/router/src/pipeline/mod.rs @@ -3,12 +3,13 @@ use std::sync::Arc; use hive_router_plan_executor::{ execution::{ client_request_details::{ClientRequestDetails, JwtRequestDetails, OperationDetails}, - plan::{PlanExecutionOutput, ResultWithRequest, WithResult}, + plan::PlanExecutionOutput, }, hooks::{ on_graphql_params::{OnGraphQLParamsEndPayload, OnGraphQLParamsStartPayload}, on_supergraph_load::SupergraphData, }, + plugin_context::{PluginContext, PluginManager, RouterHttpRequest}, plugin_trait::ControlFlowResult, }; use hive_router_query_planner::{ @@ -59,28 +60,26 @@ static GRAPHIQL_HTML: &str = include_str!("../../static/graphiql.html"); #[inline] pub async fn graphql_request_handler( - req: HttpRequest, + req: &HttpRequest, body_bytes: Bytes, supergraph: &SupergraphData, shared_state: Arc, schema_state: Arc, -) -> Result, PipelineErrorVariant> { +) -> Result { if req.method() == Method::GET && req.accepts_content_type(*TEXT_HTML_CONTENT_TYPE) { if shared_state.router_config.graphiql.enabled { - return Ok(req.with_result( - web::HttpResponse::Ok() - .header(CONTENT_TYPE, *TEXT_HTML_CONTENT_TYPE) - .body(GRAPHIQL_HTML), - )); + return Ok(web::HttpResponse::Ok() + .header(CONTENT_TYPE, *TEXT_HTML_CONTENT_TYPE) + .body(GRAPHIQL_HTML)); } else { - return Ok(req.with_result(web::HttpResponse::NotFound().into())); + return Ok(web::HttpResponse::NotFound().into()); } } let jwt_context = if let Some(jwt) = &shared_state.jwt_auth_runtime { - match jwt.validate_request(&req) { + match jwt.validate_request(req) { Ok(jwt_context) => jwt_context, - Err(err) => return Ok(req.with_result(err.make_response())), + Err(err) => return Ok(err.make_response()), } } else { None @@ -93,16 +92,36 @@ pub async fn graphql_request_handler( &APPLICATION_JSON }; - let execution_result_with_req = execute_pipeline( + let plugin_context = req + .extensions() + .get::>() + .cloned() + .expect("Plugin manager should be loaded"); + + let plugin_manager = PluginManager { + plugins: shared_state.plugins.clone(), + router_http_request: RouterHttpRequest { + uri: req.uri(), + method: req.method(), + version: req.version(), + headers: req.headers(), + match_info: req.match_info(), + query_string: req.query_string(), + path: req.path(), + }, + context: plugin_context, + }; + + let response = execute_pipeline( req, body_bytes, supergraph, shared_state, schema_state, jwt_context, + plugin_manager, ) .await?; - let response = execution_result_with_req.result; let response_bytes = Bytes::from(response.body); let response_headers = response.headers; @@ -113,41 +132,39 @@ pub async fn graphql_request_handler( } } - Ok(execution_result_with_req.request.with_result( - response_builder - .header(http::header::CONTENT_TYPE, response_content_type) - .body(response_bytes), - )) + Ok(response_builder + .header(http::header::CONTENT_TYPE, response_content_type) + .body(response_bytes)) } #[inline] #[allow(clippy::await_holding_refcell_ref)] -pub async fn execute_pipeline( - req: HttpRequest, +pub async fn execute_pipeline<'req>( + req: &'req HttpRequest, body: Bytes, supergraph: &SupergraphData, shared_state: Arc, schema_state: Arc, jwt_context: Option, -) -> Result, PipelineErrorVariant> { - perform_csrf_prevention(&req, &shared_state.router_config.csrf)?; + plugin_manager: PluginManager<'req>, +) -> Result { + perform_csrf_prevention(req, &shared_state.router_config.csrf)?; /* Handle on_deserialize hook in the plugins - START */ let mut deserialization_end_callbacks = vec![]; let mut deserialization_payload: OnGraphQLParamsStartPayload = OnGraphQLParamsStartPayload { - router_http_request: req, + router_http_request: &plugin_manager.router_http_request, + context: &plugin_manager.context, body, graphql_params: None, }; for plugin in shared_state.plugins.as_ref() { - let result = plugin.on_graphql_params(deserialization_payload); + let result = plugin.on_graphql_params(deserialization_payload).await; deserialization_payload = result.payload; match result.control_flow { ControlFlowResult::Continue => { /* continue to next plugin */ } ControlFlowResult::EndResponse(response) => { - return Ok(deserialization_payload - .router_http_request - .with_result(response)); + return Ok(response); } ControlFlowResult::OnEnd(callback) => { deserialization_end_callbacks.push(callback); @@ -155,11 +172,8 @@ pub async fn execute_pipeline( } } let graphql_params = deserialization_payload.graphql_params.unwrap_or_else(|| { - deserialize_graphql_params( - &deserialization_payload.router_http_request, - deserialization_payload.body, - ) - .expect("Failed to parse execution request") + deserialize_graphql_params(req, deserialization_payload.body) + .expect("Failed to parse execution request") }); let mut payload = OnGraphQLParamsEndPayload { graphql_params }; @@ -169,9 +183,7 @@ pub async fn execute_pipeline( match result.control_flow { ControlFlowResult::Continue => { /* continue to next plugin */ } ControlFlowResult::EndResponse(response) => { - return Ok(deserialization_payload - .router_http_request - .with_result(response)); + return Ok(response); } ControlFlowResult::OnEnd(_) => { // on_end callbacks should not return OnEnd again @@ -182,25 +194,22 @@ pub async fn execute_pipeline( let mut graphql_params = payload.graphql_params; /* Handle on_deserialize hook in the plugins - END */ - let req = deserialization_payload.router_http_request; let parser_result = - parse_operation_with_cache(req, shared_state.clone(), &graphql_params).await?; - - let mut req = parser_result.request; + parse_operation_with_cache(shared_state.clone(), &graphql_params, &plugin_manager).await?; - let parser_payload = match parser_result.result { + let parser_payload = match parser_result { ParseResult::Payload(payload) => payload, ParseResult::Response(response) => { - return Ok(req.with_result(response)); + return Ok(response); } }; validate_operation_with_cache( - &mut req, supergraph, schema_state.clone(), shared_state.clone(), &parser_payload, + &plugin_manager, ) .await?; @@ -213,7 +222,7 @@ pub async fn execute_pipeline( .await?; let variable_payload = - coerce_request_variables(&req, supergraph, &mut graphql_params, &normalize_payload)?; + coerce_request_variables(req, supergraph, &mut graphql_params, &normalize_payload)?; let query_plan_cancellation_token = CancellationToken::with_timeout(shared_state.router_config.query_planner.timeout); @@ -231,9 +240,9 @@ pub async fn execute_pipeline( }; let client_request_details = ClientRequestDetails { - method: req.method().clone(), - url: req.uri().clone(), - headers: req.headers().clone(), + method: req.method(), + url: req.uri(), + headers: req.headers(), operation: OperationDetails { name: normalize_payload.operation_for_plan.name.as_deref(), kind: match normalize_payload.operation_for_plan.operation_kind { @@ -254,20 +263,19 @@ pub async fn execute_pipeline( .map_err(PipelineErrorVariant::LabelEvaluationError)?; let query_plan_result = plan_operation_with_cache( - req, supergraph, schema_state.clone(), normalize_payload.clone(), &progressive_override_ctx, &query_plan_cancellation_token, shared_state.clone(), + &plugin_manager, ) .await?; - let req = query_plan_result.request; - let query_plan_payload = match query_plan_result.result { + let query_plan_payload = match query_plan_result { QueryPlanResult::QueryPlan(plan) => plan, QueryPlanResult::Response(response) => { - return Ok(req.with_result(response)); + return Ok(response); } }; @@ -279,6 +287,7 @@ pub async fn execute_pipeline( query_plan_payload, &variable_payload, &client_request_details, + plugin_manager, ) .await?; diff --git a/bin/router/src/pipeline/parser.rs b/bin/router/src/pipeline/parser.rs index e6194fbd5..0a11ab2aa 100644 --- a/bin/router/src/pipeline/parser.rs +++ b/bin/router/src/pipeline/parser.rs @@ -2,16 +2,14 @@ use std::hash::{Hash, Hasher}; use std::sync::Arc; use graphql_parser::query::Document; -use hive_router_plan_executor::execution::plan::{ - PlanExecutionOutput, ResultWithRequest, WithResult, -}; +use hive_router_plan_executor::execution::plan::PlanExecutionOutput; use hive_router_plan_executor::hooks::on_graphql_params::GraphQLParams; use hive_router_plan_executor::hooks::on_graphql_parse::{ OnGraphQLParseEndPayload, OnGraphQLParseStartPayload, }; +use hive_router_plan_executor::plugin_context::PluginManager; use hive_router_plan_executor::plugin_trait::ControlFlowResult; use hive_router_query_planner::utils::parsing::safe_parse_operation; -use ntex::web::HttpRequest; use xxhash_rust::xxh3::Xxh3; use crate::pipeline::deserialize_graphql_params::GetQueryStr; @@ -31,11 +29,11 @@ pub enum ParseResult { } #[inline] -pub async fn parse_operation_with_cache( - req: HttpRequest, +pub async fn parse_operation_with_cache<'req>( app_state: Arc, graphql_params: &GraphQLParams, -) -> Result, PipelineErrorVariant> { + plugin_manager: &PluginManager<'req>, +) -> Result { let cache_key = { let mut hasher = Xxh3::new(); graphql_params.query.hash(&mut hasher); @@ -43,7 +41,8 @@ pub async fn parse_operation_with_cache( }; /* Handle on_graphql_parse hook in the plugins - START */ let mut start_payload = OnGraphQLParseStartPayload { - router_http_request: req, + router_http_request: &plugin_manager.router_http_request, + context: &plugin_manager.context, graphql_params, document: None, }; @@ -54,16 +53,14 @@ pub async fn parse_operation_with_cache( } else { let mut on_end_callbacks = vec![]; for plugin in app_state.plugins.as_ref() { - let result = plugin.on_graphql_parse(start_payload); + let result = plugin.on_graphql_parse(start_payload).await; start_payload = result.payload; match result.control_flow { ControlFlowResult::Continue => { // continue to next plugin } ControlFlowResult::EndResponse(response) => { - return Ok(start_payload - .router_http_request - .with_result(ParseResult::Response(response))); + return Ok(ParseResult::Response(response)); } ControlFlowResult::OnEnd(callback) => { // store the callback to be called later @@ -93,9 +90,7 @@ pub async fn parse_operation_with_cache( // continue to next callback } ControlFlowResult::EndResponse(response) => { - return Ok(start_payload - .router_http_request - .with_result(ParseResult::Response(response))); + return Ok(ParseResult::Response(response)); } ControlFlowResult::OnEnd(_) => { // on_end callbacks should not return OnEnd again @@ -114,10 +109,8 @@ pub async fn parse_operation_with_cache( parsed_arc }; - Ok(start_payload - .router_http_request - .with_result(ParseResult::Payload(GraphQLParserPayload { - parsed_operation, - cache_key, - }))) + Ok(ParseResult::Payload(GraphQLParserPayload { + parsed_operation, + cache_key, + })) } diff --git a/bin/router/src/pipeline/progressive_override.rs b/bin/router/src/pipeline/progressive_override.rs index dc28a0d60..d0b09c183 100644 --- a/bin/router/src/pipeline/progressive_override.rs +++ b/bin/router/src/pipeline/progressive_override.rs @@ -51,9 +51,9 @@ pub struct RequestOverrideContext { } #[inline] -pub fn request_override_context<'exec>( +pub fn request_override_context<'exec, 'req>( override_labels_evaluator: &OverrideLabelsEvaluator, - client_request_details: &ClientRequestDetails<'exec>, + client_request_details: &ClientRequestDetails<'exec, 'req>, ) -> Result { let active_flags = override_labels_evaluator.evaluate(client_request_details)?; @@ -158,9 +158,9 @@ impl OverrideLabelsEvaluator { }) } - pub(crate) fn evaluate<'exec>( + pub(crate) fn evaluate<'exec, 'req>( &self, - client_request: &ClientRequestDetails<'exec>, + client_request: &ClientRequestDetails<'exec, 'req>, ) -> Result, LabelEvaluationError> { let mut active_flags = self.static_enabled_labels.clone(); diff --git a/bin/router/src/pipeline/query_plan.rs b/bin/router/src/pipeline/query_plan.rs index 807946296..156b8ef9f 100644 --- a/bin/router/src/pipeline/query_plan.rs +++ b/bin/router/src/pipeline/query_plan.rs @@ -6,18 +6,16 @@ use crate::pipeline::normalize::GraphQLNormalizationPayload; use crate::pipeline::progressive_override::{RequestOverrideContext, StableOverrideContext}; use crate::schema_state::SchemaState; use crate::RouterSharedState; -use hive_router_plan_executor::execution::plan::{ - PlanExecutionOutput, ResultWithRequest, WithResult, -}; +use hive_router_plan_executor::execution::plan::PlanExecutionOutput; use hive_router_plan_executor::hooks::on_query_plan::{ OnQueryPlanEndPayload, OnQueryPlanStartPayload, }; use hive_router_plan_executor::hooks::on_supergraph_load::SupergraphData; +use hive_router_plan_executor::plugin_context::PluginManager; use hive_router_plan_executor::plugin_trait::ControlFlowResult; use hive_router_query_planner::planner::plan_nodes::QueryPlan; use hive_router_query_planner::planner::PlannerError; use hive_router_query_planner::utils::cancellation::CancellationToken; -use ntex::web::HttpRequest; use xxhash_rust::xxh3::Xxh3; pub enum QueryPlanResult { @@ -31,15 +29,15 @@ pub enum QueryPlanGetterError { } #[inline] -pub async fn plan_operation_with_cache( - mut req: HttpRequest, +pub async fn plan_operation_with_cache<'req>( supergraph: &SupergraphData, schema_state: Arc, normalized_operation: Arc, request_override_context: &RequestOverrideContext, cancellation_token: &CancellationToken, app_state: Arc, -) -> Result, PipelineErrorVariant> { + plugin_manager: &PluginManager<'req>, +) -> Result { let stable_override_context = StableOverrideContext::new(&supergraph.planner.supergraph, request_override_context); @@ -61,7 +59,8 @@ pub async fn plan_operation_with_cache( /* Handle on_query_plan hook in the plugins - START */ let mut start_payload = OnQueryPlanStartPayload { - router_http_request: &mut req, + router_http_request: &plugin_manager.router_http_request, + context: &plugin_manager.context, filtered_operation_for_plan, planner_override_context: (&request_override_context.clone()).into(), cancellation_token, @@ -71,7 +70,7 @@ pub async fn plan_operation_with_cache( let mut on_end_callbacks = vec![]; for plugin in app_state.plugins.as_ref() { - let result = plugin.on_query_plan(start_payload); + let result = plugin.on_query_plan(start_payload).await; start_payload = result.payload; match result.control_flow { ControlFlowResult::Continue => { @@ -121,11 +120,11 @@ pub async fn plan_operation_with_cache( .await; match plan_result { - Ok(plan) => Ok(req.with_result(QueryPlanResult::QueryPlan(plan))), + Ok(plan) => Ok(QueryPlanResult::QueryPlan(plan)), Err(e) => match e.as_ref() { QueryPlanGetterError::Planner(e) => Err(PipelineErrorVariant::PlannerError(e.clone())), QueryPlanGetterError::Response(response) => { - Ok(req.with_result(QueryPlanResult::Response(response.clone()))) + Ok(QueryPlanResult::Response(response.clone())) } }, } diff --git a/bin/router/src/pipeline/validation.rs b/bin/router/src/pipeline/validation.rs index 3efbddcca..afe833656 100644 --- a/bin/router/src/pipeline/validation.rs +++ b/bin/router/src/pipeline/validation.rs @@ -10,17 +10,17 @@ use hive_router_plan_executor::hooks::on_graphql_validation::{ OnGraphQLValidationEndPayload, OnGraphQLValidationStartPayload, }; use hive_router_plan_executor::hooks::on_supergraph_load::SupergraphData; +use hive_router_plan_executor::plugin_context::PluginManager; use hive_router_plan_executor::plugin_trait::ControlFlowResult; -use ntex::web::HttpRequest; use tracing::{error, trace}; #[inline] pub async fn validate_operation_with_cache( - req: &mut HttpRequest, supergraph: &SupergraphData, schema_state: Arc, app_state: Arc, parser_payload: &GraphQLParserPayload, + plugin_manager: &PluginManager<'_>, ) -> Result, PipelineErrorVariant> { let consumer_schema_ast = &supergraph.planner.consumer_schema.document; @@ -45,14 +45,14 @@ pub async fn validate_operation_with_cache( /* Handle on_graphql_validate hook in the plugins - START */ let mut start_payload = OnGraphQLValidationStartPayload::new( - req, + plugin_manager, consumer_schema_ast, &parser_payload.parsed_operation, &app_state.validation_plan, ); let mut on_end_callbacks = vec![]; for plugin in app_state.plugins.as_ref() { - let result = plugin.on_graphql_validation(start_payload); + let result = plugin.on_graphql_validation(start_payload).await; start_payload = result.payload; match result.control_flow { ControlFlowResult::Continue => { diff --git a/bin/router/src/plugins/mod.rs b/bin/router/src/plugins/mod.rs new file mode 100644 index 000000000..3753246b2 --- /dev/null +++ b/bin/router/src/plugins/mod.rs @@ -0,0 +1 @@ +pub mod plugins_service; diff --git a/bin/router/src/plugins/plugins_service.rs b/bin/router/src/plugins/plugins_service.rs new file mode 100644 index 000000000..49204b662 --- /dev/null +++ b/bin/router/src/plugins/plugins_service.rs @@ -0,0 +1,109 @@ +use std::sync::Arc; + +use hive_router_plan_executor::{ + hooks::on_http_request::{OnHttpRequestPayload, OnHttpResponsePayload}, + plugin_context::PluginContext, + plugin_trait::ControlFlowResult, +}; +use ntex::{ + service::{Service, ServiceCtx}, + web::{self, DefaultError}, + Middleware, +}; + +use crate::RouterSharedState; + +pub struct PluginService; + +impl Middleware for PluginService { + type Service = PluginMiddleware; + + fn create(&self, service: S) -> Self::Service { + PluginMiddleware { service } + } +} + +pub struct PluginMiddleware { + // This is special: We need this to avoid lifetime issues. + service: S, +} + +impl Service> for PluginMiddleware +where + S: Service, Response = web::WebResponse, Error = web::Error>, +{ + type Response = web::WebResponse; + type Error = S::Error; + + ntex::forward_ready!(service); + + async fn call( + &self, + req: web::WebRequest, + ctx: ServiceCtx<'_, Self>, + ) -> Result { + let plugins = req + .app_state::>() + .map(|shared_state| shared_state.plugins.clone()); + + if let Some(plugins) = plugins { + let plugin_context = Arc::new(PluginContext::default()); + req.extensions_mut().insert(plugin_context.clone()); + let mut start_payload = OnHttpRequestPayload { + router_http_request: req, + context: &plugin_context, + response: None, + }; + + let mut on_end_callbacks = vec![]; + + for plugin in plugins.iter() { + let result = plugin.on_http_request(start_payload); + start_payload = result.payload; + match result.control_flow { + ControlFlowResult::Continue => { + // continue to next plugin + } + ControlFlowResult::OnEnd(callback) => { + on_end_callbacks.push(callback); + } + ControlFlowResult::EndResponse(_response) => { + // Short-circuit the request with the provided response + unimplemented!(); + } + } + } + + let req = start_payload.router_http_request; + + let response = match start_payload.response { + Some(response) => response, + None => ctx.call(&self.service, req).await?, + }; + + let mut end_payload = OnHttpResponsePayload { response }; + + for callback in on_end_callbacks.into_iter().rev() { + let result = callback(end_payload); + end_payload = result.payload; + match result.control_flow { + ControlFlowResult::Continue => { + // continue to next callback + } + ControlFlowResult::EndResponse(_response) => { + // Short-circuit the request with the provided response + unimplemented!() + } + ControlFlowResult::OnEnd(_) => { + // This should not happen + unreachable!(); + } + } + } + + return Ok(end_payload.response); + } + + ctx.call(&self.service, req).await + } +} diff --git a/lib/executor/src/execution/client_request_details.rs b/lib/executor/src/execution/client_request_details.rs index 71f28b746..20b7dcf98 100644 --- a/lib/executor/src/execution/client_request_details.rs +++ b/lib/executor/src/execution/client_request_details.rs @@ -13,10 +13,10 @@ pub struct OperationDetails<'exec> { pub kind: &'static str, } -pub struct ClientRequestDetails<'exec> { - pub method: Method, - pub url: http::Uri, - pub headers: NtexHeaderMap, +pub struct ClientRequestDetails<'exec, 'req> { + pub method: &'req Method, + pub url: &'req http::Uri, + pub headers: &'req NtexHeaderMap, pub operation: OperationDetails<'exec>, pub jwt: JwtRequestDetails, } @@ -31,10 +31,10 @@ pub enum JwtRequestDetails { Unauthenticated, } -impl From<&ClientRequestDetails<'_>> for Value { +impl From<&ClientRequestDetails<'_, '_>> for Value { fn from(details: &ClientRequestDetails) -> Self { // .request.headers - let headers_value = client_header_map_to_vrl_value(&details.headers); + let headers_value = client_header_map_to_vrl_value(details.headers); // .request.url let url_value = Self::Object(BTreeMap::from([ diff --git a/lib/executor/src/execution/plan.rs b/lib/executor/src/execution/plan.rs index 5ea8bd758..bbe5ba836 100644 --- a/lib/executor/src/execution/plan.rs +++ b/lib/executor/src/execution/plan.rs @@ -10,7 +10,6 @@ use hive_router_query_planner::planner::plan_nodes::{ QueryPlan, SequenceNode, }; use http::HeaderMap; -use ntex::web::HttpRequest; use serde::Deserialize; use sonic_rs::ValueRef; @@ -36,7 +35,8 @@ use crate::{ resolve::{resolve_introspection, IntrospectionContext}, schema::SchemaMetadata, }, - plugin_trait::{ControlFlowResult, RouterPlugin}, + plugin_context::PluginManager, + plugin_trait::ControlFlowResult, projection::{ plan::FieldProjectionPlan, request::{project_requires, RequestProjectionContext}, @@ -54,15 +54,14 @@ use crate::{ }, }; -pub struct QueryPlanExecutionContext<'exec> { - pub router_http_request: HttpRequest, - pub plugins: &'exec Vec>, +pub struct QueryPlanExecutionContext<'exec, 'req> { + pub plugin_manager: &'exec PluginManager<'exec>, pub query_plan: Arc, pub projection_plan: &'exec Vec, pub headers_plan: &'exec HeaderRulesPlan, pub variable_values: &'exec Option>, pub extensions: Option>, - pub client_request: &'exec ClientRequestDetails<'exec>, + pub client_request: &'exec ClientRequestDetails<'exec, 'req>, pub introspection_context: &'exec IntrospectionContext<'exec, 'static>, pub operation_type_name: &'exec str, pub executors: &'exec SubgraphExecutorMap, @@ -75,25 +74,8 @@ pub struct PlanExecutionOutput { pub headers: HeaderMap, } -pub struct ResultWithRequest { - pub result: T, - pub request: HttpRequest, -} - -pub trait WithResult { - fn with_result(self, result: T) -> ResultWithRequest; -} - -impl WithResult for HttpRequest { - fn with_result(self, result: T) -> ResultWithRequest { - ResultWithRequest { result, request: self } - } -} - -impl<'exec> QueryPlanExecutionContext<'exec> { - pub async fn execute_query_plan( - self, - ) -> Result, PlanExecutionError> { +impl<'exec, 'req> QueryPlanExecutionContext<'exec, 'req> { + pub async fn execute_query_plan(self) -> Result { let init_value = if let Some(introspection_query) = self.introspection_context.query { resolve_introspection(introspection_query, self.introspection_context) } else { @@ -103,8 +85,9 @@ impl<'exec> QueryPlanExecutionContext<'exec> { let dedupe_subgraph_requests = self.operation_type_name == "Query"; let mut start_payload = OnExecuteStartPayload { - router_http_request: self.router_http_request, - query_plan: self.query_plan, + router_http_request: &self.plugin_manager.router_http_request, + context: &self.plugin_manager.context, + query_plan: &self.query_plan, data: init_value, errors: Vec::new(), extensions: self.extensions.clone(), @@ -114,13 +97,13 @@ impl<'exec> QueryPlanExecutionContext<'exec> { let mut on_end_callbacks = vec![]; - for plugin in self.plugins { - let result = plugin.on_execute(start_payload); + for plugin in self.plugin_manager.plugins.iter() { + let result = plugin.on_execute(start_payload).await; start_payload = result.payload; match result.control_flow { ControlFlowResult::Continue => { /* continue to next plugin */ } ControlFlowResult::EndResponse(response) => { - return Ok(start_payload.router_http_request.with_result(response)); + return Ok(response); } ControlFlowResult::OnEnd(callback) => { on_end_callbacks.push(callback); @@ -132,7 +115,7 @@ impl<'exec> QueryPlanExecutionContext<'exec> { let init_value = start_payload.data; - let mut exec_ctx = ExecutionContext::new(&query_plan, init_value); + let mut exec_ctx = ExecutionContext::new(query_plan, init_value); let executor = Executor::new( self.variable_values, self.executors, @@ -142,6 +125,7 @@ impl<'exec> QueryPlanExecutionContext<'exec> { self.jwt_auth_forwarding, // Deduplicate subgraph requests only if the operation type is a query self.operation_type_name == "Query", + self.plugin_manager, ); if query_plan.node.is_some() { @@ -170,7 +154,7 @@ impl<'exec> QueryPlanExecutionContext<'exec> { match result.control_flow { ControlFlowResult::Continue => { /* continue to next callback */ } ControlFlowResult::EndResponse(output) => { - return Ok(start_payload.router_http_request.with_result(output)); + return Ok(output); } ControlFlowResult::OnEnd(_) => { // on_end callbacks should not return OnEnd again @@ -193,23 +177,22 @@ impl<'exec> QueryPlanExecutionContext<'exec> { affected_path: || None, })?; - Ok(start_payload - .router_http_request - .with_result(PlanExecutionOutput { - body, - headers: response_headers, - })) + Ok(PlanExecutionOutput { + body, + headers: response_headers, + }) } } -pub struct Executor<'exec> { +pub struct Executor<'exec, 'req> { variable_values: &'exec Option>, schema_metadata: &'exec SchemaMetadata, executors: &'exec SubgraphExecutorMap, - client_request: &'exec ClientRequestDetails<'exec>, + client_request: &'exec ClientRequestDetails<'exec, 'req>, headers_plan: &'exec HeaderRulesPlan, jwt_forwarding_plan: Option, dedupe_subgraph_requests: bool, + plugin_manager: &'exec PluginManager<'exec>, } struct ConcurrencyScope<'exec, T> { @@ -296,15 +279,16 @@ struct PreparedFlattenData { representation_hash_to_index: HashMap, } -impl<'exec> Executor<'exec> { +impl<'exec, 'req> Executor<'exec, 'req> { pub fn new( variable_values: &'exec Option>, executors: &'exec SubgraphExecutorMap, schema_metadata: &'exec SchemaMetadata, - client_request: &'exec ClientRequestDetails<'exec>, + client_request: &'exec ClientRequestDetails<'exec, 'req>, headers_plan: &'exec HeaderRulesPlan, jwt_forwarding_plan: Option, dedupe_subgraph_requests: bool, + plugin_manager: &'exec PluginManager<'exec>, ) -> Self { Executor { variable_values, @@ -314,6 +298,7 @@ impl<'exec> Executor<'exec> { headers_plan, dedupe_subgraph_requests, jwt_forwarding_plan, + plugin_manager, } } @@ -803,7 +788,12 @@ impl<'exec> Executor<'exec> { subgraph_name: node.service_name.clone(), response: self .executors - .execute(&node.service_name, subgraph_request, self.client_request) + .execute( + &node.service_name, + subgraph_request, + self.client_request, + self.plugin_manager, + ) .await .into(), })) diff --git a/lib/executor/src/executors/http.rs b/lib/executor/src/executors/http.rs index a5a5ed263..9e2770965 100644 --- a/lib/executor/src/executors/http.rs +++ b/lib/executor/src/executors/http.rs @@ -191,7 +191,7 @@ async fn send_request( let mut on_end_callbacks = vec![]; for plugin in plugins.as_ref() { - let result = plugin.on_subgraph_http_request(start_payload); + let result = plugin.on_subgraph_http_request(start_payload).await; start_payload = result.payload; match result.control_flow { ControlFlowResult::Continue => { /* continue to next plugin */ } diff --git a/lib/executor/src/executors/map.rs b/lib/executor/src/executors/map.rs index cf2f6dc13..cf6033040 100644 --- a/lib/executor/src/executors/map.rs +++ b/lib/executor/src/executors/map.rs @@ -38,6 +38,7 @@ use crate::{ http::{HTTPSubgraphExecutor, HttpClient}, }, hooks::on_subgraph_execute::{OnSubgraphExecuteEndPayload, OnSubgraphExecuteStartPayload}, + plugin_context::PluginManager, plugin_trait::{ControlFlowResult, RouterPlugin}, response::graphql_error::GraphQLError, }; @@ -124,13 +125,16 @@ impl SubgraphExecutorMap { Ok(subgraph_executor_map) } - pub async fn execute<'a>( + pub async fn execute<'exec, 'req>( &self, subgraph_name: &str, - execution_request: SubgraphExecutionRequest<'a>, - client_request: &ClientRequestDetails<'a>, + execution_request: SubgraphExecutionRequest<'exec>, + client_request: &ClientRequestDetails<'exec, 'req>, + plugin_manager: &PluginManager<'req>, ) -> HttpExecutionResponse { let mut start_payload = OnSubgraphExecuteStartPayload { + router_http_request: &plugin_manager.router_http_request, + context: &plugin_manager.context, subgraph_name: subgraph_name.to_string(), execution_request, execution_result: None, @@ -139,7 +143,7 @@ impl SubgraphExecutorMap { let mut on_end_callbacks = vec![]; for plugin in self.plugins.as_ref() { - let result = plugin.on_subgraph_execute(start_payload); + let result = plugin.on_subgraph_execute(start_payload).await; start_payload = result.payload; match result.control_flow { ControlFlowResult::Continue => { @@ -227,7 +231,7 @@ impl SubgraphExecutorMap { fn get_or_create_executor( &self, subgraph_name: &str, - client_request: &ClientRequestDetails<'_>, + client_request: &ClientRequestDetails<'_, '_>, ) -> Result, SubgraphExecutorError> { let from_expression = self.get_or_create_executor_from_expression(subgraph_name, client_request)?; @@ -246,7 +250,7 @@ impl SubgraphExecutorMap { fn get_or_create_executor_from_expression( &self, subgraph_name: &str, - client_request: &ClientRequestDetails<'_>, + client_request: &ClientRequestDetails<'_, '_>, ) -> Result, SubgraphExecutorError> { if let Some(expression) = self.expressions_by_subgraph.get(subgraph_name) { let original_url_value = VrlValue::Bytes(Bytes::from( diff --git a/lib/executor/src/headers/expression.rs b/lib/executor/src/headers/expression.rs index 5852f2b75..ed63ebfc0 100644 --- a/lib/executor/src/headers/expression.rs +++ b/lib/executor/src/headers/expression.rs @@ -46,7 +46,7 @@ fn header_map_to_vrl_value(headers: &HeaderMap) -> Value { Value::Object(obj) } -impl From<&RequestExpressionContext<'_>> for Value { +impl From<&RequestExpressionContext<'_, '_>> for Value { /// NOTE: If performance becomes an issue, consider pre-computing parts of this context that do not change fn from(ctx: &RequestExpressionContext) -> Self { // .subgraph @@ -65,7 +65,7 @@ impl From<&RequestExpressionContext<'_>> for Value { } } -impl From<&ResponseExpressionContext<'_>> for Value { +impl From<&ResponseExpressionContext<'_, '_>> for Value { /// NOTE: If performance becomes an issue, consider pre-computing parts of this context that do not change fn from(ctx: &ResponseExpressionContext) -> Self { // .subgraph diff --git a/lib/executor/src/headers/mod.rs b/lib/executor/src/headers/mod.rs index c617edfa0..0338bf6df 100644 --- a/lib/executor/src/headers/mod.rs +++ b/lib/executor/src/headers/mod.rs @@ -74,9 +74,9 @@ mod tests { ); let client_details = ClientRequestDetails { - method: http::Method::POST, - url: "http://example.com".parse().unwrap(), - headers: client_headers, + method: &http::Method::POST, + url: &"http://example.com".parse().unwrap(), + headers: &client_headers, operation: OperationDetails { name: None, query: "{ __typename }", @@ -108,9 +108,9 @@ mod tests { let plan = compile_headers_plan(&config.headers).unwrap(); let client_headers = NtexHeaderMap::new(); let client_details = ClientRequestDetails { - method: http::Method::POST, - url: "http://example.com".parse().unwrap(), - headers: client_headers, + method: &http::Method::POST, + url: &"http://example.com".parse().unwrap(), + headers: &client_headers, operation: OperationDetails { name: None, query: "{ __typename }", @@ -155,9 +155,9 @@ mod tests { ); let client_details = ClientRequestDetails { - method: http::Method::POST, - url: "http://example.com".parse().unwrap(), - headers: client_headers, + method: &http::Method::POST, + url: &"http://example.com".parse().unwrap(), + headers: &client_headers, operation: OperationDetails { name: None, query: "{ __typename }", @@ -193,9 +193,9 @@ mod tests { let plan = compile_headers_plan(&config.headers).unwrap(); let client_headers = NtexHeaderMap::new(); let client_details = ClientRequestDetails { - method: http::Method::POST, - url: "http://example.com".parse().unwrap(), - headers: client_headers, + method: &http::Method::POST, + url: &"http://example.com".parse().unwrap(), + headers: &client_headers, operation: OperationDetails { name: Some("MyQuery"), query: "{ __typename }", @@ -227,9 +227,9 @@ mod tests { let plan = compile_headers_plan(&config.headers).unwrap(); let client_headers = NtexHeaderMap::new(); let client_details = ClientRequestDetails { - method: http::Method::POST, - url: "http://example.com".parse().unwrap(), - headers: client_headers, + method: &http::Method::POST, + url: &"http://example.com".parse().unwrap(), + headers: &client_headers, operation: OperationDetails { name: None, query: "{ __typename }", @@ -267,9 +267,9 @@ mod tests { let plan = compile_headers_plan(&config.headers).unwrap(); let client_headers = NtexHeaderMap::new(); let client_details = ClientRequestDetails { - method: http::Method::POST, - url: "http://example.com".parse().unwrap(), - headers: client_headers, + method: &http::Method::POST, + url: &"http://example.com".parse().unwrap(), + headers: &client_headers, operation: OperationDetails { name: None, query: "{ __typename }", @@ -311,9 +311,9 @@ mod tests { let plan = compile_headers_plan(&config.headers).unwrap(); let client_headers = NtexHeaderMap::new(); let client_details = ClientRequestDetails { - method: http::Method::POST, - url: "http://example.com".parse().unwrap(), - headers: client_headers, + method: &http::Method::POST, + url: &"http://example.com".parse().unwrap(), + headers: &client_headers, operation: OperationDetails { name: None, query: "{ __typename }", @@ -376,9 +376,9 @@ mod tests { let plan = compile_headers_plan(&config.headers).unwrap(); let client_headers = NtexHeaderMap::new(); let client_details = ClientRequestDetails { - method: http::Method::POST, - url: "http://example.com".parse().unwrap(), - headers: client_headers, + method: &http::Method::POST, + url: &"http://example.com".parse().unwrap(), + headers: &client_headers, operation: OperationDetails { name: None, query: "{ __typename }", @@ -440,9 +440,9 @@ mod tests { let plan = compile_headers_plan(&config.headers).unwrap(); let client_headers = NtexHeaderMap::new(); let client_details = ClientRequestDetails { - method: http::Method::POST, - url: "http://example.com".parse().unwrap(), - headers: client_headers, + method: &http::Method::POST, + url: &"http://example.com".parse().unwrap(), + headers: &client_headers, operation: OperationDetails { name: None, query: "{ __typename }", @@ -497,9 +497,9 @@ mod tests { let plan = compile_headers_plan(&config.headers).unwrap(); let client_headers = NtexHeaderMap::new(); let client_details = ClientRequestDetails { - method: http::Method::POST, - url: "http://example.com".parse().unwrap(), - headers: client_headers, + method: &http::Method::POST, + url: &"http://example.com".parse().unwrap(), + headers: &client_headers, operation: OperationDetails { name: None, query: "{ __typename }", @@ -555,9 +555,9 @@ mod tests { let plan = compile_headers_plan(&config.headers).unwrap(); let client_headers = NtexHeaderMap::new(); let client_details = ClientRequestDetails { - method: http::Method::POST, - url: "http://example.com".parse().unwrap(), - headers: client_headers, + method: &http::Method::POST, + url: &"http://example.com".parse().unwrap(), + headers: &client_headers, operation: OperationDetails { name: None, query: "{ __typename }", @@ -614,9 +614,9 @@ mod tests { client_headers.insert(header_name_owned("x-keep"), header_value_owned("hi").into()); let client_details = ClientRequestDetails { - method: http::Method::POST, - url: "http://example.com".parse().unwrap(), - headers: client_headers, + method: &http::Method::POST, + url: &"http://example.com".parse().unwrap(), + headers: &client_headers, operation: OperationDetails { name: None, query: "{ __typename }", diff --git a/lib/executor/src/headers/request.rs b/lib/executor/src/headers/request.rs index 44b6ed8b9..7f362be73 100644 --- a/lib/executor/src/headers/request.rs +++ b/lib/executor/src/headers/request.rs @@ -45,9 +45,9 @@ pub fn modify_subgraph_request_headers( Ok(()) } -pub struct RequestExpressionContext<'a> { - pub subgraph_name: &'a str, - pub client_request: &'a ClientRequestDetails<'a>, +pub struct RequestExpressionContext<'exec, 'req> { + pub subgraph_name: &'exec str, + pub client_request: &'exec ClientRequestDetails<'exec, 'req>, } trait ApplyRequestHeader { diff --git a/lib/executor/src/headers/response.rs b/lib/executor/src/headers/response.rs index 94019d585..b4942837f 100644 --- a/lib/executor/src/headers/response.rs +++ b/lib/executor/src/headers/response.rs @@ -50,10 +50,10 @@ pub fn apply_subgraph_response_headers( Ok(()) } -pub struct ResponseExpressionContext<'a> { - pub subgraph_name: &'a str, - pub client_request: &'a ClientRequestDetails<'a>, - pub subgraph_headers: &'a HeaderMap, +pub struct ResponseExpressionContext<'exec, 'req> { + pub subgraph_name: &'exec str, + pub client_request: &'exec ClientRequestDetails<'exec, 'req>, + pub subgraph_headers: &'exec HeaderMap, } trait ApplyResponseHeader { diff --git a/lib/executor/src/plugins/examples/apq.rs b/lib/executor/src/plugins/examples/apq.rs index f5e380973..91a473b5e 100644 --- a/lib/executor/src/plugins/examples/apq.rs +++ b/lib/executor/src/plugins/examples/apq.rs @@ -10,11 +10,12 @@ pub struct APQPlugin { cache: DashMap, } +#[async_trait::async_trait] impl RouterPlugin for APQPlugin { - fn on_graphql_params<'exec>( + async fn on_graphql_params<'exec>( &'exec self, - payload: OnGraphQLParamsStartPayload, - ) -> HookResult<'exec, OnGraphQLParamsStartPayload, OnGraphQLParamsEndPayload> { + payload: OnGraphQLParamsStartPayload<'exec>, + ) -> HookResult<'exec, OnGraphQLParamsStartPayload<'exec>, OnGraphQLParamsEndPayload> { payload.on_end(|mut payload| { let persisted_query_ext = payload .graphql_params diff --git a/lib/executor/src/plugins/examples/response_cache.rs b/lib/executor/src/plugins/examples/response_cache.rs index fc6276643..ee92a029d 100644 --- a/lib/executor/src/plugins/examples/response_cache.rs +++ b/lib/executor/src/plugins/examples/response_cache.rs @@ -28,8 +28,9 @@ impl ResponseCachePlugin { } } +#[async_trait::async_trait] impl RouterPlugin for ResponseCachePlugin { - fn on_execute<'exec>( + async fn on_execute<'exec>( &'exec self, payload: OnExecuteStartPayload<'exec>, ) -> HookResult<'exec, OnExecuteStartPayload<'exec>, OnExecuteEndPayload<'exec>> { diff --git a/lib/executor/src/plugins/examples/subgraph_response_cache.rs b/lib/executor/src/plugins/examples/subgraph_response_cache.rs index 4d192dd39..4e4b36666 100644 --- a/lib/executor/src/plugins/examples/subgraph_response_cache.rs +++ b/lib/executor/src/plugins/examples/subgraph_response_cache.rs @@ -10,8 +10,9 @@ pub struct SubgraphResponseCachePlugin { cache: DashMap, } +#[async_trait::async_trait] impl RouterPlugin for SubgraphResponseCachePlugin { - fn on_subgraph_execute<'exec>( + async fn on_subgraph_execute<'exec>( &'exec self, mut payload: OnSubgraphExecuteStartPayload<'exec>, ) -> HookResult<'exec, OnSubgraphExecuteStartPayload<'exec>, OnSubgraphExecuteEndPayload> { diff --git a/lib/executor/src/plugins/hooks/on_execute.rs b/lib/executor/src/plugins/hooks/on_execute.rs index 9a77679b9..328a1a8fd 100644 --- a/lib/executor/src/plugins/hooks/on_execute.rs +++ b/lib/executor/src/plugins/hooks/on_execute.rs @@ -1,16 +1,16 @@ use std::collections::HashMap; -use std::sync::Arc; use hive_router_query_planner::planner::plan_nodes::QueryPlan; -use ntex::web::HttpRequest; +use crate::plugin_context::{PluginContext, RouterHttpRequest}; use crate::plugin_trait::{EndPayload, StartPayload}; use crate::response::graphql_error::GraphQLError; use crate::response::value::Value; pub struct OnExecuteStartPayload<'exec> { - pub router_http_request: HttpRequest, - pub query_plan: Arc, + pub router_http_request: &'exec RouterHttpRequest<'exec>, + pub context: &'exec PluginContext, + pub query_plan: &'exec QueryPlan, pub data: Value<'exec>, pub errors: Vec, diff --git a/lib/executor/src/plugins/hooks/on_graphql_params.rs b/lib/executor/src/plugins/hooks/on_graphql_params.rs index a9afabed1..d954f94e7 100644 --- a/lib/executor/src/plugins/hooks/on_graphql_params.rs +++ b/lib/executor/src/plugins/hooks/on_graphql_params.rs @@ -6,6 +6,8 @@ use ntex::util::Bytes; use serde::{de, Deserialize, Deserializer}; use sonic_rs::Value; +use crate::plugin_context::PluginContext; +use crate::plugin_context::RouterHttpRequest; use crate::plugin_trait::EndPayload; use crate::plugin_trait::StartPayload; @@ -91,13 +93,14 @@ impl<'de> Deserialize<'de> for GraphQLParams { } } -pub struct OnGraphQLParamsStartPayload { - pub router_http_request: ntex::web::HttpRequest, +pub struct OnGraphQLParamsStartPayload<'exec> { + pub router_http_request: &'exec RouterHttpRequest<'exec>, + pub context: &'exec PluginContext, pub body: Bytes, pub graphql_params: Option, } -impl StartPayload for OnGraphQLParamsStartPayload {} +impl<'exec> StartPayload for OnGraphQLParamsStartPayload<'exec> {} pub struct OnGraphQLParamsEndPayload { pub graphql_params: GraphQLParams, diff --git a/lib/executor/src/plugins/hooks/on_graphql_parse.rs b/lib/executor/src/plugins/hooks/on_graphql_parse.rs index df9b4e480..fa29e3b9d 100644 --- a/lib/executor/src/plugins/hooks/on_graphql_parse.rs +++ b/lib/executor/src/plugins/hooks/on_graphql_parse.rs @@ -2,11 +2,13 @@ use graphql_tools::static_graphql::query::Document; use crate::{ hooks::on_graphql_params::GraphQLParams, + plugin_context::{PluginContext, RouterHttpRequest}, plugin_trait::{EndPayload, StartPayload}, }; pub struct OnGraphQLParseStartPayload<'exec> { - pub router_http_request: ntex::web::HttpRequest, + pub router_http_request: &'exec RouterHttpRequest<'exec>, + pub context: &'exec PluginContext, pub graphql_params: &'exec GraphQLParams, pub document: Option, } diff --git a/lib/executor/src/plugins/hooks/on_graphql_validation.rs b/lib/executor/src/plugins/hooks/on_graphql_validation.rs index f6bb55004..c341a6a36 100644 --- a/lib/executor/src/plugins/hooks/on_graphql_validation.rs +++ b/lib/executor/src/plugins/hooks/on_graphql_validation.rs @@ -8,10 +8,14 @@ use graphql_tools::{ }; use hive_router_query_planner::state::supergraph_state::SchemaDocument; -use crate::plugin_trait::{EndPayload, StartPayload}; +use crate::{ + plugin_context::{PluginContext, PluginManager, RouterHttpRequest}, + plugin_trait::{EndPayload, StartPayload}, +}; pub struct OnGraphQLValidationStartPayload<'exec> { - pub router_http_request: &'exec mut ntex::web::HttpRequest, + pub router_http_request: &'exec RouterHttpRequest<'exec>, + pub context: &'exec PluginContext, pub schema: &'exec SchemaDocument, pub document: &'exec Document, default_validation_plan: &'exec ValidationPlan, @@ -23,13 +27,14 @@ impl<'exec> StartPayload for OnGraphQLValidationS impl<'exec> OnGraphQLValidationStartPayload<'exec> { pub fn new( - router_http_request: &'exec mut ntex::web::HttpRequest, + plugin_manager: &'exec PluginManager<'exec>, schema: &'exec SchemaDocument, document: &'exec Document, default_validation_plan: &'exec ValidationPlan, ) -> Self { OnGraphQLValidationStartPayload { - router_http_request, + router_http_request: &plugin_manager.router_http_request, + context: &plugin_manager.context, schema, document, default_validation_plan, diff --git a/lib/executor/src/plugins/hooks/on_http_request.rs b/lib/executor/src/plugins/hooks/on_http_request.rs index a7f6f6bb5..84a683948 100644 --- a/lib/executor/src/plugins/hooks/on_http_request.rs +++ b/lib/executor/src/plugins/hooks/on_http_request.rs @@ -1,16 +1,20 @@ -use ntex::{http::Response, web::HttpRequest}; +use ntex::web::{self, DefaultError}; -use crate::plugin_trait::{EndPayload, StartPayload}; +use crate::{ + plugin_context::PluginContext, + plugin_trait::{EndPayload, StartPayload}, +}; -pub struct OnHttpRequestPayload<'exec> { - pub client_request: &'exec HttpRequest, +pub struct OnHttpRequestPayload<'req> { + pub router_http_request: web::WebRequest, + pub context: &'req PluginContext, + pub response: Option, } -impl<'exec> StartPayload> for OnHttpRequestPayload<'exec> {} +impl<'req> StartPayload for OnHttpRequestPayload<'req> {} -pub struct OnHttpResponse<'exec> { - pub router_http_request: &'exec HttpRequest, - pub response: &'exec mut Response, +pub struct OnHttpResponsePayload { + pub response: web::WebResponse, } -impl<'exec> EndPayload for OnHttpResponse<'exec> {} +impl EndPayload for OnHttpResponsePayload {} diff --git a/lib/executor/src/plugins/hooks/on_query_plan.rs b/lib/executor/src/plugins/hooks/on_query_plan.rs index fd2089ec6..9b2110fd7 100644 --- a/lib/executor/src/plugins/hooks/on_query_plan.rs +++ b/lib/executor/src/plugins/hooks/on_query_plan.rs @@ -5,10 +5,14 @@ use hive_router_query_planner::{ utils::cancellation::CancellationToken, }; -use crate::plugin_trait::{EndPayload, StartPayload}; +use crate::{ + plugin_context::{PluginContext, RouterHttpRequest}, + plugin_trait::{EndPayload, StartPayload}, +}; pub struct OnQueryPlanStartPayload<'exec> { - pub router_http_request: &'exec mut ntex::web::HttpRequest, + pub router_http_request: &'exec RouterHttpRequest<'exec>, + pub context: &'exec PluginContext, pub filtered_operation_for_plan: &'exec OperationDefinition, pub planner_override_context: PlannerOverrideContext, pub cancellation_token: &'exec CancellationToken, diff --git a/lib/executor/src/plugins/hooks/on_subgraph_execute.rs b/lib/executor/src/plugins/hooks/on_subgraph_execute.rs index 5a6fcc6a6..25ec2c30b 100644 --- a/lib/executor/src/plugins/hooks/on_subgraph_execute.rs +++ b/lib/executor/src/plugins/hooks/on_subgraph_execute.rs @@ -1,9 +1,13 @@ use crate::{ executors::common::{HttpExecutionResponse, SubgraphExecutionRequest}, + plugin_context::{PluginContext, RouterHttpRequest}, plugin_trait::{EndPayload, StartPayload}, }; pub struct OnSubgraphExecuteStartPayload<'exec> { + pub router_http_request: &'exec RouterHttpRequest<'exec>, + pub context: &'exec PluginContext, + pub subgraph_name: String, pub execution_request: SubgraphExecutionRequest<'exec>, diff --git a/lib/executor/src/plugins/mod.rs b/lib/executor/src/plugins/mod.rs index 02490fb5e..3c24ff9f2 100644 --- a/lib/executor/src/plugins/mod.rs +++ b/lib/executor/src/plugins/mod.rs @@ -1,3 +1,4 @@ pub mod examples; pub mod hooks; +pub mod plugin_context; pub mod plugin_trait; diff --git a/lib/executor/src/plugins/plugin_context.rs b/lib/executor/src/plugins/plugin_context.rs new file mode 100644 index 000000000..bce85a517 --- /dev/null +++ b/lib/executor/src/plugins/plugin_context.rs @@ -0,0 +1,43 @@ +use std::{ + any::{Any, TypeId}, + sync::Arc, +}; + +use dashmap::DashMap; +use http::Uri; +use ntex::router::Path; +use ntex_http::HeaderMap; + +use crate::plugin_trait::RouterPlugin; + +pub struct RouterHttpRequest<'exec> { + pub uri: &'exec Uri, + pub method: &'exec http::Method, + pub version: http::Version, + pub headers: &'exec HeaderMap, + pub path: &'exec str, + pub query_string: &'exec str, + pub match_info: &'exec Path, +} + +#[derive(Default)] +pub struct PluginContext { + inner: DashMap>, +} + +impl PluginContext { + pub fn insert(&self, value: T) { + self.inner.insert(TypeId::of::(), Arc::new(value)); + } + pub fn get(&self) -> Option> { + self.inner + .get(&TypeId::of::()) + .map(|v| v.clone().downcast::().ok().unwrap()) + } +} + +pub struct PluginManager<'req> { + pub plugins: Arc>>, + pub router_http_request: RouterHttpRequest<'req>, + pub context: Arc, +} diff --git a/lib/executor/src/plugins/plugin_trait.rs b/lib/executor/src/plugins/plugin_trait.rs index c502ba087..0a3f6db32 100644 --- a/lib/executor/src/plugins/plugin_trait.rs +++ b/lib/executor/src/plugins/plugin_trait.rs @@ -5,7 +5,7 @@ use crate::hooks::on_graphql_parse::{OnGraphQLParseEndPayload, OnGraphQLParseSta use crate::hooks::on_graphql_validation::{ OnGraphQLValidationEndPayload, OnGraphQLValidationStartPayload, }; -use crate::hooks::on_http_request::{OnHttpRequestPayload, OnHttpResponse}; +use crate::hooks::on_http_request::{OnHttpRequestPayload, OnHttpResponsePayload}; use crate::hooks::on_query_plan::{OnQueryPlanEndPayload, OnQueryPlanStartPayload}; use crate::hooks::on_subgraph_execute::{ OnSubgraphExecuteEndPayload, OnSubgraphExecuteStartPayload, @@ -77,51 +77,52 @@ where } } +#[async_trait::async_trait] pub trait RouterPlugin { - fn on_http_request<'exec>( + fn on_http_request<'req>( &self, - start_payload: OnHttpRequestPayload<'exec>, - ) -> HookResult<'exec, OnHttpRequestPayload<'exec>, OnHttpResponse<'exec>> { + start_payload: OnHttpRequestPayload<'req>, + ) -> HookResult<'req, OnHttpRequestPayload<'req>, OnHttpResponsePayload> { start_payload.cont() } - fn on_graphql_params<'exec>( + async fn on_graphql_params<'exec>( &'exec self, - start_payload: OnGraphQLParamsStartPayload, - ) -> HookResult<'exec, OnGraphQLParamsStartPayload, OnGraphQLParamsEndPayload> { + start_payload: OnGraphQLParamsStartPayload<'exec>, + ) -> HookResult<'exec, OnGraphQLParamsStartPayload<'exec>, OnGraphQLParamsEndPayload> { start_payload.cont() } - fn on_graphql_parse<'exec>( + async fn on_graphql_parse<'exec>( &self, start_payload: OnGraphQLParseStartPayload<'exec>, ) -> HookResult<'exec, OnGraphQLParseStartPayload<'exec>, OnGraphQLParseEndPayload> { start_payload.cont() } - fn on_graphql_validation<'exec>( + async fn on_graphql_validation<'exec>( &self, start_payload: OnGraphQLValidationStartPayload<'exec>, ) -> HookResult<'exec, OnGraphQLValidationStartPayload<'exec>, OnGraphQLValidationEndPayload> { start_payload.cont() } - fn on_query_plan<'exec>( + async fn on_query_plan<'exec>( &self, start_payload: OnQueryPlanStartPayload<'exec>, ) -> HookResult<'exec, OnQueryPlanStartPayload<'exec>, OnQueryPlanEndPayload> { start_payload.cont() } - fn on_execute<'exec>( + async fn on_execute<'exec>( &'exec self, start_payload: OnExecuteStartPayload<'exec>, ) -> HookResult<'exec, OnExecuteStartPayload<'exec>, OnExecuteEndPayload<'exec>> { start_payload.cont() } - fn on_subgraph_execute<'exec>( + async fn on_subgraph_execute<'exec>( &'exec self, start_payload: OnSubgraphExecuteStartPayload<'exec>, ) -> HookResult<'exec, OnSubgraphExecuteStartPayload<'exec>, OnSubgraphExecuteEndPayload> { start_payload.cont() } - fn on_subgraph_http_request<'exec>( + async fn on_subgraph_http_request<'exec>( &'exec self, start_payload: OnSubgraphHttpRequestPayload<'exec>, ) -> HookResult<'exec, OnSubgraphHttpRequestPayload<'exec>, OnSubgraphHttpResponsePayload> { From 419294e209b15bf58de435c5df24a1e2cb1f0cb3 Mon Sep 17 00:00:00 2001 From: Arda TANRIKULU Date: Thu, 20 Nov 2025 18:19:02 +0300 Subject: [PATCH 10/15] Improvements --- bin/router/src/lib.rs | 4 +- bin/router/src/pipeline/coerce_variables.rs | 3 +- bin/router/src/pipeline/execution.rs | 14 +++---- bin/router/src/pipeline/mod.rs | 42 +++++++++---------- bin/router/src/pipeline/normalize.rs | 2 +- bin/router/src/pipeline/parser.rs | 6 +-- .../src/pipeline/progressive_override.rs | 4 +- bin/router/src/pipeline/query_plan.rs | 6 +-- bin/router/src/pipeline/validation.rs | 4 +- lib/executor/src/execution/plan.rs | 9 ++-- 10 files changed, 43 insertions(+), 51 deletions(-) diff --git a/bin/router/src/lib.rs b/bin/router/src/lib.rs index dec90d730..5e5a0e353 100644 --- a/bin/router/src/lib.rs +++ b/bin/router/src/lib.rs @@ -62,8 +62,8 @@ async fn graphql_endpoint_handler( &req, body_bytes, supergraph, - app_state.get_ref().clone(), - schema_state.get_ref().clone(), + app_state.get_ref(), + schema_state.get_ref(), ) .await { diff --git a/bin/router/src/pipeline/coerce_variables.rs b/bin/router/src/pipeline/coerce_variables.rs index d10fbb6c4..ab5759b5e 100644 --- a/bin/router/src/pipeline/coerce_variables.rs +++ b/bin/router/src/pipeline/coerce_variables.rs @@ -1,5 +1,4 @@ use std::collections::HashMap; -use std::sync::Arc; use hive_router_plan_executor::hooks::on_graphql_params::GraphQLParams; use hive_router_plan_executor::hooks::on_supergraph_load::SupergraphData; @@ -23,7 +22,7 @@ pub fn coerce_request_variables( req: &HttpRequest, supergraph: &SupergraphData, graphql_params: &mut GraphQLParams, - normalized_operation: &Arc, + normalized_operation: &GraphQLNormalizationPayload, ) -> Result { if req.method() == Method::GET { if let Some(OperationKind::Mutation) = diff --git a/bin/router/src/pipeline/execution.rs b/bin/router/src/pipeline/execution.rs index e40fc02c3..fae429895 100644 --- a/bin/router/src/pipeline/execution.rs +++ b/bin/router/src/pipeline/execution.rs @@ -1,5 +1,4 @@ use std::collections::HashMap; -use std::sync::Arc; use crate::pipeline::coerce_variables::CoerceVariablesPayload; use crate::pipeline::error::PipelineErrorVariant; @@ -24,16 +23,17 @@ enum ExposeQueryPlanMode { DryRun, } +#[allow(clippy::too_many_arguments)] #[inline] -pub async fn execute_plan<'exec, 'req>( +pub async fn execute_plan( req: &HttpRequest, supergraph: &SupergraphData, - app_state: Arc, - normalized_payload: Arc, - query_plan_payload: Arc, + app_state: &RouterSharedState, + normalized_payload: &GraphQLNormalizationPayload, + query_plan_payload: &QueryPlan, variable_payload: &CoerceVariablesPayload, - client_request_details: &ClientRequestDetails<'exec, 'req>, - plugin_manager: PluginManager<'req>, + client_request_details: &ClientRequestDetails<'_, '_>, + plugin_manager: PluginManager<'_>, ) -> Result { let mut expose_query_plan = ExposeQueryPlanMode::No; diff --git a/bin/router/src/pipeline/mod.rs b/bin/router/src/pipeline/mod.rs index 860a13c91..f05745265 100644 --- a/bin/router/src/pipeline/mod.rs +++ b/bin/router/src/pipeline/mod.rs @@ -63,8 +63,8 @@ pub async fn graphql_request_handler( req: &HttpRequest, body_bytes: Bytes, supergraph: &SupergraphData, - shared_state: Arc, - schema_state: Arc, + shared_state: &RouterSharedState, + schema_state: &SchemaState, ) -> Result { if req.method() == Method::GET && req.accepts_content_type(*TEXT_HTML_CONTENT_TYPE) { if shared_state.router_config.graphiql.enabled { @@ -139,14 +139,14 @@ pub async fn graphql_request_handler( #[inline] #[allow(clippy::await_holding_refcell_ref)] -pub async fn execute_pipeline<'req>( - req: &'req HttpRequest, +pub async fn execute_pipeline( + req: &HttpRequest, body: Bytes, supergraph: &SupergraphData, - shared_state: Arc, - schema_state: Arc, + shared_state: &RouterSharedState, + schema_state: &SchemaState, jwt_context: Option, - plugin_manager: PluginManager<'req>, + plugin_manager: PluginManager<'_>, ) -> Result { perform_csrf_prevention(req, &shared_state.router_config.csrf)?; @@ -195,7 +195,7 @@ pub async fn execute_pipeline<'req>( /* Handle on_deserialize hook in the plugins - END */ let parser_result = - parse_operation_with_cache(shared_state.clone(), &graphql_params, &plugin_manager).await?; + parse_operation_with_cache(shared_state, &graphql_params, &plugin_manager).await?; let parser_payload = match parser_result { ParseResult::Payload(payload) => payload, @@ -206,20 +206,16 @@ pub async fn execute_pipeline<'req>( validate_operation_with_cache( supergraph, - schema_state.clone(), - shared_state.clone(), + schema_state, + shared_state, &parser_payload, &plugin_manager, ) .await?; - let normalize_payload = normalize_request_with_cache( - supergraph, - schema_state.clone(), - &graphql_params, - &parser_payload, - ) - .await?; + let normalize_payload = + normalize_request_with_cache(supergraph, schema_state, &graphql_params, &parser_payload) + .await?; let variable_payload = coerce_request_variables(req, supergraph, &mut graphql_params, &normalize_payload)?; @@ -264,11 +260,11 @@ pub async fn execute_pipeline<'req>( let query_plan_result = plan_operation_with_cache( supergraph, - schema_state.clone(), - normalize_payload.clone(), + schema_state, + &normalize_payload, &progressive_override_ctx, &query_plan_cancellation_token, - shared_state.clone(), + shared_state, &plugin_manager, ) .await?; @@ -282,9 +278,9 @@ pub async fn execute_pipeline<'req>( let execution_result = execute_plan( req, supergraph, - shared_state.clone(), - normalize_payload.clone(), - query_plan_payload, + shared_state, + &normalize_payload, + &query_plan_payload, &variable_payload, &client_request_details, plugin_manager, diff --git a/bin/router/src/pipeline/normalize.rs b/bin/router/src/pipeline/normalize.rs index 54093d065..97cbb80ac 100644 --- a/bin/router/src/pipeline/normalize.rs +++ b/bin/router/src/pipeline/normalize.rs @@ -26,7 +26,7 @@ pub struct GraphQLNormalizationPayload { #[inline] pub async fn normalize_request_with_cache( supergraph: &SupergraphData, - schema_state: Arc, + schema_state: &SchemaState, graphql_params: &GraphQLParams, parser_payload: &GraphQLParserPayload, ) -> Result, PipelineErrorVariant> { diff --git a/bin/router/src/pipeline/parser.rs b/bin/router/src/pipeline/parser.rs index 0a11ab2aa..aebbd6beb 100644 --- a/bin/router/src/pipeline/parser.rs +++ b/bin/router/src/pipeline/parser.rs @@ -29,10 +29,10 @@ pub enum ParseResult { } #[inline] -pub async fn parse_operation_with_cache<'req>( - app_state: Arc, +pub async fn parse_operation_with_cache( + app_state: &RouterSharedState, graphql_params: &GraphQLParams, - plugin_manager: &PluginManager<'req>, + plugin_manager: &PluginManager<'_>, ) -> Result { let cache_key = { let mut hasher = Xxh3::new(); diff --git a/bin/router/src/pipeline/progressive_override.rs b/bin/router/src/pipeline/progressive_override.rs index d0b09c183..4743d8672 100644 --- a/bin/router/src/pipeline/progressive_override.rs +++ b/bin/router/src/pipeline/progressive_override.rs @@ -51,9 +51,9 @@ pub struct RequestOverrideContext { } #[inline] -pub fn request_override_context<'exec, 'req>( +pub fn request_override_context( override_labels_evaluator: &OverrideLabelsEvaluator, - client_request_details: &ClientRequestDetails<'exec, 'req>, + client_request_details: &ClientRequestDetails<'_, '_>, ) -> Result { let active_flags = override_labels_evaluator.evaluate(client_request_details)?; diff --git a/bin/router/src/pipeline/query_plan.rs b/bin/router/src/pipeline/query_plan.rs index 156b8ef9f..d1f83ca7d 100644 --- a/bin/router/src/pipeline/query_plan.rs +++ b/bin/router/src/pipeline/query_plan.rs @@ -31,11 +31,11 @@ pub enum QueryPlanGetterError { #[inline] pub async fn plan_operation_with_cache<'req>( supergraph: &SupergraphData, - schema_state: Arc, - normalized_operation: Arc, + schema_state: &SchemaState, + normalized_operation: &GraphQLNormalizationPayload, request_override_context: &RequestOverrideContext, cancellation_token: &CancellationToken, - app_state: Arc, + app_state: &RouterSharedState, plugin_manager: &PluginManager<'req>, ) -> Result { let stable_override_context = diff --git a/bin/router/src/pipeline/validation.rs b/bin/router/src/pipeline/validation.rs index afe833656..92cb1eb6f 100644 --- a/bin/router/src/pipeline/validation.rs +++ b/bin/router/src/pipeline/validation.rs @@ -17,8 +17,8 @@ use tracing::{error, trace}; #[inline] pub async fn validate_operation_with_cache( supergraph: &SupergraphData, - schema_state: Arc, - app_state: Arc, + schema_state: &SchemaState, + app_state: &RouterSharedState, parser_payload: &GraphQLParserPayload, plugin_manager: &PluginManager<'_>, ) -> Result, PipelineErrorVariant> { diff --git a/lib/executor/src/execution/plan.rs b/lib/executor/src/execution/plan.rs index bbe5ba836..7d07f430e 100644 --- a/lib/executor/src/execution/plan.rs +++ b/lib/executor/src/execution/plan.rs @@ -1,7 +1,4 @@ -use std::{ - collections::{BTreeSet, HashMap}, - sync::Arc, -}; +use std::collections::{BTreeSet, HashMap}; use bytes::{BufMut, Bytes}; use futures::{future::BoxFuture, stream::FuturesUnordered, StreamExt}; @@ -56,7 +53,7 @@ use crate::{ pub struct QueryPlanExecutionContext<'exec, 'req> { pub plugin_manager: &'exec PluginManager<'exec>, - pub query_plan: Arc, + pub query_plan: &'exec QueryPlan, pub projection_plan: &'exec Vec, pub headers_plan: &'exec HeaderRulesPlan, pub variable_values: &'exec Option>, @@ -87,7 +84,7 @@ impl<'exec, 'req> QueryPlanExecutionContext<'exec, 'req> { let mut start_payload = OnExecuteStartPayload { router_http_request: &self.plugin_manager.router_http_request, context: &self.plugin_manager.context, - query_plan: &self.query_plan, + query_plan: self.query_plan, data: init_value, errors: Vec::new(), extensions: self.extensions.clone(), From 57bb0f06c5c2e2fbe16c37070553f81c29ab7293 Mon Sep 17 00:00:00 2001 From: Arda TANRIKULU Date: Sat, 22 Nov 2025 16:07:22 +0300 Subject: [PATCH 11/15] Examples --- Cargo.lock | 21 +++ bin/router/src/pipeline/mod.rs | 5 +- bin/router/src/plugins/plugins_service.rs | 23 ++- lib/executor/Cargo.toml | 3 + lib/executor/src/execution/plan.rs | 2 + lib/executor/src/executors/common.rs | 1 + lib/executor/src/executors/http.rs | 5 + lib/executor/src/executors/map.rs | 8 +- .../src/plugins/examples/apollo_sandbox.rs | 155 +++++++++++++++++ .../src/plugins/examples/async_auth.rs | 113 +++++++++++++ .../src/plugins/examples/context_data.rs | 67 ++++++++ .../examples/forbid_anonymous_operations.rs | 55 ++++++ lib/executor/src/plugins/examples/mod.rs | 7 + .../src/plugins/examples/multipart.rs | 159 ++++++++++++++++++ .../plugins/examples/propagate_status_code.rs | 46 +++++ .../src/plugins/examples/response_cache.rs | 3 +- .../src/plugins/examples/root_field_limit.rs | 62 +++++++ .../src/plugins/hooks/on_graphql_params.rs | 7 +- .../src/plugins/hooks/on_http_request.rs | 5 +- .../src/plugins/hooks/on_subgraph_execute.rs | 10 +- lib/executor/src/plugins/plugin_context.rs | 109 +++++++++++- lib/executor/src/plugins/plugin_trait.rs | 4 +- 22 files changed, 842 insertions(+), 28 deletions(-) create mode 100644 lib/executor/src/plugins/examples/apollo_sandbox.rs create mode 100644 lib/executor/src/plugins/examples/async_auth.rs create mode 100644 lib/executor/src/plugins/examples/context_data.rs create mode 100644 lib/executor/src/plugins/examples/forbid_anonymous_operations.rs create mode 100644 lib/executor/src/plugins/examples/propagate_status_code.rs create mode 100644 lib/executor/src/plugins/examples/root_field_limit.rs diff --git a/Cargo.lock b/Cargo.lock index 52b09935b..43720d595 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2055,6 +2055,7 @@ dependencies = [ "criterion", "dashmap", "futures", + "futures-util", "graphql-parser", "graphql-tools", "hive-router-config", @@ -2067,11 +2068,13 @@ dependencies = [ "indexmap 2.12.0", "insta", "itoa", + "multer", "ntex", "ntex-http", "ordered-float", "redis", "regex-automata", + "reqwest", "ryu", "serde", "sonic-rs", @@ -2802,6 +2805,16 @@ version = "0.3.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" +[[package]] +name = "mime_guess" +version = "2.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f7c44f8e672c00fe5308fa235f821cb4198414e1c77935c1ab6948d3fd78550e" +dependencies = [ + "mime", + "unicase", +] + [[package]] name = "minimal-lexical" version = "0.2.1" @@ -4349,6 +4362,7 @@ dependencies = [ "bytes", "encoding_rs", "futures-core", + "futures-util", "h2", "http", "http-body", @@ -4360,6 +4374,7 @@ dependencies = [ "js-sys", "log", "mime", + "mime_guess", "native-tls", "percent-encoding", "pin-project-lite", @@ -5831,6 +5846,12 @@ dependencies = [ "web-time", ] +[[package]] +name = "unicase" +version = "2.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75b844d17643ee918803943289730bec8aac480150456169e647ed0b576ba539" + [[package]] name = "unicode-ident" version = "1.0.22" diff --git a/bin/router/src/pipeline/mod.rs b/bin/router/src/pipeline/mod.rs index f05745265..4d2f264ee 100644 --- a/bin/router/src/pipeline/mod.rs +++ b/bin/router/src/pipeline/mod.rs @@ -176,7 +176,10 @@ pub async fn execute_pipeline( .expect("Failed to parse execution request") }); - let mut payload = OnGraphQLParamsEndPayload { graphql_params }; + let mut payload = OnGraphQLParamsEndPayload { + graphql_params, + context: &plugin_manager.context, + }; for deserialization_end_callback in deserialization_end_callbacks { let result = deserialization_end_callback(payload); payload = result.payload; diff --git a/bin/router/src/plugins/plugins_service.rs b/bin/router/src/plugins/plugins_service.rs index 49204b662..2d0d70c50 100644 --- a/bin/router/src/plugins/plugins_service.rs +++ b/bin/router/src/plugins/plugins_service.rs @@ -1,11 +1,14 @@ use std::sync::Arc; use hive_router_plan_executor::{ + execution::plan::PlanExecutionOutput, hooks::on_http_request::{OnHttpRequestPayload, OnHttpResponsePayload}, plugin_context::PluginContext, plugin_trait::ControlFlowResult, }; +use http::StatusCode; use ntex::{ + http::ResponseBuilder, service::{Service, ServiceCtx}, web::{self, DefaultError}, Middleware, @@ -52,11 +55,11 @@ where let mut start_payload = OnHttpRequestPayload { router_http_request: req, context: &plugin_context, - response: None, }; let mut on_end_callbacks = vec![]; + let mut early_response: Option = None; for plugin in plugins.iter() { let result = plugin.on_http_request(start_payload); start_payload = result.payload; @@ -67,18 +70,24 @@ where ControlFlowResult::OnEnd(callback) => { on_end_callbacks.push(callback); } - ControlFlowResult::EndResponse(_response) => { - // Short-circuit the request with the provided response - unimplemented!(); + ControlFlowResult::EndResponse(response) => { + early_response = Some(response); + break; } } } let req = start_payload.router_http_request; - let response = match start_payload.response { - Some(response) => response, - None => ctx.call(&self.service, req).await?, + let response = if let Some(early_response) = early_response { + let mut builder = ResponseBuilder::new(StatusCode::OK); + for (key, value) in early_response.headers.iter() { + builder.header(key, value); + } + let res = builder.body(early_response.body); + req.into_response(res) + } else { + ctx.call(&self.service, req).await? }; let mut end_payload = OnHttpResponsePayload { response }; diff --git a/lib/executor/Cargo.toml b/lib/executor/Cargo.toml index 39d51ee7e..07aefd5b6 100644 --- a/lib/executor/Cargo.toml +++ b/lib/executor/Cargo.toml @@ -30,6 +30,7 @@ xxhash-rust = { workspace = true } tokio = { workspace = true, features = ["sync"] } dashmap = { workspace = true } vrl = { workspace = true } +reqwest = { workspace = true, features = ["multipart"] } ahash = "0.8.12" regex-automata = "0.4.10" @@ -53,6 +54,8 @@ ryu = "1.0.20" indexmap = "2.10.0" bumpalo = "3.19.0" redis = "0.32.7" +multer = "3.1.0" +futures-util = "0.3.31" [dev-dependencies] subgraphs = { path = "../../bench/subgraphs" } diff --git a/lib/executor/src/execution/plan.rs b/lib/executor/src/execution/plan.rs index 7d07f430e..b38dccbac 100644 --- a/lib/executor/src/execution/plan.rs +++ b/lib/executor/src/execution/plan.rs @@ -69,6 +69,7 @@ pub struct QueryPlanExecutionContext<'exec, 'req> { pub struct PlanExecutionOutput { pub body: Vec, pub headers: HeaderMap, + pub status: http::StatusCode, } impl<'exec, 'req> QueryPlanExecutionContext<'exec, 'req> { @@ -177,6 +178,7 @@ impl<'exec, 'req> QueryPlanExecutionContext<'exec, 'req> { Ok(PlanExecutionOutput { body, headers: response_headers, + status: http::StatusCode::OK, }) } } diff --git a/lib/executor/src/executors/common.rs b/lib/executor/src/executors/common.rs index 9044c062e..6b3c804b5 100644 --- a/lib/executor/src/executors/common.rs +++ b/lib/executor/src/executors/common.rs @@ -49,4 +49,5 @@ impl SubgraphExecutionRequest<'_> { pub struct HttpExecutionResponse { pub body: Bytes, pub headers: HeaderMap, + pub status: http::StatusCode, } diff --git a/lib/executor/src/executors/http.rs b/lib/executor/src/executors/http.rs index 9e2770965..f33cbaedc 100644 --- a/lib/executor/src/executors/http.rs +++ b/lib/executor/src/executors/http.rs @@ -282,6 +282,7 @@ impl SubgraphExecutor for HTTPSubgraphExecutor { return HttpExecutionResponse { body: self.error_to_graphql_bytes(e), headers: Default::default(), + status: StatusCode::OK, }; } }; @@ -311,12 +312,14 @@ impl SubgraphExecutor for HTTPSubgraphExecutor { Ok(shared_response) => HttpExecutionResponse { body: shared_response.body, headers: shared_response.headers, + status: shared_response.status, }, Err(e) => { self.log_error(&e); HttpExecutionResponse { body: self.error_to_graphql_bytes(e), headers: Default::default(), + status: StatusCode::OK, } } }; @@ -362,12 +365,14 @@ impl SubgraphExecutor for HTTPSubgraphExecutor { Ok(shared_response) => HttpExecutionResponse { body: shared_response.body.clone(), headers: shared_response.headers.clone(), + status: shared_response.status, }, Err(e) => { self.log_error(&e); HttpExecutionResponse { body: self.error_to_graphql_bytes(e.clone()), headers: Default::default(), + status: StatusCode::OK, } } } diff --git a/lib/executor/src/executors/map.rs b/lib/executor/src/executors/map.rs index cf6033040..bf8eb7838 100644 --- a/lib/executor/src/executors/map.rs +++ b/lib/executor/src/executors/map.rs @@ -154,6 +154,7 @@ impl SubgraphExecutorMap { return HttpExecutionResponse { body: response.body.into(), headers: response.headers, + status: response.status, }; } ControlFlowResult::OnEnd(callback) => { @@ -182,7 +183,10 @@ impl SubgraphExecutorMap { } }; - let mut end_payload = OnSubgraphExecuteEndPayload { execution_result }; + let mut end_payload = OnSubgraphExecuteEndPayload { + context: &plugin_manager.context, + execution_result, + }; for callback in on_end_callbacks { let result = callback(end_payload); @@ -196,6 +200,7 @@ impl SubgraphExecutorMap { return HttpExecutionResponse { body: response.body.into(), headers: response.headers, + status: response.status, }; } ControlFlowResult::OnEnd(_) => { @@ -222,6 +227,7 @@ impl SubgraphExecutorMap { HttpExecutionResponse { body: buffer.freeze(), headers: Default::default(), + status: http::StatusCode::INTERNAL_SERVER_ERROR, } } diff --git a/lib/executor/src/plugins/examples/apollo_sandbox.rs b/lib/executor/src/plugins/examples/apollo_sandbox.rs new file mode 100644 index 000000000..7c559bc8c --- /dev/null +++ b/lib/executor/src/plugins/examples/apollo_sandbox.rs @@ -0,0 +1,155 @@ +use ::serde::{Deserialize, Serialize}; +use ahash::HashMap; +use http::{HeaderMap, StatusCode}; + +use crate::{ + execution::plan::PlanExecutionOutput, + hooks::on_http_request::{OnHttpRequestPayload, OnHttpResponsePayload}, + plugin_trait::{HookResult, RouterPlugin, StartPayload}, +}; + +#[derive(Default, Serialize, Deserialize, Debug, Clone)] +#[serde(rename_all = "camelCase")] +pub struct ApolloSandboxOptions { + /** + * The URL of the GraphQL endpoint that Sandbox introspects on initial load. Sandbox populates its pages using the schema obtained from this endpoint. + * The default value is `http://localhost:4000`. + * You should only pass non-production endpoints to Sandbox. Sandbox is powered by schema introspection, and we recommend [disabling introspection in production](https://www.apollographql.com/blog/graphql/security/why-you-should-disable-graphql-introspection-in-production/). + * To provide a "Sandbox-like" experience for production endpoints, we recommend using either a [public variant](https://www.apollographql.com/docs/graphos/platform/graph-management/variants#public-variants) or the [embedded Explorer](https://www.apollographql.com/docs/graphos/platform/explorer/embed). + */ + pub initial_endpoint: String, + /** + * By default, the embedded Sandbox does not show the **Include cookies** toggle in its connection settings.Set `hideCookieToggle` to `false` to enable users of your embedded Sandbox instance to toggle the **Include cookies** setting. + */ + pub hide_cookie_toggle: bool, + /** + * By default, the embedded Sandbox has a URL input box that is editable by users.Set endpointIsEditable to false to prevent users of your embedded Sandbox instance from changing the endpoint URL. + */ + pub endpoint_is_editable: bool, + /** + * You can set `includeCookies` to `true` if you instead want Sandbox to pass `{ credentials: 'include' }` for its requests.If you pass the `handleRequest` option, this option is ignored.Read more about the `fetch` API and credentials [here](https://developer.mozilla.org/en-US/docs/Web/API/fetch#credentials).This config option is deprecated in favor of using the connection settings cookie toggle in Sandbox and setting the default value via `initialState.includeCookies`. + */ + pub include_cookies: bool, + /** + * An object containing additional options related to the state of the embedded Sandbox on page load. + */ + pub initial_state: ApolloSandboxInitialStateOptions, +} + +#[derive(Serialize, Deserialize, Debug, Clone, Default)] +#[serde(rename_all = "camelCase")] +pub struct ApolloSandboxInitialStateOptions { + /** + * Set this value to `true` if you want Sandbox to pass `{ credentials: 'include' }` for its requests by default.If you set `hideCookieToggle` to `false`, users can override this default setting with the **Include cookies** toggle. (By default, the embedded Sandbox does not show the **Include cookies** toggle in its connection settings.)If you also pass the `handleRequest` option, this option is ignored.Read more about the `fetch` API and credentials [here](https://developer.mozilla.org/en-US/docs/Web/API/fetch#credentials). + */ + pub include_cookies: bool, + /** + * A URI-encoded operation to populate in Sandbox's editor on load.If you omit this, Sandbox initially loads an example query based on your schema.Example: + * ```js + * initialState: { + * document: ` + * query ExampleQuery { + * books { + * title + * } + * } + * ` + * } + * ``` + */ + pub document: Option, + /** + * A URI-encoded, serialized object containing initial variable values to populate in Sandbox on load.If provided, these variables should apply to the initial query you provide for [`document`](https://www.apollographql.com/docs/apollo-sandbox#document).Example: + * + * ```js + * initialState: { + * variables: { + * userID: "abc123" + * }, + * } + * ``` + */ + pub variables: Option, + /** + * A URI-encoded, serialized object containing initial HTTP header values to populate in Sandbox on load.Example: + * + * + * ```js + * initialState: { + * headers: { + * authorization: "Bearer abc123"; + * } + * } + * ``` + */ + pub headers: Option, + /** + * The ID of a collection, paired with an operation ID to populate in Sandbox on load. You can find these values from a registered graph in Studio by clicking the **...** menu next to an operation in the Explorer of that graph and selecting **View operation details**.Example: + * + * ```js + * initialState: { + * collectionId: 'abc1234', + * operationId: 'xyz1234' + * } + * ``` + */ + pub collection_id: Option, + pub operation_id: Option, + /** + * If `true`, the embedded Sandbox periodically polls your `initialEndpoint` for schema updates.The default value is `true`.Example: + * + * ```js + * initialState: { + * pollForSchemaUpdates: false; + * } + * ``` + */ + pub poll_for_schema_updates: bool, + /** + * Headers that are applied by default to every operation executed by the embedded Sandbox. Users can turn off the application of these headers, but they can't modify their values.The embedded Sandbox always includes these headers in its introspection queries to your `initialEndpoint`.Example: + * + * ```js + * initialState: { + * sharedHeaders: { + * authorization: "Bearer abc123"; + * } + * } + * ``` + */ + pub shared_headers: HashMap, +} + +pub struct ApolloSandboxPlugin { + pub options: ApolloSandboxOptions, +} + +impl RouterPlugin for ApolloSandboxPlugin { + fn on_http_request<'req>( + &self, + payload: OnHttpRequestPayload<'req>, + ) -> HookResult<'req, OnHttpRequestPayload<'req>, OnHttpResponsePayload> { + if payload.router_http_request.path() == "/apollo-sandbox" { + let config = sonic_rs::to_string(&self.options).unwrap_or_else(|_| "{}".to_string()); + let html = format!( + r#" +
+ + + "#, + config + ); + let mut headers = HeaderMap::new(); + headers.insert("Content-Type", "text/html".parse().unwrap()); + return payload.end_response(PlanExecutionOutput { + body: html.into_bytes(), + headers, + status: StatusCode::OK, + }); + } + payload.cont() + } +} diff --git a/lib/executor/src/plugins/examples/async_auth.rs b/lib/executor/src/plugins/examples/async_auth.rs new file mode 100644 index 000000000..ce6d73331 --- /dev/null +++ b/lib/executor/src/plugins/examples/async_auth.rs @@ -0,0 +1,113 @@ +use std::path::PathBuf; + +// From https://github.com/apollographql/router/blob/dev/examples/async-auth/rust/src/allow_client_id_from_file.rs +use serde::Deserialize; +use sonic_rs::json; + +use crate::{ + execution::plan::PlanExecutionOutput, + hooks::on_graphql_params::{OnGraphQLParamsEndPayload, OnGraphQLParamsStartPayload}, + plugin_trait::{HookResult, RouterPlugin, StartPayload}, +}; + +#[derive(Deserialize)] +pub struct AllowClientIdConfig { + pub header: String, + pub path: String, +} + +pub struct AllowClientIdFromFile { + header_key: String, + allowed_ids_path: PathBuf, +} + +#[async_trait::async_trait] +impl RouterPlugin for AllowClientIdFromFile { + // Whenever it is a GraphQL request, + // We don't use on_http_request here because we want to run this only when it is a GraphQL request + async fn on_graphql_params<'exec>( + &'exec self, + payload: OnGraphQLParamsStartPayload<'exec>, + ) -> HookResult<'exec, OnGraphQLParamsStartPayload<'exec>, OnGraphQLParamsEndPayload> { + let header = payload.router_http_request.headers.get(&self.header_key); + match header { + Some(client_id) => { + let client_id_str = client_id.to_str(); + match client_id_str { + Ok(client_id) => { + let allowed_clients: Vec = sonic_rs::from_str( + std::fs::read_to_string(self.allowed_ids_path.clone()) + .unwrap() + .as_str(), + ) + .unwrap(); + + if !allowed_clients.contains(&client_id.to_string()) { + // Prepare an HTTP 403 response with a GraphQL error message + let body = json!( + { + "errors": [ + { + "message": "client-id is not allowed", + "extensions": { + "code": "UNAUTHORIZED_CLIENT_ID" + } + } + ] + } + ); + return payload.end_response(PlanExecutionOutput { + body: sonic_rs::to_vec(&body).unwrap_or_default(), + headers: http::HeaderMap::new(), + status: http::StatusCode::FORBIDDEN, + }); + } + } + Err(_not_a_string_error) => { + let message = format!("'{}' value is not a string", self.header_key); + tracing::error!(message); + let body = json!( + { + "errors": [ + { + "message": message, + "extensions": { + "code": "BAD_CLIENT_ID" + } + } + ] + } + ); + return payload.end_response(PlanExecutionOutput { + body: sonic_rs::to_vec(&body).unwrap_or_default(), + headers: http::HeaderMap::new(), + status: http::StatusCode::BAD_REQUEST, + }); + } + } + } + None => { + let message = format!("Missing '{}' header", self.header_key); + tracing::error!(message); + let body = json!( + { + "errors": [ + { + "message": message, + "extensions": { + "code": "AUTH_ERROR" + } + } + ] + } + ); + return payload.end_response(PlanExecutionOutput { + body: sonic_rs::to_vec(&body).unwrap_or_default(), + headers: http::HeaderMap::new(), + status: http::StatusCode::UNAUTHORIZED, + }); + } + } + payload.cont() + } +} diff --git a/lib/executor/src/plugins/examples/context_data.rs b/lib/executor/src/plugins/examples/context_data.rs new file mode 100644 index 000000000..38265fe57 --- /dev/null +++ b/lib/executor/src/plugins/examples/context_data.rs @@ -0,0 +1,67 @@ +// From https://github.com/apollographql/router/blob/dev/examples/context/rust/src/context_data.rs + +use crate::{ + hooks::{ + on_graphql_params::{OnGraphQLParamsEndPayload, OnGraphQLParamsStartPayload}, + on_subgraph_execute::{OnSubgraphExecuteEndPayload, OnSubgraphExecuteStartPayload}, + }, + plugin_context::PluginContextMutEntry, + plugin_trait::{EndPayload, HookResult, RouterPlugin, StartPayload}, +}; + +pub struct ContextDataPlugin {} + +pub struct ContextData { + incoming_data: String, + response_count: u64, +} + +#[async_trait::async_trait] +impl RouterPlugin for ContextDataPlugin { + async fn on_graphql_params<'exec>( + &'exec self, + payload: OnGraphQLParamsStartPayload<'exec>, + ) -> HookResult<'exec, OnGraphQLParamsStartPayload<'exec>, OnGraphQLParamsEndPayload> { + let context_data = ContextData { + incoming_data: "world".to_string(), + response_count: 0, + }; + + payload.context.insert(context_data); + + payload.on_end(|payload| { + let mut ctx_data_entry = payload.context.get_mut_entry(); + let context_data: Option<&mut ContextData> = ctx_data_entry.get_ref_mut(); + if let Some(context_data) = context_data { + context_data.response_count += 1; + tracing::info!("subrequest count {}", context_data.response_count); + } + payload.cont() + }) + } + async fn on_subgraph_execute<'exec>( + &'exec self, + mut payload: OnSubgraphExecuteStartPayload<'exec>, + ) -> HookResult<'exec, OnSubgraphExecuteStartPayload<'exec>, OnSubgraphExecuteEndPayload> { + let ctx_data_entry = payload.context.get_ref_entry(); + let context_data: Option<&ContextData> = ctx_data_entry.get_ref(); + if let Some(context_data) = context_data { + tracing::info!("hello {}", context_data.incoming_data); // Hello world! + let new_header_value = format!("Hello {}", context_data.incoming_data); + payload.execution_request.headers.insert( + "x-hello", + http::HeaderValue::from_str(&new_header_value).unwrap(), + ); + } + payload.on_end(|payload: OnSubgraphExecuteEndPayload<'exec>| { + let mut ctx_data_entry: PluginContextMutEntry = + payload.context.get_mut_entry(); + let context_data: Option<&mut ContextData> = ctx_data_entry.get_ref_mut(); + if let Some(context_data) = context_data { + context_data.response_count += 1; + tracing::info!("subrequest count {}", context_data.response_count); + } + payload.cont() + }) + } +} diff --git a/lib/executor/src/plugins/examples/forbid_anonymous_operations.rs b/lib/executor/src/plugins/examples/forbid_anonymous_operations.rs new file mode 100644 index 000000000..a566d0e9d --- /dev/null +++ b/lib/executor/src/plugins/examples/forbid_anonymous_operations.rs @@ -0,0 +1,55 @@ +// Same with https://github.com/apollographql/router/blob/dev/examples/forbid-anonymous-operations/rust/src/forbid_anonymous_operations.rs + +use http::StatusCode; +use sonic_rs::json; + +use crate::{ + execution::plan::PlanExecutionOutput, + hooks::on_graphql_params::{OnGraphQLParamsEndPayload, OnGraphQLParamsStartPayload}, + plugin_trait::{HookResult, RouterPlugin, StartPayload}, +}; + +pub struct ForbidAnonymousOperations {} + +#[async_trait::async_trait] +impl RouterPlugin for ForbidAnonymousOperations { + async fn on_graphql_params<'exec>( + &'exec self, + payload: OnGraphQLParamsStartPayload<'exec>, + ) -> HookResult<'exec, OnGraphQLParamsStartPayload<'exec>, OnGraphQLParamsEndPayload> { + let maybe_operation_name = &payload + .graphql_params + .as_ref() + .and_then(|params| params.operation_name.as_ref()); + + if maybe_operation_name.is_none() + || maybe_operation_name + .expect("is_none() has been checked before; qed") + .is_empty() + { + // let's log the error + tracing::error!("Operation is not allowed!"); + + // Prepare an HTTP 400 response with a GraphQL error message + let response_body = json!({ + "errors": [ + { + "message": "Anonymous operations are not allowed", + "extensions": { + "code": "ANONYMOUS_OPERATION" + } + } + ] + }); + return payload.end_response(PlanExecutionOutput { + body: sonic_rs::to_vec(&response_body).unwrap_or_default(), + headers: http::HeaderMap::new(), + status: StatusCode::BAD_REQUEST, + }); + } else { + // we're good to go! + tracing::info!("operation is allowed!"); + return payload.cont(); + } + } +} diff --git a/lib/executor/src/plugins/examples/mod.rs b/lib/executor/src/plugins/examples/mod.rs index a6d766a9c..70ff81639 100644 --- a/lib/executor/src/plugins/examples/mod.rs +++ b/lib/executor/src/plugins/examples/mod.rs @@ -1,3 +1,10 @@ +pub mod apollo_sandbox; pub mod apq; +pub mod async_auth; +pub mod context_data; +pub mod forbid_anonymous_operations; +pub mod multipart; +pub mod propagate_status_code; pub mod response_cache; +pub mod root_field_limit; pub mod subgraph_response_cache; diff --git a/lib/executor/src/plugins/examples/multipart.rs b/lib/executor/src/plugins/examples/multipart.rs index e69de29bb..6a6162bc6 100644 --- a/lib/executor/src/plugins/examples/multipart.rs +++ b/lib/executor/src/plugins/examples/multipart.rs @@ -0,0 +1,159 @@ +use std::collections::HashMap; + +use crate::{ + executors::common::HttpExecutionResponse, + hooks::{ + on_graphql_params::{ + GraphQLParams, OnGraphQLParamsEndPayload, OnGraphQLParamsStartPayload, + }, + on_subgraph_execute::{OnSubgraphExecuteEndPayload, OnSubgraphExecuteStartPayload}, + }, + plugin_trait::{HookResult, RouterPlugin, StartPayload}, +}; +use bytes::Bytes; +use dashmap::DashMap; +use multer::Multipart; +use serde::Serialize; +pub struct MultipartPlugin {} + +pub struct MultipartFile { + pub filename: Option, + pub content_type: Option, + pub content: Bytes, +} + +pub struct MultipartContext { + pub file_map: HashMap>, + pub files: DashMap, +} + +#[derive(Serialize)] +struct MultipartOperations<'a> { + pub query: &'a str, + pub variables: Option<&'a HashMap<&'a str, &'a sonic_rs::Value>>, + pub operation_name: Option<&'a str>, +} + +#[async_trait::async_trait] +impl RouterPlugin for MultipartPlugin { + async fn on_graphql_params<'exec>( + &'exec self, + mut payload: OnGraphQLParamsStartPayload<'exec>, + ) -> HookResult<'exec, OnGraphQLParamsStartPayload<'exec>, OnGraphQLParamsEndPayload> { + if let Some(content_type) = payload.router_http_request.headers.get("content-type") { + if let Ok(content_type_str) = content_type.to_str() { + if content_type_str.starts_with("multipart/form-data") { + let boundary = multer::parse_boundary(content_type_str).unwrap(); + let body = payload.body.clone(); + let stream = futures_util::stream::once(async move { + Ok::(Bytes::from(body.to_vec())) + }); + let mut multipart = Multipart::new(stream, boundary); + while let Some(field) = multipart.next_field().await.unwrap() { + let field_name = field.name().unwrap().to_string(); + let filename = field.file_name().map(|s| s.to_string()); + let content_type = field.content_type().map(|s| s.to_string()); + let data = field.bytes().await.unwrap(); + match field_name.as_str() { + "operations" => { + let graphql_params: GraphQLParams = + sonic_rs::from_slice(&data).unwrap(); + payload.graphql_params = Some(graphql_params); + } + "map" => { + let file_map: HashMap> = + sonic_rs::from_slice(&data).unwrap(); + payload.context.insert(MultipartContext { + file_map, + files: DashMap::new(), + }); + } + field_name => { + let mut ctx_entry = payload.context.get_mut_entry(); + let multipart_ctx: Option<&mut MultipartContext> = + ctx_entry.get_ref_mut(); + if let Some(multipart_ctx) = multipart_ctx { + let multipart_file = MultipartFile { + filename, + content_type, + content: data, + }; + multipart_ctx + .files + .insert(field_name.to_string(), multipart_file); + } + } + } + } + } + } + } + payload.cont() + } + + async fn on_subgraph_execute<'exec>( + &'exec self, + mut payload: OnSubgraphExecuteStartPayload<'exec>, + ) -> HookResult<'exec, OnSubgraphExecuteStartPayload<'exec>, OnSubgraphExecuteEndPayload> { + if let Some(variables) = &payload.execution_request.variables { + let ctx_ref = payload.context.get_ref_entry(); + let multipart_ctx: Option<&MultipartContext> = ctx_ref.get_ref(); + if let Some(multipart_ctx) = multipart_ctx { + let mut file_map: HashMap> = HashMap::new(); + for variable_name in variables.keys() { + // Matching variables that are file references + for (files_ref, op_refs) in &multipart_ctx.file_map { + for op_ref in op_refs { + if op_ref.starts_with(format!("variables.{}", variable_name).as_str()) { + let op_refs_in_curr_map = + file_map.entry(files_ref.to_string()).or_default(); + op_refs_in_curr_map.push(op_ref.to_string()); + } + } + } + } + if !file_map.is_empty() { + let mut form = reqwest::multipart::Form::new(); + let operations_struct = MultipartOperations { + query: payload.execution_request.query, + variables: payload.execution_request.variables.as_ref(), + operation_name: payload.execution_request.operation_name, + }; + let operations = sonic_rs::to_string(&operations_struct).unwrap(); + form = form.text("operations", operations); + let file_map_str: String = sonic_rs::to_string(&file_map).unwrap(); + form = form.text("map", file_map_str); + for (file_ref, _op_refs) in file_map { + if let Some(file_field) = multipart_ctx.files.get(&file_ref) { + let mut part = + reqwest::multipart::Part::bytes(file_field.content.to_vec()); + if let Some(file_name) = &file_field.filename { + part = part.file_name(file_name.to_string()); + } + if let Some(content_type) = &file_field.content_type { + part = part.mime_str(&content_type.to_string()).unwrap(); + } + form = form.part(file_ref, part); + } + } + let resp = reqwest::Client::new() + .post("http://example.com/graphql") + // Using query as endpoint URL + .multipart(form) + .send() + .await + .unwrap(); + let headers = resp.headers().clone(); + let status = resp.status(); + let body = resp.bytes().await.unwrap(); + payload.execution_result = Some(HttpExecutionResponse { + body, + headers, + status, + }); + } + } + } + payload.cont() + } +} diff --git a/lib/executor/src/plugins/examples/propagate_status_code.rs b/lib/executor/src/plugins/examples/propagate_status_code.rs new file mode 100644 index 000000000..1e1b28d9d --- /dev/null +++ b/lib/executor/src/plugins/examples/propagate_status_code.rs @@ -0,0 +1,46 @@ +// From https://github.com/apollographql/router/blob/dev/examples/status-code-propagation/rust/src/propagate_status_code.rs + +use http::StatusCode; + +use crate::{ + hooks::on_subgraph_execute::{OnSubgraphExecuteEndPayload, OnSubgraphExecuteStartPayload}, + plugin_trait::{EndPayload, HookResult, RouterPlugin, StartPayload}, +}; + +pub struct PropagateStatusCodePlugin { + pub status_codes: Vec, +} + +pub struct PropagateStatusCodeCtx { + pub status_code: StatusCode, +} + +#[async_trait::async_trait] +impl RouterPlugin for PropagateStatusCodePlugin { + async fn on_subgraph_execute<'exec>( + &'exec self, + payload: OnSubgraphExecuteStartPayload<'exec>, + ) -> HookResult<'exec, OnSubgraphExecuteStartPayload<'exec>, OnSubgraphExecuteEndPayload<'exec>> + { + payload.on_end(|payload| { + let status_code = payload.execution_result.status; + // if a response contains a status code we're watching... + if self.status_codes.contains(&status_code) { + // Checking if there is already a context entry + let mut ctx_entry = payload.context.get_mut_entry(); + let ctx: Option<&mut PropagateStatusCodeCtx> = ctx_entry.get_ref_mut(); + if let Some(ctx) = ctx { + // Update the status code if the new one is more severe (higher) + if status_code.as_u16() > ctx.status_code.as_u16() { + ctx.status_code = status_code; + } + } else { + // Insert a new context entry + let new_ctx = PropagateStatusCodeCtx { status_code }; + payload.context.insert(new_ctx); + } + } + payload.cont() + }) + } +} diff --git a/lib/executor/src/plugins/examples/response_cache.rs b/lib/executor/src/plugins/examples/response_cache.rs index ee92a029d..d144b07ec 100644 --- a/lib/executor/src/plugins/examples/response_cache.rs +++ b/lib/executor/src/plugins/examples/response_cache.rs @@ -1,5 +1,5 @@ use dashmap::DashMap; -use http::HeaderMap; +use http::{HeaderMap, StatusCode}; use redis::Commands; use crate::{ @@ -44,6 +44,7 @@ impl RouterPlugin for ResponseCachePlugin { return payload.end_response(PlanExecutionOutput { body: cached_response, headers: HeaderMap::new(), + status: StatusCode::OK, }); } return payload.on_end(move |mut payload: OnExecuteEndPayload<'exec>| { diff --git a/lib/executor/src/plugins/examples/root_field_limit.rs b/lib/executor/src/plugins/examples/root_field_limit.rs new file mode 100644 index 000000000..0c57771c3 --- /dev/null +++ b/lib/executor/src/plugins/examples/root_field_limit.rs @@ -0,0 +1,62 @@ +use hive_router_query_planner::ast::selection_item::SelectionItem; +use sonic_rs::json; + +use crate::{ + execution::plan::PlanExecutionOutput, + hooks::on_query_plan::{OnQueryPlanEndPayload, OnQueryPlanStartPayload}, + plugin_trait::{HookResult, RouterPlugin, StartPayload}, +}; + +pub struct RootFieldLimitPlugin { + pub max_root_fields: usize, +} + +#[async_trait::async_trait] +impl RouterPlugin for RootFieldLimitPlugin { + async fn on_query_plan<'exec>( + &'exec self, + payload: OnQueryPlanStartPayload<'exec>, + ) -> HookResult<'exec, OnQueryPlanStartPayload<'exec>, OnQueryPlanEndPayload> { + let mut cnt = 0; + for selection in payload + .filtered_operation_for_plan + .selection_set + .items + .iter() + { + match selection { + SelectionItem::Field(_) => { + cnt += 1; + if cnt > self.max_root_fields { + let err_msg = format!( + "Query has too many root fields: {}, maximum allowed is {}", + cnt, self.max_root_fields + ); + tracing::warn!("{}", err_msg); + let body = json!({ + "errors": [{ + "message": err_msg, + "extensions": { + "code": "TOO_MANY_ROOT_FIELDS" + } + }] + }); + // Return error + return payload.end_response(PlanExecutionOutput { + body: sonic_rs::to_vec(&body).unwrap_or_default(), + headers: http::HeaderMap::new(), + status: http::StatusCode::PAYLOAD_TOO_LARGE, + }); + } + } + SelectionItem::InlineFragment(_) => { + unreachable!("Inline fragments should have been inlined before query planning"); + } + SelectionItem::FragmentSpread(_) => { + unreachable!("Fragment spreads should have been inlined before query planning"); + } + } + } + payload.cont() + } +} diff --git a/lib/executor/src/plugins/hooks/on_graphql_params.rs b/lib/executor/src/plugins/hooks/on_graphql_params.rs index d954f94e7..c69d094e4 100644 --- a/lib/executor/src/plugins/hooks/on_graphql_params.rs +++ b/lib/executor/src/plugins/hooks/on_graphql_params.rs @@ -100,10 +100,11 @@ pub struct OnGraphQLParamsStartPayload<'exec> { pub graphql_params: Option, } -impl<'exec> StartPayload for OnGraphQLParamsStartPayload<'exec> {} +impl<'exec> StartPayload> for OnGraphQLParamsStartPayload<'exec> {} -pub struct OnGraphQLParamsEndPayload { +pub struct OnGraphQLParamsEndPayload<'exec> { pub graphql_params: GraphQLParams, + pub context: &'exec PluginContext, } -impl EndPayload for OnGraphQLParamsEndPayload {} +impl<'exec> EndPayload for OnGraphQLParamsEndPayload<'exec> {} diff --git a/lib/executor/src/plugins/hooks/on_http_request.rs b/lib/executor/src/plugins/hooks/on_http_request.rs index 84a683948..9964425df 100644 --- a/lib/executor/src/plugins/hooks/on_http_request.rs +++ b/lib/executor/src/plugins/hooks/on_http_request.rs @@ -1,4 +1,4 @@ -use ntex::web::{self, DefaultError}; +use ntex::web::{self, DefaultError, WebRequest}; use crate::{ plugin_context::PluginContext, @@ -6,9 +6,8 @@ use crate::{ }; pub struct OnHttpRequestPayload<'req> { - pub router_http_request: web::WebRequest, + pub router_http_request: WebRequest, pub context: &'req PluginContext, - pub response: Option, } impl<'req> StartPayload for OnHttpRequestPayload<'req> {} diff --git a/lib/executor/src/plugins/hooks/on_subgraph_execute.rs b/lib/executor/src/plugins/hooks/on_subgraph_execute.rs index 25ec2c30b..18d037c10 100644 --- a/lib/executor/src/plugins/hooks/on_subgraph_execute.rs +++ b/lib/executor/src/plugins/hooks/on_subgraph_execute.rs @@ -14,10 +14,14 @@ pub struct OnSubgraphExecuteStartPayload<'exec> { pub execution_result: Option, } -impl<'exec> StartPayload for OnSubgraphExecuteStartPayload<'exec> {} +impl<'exec> StartPayload> + for OnSubgraphExecuteStartPayload<'exec> +{ +} -pub struct OnSubgraphExecuteEndPayload { +pub struct OnSubgraphExecuteEndPayload<'exec> { pub execution_result: HttpExecutionResponse, + pub context: &'exec PluginContext, } -impl EndPayload for OnSubgraphExecuteEndPayload {} +impl<'exec> EndPayload for OnSubgraphExecuteEndPayload<'exec> {} diff --git a/lib/executor/src/plugins/plugin_context.rs b/lib/executor/src/plugins/plugin_context.rs index bce85a517..d5ea9421f 100644 --- a/lib/executor/src/plugins/plugin_context.rs +++ b/lib/executor/src/plugins/plugin_context.rs @@ -3,7 +3,10 @@ use std::{ sync::Arc, }; -use dashmap::DashMap; +use dashmap::{ + mapref::one::{Ref, RefMut}, + DashMap, +}; use http::Uri; use ntex::router::Path; use ntex_http::HeaderMap; @@ -22,17 +25,69 @@ pub struct RouterHttpRequest<'exec> { #[derive(Default)] pub struct PluginContext { - inner: DashMap>, + inner: DashMap>, +} + +pub struct PluginContextRefEntry<'a, T> { + pub entry: Option>>, + phantom: std::marker::PhantomData, +} + +impl<'a, T: Any + Send + Sync> PluginContextRefEntry<'a, T> { + pub fn get_ref(&self) -> Option<&T> { + match &self.entry { + None => None, + Some(entry) => { + let boxed_any = entry.value(); + Some(boxed_any.downcast_ref::()?) + } + } + } +} +pub struct PluginContextMutEntry<'a, T> { + pub entry: Option>>, + phantom: std::marker::PhantomData, +} + +impl<'a, T: Any + Send + Sync> PluginContextMutEntry<'a, T> { + pub fn get_ref_mut(&mut self) -> Option<&mut T> { + match &mut self.entry { + None => None, + Some(entry) => { + let boxed_any = entry.value_mut(); + Some(boxed_any.downcast_mut::()?) + } + } + } } impl PluginContext { - pub fn insert(&self, value: T) { - self.inner.insert(TypeId::of::(), Arc::new(value)); + pub fn contains(&self) -> bool { + let type_id = TypeId::of::(); + self.inner.contains_key(&type_id) } - pub fn get(&self) -> Option> { + pub fn insert(&self, value: T) -> Option> { + let type_id = TypeId::of::(); self.inner - .get(&TypeId::of::()) - .map(|v| v.clone().downcast::().ok().unwrap()) + .insert(type_id, Box::new(value)) + .and_then(|boxed_any| boxed_any.downcast::().ok()) + } + pub fn get_ref_entry(&self) -> PluginContextRefEntry<'_, T> { + let type_id = TypeId::of::(); + let entry = self.inner.get(&type_id); + PluginContextRefEntry { + entry, + phantom: std::marker::PhantomData, + } + } + pub fn get_mut_entry<'a, T: Any + Send + Sync>(&'a self) -> PluginContextMutEntry<'a, T> { + let type_id = TypeId::of::(); + let entry = self.inner.get_mut(&type_id); + + PluginContextMutEntry { + entry, + phantom: std::marker::PhantomData, + } } } @@ -41,3 +96,43 @@ pub struct PluginManager<'req> { pub router_http_request: RouterHttpRequest<'req>, pub context: Arc, } + +#[cfg(test)] +mod tests { + #[test] + fn inserts_and_gets_immut_ref() { + use super::PluginContext; + + struct TestCtx { + pub value: u32, + } + + let ctx = PluginContext::default(); + ctx.insert(TestCtx { value: 42 }); + + let entry = ctx.get_ref_entry(); + let ctx_ref: &TestCtx = entry.get_ref().unwrap(); + assert_eq!(ctx_ref.value, 42); + } + #[test] + fn inserts_and_mutates_with_mut_ref() { + use super::PluginContext; + + struct TestCtx { + pub value: u32, + } + + let ctx = PluginContext::default(); + ctx.insert(TestCtx { value: 42 }); + + { + let mut entry = ctx.get_mut_entry(); + let ctx_mut: &mut TestCtx = entry.get_ref_mut().unwrap(); + ctx_mut.value = 100; + } + + let entry = ctx.get_ref_entry(); + let ctx_ref: &TestCtx = entry.get_ref().unwrap(); + assert_eq!(ctx_ref.value, 100); + } +} diff --git a/lib/executor/src/plugins/plugin_trait.rs b/lib/executor/src/plugins/plugin_trait.rs index 0a3f6db32..c946bc0ed 100644 --- a/lib/executor/src/plugins/plugin_trait.rs +++ b/lib/executor/src/plugins/plugin_trait.rs @@ -92,7 +92,7 @@ pub trait RouterPlugin { start_payload.cont() } async fn on_graphql_parse<'exec>( - &self, + &'exec self, start_payload: OnGraphQLParseStartPayload<'exec>, ) -> HookResult<'exec, OnGraphQLParseStartPayload<'exec>, OnGraphQLParseEndPayload> { start_payload.cont() @@ -105,7 +105,7 @@ pub trait RouterPlugin { start_payload.cont() } async fn on_query_plan<'exec>( - &self, + &'exec self, start_payload: OnQueryPlanStartPayload<'exec>, ) -> HookResult<'exec, OnQueryPlanStartPayload<'exec>, OnQueryPlanEndPayload> { start_payload.cont() From 1ac23c24b364d84c68ab58d6c2ec8cead94e85b4 Mon Sep 17 00:00:00 2001 From: Arda TANRIKULU Date: Sat, 22 Nov 2025 17:06:56 +0300 Subject: [PATCH 12/15] Add oneOf --- bin/router/src/pipeline/execution.rs | 1 + lib/executor/src/execution/plan.rs | 12 +- lib/executor/src/plugins/examples/mod.rs | 1 + lib/executor/src/plugins/examples/one_of.rs | 134 ++++++++++++++++++ .../src/plugins/examples/root_field_limit.rs | 91 +++++++++++- lib/executor/src/plugins/hooks/on_execute.rs | 2 + lib/executor/src/plugins/plugin_trait.rs | 2 +- 7 files changed, 234 insertions(+), 9 deletions(-) create mode 100644 lib/executor/src/plugins/examples/one_of.rs diff --git a/bin/router/src/pipeline/execution.rs b/bin/router/src/pipeline/execution.rs index fae429895..991f503c1 100644 --- a/bin/router/src/pipeline/execution.rs +++ b/bin/router/src/pipeline/execution.rs @@ -88,6 +88,7 @@ pub async fn execute_plan( let ctx = QueryPlanExecutionContext { plugin_manager: &plugin_manager, query_plan: query_plan_payload, + operation_for_plan: &normalized_payload.operation_for_plan, projection_plan: &normalized_payload.projection_plan, headers_plan: &app_state.headers_plan, variable_values: &variable_payload.variables_map, diff --git a/lib/executor/src/execution/plan.rs b/lib/executor/src/execution/plan.rs index b38dccbac..ad1a46744 100644 --- a/lib/executor/src/execution/plan.rs +++ b/lib/executor/src/execution/plan.rs @@ -2,9 +2,12 @@ use std::collections::{BTreeSet, HashMap}; use bytes::{BufMut, Bytes}; use futures::{future::BoxFuture, stream::FuturesUnordered, StreamExt}; -use hive_router_query_planner::planner::plan_nodes::{ - ConditionNode, FetchNode, FetchRewrite, FlattenNode, FlattenNodePath, ParallelNode, PlanNode, - QueryPlan, SequenceNode, +use hive_router_query_planner::{ + ast::operation::OperationDefinition, + planner::plan_nodes::{ + ConditionNode, FetchNode, FetchRewrite, FlattenNode, FlattenNodePath, ParallelNode, + PlanNode, QueryPlan, SequenceNode, + }, }; use http::HeaderMap; use serde::Deserialize; @@ -54,6 +57,7 @@ use crate::{ pub struct QueryPlanExecutionContext<'exec, 'req> { pub plugin_manager: &'exec PluginManager<'exec>, pub query_plan: &'exec QueryPlan, + pub operation_for_plan: &'exec OperationDefinition, pub projection_plan: &'exec Vec, pub headers_plan: &'exec HeaderRulesPlan, pub variable_values: &'exec Option>, @@ -81,11 +85,11 @@ impl<'exec, 'req> QueryPlanExecutionContext<'exec, 'req> { }; let dedupe_subgraph_requests = self.operation_type_name == "Query"; - let mut start_payload = OnExecuteStartPayload { router_http_request: &self.plugin_manager.router_http_request, context: &self.plugin_manager.context, query_plan: self.query_plan, + operation_for_plan: self.operation_for_plan, data: init_value, errors: Vec::new(), extensions: self.extensions.clone(), diff --git a/lib/executor/src/plugins/examples/mod.rs b/lib/executor/src/plugins/examples/mod.rs index 70ff81639..eee9d7f7a 100644 --- a/lib/executor/src/plugins/examples/mod.rs +++ b/lib/executor/src/plugins/examples/mod.rs @@ -4,6 +4,7 @@ pub mod async_auth; pub mod context_data; pub mod forbid_anonymous_operations; pub mod multipart; +pub mod one_of; pub mod propagate_status_code; pub mod response_cache; pub mod root_field_limit; diff --git a/lib/executor/src/plugins/examples/one_of.rs b/lib/executor/src/plugins/examples/one_of.rs new file mode 100644 index 000000000..90ced2d07 --- /dev/null +++ b/lib/executor/src/plugins/examples/one_of.rs @@ -0,0 +1,134 @@ +// This example will show `@oneOf` input type validation in two steps: +// 1. During validation step +// 2. During execution step + +// We handle execution too to validate input objects at runtime as well. + +use std::{collections::BTreeMap, sync::RwLock}; + +use crate::{ + hooks::{ + on_execute::{OnExecuteEndPayload, OnExecuteStartPayload}, + on_graphql_validation::{OnGraphQLValidationEndPayload, OnGraphQLValidationStartPayload}, + on_supergraph_load::{OnSupergraphLoadEndPayload, OnSupergraphLoadStartPayload}, + }, + plugin_trait::{HookResult, RouterPlugin, StartPayload}, +}; +use graphql_parser::{ + query::Value, + schema::{Definition, TypeDefinition}, +}; +use graphql_tools::ast::visit_document; +use graphql_tools::{ + ast::{OperationVisitor, OperationVisitorContext}, + validation::{ + rules::ValidationRule, + utils::{ValidationError, ValidationErrorContext}, + }, +}; + +pub struct OneOfPlugin { + pub one_of_types: RwLock>, +} + +#[async_trait::async_trait] +impl RouterPlugin for OneOfPlugin { + // 1. During validation step + async fn on_graphql_validation<'exec>( + &'exec self, + mut payload: OnGraphQLValidationStartPayload<'exec>, + ) -> HookResult<'exec, OnGraphQLValidationStartPayload<'exec>, OnGraphQLValidationEndPayload> + { + let rule = OneOfValidationRule { + one_of_types: self.one_of_types.read().unwrap().clone(), + }; + payload.add_validation_rule(Box::new(rule)); + payload.cont() + } + // 2. During execution step + async fn on_execute<'exec>( + &'exec self, + payload: OnExecuteStartPayload<'exec>, + ) -> HookResult<'exec, OnExecuteStartPayload<'exec>, OnExecuteEndPayload> { + payload.cont() + } + fn on_supergraph_reload<'exec>( + &'exec self, + start_payload: OnSupergraphLoadStartPayload, + ) -> HookResult<'exec, OnSupergraphLoadStartPayload, OnSupergraphLoadEndPayload> { + for def in start_payload.new_ast.definitions.iter() { + if let Definition::TypeDefinition(TypeDefinition::InputObject(input_obj)) = def { + for directive in input_obj.directives.iter() { + if directive.name == "oneOf" { + self.one_of_types + .write() + .unwrap() + .push(input_obj.name.clone()); + } + } + } + } + start_payload.cont() + } +} + +struct OneOfValidationRule { + one_of_types: Vec, +} + +impl ValidationRule for OneOfValidationRule { + fn error_code<'a>(&self) -> &'a str { + "TOO_MANY_ROOT_FIELDS" + } + fn validate( + &self, + op_ctx: &mut OperationVisitorContext<'_>, + validation_error_context: &mut ValidationErrorContext, + ) { + visit_document( + &mut OneOfValidation { + one_of_types: self.one_of_types.clone(), + }, + op_ctx.operation, + op_ctx, + validation_error_context, + ); + } +} + +struct OneOfValidation { + one_of_types: Vec, +} + +impl<'a> OperationVisitor<'a, ValidationErrorContext> for OneOfValidation { + fn enter_object_value( + &mut self, + visitor_context: &mut OperationVisitorContext<'a>, + user_context: &mut ValidationErrorContext, + fields: &BTreeMap, + ) { + if let Some(TypeDefinition::InputObject(input_type)) = visitor_context.current_input_type() + { + if self.one_of_types.contains(&input_type.name) { + let mut set_fields = vec![]; + for (field_name, field_value) in fields.iter() { + if !matches!(field_value, Value::Null) { + set_fields.push(field_name.clone()); + } + } + if set_fields.len() > 1 { + let err_msg = format!( + "Input object of type '{}' with @oneOf directive has multiple fields set: {:?}. Only one field must be set.", + input_type.name, + set_fields + ); + user_context.report_error(ValidationError { + error_code: "TOO_MANY_FIELDS_SET_IN_ONEOF", + locations: vec![], + message: err_msg, + }); + } + } + } + } +} diff --git a/lib/executor/src/plugins/examples/root_field_limit.rs b/lib/executor/src/plugins/examples/root_field_limit.rs index 0c57771c3..24872e737 100644 --- a/lib/executor/src/plugins/examples/root_field_limit.rs +++ b/lib/executor/src/plugins/examples/root_field_limit.rs @@ -1,18 +1,42 @@ +use graphql_tools::{ + ast::{visit_document, OperationVisitor, OperationVisitorContext, TypeDefinitionExtension}, + static_graphql, + validation::{ + rules::ValidationRule, + utils::{ValidationError, ValidationErrorContext}, + }, +}; use hive_router_query_planner::ast::selection_item::SelectionItem; use sonic_rs::json; use crate::{ execution::plan::PlanExecutionOutput, - hooks::on_query_plan::{OnQueryPlanEndPayload, OnQueryPlanStartPayload}, + hooks::{ + on_graphql_validation::{OnGraphQLValidationEndPayload, OnGraphQLValidationStartPayload}, + on_query_plan::{OnQueryPlanEndPayload, OnQueryPlanStartPayload}, + }, plugin_trait::{HookResult, RouterPlugin, StartPayload}, }; -pub struct RootFieldLimitPlugin { - pub max_root_fields: usize, -} +// This example shows two ways of limiting the number of root fields in a query: +// 1. During validation step +// 2. During query planning step #[async_trait::async_trait] impl RouterPlugin for RootFieldLimitPlugin { + // Using validation step + async fn on_graphql_validation<'exec>( + &'exec self, + mut payload: OnGraphQLValidationStartPayload<'exec>, + ) -> HookResult<'exec, OnGraphQLValidationStartPayload<'exec>, OnGraphQLValidationEndPayload> + { + let rule = RootFieldLimitRule { + max_root_fields: self.max_root_fields, + }; + payload.add_validation_rule(Box::new(rule)); + payload.cont() + } + // Or during query planning async fn on_query_plan<'exec>( &'exec self, payload: OnQueryPlanStartPayload<'exec>, @@ -60,3 +84,62 @@ impl RouterPlugin for RootFieldLimitPlugin { payload.cont() } } + +pub struct RootFieldLimitPlugin { + max_root_fields: usize, +} + +pub struct RootFieldLimitRule { + max_root_fields: usize, +} + +struct RootFieldSelections { + max_root_fields: usize, + count: usize, +} + +impl<'a> OperationVisitor<'a, ValidationErrorContext> for RootFieldSelections { + fn enter_field( + &mut self, + visitor_context: &mut OperationVisitorContext, + user_context: &mut ValidationErrorContext, + field: &static_graphql::query::Field, + ) { + let parent_type_name = visitor_context.current_parent_type().map(|t| t.name()); + if parent_type_name == Some("Query") { + self.count += 1; + if self.count > self.max_root_fields { + let err_msg = format!( + "Query has too many root fields: {}, maximum allowed is {}", + self.count, self.max_root_fields + ); + user_context.report_error(ValidationError { + error_code: "TOO_MANY_ROOT_FIELDS", + locations: vec![field.position], + message: err_msg, + }); + } + } + } +} + +impl ValidationRule for RootFieldLimitRule { + fn error_code<'a>(&self) -> &'a str { + "TOO_MANY_ROOT_FIELDS" + } + fn validate( + &self, + ctx: &mut OperationVisitorContext<'_>, + error_collector: &mut ValidationErrorContext, + ) { + visit_document( + &mut RootFieldSelections { + max_root_fields: self.max_root_fields, + count: 0, + }, + ctx.operation, + ctx, + error_collector, + ); + } +} diff --git a/lib/executor/src/plugins/hooks/on_execute.rs b/lib/executor/src/plugins/hooks/on_execute.rs index 328a1a8fd..b69ba3297 100644 --- a/lib/executor/src/plugins/hooks/on_execute.rs +++ b/lib/executor/src/plugins/hooks/on_execute.rs @@ -1,5 +1,6 @@ use std::collections::HashMap; +use hive_router_query_planner::ast::operation::OperationDefinition; use hive_router_query_planner::planner::plan_nodes::QueryPlan; use crate::plugin_context::{PluginContext, RouterHttpRequest}; @@ -11,6 +12,7 @@ pub struct OnExecuteStartPayload<'exec> { pub router_http_request: &'exec RouterHttpRequest<'exec>, pub context: &'exec PluginContext, pub query_plan: &'exec QueryPlan, + pub operation_for_plan: &'exec OperationDefinition, pub data: Value<'exec>, pub errors: Vec, diff --git a/lib/executor/src/plugins/plugin_trait.rs b/lib/executor/src/plugins/plugin_trait.rs index c946bc0ed..5a3e7c242 100644 --- a/lib/executor/src/plugins/plugin_trait.rs +++ b/lib/executor/src/plugins/plugin_trait.rs @@ -98,7 +98,7 @@ pub trait RouterPlugin { start_payload.cont() } async fn on_graphql_validation<'exec>( - &self, + &'exec self, start_payload: OnGraphQLValidationStartPayload<'exec>, ) -> HookResult<'exec, OnGraphQLValidationStartPayload<'exec>, OnGraphQLValidationEndPayload> { From f1a48ef6f333d9b65611268182999cda422c469f Mon Sep 17 00:00:00 2001 From: Arda TANRIKULU Date: Sat, 22 Nov 2025 17:07:06 +0300 Subject: [PATCH 13/15] Runtime error --- lib/executor/src/plugins/examples/one_of.rs | 38 +++++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/lib/executor/src/plugins/examples/one_of.rs b/lib/executor/src/plugins/examples/one_of.rs index 90ced2d07..601107af1 100644 --- a/lib/executor/src/plugins/examples/one_of.rs +++ b/lib/executor/src/plugins/examples/one_of.rs @@ -7,6 +7,7 @@ use std::{collections::BTreeMap, sync::RwLock}; use crate::{ + execution::plan::PlanExecutionOutput, hooks::{ on_execute::{OnExecuteEndPayload, OnExecuteStartPayload}, on_graphql_validation::{OnGraphQLValidationEndPayload, OnGraphQLValidationStartPayload}, @@ -26,6 +27,7 @@ use graphql_tools::{ utils::{ValidationError, ValidationErrorContext}, }, }; +use sonic_rs::{json, JsonContainerTrait}; pub struct OneOfPlugin { pub one_of_types: RwLock>, @@ -50,6 +52,42 @@ impl RouterPlugin for OneOfPlugin { &'exec self, payload: OnExecuteStartPayload<'exec>, ) -> HookResult<'exec, OnExecuteStartPayload<'exec>, OnExecuteEndPayload> { + if let (Some(variable_values), Some(variable_defs)) = ( + &payload.variable_values, + &payload.operation_for_plan.variable_definitions, + ) { + for def in variable_defs { + let variable_named_type = def.variable_type.inner_type(); + let one_of_types = self.one_of_types.read().unwrap(); + if one_of_types.contains(&variable_named_type.to_string()) { + let var_name = &def.name; + if let Some(value) = variable_values.get(var_name).and_then(|v| v.as_object()) { + let keys_num = value.len(); + if keys_num > 1 { + let err_msg = format!( + "Variable '${}' of input object type '{}' with @oneOf directive has multiple fields set: {:?}. Only one field must be set.", + var_name, + variable_named_type, + keys_num + ); + return payload.end_response(PlanExecutionOutput { + body: sonic_rs::to_vec(&json!({ + "errors": [{ + "message": err_msg, + "extensions": { + "code": "TOO_MANY_FIELDS_SET_IN_ONEOF" + } + }] + })) + .unwrap(), + headers: Default::default(), + status: http::StatusCode::BAD_REQUEST, + }); + } + } + } + } + } payload.cont() } fn on_supergraph_reload<'exec>( From 73c693ec3bc5ae7ae1ccf195ca09f7c392e379ef Mon Sep 17 00:00:00 2001 From: Arda TANRIKULU Date: Sat, 22 Nov 2025 17:10:24 +0300 Subject: [PATCH 14/15] Add description --- lib/executor/src/plugins/examples/one_of.rs | 41 +++++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/lib/executor/src/plugins/examples/one_of.rs b/lib/executor/src/plugins/examples/one_of.rs index 601107af1..738ece25a 100644 --- a/lib/executor/src/plugins/examples/one_of.rs +++ b/lib/executor/src/plugins/examples/one_of.rs @@ -3,6 +3,47 @@ // 2. During execution step // We handle execution too to validate input objects at runtime as well. +/* + Let's say we have the following input type with `@oneOf` directive: + input PaymentMethod @oneOf { + creditCard: CreditCardInput + bankTransfer: BankTransferInput + paypal: PayPalInput + } + + During validation, if a variable of type `PaymentMethod` is provided with multiple fields set, + we will raise a validation error. + + ```graphql + mutation MakePayment { + makePayment(method: { + creditCard: { number: "1234", expiry: "12/24" }, + paypal: { email: "john@doe.com" } + }) { + success + } + } + ``` + + But since variables can be dynamic, we also validate during execution. If the input object has multiple fields set, + we return an error in the response. + + ```graphql + mutation MakePayment($method: PaymentMethod!) { + makePayment(method: $method) { + success + } + } + ``` + + with variables: + { + "method": { + "creditCard": { "number": "1234", "expiry": "12/24" }, + "paypal": { "email": "john@doe.com" } + } + } +*/ use std::{collections::BTreeMap, sync::RwLock}; From 0afbea593ca025eb0772f3f08f5a622225ad85e8 Mon Sep 17 00:00:00 2001 From: Arda TANRIKULU Date: Sat, 22 Nov 2025 17:34:55 +0300 Subject: [PATCH 15/15] Propagate status code --- bin/router/src/plugins/plugins_service.rs | 5 ++++- .../src/plugins/examples/apollo_sandbox.rs | 4 ++-- .../plugins/examples/propagate_status_code.rs | 20 ++++++++++++++++++- .../src/plugins/hooks/on_http_request.rs | 7 ++++--- lib/executor/src/plugins/plugin_trait.rs | 4 ++-- 5 files changed, 31 insertions(+), 9 deletions(-) diff --git a/bin/router/src/plugins/plugins_service.rs b/bin/router/src/plugins/plugins_service.rs index 2d0d70c50..223a702ee 100644 --- a/bin/router/src/plugins/plugins_service.rs +++ b/bin/router/src/plugins/plugins_service.rs @@ -90,7 +90,10 @@ where ctx.call(&self.service, req).await? }; - let mut end_payload = OnHttpResponsePayload { response }; + let mut end_payload = OnHttpResponsePayload { + response, + context: &plugin_context, + }; for callback in on_end_callbacks.into_iter().rev() { let result = callback(end_payload); diff --git a/lib/executor/src/plugins/examples/apollo_sandbox.rs b/lib/executor/src/plugins/examples/apollo_sandbox.rs index 7c559bc8c..224654bdb 100644 --- a/lib/executor/src/plugins/examples/apollo_sandbox.rs +++ b/lib/executor/src/plugins/examples/apollo_sandbox.rs @@ -125,9 +125,9 @@ pub struct ApolloSandboxPlugin { impl RouterPlugin for ApolloSandboxPlugin { fn on_http_request<'req>( - &self, + &'req self, payload: OnHttpRequestPayload<'req>, - ) -> HookResult<'req, OnHttpRequestPayload<'req>, OnHttpResponsePayload> { + ) -> HookResult<'req, OnHttpRequestPayload<'req>, OnHttpResponsePayload<'req>> { if payload.router_http_request.path() == "/apollo-sandbox" { let config = sonic_rs::to_string(&self.options).unwrap_or_else(|_| "{}".to_string()); let html = format!( diff --git a/lib/executor/src/plugins/examples/propagate_status_code.rs b/lib/executor/src/plugins/examples/propagate_status_code.rs index 1e1b28d9d..1d519904e 100644 --- a/lib/executor/src/plugins/examples/propagate_status_code.rs +++ b/lib/executor/src/plugins/examples/propagate_status_code.rs @@ -3,7 +3,10 @@ use http::StatusCode; use crate::{ - hooks::on_subgraph_execute::{OnSubgraphExecuteEndPayload, OnSubgraphExecuteStartPayload}, + hooks::{ + on_http_request::{OnHttpRequestPayload, OnHttpResponsePayload}, + on_subgraph_execute::{OnSubgraphExecuteEndPayload, OnSubgraphExecuteStartPayload}, + }, plugin_trait::{EndPayload, HookResult, RouterPlugin, StartPayload}, }; @@ -43,4 +46,19 @@ impl RouterPlugin for PropagateStatusCodePlugin { payload.cont() }) } + fn on_http_request<'exec>( + &'exec self, + payload: OnHttpRequestPayload<'exec>, + ) -> HookResult<'exec, OnHttpRequestPayload<'exec>, OnHttpResponsePayload<'exec>> { + payload.on_end(|mut payload| { + // Checking if there is a context entry + let ctx_entry = payload.context.get_ref_entry(); + let ctx: Option<&PropagateStatusCodeCtx> = ctx_entry.get_ref(); + if let Some(ctx) = ctx { + // Update the HTTP response status code + *payload.response.response_mut().status_mut() = ctx.status_code; + } + payload.cont() + }) + } } diff --git a/lib/executor/src/plugins/hooks/on_http_request.rs b/lib/executor/src/plugins/hooks/on_http_request.rs index 9964425df..e9e857a80 100644 --- a/lib/executor/src/plugins/hooks/on_http_request.rs +++ b/lib/executor/src/plugins/hooks/on_http_request.rs @@ -10,10 +10,11 @@ pub struct OnHttpRequestPayload<'req> { pub context: &'req PluginContext, } -impl<'req> StartPayload for OnHttpRequestPayload<'req> {} +impl<'req> StartPayload> for OnHttpRequestPayload<'req> {} -pub struct OnHttpResponsePayload { +pub struct OnHttpResponsePayload<'req> { pub response: web::WebResponse, + pub context: &'req PluginContext, } -impl EndPayload for OnHttpResponsePayload {} +impl<'req> EndPayload for OnHttpResponsePayload<'req> {} diff --git a/lib/executor/src/plugins/plugin_trait.rs b/lib/executor/src/plugins/plugin_trait.rs index 5a3e7c242..12292be11 100644 --- a/lib/executor/src/plugins/plugin_trait.rs +++ b/lib/executor/src/plugins/plugin_trait.rs @@ -80,9 +80,9 @@ where #[async_trait::async_trait] pub trait RouterPlugin { fn on_http_request<'req>( - &self, + &'req self, start_payload: OnHttpRequestPayload<'req>, - ) -> HookResult<'req, OnHttpRequestPayload<'req>, OnHttpResponsePayload> { + ) -> HookResult<'req, OnHttpRequestPayload<'req>, OnHttpResponsePayload<'req>> { start_payload.cont() } async fn on_graphql_params<'exec>(