diff --git a/.changeset/graphql_endpoint.md b/.changeset/graphql_endpoint.md
new file mode 100644
index 000000000..ea85f1beb
--- /dev/null
+++ b/.changeset/graphql_endpoint.md
@@ -0,0 +1,25 @@
+---
+router: patch
+config: patch
+---
+
+# `graphql_endpoint` Configuration and `.request.path_params` in VRL
+
+- Adds support for configuring the GraphQL endpoint path via the `graphql_endpoint` configuration option.
+
+The endpoint path may contain dynamic path parameters, which can then be referenced from VRL expressions.
+
+`path_params` is also added to the `.request` context in VRL for more dynamic configurations.
+
+```yaml
+http:
+  graphql_endpoint: /graphql/{document_id}
+persisted_documents:
+  enabled: true
+  spec:
+    expression: .request.path_params.document_id
+```
+
+[Learn more about the `graphql_endpoint` configuration option in the documentation.](https://the-guild.dev/graphql/hive/docs/router/configuration/graphql_endpoint)
+
+[Learn more about the `.request.path_params` expression context in the documentation.](https://the-guild.dev/graphql/hive/docs/router/configuration/expressions#request)
\ No newline at end of file
diff --git a/.changeset/persisted_documents.md b/.changeset/persisted_documents.md
new file mode 100644
index 000000000..800195c82
--- /dev/null
+++ b/.changeset/persisted_documents.md
@@ -0,0 +1,15 @@
+---
+router: patch
+config: patch
+executor: patch
+---
+
+# Persisted Documents
+
+- Supports Hive's `documentId` spec, Relay's `doc_id` spec, and Apollo's `extensions`-based spec as options
+  - You can also bring your own method of extracting document ids by using a VRL expression
+- Hive Console and File sources are supported
+- A flag to enable/disable arbitrary (non-persisted) operations
+  - A VRL expression can also be used to decide this dynamically, based on headers or any other request details
+
+[Learn more about Persisted Documents in the documentation.](https://the-guild.dev/graphql/hive/docs/router/configuration/persisted_documents)
\ No newline at end of file
diff --git a/.changeset/shared_utilities_to_handle_vrl_expressions.md b/.changeset/shared_utilities_to_handle_vrl_expressions.md
new file mode 100644
index 000000000..97c07aa3a
--- /dev/null
+++ b/.changeset/shared_utilities_to_handle_vrl_expressions.md
@@ -0,0 +1,15 @@
+---
+default: minor
+---
+
+# Breaking
+
+Removed `pool_idle_timeout_seconds` from `traffic_shaping`; use `pool_idle_timeout` with a duration format (e.g. `30s`) instead.
+ +```diff +traffic_shaping: +- pool_idle_timeout_seconds: 30 ++ pool_idle_timeout: 30s +``` + +#540 by @ardatan diff --git a/Cargo.lock b/Cargo.lock index f405ed749..82df58136 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2048,6 +2048,31 @@ version = "0.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" +[[package]] +name = "hive-console-sdk" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "752a852d62a36b0492125778563012cf6f84ca4ac8a8e2566e85f8f0f9a4c345" +dependencies = [ + "anyhow", + "async-trait", + "axum-core", + "graphql-parser", + "graphql-tools", + "md5", + "moka", + "reqwest", + "reqwest-middleware", + "reqwest-retry", + "serde", + "serde_json", + "sha2", + "thiserror 2.0.17", + "tokio", + "tokio-util", + "tracing", +] + [[package]] name = "hive-router" version = "0.0.17" @@ -2058,6 +2083,7 @@ dependencies = [ "futures", "graphql-parser", "graphql-tools", + "hive-console-sdk", "hive-router-config", "hive-router-plan-executor", "hive-router-query-planner", @@ -2129,7 +2155,9 @@ dependencies = [ "indexmap 2.12.0", "insta", "itoa", + "ntex", "ntex-http", + "once_cell", "ordered-float", "regex-automata", "ryu", @@ -2295,6 +2323,7 @@ dependencies = [ "tokio", "tokio-rustls", "tower-service", + "webpki-roots", ] [[package]] @@ -2847,6 +2876,12 @@ dependencies = [ "digest", ] +[[package]] +name = "md5" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "490cc448043f947bae3cbee9c203358d62dbee0db12107a74be5c30ccfd09771" + [[package]] name = "memchr" version = "2.7.6" @@ -4404,7 +4439,9 @@ dependencies = [ "base64 0.22.1", "bytes", "encoding_rs", + "futures-channel", "futures-core", + "futures-util", "h2", "http", "http-body", @@ -4437,6 +4474,7 @@ dependencies = [ "wasm-bindgen", "wasm-bindgen-futures", "web-sys", + "webpki-roots", ] [[package]] @@ -6257,6 +6295,15 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "webpki-roots" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b2878ef029c47c6e8cf779119f20fcf52bde7ad42a731b2a304bc221df17571e" +dependencies = [ + "rustls-pki-types", +] + [[package]] name = "winapi" version = "0.3.9" diff --git a/bin/router/Cargo.toml b/bin/router/Cargo.toml index 4969b0a46..7c3e76db7 100644 --- a/bin/router/Cargo.toml +++ b/bin/router/Cargo.toml @@ -53,3 +53,4 @@ tokio-util = "0.7.16" cookie = "0.18.1" regex-automata = "0.4.10" arc-swap = "1.7.1" +hive-console-sdk = "0.1.0" diff --git a/bin/router/src/lib.rs b/bin/router/src/lib.rs index 6a3f7f5c0..ce380488e 100644 --- a/bin/router/src/lib.rs +++ b/bin/router/src/lib.rs @@ -3,6 +3,7 @@ mod consts; mod http_utils; mod jwt; mod logger; +mod persisted_documents; mod pipeline; mod schema_state; mod shared_state; @@ -19,6 +20,7 @@ use crate::{ }, jwt::JwtAuthRuntime, logger::configure_logging, + persisted_documents::PersistedDocumentsLoader, pipeline::graphql_request_handler, }; @@ -88,7 +90,9 @@ pub async fn router_entrypoint() -> Result<(), Box> { web::App::new() .state(shared_state.clone()) .state(schema_state.clone()) - .configure(configure_ntex_app) + .configure(|service_config| { + configure_ntex_app(service_config, &shared_state.router_config); + }) .default_service(web::to(landing_page_handler)) }) .bind(addr)? 
@@ -111,17 +115,36 @@ pub async fn configure_app_from_config( false => None, }; + let persisted_docs = if router_config.persisted_documents.enabled { + Some(PersistedDocumentsLoader::try_new( + &router_config.persisted_documents, + )?) + } else { + None + }; + let router_config_arc = Arc::new(router_config); let schema_state = SchemaState::new_from_config(bg_tasks_manager, router_config_arc.clone()).await?; let schema_state_arc = Arc::new(schema_state); - let shared_state = Arc::new(RouterSharedState::new(router_config_arc, jwt_runtime)?); + let shared_state = Arc::new(RouterSharedState::new( + router_config_arc, + jwt_runtime, + persisted_docs, + )?); Ok((shared_state, schema_state_arc)) } -pub fn configure_ntex_app(cfg: &mut web::ServiceConfig) { - cfg.route("/graphql", web::to(graphql_endpoint_handler)) +pub fn configure_ntex_app( + service_config: &mut web::ServiceConfig, + router_config: &HiveRouterConfig, +) { + service_config + .route( + &router_config.http.graphql_endpoint, + web::to(graphql_endpoint_handler), + ) .route("/health", web::to(health_check_handler)) .route("/readiness", web::to(readiness_check_handler)); } diff --git a/bin/router/src/persisted_documents/expr_input_val.rs b/bin/router/src/persisted_documents/expr_input_val.rs new file mode 100644 index 000000000..d0e41330f --- /dev/null +++ b/bin/router/src/persisted_documents/expr_input_val.rs @@ -0,0 +1,89 @@ +use std::collections::BTreeMap; + +use hive_router_plan_executor::execution::client_request_details::{ + client_header_map_to_vrl_value, client_path_params_to_vrl_value, client_url_to_vrl_value, + JwtRequestDetails, +}; +use ntex::web::HttpRequest; +use sonic_rs::{JsonContainerTrait, JsonValueTrait}; +use vrl::core::Value as VrlValue; + +use crate::pipeline::execution_request::ExecutionRequest; + +pub fn get_expression_input_val( + execution_request: &ExecutionRequest, + req: &HttpRequest, + jwt_request_details: &JwtRequestDetails<'_>, +) -> VrlValue { + let headers_value = client_header_map_to_vrl_value(req.headers()); + let url_value = client_url_to_vrl_value(req.uri()); + let path_params_value = client_path_params_to_vrl_value(req.match_info()); + let request_obj = VrlValue::Object(BTreeMap::from([ + ("method".into(), req.method().as_str().into()), + ("headers".into(), headers_value), + ("url".into(), url_value), + ("path_params".into(), path_params_value), + ("jwt".into(), jwt_request_details.into()), + ( + "body".into(), + execution_request_to_vrl_value(execution_request), + ), + ])); + + VrlValue::Object(BTreeMap::from([("request".into(), request_obj)])) +} + +fn execution_request_to_vrl_value(execution_request: &ExecutionRequest) -> VrlValue { + let mut obj = BTreeMap::new(); + if let Some(op_name) = &execution_request.operation_name { + obj.insert("operationName".into(), op_name.clone().into()); + } + if let Some(query) = &execution_request.query { + obj.insert("query".into(), query.clone().into()); + } + for (k, v) in &execution_request.extra_params { + obj.insert(k.clone().into(), from_sonic_value_to_vrl_value(v)); + } + VrlValue::Object(obj) +} + +fn from_sonic_value_to_vrl_value(value: &sonic_rs::Value) -> VrlValue { + match value.get_type() { + sonic_rs::JsonType::Null => VrlValue::Null, + sonic_rs::JsonType::Boolean => VrlValue::Boolean(value.as_bool().unwrap_or(false)), + sonic_rs::JsonType::Number => { + if let Some(n) = value.as_i64() { + VrlValue::Integer(n) + } else if let Some(n) = value.as_f64() { + VrlValue::from_f64_or_zero(n) + } else { + VrlValue::Null + } + } + 
sonic_rs::JsonType::String => {
+            if let Some(s) = value.as_str() {
+                s.into()
+            } else {
+                VrlValue::Null
+            }
+        }
+        sonic_rs::JsonType::Array => {
+            if let Some(array) = value.as_array() {
+                let vec = array.iter().map(from_sonic_value_to_vrl_value).collect();
+                VrlValue::Array(vec)
+            } else {
+                VrlValue::Null
+            }
+        }
+        sonic_rs::JsonType::Object => {
+            if let Some(obj) = value.as_object() {
+                obj.iter()
+                    .map(|(k, v)| (k.into(), from_sonic_value_to_vrl_value(v)))
+                    .collect::<BTreeMap<_, _>>()
+                    .into()
+            } else {
+                VrlValue::Null
+            }
+        }
+    }
+}
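For reference, the `.request` value that `get_expression_input_val` hands to persisted-document VRL expressions has roughly this shape (an illustrative sketch based on the fields assembled above; `body` is the parsed GraphQL request including any extra top-level params, and `jwt` is shown in its unauthenticated form):

```json
{
  "request": {
    "method": "POST",
    "headers": { "content-type": "application/json" },
    "url": { "host": "localhost", "path": "/graphql", "port": 4000 },
    "path_params": {},
    "jwt": { "claims": {}, "scopes": [] },
    "body": { "documentId": "simple" }
  }
}
```

An expression such as `.request.body.my_id` (used in the e2e configs later in this diff) or `.request.headers."x-document-id"` (hypothetical) would therefore resolve a document id from this value.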
+ } + } + } +} + +impl From + for PersistedDocumentsError +{ + fn from( + orig_err: hive_console_sdk::persisted_documents::PersistedDocumentsError, + ) -> PersistedDocumentsError { + match orig_err { + hive_console_sdk::persisted_documents::PersistedDocumentsError::DocumentNotFound => PersistedDocumentsError::NotFound("unknown".to_string()), + hive_console_sdk::persisted_documents::PersistedDocumentsError::FailedToFetchFromCDN(e) => PersistedDocumentsError::NetworkError(e), + hive_console_sdk::persisted_documents::PersistedDocumentsError::PersistedDocumentRequired => PersistedDocumentsError::PersistedDocumentsOnly, + hive_console_sdk::persisted_documents::PersistedDocumentsError::FailedToParseBody(e) => PersistedDocumentsError::ParseError(e), + hive_console_sdk::persisted_documents::PersistedDocumentsError::KeyNotFound => PersistedDocumentsError::KeyNotFound, + hive_console_sdk::persisted_documents::PersistedDocumentsError::FailedToReadCDNResponse(e) => PersistedDocumentsError::NetworkError( + reqwest_middleware::Error::Reqwest(e), + ), + hive_console_sdk::persisted_documents::PersistedDocumentsError::FailedToReadBody(e) => PersistedDocumentsError::ReadError(e), + } + } +} diff --git a/bin/router/src/persisted_documents/mod.rs b/bin/router/src/persisted_documents/mod.rs new file mode 100644 index 000000000..c2105cc46 --- /dev/null +++ b/bin/router/src/persisted_documents/mod.rs @@ -0,0 +1,152 @@ +mod expr_input_val; +mod fetcher; +mod spec; + +use hive_router_config::persisted_documents::{BoolOrExpression, PersistedDocumentsConfig}; +use hive_router_plan_executor::{ + execution::client_request_details::JwtRequestDetails, + utils::expression::{compile_expression, execute_expression_with_value}, +}; +use ntex::web::HttpRequest; +use tracing::trace; +use vrl::{compiler::Program as VrlProgram, core::Value as VrlValue}; + +use crate::{ + persisted_documents::{ + expr_input_val::get_expression_input_val, fetcher::PersistedDocumentsFetcher, + spec::PersistedDocumentsSpecResolver, + }, + pipeline::execution_request::ExecutionRequest, +}; + +pub struct PersistedDocumentsLoader { + fetcher: PersistedDocumentsFetcher, + spec: PersistedDocumentsSpecResolver, + allow_arbitrary_operations: BoolOrProgram, +} + +pub enum BoolOrProgram { + Bool(bool), + Program(Box), +} + +pub fn compile_bool_or_expression( + bool_or_expr: &BoolOrExpression, +) -> Result { + match bool_or_expr { + BoolOrExpression::Bool(b) => Ok(BoolOrProgram::Bool(*b)), + BoolOrExpression::Expression { expression } => { + let program = compile_expression(expression, None).map_err(|err| { + PersistedDocumentsError::ArbitraryOpsExpressionBuild(expression.to_string(), err) + })?; + Ok(BoolOrProgram::Program(Box::new(program))) + } + } +} + +pub fn execute_bool_or_program( + program: &BoolOrProgram, + execution_request: &ExecutionRequest, + req: &HttpRequest, + jwt_request_details: &JwtRequestDetails<'_>, +) -> Result { + match program { + BoolOrProgram::Bool(b) => Ok(*b), + BoolOrProgram::Program(prog) => { + let input = get_expression_input_val(execution_request, req, jwt_request_details); + let output = execute_expression_with_value(prog, input).map_err(|e| { + PersistedDocumentsError::ArbitraryOpsExpressionExecute(e.to_string()) + })?; + + match output { + VrlValue::Boolean(b) => Ok(b), + _ => Err(PersistedDocumentsError::ArbitraryOpsExpressionExecute( + format!( + "Expected boolean output from allow arbitrary operations expression, got {:?}", + output + ), + )), + } + } + } +} + +#[derive(Debug, thiserror::Error)] +pub enum 
+#[derive(Debug, thiserror::Error)]
+pub enum PersistedDocumentsError {
+    #[error("Persisted document not found: {0}")]
+    NotFound(String),
+    #[error("Only persisted documents are allowed")]
+    PersistedDocumentsOnly,
+    #[error("Network error: {0}")]
+    NetworkError(reqwest_middleware::Error),
+    #[error("Failed to read persisted documents from file: {0}")]
+    FileReadError(std::io::Error),
+    #[error("Failed to parse persisted documents: {0}")]
+    ParseError(serde_json::Error),
+    #[error("Failed to compile VRL expression to extract the document id for the persisted documents '{0}'. Please check your VRL expression for syntax errors. Diagnostic: {1}")]
+    SpecExpressionBuild(String, String),
+    #[error("Failed to execute VRL expression to extract the document id for the persisted documents: {0}")]
+    SpecExpressionExecute(String),
+    #[error("Failed to compile VRL expression to decide whether to allow arbitrary operations '{0}'. Please check your VRL expression for syntax errors. Diagnostic: {1}")]
+    ArbitraryOpsExpressionBuild(String, String),
+    #[error("Failed to execute VRL expression to decide whether to allow arbitrary operations: {0}")]
+    ArbitraryOpsExpressionExecute(String),
+    #[error("Failed to read persisted document: {0}")]
+    ReadError(String),
+    #[error("Key not found in persisted documents request")]
+    KeyNotFound,
+}
+
+impl PersistedDocumentsLoader {
+    pub fn try_new(config: &PersistedDocumentsConfig) -> Result<Self, PersistedDocumentsError> {
+        let fetcher = PersistedDocumentsFetcher::try_new(&config.source)?;
+
+        let spec = PersistedDocumentsSpecResolver::new(&config.spec)?;
+
+        let allow_arbitrary_operations =
+            compile_bool_or_expression(&config.allow_arbitrary_operations)?;
+
+        Ok(Self {
+            fetcher,
+            spec,
+            allow_arbitrary_operations,
+        })
+    }
+
+    pub async fn handle(
+        &self,
+        execution_request: &mut ExecutionRequest,
+        req: &HttpRequest,
+        jwt_request_details: &JwtRequestDetails<'_>,
+    ) -> Result<(), PersistedDocumentsError> {
+        if let Some(ref query) = &execution_request.query {
+            if !query.is_empty() {
+                trace!("arbitrary operation detected in request");
+                let allow_arbitrary_operations = execute_bool_or_program(
+                    &self.allow_arbitrary_operations,
+                    execution_request,
+                    req,
+                    jwt_request_details,
+                )?;
+                // If arbitrary operations are not allowed, return an error.
+                if !allow_arbitrary_operations {
+                    return Err(PersistedDocumentsError::PersistedDocumentsOnly);
+                } else {
+                    // If they are allowed, skip fetching a persisted document.
+                    return Ok(());
+                }
+            }
+        }
+
+        trace!("extracting persisted document id from request");
+        let document_id =
+            self.spec
+                .extract_document_id(execution_request, req, jwt_request_details)?;
+        trace!("fetching persisted document for id {}", document_id);
+        let query = self.fetcher.resolve(&document_id).await?;
+        trace!("persisted document fetched successfully {}", query);
+        execution_request.query = Some(query);
+
+        Ok(())
+    }
+}
diff --git a/bin/router/src/persisted_documents/spec/mod.rs b/bin/router/src/persisted_documents/spec/mod.rs
new file mode 100644
index 000000000..fd2dd721d
--- /dev/null
+++ b/bin/router/src/persisted_documents/spec/mod.rs
@@ -0,0 +1,77 @@
+use hive_router_config::persisted_documents::PersistedDocumentsSpec;
+use hive_router_plan_executor::{
+    execution::client_request_details::JwtRequestDetails,
+    utils::expression::{compile_expression, execute_expression_with_value},
+};
+use ntex::web::HttpRequest;
+use sonic_rs::JsonValueTrait;
+use vrl::{compiler::Program as VrlProgram, core::Value as VrlValue};
+
+use crate::{
+    persisted_documents::{expr_input_val::get_expression_input_val, PersistedDocumentsError},
+    pipeline::execution_request::ExecutionRequest,
+};
+
+pub enum PersistedDocumentsSpecResolver {
+    Hive,
+    Apollo,
+    Relay,
+    Expression(Box<VrlProgram>),
+}
+
+impl PersistedDocumentsSpecResolver {
+    pub fn new(spec: &PersistedDocumentsSpec) -> Result<Self, PersistedDocumentsError> {
+        match spec {
+            PersistedDocumentsSpec::Hive => Ok(PersistedDocumentsSpecResolver::Hive),
+            PersistedDocumentsSpec::Apollo => Ok(PersistedDocumentsSpecResolver::Apollo),
+            PersistedDocumentsSpec::Relay => Ok(PersistedDocumentsSpecResolver::Relay),
+            PersistedDocumentsSpec::Expression(expression) => {
+                let program = compile_expression(expression, None).map_err(|err| {
+                    PersistedDocumentsError::SpecExpressionBuild(expression.to_string(), err)
+                })?;
+                Ok(PersistedDocumentsSpecResolver::Expression(Box::new(
+                    program,
+                )))
+            }
+        }
+    }
+
+    pub fn extract_document_id(
+        &self,
+        execution_request: &ExecutionRequest,
+        req: &HttpRequest,
+        jwt_request_details: &JwtRequestDetails<'_>,
+    ) -> Result<String, PersistedDocumentsError> {
+        match &self {
+            PersistedDocumentsSpecResolver::Hive => execution_request
+                .extra_params
+                .get("documentId")
+                .and_then(|val| val.as_str().map(|s| s.to_string()))
+                .ok_or(PersistedDocumentsError::KeyNotFound),
+            PersistedDocumentsSpecResolver::Apollo => execution_request
+                .extensions
+                .get("persistedQuery")
+                .and_then(|val| val.get("sha256Hash"))
+                .and_then(|val| val.as_str().map(|s| s.to_string()))
+                .ok_or(PersistedDocumentsError::KeyNotFound),
+            PersistedDocumentsSpecResolver::Relay => execution_request
+                .extra_params
+                .get("doc_id")
+                .and_then(|s| s.as_str().map(|s| s.to_string()))
+                .ok_or(PersistedDocumentsError::KeyNotFound),
+            PersistedDocumentsSpecResolver::Expression(program) => {
+                let input = get_expression_input_val(execution_request, req, jwt_request_details);
+
+                let output = execute_expression_with_value(program, input)
+                    .map_err(|e| PersistedDocumentsError::SpecExpressionExecute(e.to_string()))?;
+
+                match output {
+                    VrlValue::Bytes(b) => Ok(String::from_utf8_lossy(&b).to_string()),
+                    _ => Err(PersistedDocumentsError::SpecExpressionExecute(format!(
+                        "Expected string output from persisted documents expression, got {:?}",
+                        output
+                    ))),
+                }
+            }
+        }
+    }
+}
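For reference, the three built-in specs expect the document id in the following places in the request body (examples; `"simple"` stands for any document id):

```json
{ "documentId": "simple" }
```

for `hive`;

```json
{ "extensions": { "persistedQuery": { "version": 1, "sha256Hash": "simple" } } }
```

for `apollo`; and

```json
{ "doc_id": "simple" }
```

for `relay`. A custom `expression` spec must return a string (`VrlValue::Bytes`), as enforced above.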
diff --git a/bin/router/src/pipeline/error.rs b/bin/router/src/pipeline/error.rs
index eec36ea76..d09b3410a 100644
--- a/bin/router/src/pipeline/error.rs
+++ b/bin/router/src/pipeline/error.rs
@@ -15,9 +15,12 @@ use ntex::{
 };
 use serde::{Deserialize, Serialize};
 
-use crate::pipeline::{
-    header::{RequestAccepts, APPLICATION_GRAPHQL_RESPONSE_JSON_STR},
-    progressive_override::LabelEvaluationError,
+use crate::{
+    persisted_documents::PersistedDocumentsError,
+    pipeline::{
+        header::{RequestAccepts, APPLICATION_GRAPHQL_RESPONSE_JSON_STR},
+        progressive_override::LabelEvaluationError,
+    },
 };
 
 #[derive(Debug)]
@@ -51,9 +54,7 @@ pub enum PipelineErrorVariant {
     UnsupportedContentType,
 
     // GET Specific pipeline errors
-    #[error("Failed to deserialize query parameters")]
-    GetInvalidQueryParams,
-    #[error("Missing query parameter: {0}")]
+    #[error("Missing query parameter")]
     GetMissingQueryParam(&'static str),
     #[error("Cannot perform mutations over GET")]
     MutationNotAllowedOverHttpGet,
@@ -89,6 +90,10 @@ pub enum PipelineErrorVariant {
     // JWT-auth plugin errors
     #[error("Failed to forward jwt: {0}")]
     JwtForwardingError(JwtForwardingError),
+
+    // Persisted Documents errors
+    #[error(transparent)]
+    PersistedDocumentsError(#[from] PersistedDocumentsError),
 }
 
 impl PipelineErrorVariant {
@@ -110,6 +115,12 @@ impl PipelineErrorVariant {
             Self::NormalizationError(NormalizationError::MultipleMatchingOperationsFound) => {
                 "OPERATION_RESOLUTION_FAILURE"
             }
+            Self::PersistedDocumentsError(err) => match err {
+                PersistedDocumentsError::NotFound(_) => "PERSISTED_QUERY_NOT_FOUND",
+                PersistedDocumentsError::KeyNotFound => "PERSISTED_QUERY_KEY_NOT_FOUND",
+                PersistedDocumentsError::PersistedDocumentsOnly => "PERSISTED_QUERY_ONLY",
+                _ => "PERSISTED_DOCUMENT_ERROR",
+            },
             _ => "BAD_REQUEST",
         }
     }
@@ -130,7 +141,6 @@ impl PipelineErrorVariant {
             (Self::UnsupportedHttpMethod(_), _) => StatusCode::METHOD_NOT_ALLOWED,
             (Self::InvalidHeaderValue(_), _) => StatusCode::BAD_REQUEST,
             (Self::GetUnprocessableQueryParams(_), _) => StatusCode::BAD_REQUEST,
-            (Self::GetInvalidQueryParams, _) => StatusCode::BAD_REQUEST,
             (Self::GetMissingQueryParam(_), _) => StatusCode::BAD_REQUEST,
             (Self::FailedToParseBody(_), _) => StatusCode::BAD_REQUEST,
             (Self::FailedToParseVariables(_), _) => StatusCode::BAD_REQUEST,
@@ -146,6 +156,12 @@ impl PipelineErrorVariant {
             (Self::MissingContentTypeHeader, _) => StatusCode::NOT_ACCEPTABLE,
             (Self::UnsupportedContentType, _) => StatusCode::UNSUPPORTED_MEDIA_TYPE,
             (Self::CsrfPreventionFailed, _) => StatusCode::FORBIDDEN,
+            (Self::PersistedDocumentsError(err), _) => match err {
+                PersistedDocumentsError::NotFound(_) => StatusCode::NOT_FOUND,
+                PersistedDocumentsError::KeyNotFound => StatusCode::BAD_REQUEST,
+                PersistedDocumentsError::PersistedDocumentsOnly => StatusCode::BAD_REQUEST,
+                _ => StatusCode::INTERNAL_SERVER_ERROR,
+            },
         }
     }
 }
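With this mapping, a lookup miss is surfaced to clients as HTTP 404 with a GraphQL error carrying the `PERSISTED_QUERY_NOT_FOUND` code — roughly like the following (a sketch; the exact body follows the router's usual error formatting):

```json
{
  "errors": [
    {
      "message": "Persisted document not found: simple",
      "extensions": { "code": "PERSISTED_QUERY_NOT_FOUND" }
    }
  ]
}
```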
diff --git a/bin/router/src/pipeline/execution_request.rs b/bin/router/src/pipeline/execution_request.rs
index c17a6f355..a30b64cbd 100644
--- a/bin/router/src/pipeline/execution_request.rs
+++ b/bin/router/src/pipeline/execution_request.rs
@@ -1,55 +1,114 @@
+use core::fmt;
 use std::collections::HashMap;
 
 use http::Method;
 use ntex::util::Bytes;
 use ntex::web::types::Query;
 use ntex::web::HttpRequest;
-use serde::{Deserialize, Deserializer};
+use serde::{de, Deserialize, Deserializer};
 use sonic_rs::Value;
 use tracing::{trace, warn};
 
 use crate::pipeline::error::{PipelineError, PipelineErrorFromAcceptHeader, PipelineErrorVariant};
 use crate::pipeline::header::AssertRequestJson;
 
-#[derive(serde::Deserialize, Debug)]
+#[derive(serde::Deserialize, Debug, Default)]
 struct GETQueryParams {
     pub query: Option<String>,
     #[serde(rename = "operationName")]
     pub operation_name: Option<String>,
     pub variables: Option<String>,
     pub extensions: Option<String>,
+    #[serde(flatten)]
+    pub extra_params: HashMap<String, String>,
 }
 
-#[derive(Deserialize, Debug, Clone)]
-#[serde(rename_all = "camelCase")]
+#[derive(Debug, Clone, Default)]
 pub struct ExecutionRequest {
-    pub query: String,
+    pub query: Option<String>,
     pub operation_name: Option<String>,
-    #[serde(default, deserialize_with = "deserialize_null_default")]
     pub variables: HashMap<String, Value>,
-    // TODO: We don't use extensions yet, but we definitely will in the future.
-    #[allow(dead_code)]
-    pub extensions: Option<HashMap<String, Value>>,
+    pub extensions: HashMap<String, Value>,
+    pub extra_params: HashMap<String, Value>,
 }
 
-fn deserialize_null_default<'de, D, T>(deserializer: D) -> Result<T, D::Error>
-where
-    T: Default + Deserialize<'de>,
-    D: Deserializer<'de>,
-{
-    let opt = Option::<T>::deserialize(deserializer)?;
-    Ok(opt.unwrap_or_default())
+// Workaround for https://github.com/cloudwego/sonic-rs/issues/114
+
+impl<'de> Deserialize<'de> for ExecutionRequest {
+    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
+    where
+        D: Deserializer<'de>,
+    {
+        struct ExecutionRequestVisitor;
+
+        impl<'de> de::Visitor<'de> for ExecutionRequestVisitor {
+            type Value = ExecutionRequest;
+
+            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
+                formatter.write_str("a map for ExecutionRequest")
+            }
+
+            fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
+            where
+                A: de::MapAccess<'de>,
+            {
+                let mut query = None;
+                let mut operation_name = None;
+                let mut variables: Option<HashMap<String, Value>> = None;
+                let mut extensions: Option<HashMap<String, Value>> = None;
+                let mut extra_params = HashMap::new();
+
+                while let Some(key) = map.next_key::<String>()? {
+                    match key.as_str() {
+                        "query" => {
+                            if query.is_some() {
+                                return Err(de::Error::duplicate_field("query"));
+                            }
+                            query = map.next_value::<Option<String>>()?;
+                        }
+                        "operationName" => {
+                            if operation_name.is_some() {
+                                return Err(de::Error::duplicate_field("operationName"));
+                            }
+                            operation_name = map.next_value::<Option<String>>()?;
+                        }
+                        "variables" => {
+                            if variables.is_some() {
+                                return Err(de::Error::duplicate_field("variables"));
+                            }
+                            variables = map.next_value::<Option<HashMap<String, Value>>>()?;
+                        }
+                        "extensions" => {
+                            if extensions.is_some() {
+                                return Err(de::Error::duplicate_field("extensions"));
+                            }
+                            extensions = map.next_value::<Option<HashMap<String, Value>>>()?;
+                        }
+                        other => {
+                            let value: Value = map.next_value()?;
+                            extra_params.insert(other.to_string(), value);
+                        }
+                    }
+                }
+
+                Ok(ExecutionRequest {
+                    query,
+                    operation_name,
+                    variables: variables.unwrap_or_default(),
+                    extensions: extensions.unwrap_or_default(),
+                    extra_params,
+                })
+            }
+        }
+
+        deserializer.deserialize_map(ExecutionRequestVisitor)
+    }
 }
 
 impl TryInto<ExecutionRequest> for GETQueryParams {
     type Error = PipelineErrorVariant;
 
     fn try_into(self) -> Result<ExecutionRequest, Self::Error> {
-        let query = match self.query {
-            Some(q) => q,
-            None => return Err(PipelineErrorVariant::GetMissingQueryParam("query")),
-        };
-
         let variables = match self.variables.as_deref() {
             Some(v_str) if !v_str.is_empty() => match sonic_rs::from_str(v_str) {
                 Ok(vars) => vars,
@@ -62,19 +121,20 @@ impl TryInto<ExecutionRequest> for GETQueryParams {
 
         let extensions = match self.extensions.as_deref() {
             Some(e_str) if !e_str.is_empty() => match sonic_rs::from_str(e_str) {
-                Ok(exts) => Some(exts),
+                Ok(exts) => exts,
                 Err(e) => {
                     return Err(PipelineErrorVariant::FailedToParseExtensions(e));
                 }
             },
-            _ => None,
+            _ => HashMap::new(),
         };
 
         let execution_request = ExecutionRequest {
-            query,
+            query: self.query,
             operation_name: self.operation_name,
             variables,
             extensions,
+            extra_params: self.extra_params,
         };
 
         Ok(execution_request)
@@ -90,20 +150,24 @@ pub async fn get_execution_request(
     let execution_request: ExecutionRequest = match *http_method {
         Method::GET => {
trace!("processing GET GraphQL operation"); - let query_params_str = req.uri().query().ok_or_else(|| { - req.new_pipeline_error(PipelineErrorVariant::GetInvalidQueryParams) - })?; - let query_params = Query::::from_query(query_params_str) - .map_err(|e| { - req.new_pipeline_error(PipelineErrorVariant::GetUnprocessableQueryParams(e)) - })? - .0; - - trace!("parsed GET query params: {:?}", query_params); - - query_params - .try_into() - .map_err(|err| req.new_pipeline_error(err))? + match req.uri().query() { + Some(query_params_str) => { + let query_params = Query::::from_query(query_params_str) + .map_err(|e| { + req.new_pipeline_error( + PipelineErrorVariant::GetUnprocessableQueryParams(e), + ) + })? + .0; + + trace!("parsed GET query params: {:?}", query_params); + + query_params + .try_into() + .map_err(|err| req.new_pipeline_error(err))? + } + None => ExecutionRequest::default(), + } } Method::POST => { trace!("Processing POST GraphQL request"); @@ -132,3 +196,12 @@ pub async fn get_execution_request( Ok(execution_request) } + +impl ExecutionRequest { + pub fn get_query_str(&self) -> Result<&str, PipelineErrorVariant> { + match &self.query { + Some(query_str) => Ok(query_str.as_str()), + None => Err(PipelineErrorVariant::GetMissingQueryParam("query")), + } + } +} diff --git a/bin/router/src/pipeline/mod.rs b/bin/router/src/pipeline/mod.rs index 2b4721972..08fa7a4ae 100644 --- a/bin/router/src/pipeline/mod.rs +++ b/bin/router/src/pipeline/mod.rs @@ -12,6 +12,7 @@ use ntex::{ util::Bytes, web::{self, HttpRequest}, }; +use tracing::trace; use crate::{ jwt::context::JwtRequestContext, @@ -114,6 +115,32 @@ pub async fn execute_pipeline( perform_csrf_prevention(req, &shared_state.router_config.csrf)?; let mut execution_request = get_execution_request(req, body_bytes).await?; + + trace!("building JWT request details"); + let req_extensions = req.extensions(); + let jwt_context = req_extensions.get::(); + let jwt_request_details = match jwt_context { + Some(jwt_context) => JwtRequestDetails::Authenticated { + token: jwt_context.token_raw.as_str(), + prefix: jwt_context.token_prefix.as_deref(), + scopes: jwt_context.extract_scopes(), + claims: &jwt_context + .get_claims_value() + .map_err(|e| req.new_pipeline_error(PipelineErrorVariant::JwtForwardingError(e)))?, + }, + None => JwtRequestDetails::Unauthenticated, + }; + + if let Some(persisted_docs) = &shared_state.persisted_docs { + trace!("handling persisted documents"); + persisted_docs + .handle(&mut execution_request, req, &jwt_request_details) + .await + .map_err(|e| { + req.new_pipeline_error(PipelineErrorVariant::PersistedDocumentsError(e)) + })?; + } + let parser_payload = parse_operation_with_cache(req, shared_state, &execution_request).await?; validate_operation_with_cache(req, supergraph, schema_state, shared_state, &parser_payload) .await?; @@ -132,24 +159,11 @@ pub async fn execute_pipeline( let query_plan_cancellation_token = CancellationToken::with_timeout(shared_state.router_config.query_planner.timeout); - let req_extensions = req.extensions(); - let jwt_context = req_extensions.get::(); - let jwt_request_details = match jwt_context { - Some(jwt_context) => JwtRequestDetails::Authenticated { - token: jwt_context.token_raw.as_str(), - prefix: jwt_context.token_prefix.as_deref(), - scopes: jwt_context.extract_scopes(), - claims: &jwt_context - .get_claims_value() - .map_err(|e| req.new_pipeline_error(PipelineErrorVariant::JwtForwardingError(e)))?, - }, - None => JwtRequestDetails::Unauthenticated, - }; - let client_request_details 
     let client_request_details = ClientRequestDetails {
         method: req.method(),
         url: req.uri(),
         headers: req.headers(),
+        path_params: req.match_info(),
         operation: OperationDetails {
             name: normalize_payload.operation_for_plan.name.as_deref(),
             kind: match normalize_payload.operation_for_plan.operation_kind {
@@ -158,7 +172,9 @@ pub async fn execute_pipeline(
                 Some(OperationKind::Subscription) => "subscription",
                 None => "query",
             },
-            query: &execution_request.query,
+            query: execution_request
+                .get_query_str()
+                .map_err(|e| req.new_pipeline_error(e))?,
         },
         jwt: &jwt_request_details,
     };
diff --git a/bin/router/src/pipeline/parser.rs b/bin/router/src/pipeline/parser.rs
index 6e8a37141..79384dac6 100644
--- a/bin/router/src/pipeline/parser.rs
+++ b/bin/router/src/pipeline/parser.rs
@@ -23,9 +23,12 @@ pub async fn parse_operation_with_cache(
     app_state: &Arc<RouterSharedState>,
     execution_params: &ExecutionRequest,
 ) -> Result {
+    let query = execution_params
+        .get_query_str()
+        .map_err(|err| req.new_pipeline_error(err))?;
     let cache_key = {
         let mut hasher = Xxh3::new();
-        execution_params.query.hash(&mut hasher);
+        query.hash(&mut hasher);
         hasher.finish()
     };
 
@@ -33,11 +36,11 @@ pub async fn parse_operation_with_cache(
         trace!("Found cached parsed operation for query");
         cached
     } else {
-        let parsed = safe_parse_operation(&execution_params.query).map_err(|err| {
+        let parsed = safe_parse_operation(query).map_err(|err| {
             error!("Failed to parse GraphQL operation: {}", err);
             req.new_pipeline_error(PipelineErrorVariant::FailedToParseOperation(err))
         })?;
-        trace!("sucessfully parsed GraphQL operation");
+        trace!("successfully parsed GraphQL operation");
         let parsed_arc = Arc::new(parsed);
         app_state
             .parse_cache
diff --git a/bin/router/src/pipeline/progressive_override.rs b/bin/router/src/pipeline/progressive_override.rs
index d0b09c183..0c2ad9c15 100644
--- a/bin/router/src/pipeline/progressive_override.rs
+++ b/bin/router/src/pipeline/progressive_override.rs
@@ -1,20 +1,22 @@
 use std::collections::{BTreeMap, HashMap, HashSet};
 
 use hive_router_config::override_labels::{LabelOverrideValue, OverrideLabelsConfig};
-use hive_router_plan_executor::execution::client_request_details::ClientRequestDetails;
+use hive_router_plan_executor::{
+    execution::client_request_details::ClientRequestDetails, utils::expression::compile_expression,
+};
 use hive_router_query_planner::{
     graph::{PlannerOverrideContext, PERCENTAGE_SCALE_FACTOR},
     state::supergraph_state::SupergraphState,
 };
 use rand::Rng;
 use vrl::{
-    compiler::{compile as vrl_compile, Program as VrlProgram, TargetValue as VrlTargetValue},
+    compiler::Program as VrlProgram,
+    compiler::TargetValue as VrlTargetValue,
     core::Value as VrlValue,
     prelude::{
         state::RuntimeState as VrlState, Context as VrlContext, ExpressionError,
         TimeZone as VrlTimeZone,
     },
-    stdlib::all as vrl_build_functions,
     value::Secrets as VrlSecrets,
 };
 
@@ -126,7 +128,6 @@ impl OverrideLabelsEvaluator {
     ) -> Result {
         let mut static_enabled_labels = HashSet::new();
         let mut expressions = HashMap::new();
-        let vrl_functions = vrl_build_functions();
 
         for (label, value) in override_labels_config.iter() {
             match value {
@@ -134,19 +135,13 @@ impl OverrideLabelsEvaluator {
                     static_enabled_labels.insert(label.clone());
                 }
                 LabelOverrideValue::Expression { expression } => {
-                    let compilation_result =
-                        vrl_compile(expression, &vrl_functions).map_err(|diagnostics| {
-                            OverrideLabelsCompileError {
-                                label: label.clone(),
-                                error: diagnostics
-                                    .errors()
-                                    .into_iter()
-                                    .map(|d| d.code.to_string() + ": " + &d.message)
-                                    .collect::<Vec<String>>()
-                                    .join(", "),
-                            }
-                        })?;
-                    expressions.insert(label.clone(), compilation_result.program);
+                    let program = compile_expression(expression, None).map_err(|err| {
+                        OverrideLabelsCompileError {
+                            label: label.clone(),
+                            error: err.to_string(),
+                        }
+                    })?;
+                    expressions.insert(label.clone(), program);
                 }
                 _ => {} // Skip false booleans
             }
diff --git a/bin/router/src/shared_state.rs b/bin/router/src/shared_state.rs
index f36bda6cd..10a5da6c7 100644
--- a/bin/router/src/shared_state.rs
+++ b/bin/router/src/shared_state.rs
@@ -7,6 +7,7 @@ use moka::future::Cache;
 use std::sync::Arc;
 
 use crate::jwt::JwtAuthRuntime;
+use crate::persisted_documents::{PersistedDocumentsError, PersistedDocumentsLoader};
 use crate::pipeline::cors::{CORSConfigError, Cors};
 use crate::pipeline::progressive_override::{OverrideLabelsCompileError, OverrideLabelsEvaluator};
 
@@ -18,12 +19,14 @@ pub struct RouterSharedState {
     pub override_labels_evaluator: OverrideLabelsEvaluator,
     pub cors_runtime: Option<Cors>,
     pub jwt_auth_runtime: Option<JwtAuthRuntime>,
+    pub persisted_docs: Option<PersistedDocumentsLoader>,
 }
 
 impl RouterSharedState {
     pub fn new(
         router_config: Arc<HiveRouterConfig>,
         jwt_auth_runtime: Option<JwtAuthRuntime>,
+        persisted_docs: Option<PersistedDocumentsLoader>,
     ) -> Result<Self, SharedStateError> {
         Ok(Self {
             validation_plan: graphql_tools::validation::rules::default_rules_validation_plan(),
@@ -36,6 +39,7 @@ impl RouterSharedState {
             )
             .map_err(Box::new)?,
             jwt_auth_runtime,
+            persisted_docs,
         })
     }
 }
@@ -48,4 +52,6 @@ pub enum SharedStateError {
     CORSConfig(#[from] Box<CORSConfigError>),
     #[error("invalid override labels config: {0}")]
     OverrideLabelsCompile(#[from] Box<OverrideLabelsCompileError>),
+    #[error("failed to build the persisted documents manager: {0}")]
+    PersistedDocuments(#[from] Box<PersistedDocumentsError>),
 }
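Putting the pieces together, a typical router configuration combines a source, a spec, and the arbitrary-operations flag — a sketch based on the options documented below (the Hive Console `endpoint` and `key` values are placeholders):

```yaml
persisted_documents:
  enabled: true
  spec: hive
  source:
    hive:
      endpoint: https://cdn.example.com/artifacts/v1/my-target
      key: my-cdn-access-key
  allow_arbitrary_operations: false
```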
Default: `{"enabled":false,"required_headers":[]}`
|| |[**graphiql**](#graphiql)|`object`|Configuration for the GraphiQL interface.
Default: `{"enabled":true}`
|| |[**headers**](#headers)|`object`|Configuration for the headers.
Default: `{}`
|| -|[**http**](#http)|`object`|Configuration for the HTTP server/listener.
Default: `{"host":"0.0.0.0","port":4000}`
|| +|[**http**](#http)|`object`|Configuration for the HTTP server/listener.
Default: `{"graphql_endpoint":"/graphql","host":"0.0.0.0","port":4000}`
|| |[**jwt**](#jwt)|`object`|Configuration for JWT authentication plugin.
|yes| |[**log**](#log)|`object`|The router logger configuration.
Default: `{"filter":null,"format":"json","level":"info"}`
|| |[**override\_labels**](#override_labels)|`object`|Configuration for overriding labels.
|| |[**override\_subgraph\_urls**](#override_subgraph_urls)|`object`|Configuration for overriding subgraph URLs.
Default: `{}`
|| +|[**persisted\_documents**](#persisted_documents)|`object`|Configuration for persisted operations.
|| |[**query\_planner**](#query_planner)|`object`|Query planning configuration.
Default: `{"allow_expose":false,"timeout":"10s"}`
|| |[**supergraph**](#supergraph)|`object`|Configuration for the Federation supergraph source. By default, the router will use a local file-based supergraph source (`./supergraph.graphql`).
|| -|[**traffic\_shaping**](#traffic_shaping)|`object`|Configuration for the traffic-shaper executor. Use these configurations to control how requests are being executed to subgraphs.
Default: `{"dedupe_enabled":true,"max_connections_per_host":100,"pool_idle_timeout_seconds":50}`
|| +|[**traffic\_shaping**](#traffic_shaping)|`object`|Configuration for traffic shaping in the executor. Use these settings to control how requests to subgraphs are executed.
Default: `{"dedupe_enabled":true,"max_connections_per_host":100,"pool_idle_timeout":"50s"}`
|| **Additional Properties:** not allowed **Example** @@ -58,6 +59,7 @@ headers: named: x-tenant-id rename: x-acct-tenant http: + graphql_endpoint: /graphql host: 0.0.0.0 port: 4000 jwt: @@ -102,6 +104,19 @@ override_subgraph_urls: .original_url } +persisted_documents: + allow_arbitrary_operations: false + enabled: false + source: + hive: + accept_invalid_certs: false + cache_size: 1000 + connect_timeout: 5s + endpoint: '' + key: '' + request_timeout: 15s + retry_count: 3 + spec: hive query_planner: allow_expose: false timeout: 10s @@ -109,7 +124,7 @@ supergraph: {} traffic_shaping: dedupe_enabled: true max_connections_per_host: 100 - pool_idle_timeout_seconds: 50 + pool_idle_timeout: 50s ``` @@ -1351,6 +1366,7 @@ Configuration for the HTTP server/listener. |Name|Type|Description|Required| |----|----|-----------|--------| +|**graphql\_endpoint**|`string`|Default: `"/graphql"`
|| |**host**|`string`|The host address to bind the HTTP server to.

Can also be set via the `HOST` environment variable.
Default: `"0.0.0.0"`
|| |**port**|`integer`|The port to bind the HTTP server to.

Can also be set via the `PORT` environment variable.

If you are running the router inside a Docker container, please ensure that the port is exposed correctly using `-p :` flag.
Default: `4000`
Format: `"uint16"`
Minimum: `0`
Maximum: `65535`
|| @@ -1358,6 +1374,7 @@ Configuration for the HTTP server/listener. **Example** ```yaml +graphql_endpoint: /graphql host: 0.0.0.0 port: 4000 @@ -1641,6 +1658,40 @@ products: |----|----|-----------|--------| |**url**||Overrides for the URL of the subgraph.

For convenience, a plain string in your configuration will be treated as a static URL.

### Static URL Example
```yaml
url: "https://api.example.com/graphql"
```

### Dynamic Expression Example

The expression has access to the following variables:
- `request`: The incoming HTTP request, including headers and other metadata.
- `original_url`: The original URL of the subgraph (from supergraph sdl).

```yaml
url:
expression: \|
if .request.headers."x-region" == "us-east" {
"https://products-us-east.example.com/graphql"
} else if .request.headers."x-region" == "eu-west" {
"https://products-eu-west.example.com/graphql"
} else {
.original_url
}
|yes| +
+## persisted\_documents: object + +Configuration for persisted operations. + + +**Properties** + +|Name|Type|Description|Required| +|----|----|-----------|--------| +|**allow\_arbitrary\_operations**||Whether to allow arbitrary operations that are not persisted.
Default: `false`
|| +|**enabled**|`boolean`|Whether persisted operations are enabled.
Default: `false`
|| +|**source**||The source of persisted documents.
Default: `{"hive":{"accept_invalid_certs":false,"cache_size":1000,"connect_timeout":"5s","endpoint":"","key":"","request_timeout":"15s","retry_count":3}}`
|| +|**spec**||The specification to extract persisted operations.
Default: `"hive"`
|| + +**Additional Properties:** not allowed +**Example** + +```yaml +allow_arbitrary_operations: false +enabled: false +source: + hive: + accept_invalid_certs: false + cache_size: 1000 + connect_timeout: 5s + endpoint: '' + key: '' + request_timeout: 15s + retry_count: 3 +spec: hive + +``` + ## query\_planner: object @@ -1808,7 +1859,7 @@ Request timeout for the Hive Console CDN requests. ## traffic\_shaping: object -Configuration for the traffic-shaper executor. Use these configurations to control how requests are being executed to subgraphs. +Configuration for traffic shaping in the executor. Use these settings to control how requests to subgraphs are executed. **Properties** |Name|Type|Description|Required| |----|----|-----------|--------| |**dedupe\_enabled**|`boolean`|Enables/disables request deduplication to subgraphs.

When requests exactly matches the hashing mechanism (e.g., subgraph name, URL, headers, query, variables), and are executed at the same time, they will
be deduplicated by sharing the response of other in-flight requests.
Default: `true`
|| |**max\_connections\_per\_host**|`integer`|Limits the concurrent amount of requests/connections per host/subgraph.
Default: `100`
Format: `"uint"`
Minimum: `0`
|| -|**pool\_idle\_timeout\_seconds**|`integer`|Timeout for idle sockets being kept-alive.
Default: `50`
Format: `"uint64"`
Minimum: `0`
|| +|**pool\_idle\_timeout**|`string`|Timeout for idle sockets being kept-alive.
Default: `"50s"`
|| **Additional Properties:** not allowed **Example** @@ -1825,7 +1876,7 @@ Configuration for the traffic-shaper executor. Use these configurations to contr ```yaml dedupe_enabled: true max_connections_per_host: 100 -pool_idle_timeout_seconds: 50 +pool_idle_timeout: 50s ``` diff --git a/e2e/configs/persisted_documents/apollo_spec.yaml b/e2e/configs/persisted_documents/apollo_spec.yaml new file mode 100644 index 000000000..2094120cb --- /dev/null +++ b/e2e/configs/persisted_documents/apollo_spec.yaml @@ -0,0 +1,10 @@ +# yaml-language-server: $schema=../../../router-config.schema.json +supergraph: + source: file + path: ../../supergraph.graphql +persisted_documents: + enabled: true + source: + file: + path: ../../persisted_docs.json + spec: apollo \ No newline at end of file diff --git a/e2e/configs/persisted_documents/expr_allow_arbitrary.yaml b/e2e/configs/persisted_documents/expr_allow_arbitrary.yaml new file mode 100644 index 000000000..df69b5ad0 --- /dev/null +++ b/e2e/configs/persisted_documents/expr_allow_arbitrary.yaml @@ -0,0 +1,12 @@ +# yaml-language-server: $schema=../../../router-config.schema.json +supergraph: + source: file + path: ../../supergraph.graphql +persisted_documents: + enabled: true + source: + file: + path: ../../persisted_docs.json + allow_arbitrary_operations: + expression: | + .request.headers."x-allow-arbitrary-operations" == "true" \ No newline at end of file diff --git a/e2e/configs/persisted_documents/expr_spec.yaml b/e2e/configs/persisted_documents/expr_spec.yaml new file mode 100644 index 000000000..60f7ce8f5 --- /dev/null +++ b/e2e/configs/persisted_documents/expr_spec.yaml @@ -0,0 +1,11 @@ +# yaml-language-server: $schema=../../../router-config.schema.json +supergraph: + source: file + path: ../../supergraph.graphql +persisted_documents: + enabled: true + source: + file: + path: ../../persisted_docs.json + spec: + expression: .request.body.my_id diff --git a/e2e/configs/persisted_documents/file_source.yaml b/e2e/configs/persisted_documents/file_source.yaml new file mode 100644 index 000000000..b70b46013 --- /dev/null +++ b/e2e/configs/persisted_documents/file_source.yaml @@ -0,0 +1,9 @@ +# yaml-language-server: $schema=../../../router-config.schema.json +supergraph: + source: file + path: ../../supergraph.graphql +persisted_documents: + enabled: true + source: + file: + path: ../../persisted_docs.json \ No newline at end of file diff --git a/e2e/configs/persisted_documents/url_spec.yaml b/e2e/configs/persisted_documents/url_spec.yaml new file mode 100644 index 000000000..f92334dc9 --- /dev/null +++ b/e2e/configs/persisted_documents/url_spec.yaml @@ -0,0 +1,13 @@ +# yaml-language-server: $schema=../../../router-config.schema.json +supergraph: + source: file + path: ../../supergraph.graphql +http: + graphql_endpoint: /graphql/{document_id} +persisted_documents: + enabled: true + source: + file: + path: ../../persisted_docs.json + spec: + expression: .request.path_params.document_id diff --git a/e2e/persisted_docs.json b/e2e/persisted_docs.json new file mode 100644 index 000000000..590eea858 --- /dev/null +++ b/e2e/persisted_docs.json @@ -0,0 +1,3 @@ +{ + "simple": "{ users { id } }" +} \ No newline at end of file diff --git a/e2e/src/lib.rs b/e2e/src/lib.rs index 9086e01f4..f39b6ccf4 100644 --- a/e2e/src/lib.rs +++ b/e2e/src/lib.rs @@ -7,6 +7,8 @@ mod jwt; #[cfg(test)] mod override_subgraph_urls; #[cfg(test)] +mod persisted_documents; +#[cfg(test)] mod probes; #[cfg(test)] mod supergraph; diff --git a/e2e/src/persisted_documents.rs 
b/e2e/src/persisted_documents.rs
new file mode 100644
index 000000000..ae8fd2ddd
--- /dev/null
+++ b/e2e/src/persisted_documents.rs
@@ -0,0 +1,181 @@
+#[cfg(test)]
+mod persisted_documents_e2e_tests {
+    use ntex::web::test;
+    use sonic_rs::json;
+
+    use crate::testkit::{init_router_from_config_file, wait_for_readiness, SubgraphsServer};
+
+    #[ntex::test]
+    /// Tests retrieval of a simple persisted document from a file using the "hive" spec.
+    async fn should_get_persisted_document_from_a_file() {
+        let subgraphs_server = SubgraphsServer::start().await;
+        let app = init_router_from_config_file("configs/persisted_documents/file_source.yaml")
+            .await
+            .unwrap();
+        wait_for_readiness(&app.app).await;
+
+        let body = json!({
+            "documentId": "simple",
+        });
+
+        let req = test::TestRequest::post()
+            .uri("/graphql")
+            .header("content-type", "application/json")
+            .set_payload(body.to_string());
+        let resp = test::call_service(&app.app, req.to_request()).await;
+        assert!(resp.status().is_success(), "Expected 200 OK");
+
+        let subgraph_requests = subgraphs_server
+            .get_subgraph_requests_log("accounts")
+            .await
+            .expect("expected requests sent to accounts subgraph");
+        assert_eq!(
+            subgraph_requests.len(),
+            1,
+            "expected 1 request to accounts subgraph"
+        );
+    }
+
+    #[ntex::test]
+    /// Tests retrieval of a persisted document using a custom expression spec.
+    async fn should_support_custom_spec() {
+        let subgraphs_server = SubgraphsServer::start().await;
+        let app = init_router_from_config_file("configs/persisted_documents/expr_spec.yaml")
+            .await
+            .unwrap();
+        wait_for_readiness(&app.app).await;
+
+        let body = json!({
+            "my_id": "simple",
+        });
+
+        let req = test::TestRequest::post()
+            .uri("/graphql")
+            .header("content-type", "application/json")
+            .set_payload(body.to_string());
+        let resp = test::call_service(&app.app, req.to_request()).await;
+        assert!(resp.status().is_success(), "Expected 200 OK");
+
+        let subgraph_requests = subgraphs_server
+            .get_subgraph_requests_log("accounts")
+            .await
+            .expect("expected requests sent to accounts subgraph");
+        assert_eq!(
+            subgraph_requests.len(),
+            1,
+            "expected 1 request to accounts subgraph"
+        );
+    }
+
+    #[ntex::test]
+    /// Tests if arbitrary operations are allowed based on a custom expression.
+    async fn should_allow_arbitrary_operations_based_on_expression() {
+        let subgraphs_server = SubgraphsServer::start().await;
+        let app =
+            init_router_from_config_file("configs/persisted_documents/expr_allow_arbitrary.yaml")
+                .await
+                .unwrap();
+        wait_for_readiness(&app.app).await;
+
+        let arbitrary = json!({
+            "query": "{ users { id } }",
+        });
+
+        let req = test::TestRequest::post()
+            .uri("/graphql")
+            .header("content-type", "application/json")
+            .header("x-allow-arbitrary-operations", "true")
+            .set_payload(arbitrary.to_string());
+        let resp = test::call_service(&app.app, req.to_request()).await;
+        assert!(resp.status().is_success(), "Expected 200 OK");
+
+        let subgraph_requests = subgraphs_server
+            .get_subgraph_requests_log("accounts")
+            .await
+            .expect("expected requests sent to accounts subgraph");
+        assert_eq!(
+            subgraph_requests.len(),
+            1,
+            "expected 1 request to accounts subgraph"
+        );
+
+        let not_allowed = json!({
+            "query": "{ users { id } }",
+        });
+
+        let req = test::TestRequest::post()
+            .uri("/graphql")
+            .header("content-type", "application/json")
+            .set_payload(not_allowed.to_string());
+        let resp = test::call_service(&app.app, req.to_request()).await;
+        assert_eq!(resp.status().as_u16(), 400, "Expected 400 Bad Request");
+    }
+
+    #[ntex::test]
+    /// Tests persisted document retrieval using the Apollo persisted queries spec.
+    async fn should_support_apollo_spec() {
+        let subgraphs_server = SubgraphsServer::start().await;
+        let app = init_router_from_config_file("configs/persisted_documents/apollo_spec.yaml")
+            .await
+            .unwrap();
+        wait_for_readiness(&app.app).await;
+
+        let body = json!({
+            "extensions": {
+                "persistedQuery": {
+                    "version": 1,
+                    "sha256Hash": "simple"
+                }
+            }
+        });
+
+        let req = test::TestRequest::post()
+            .uri("/graphql")
+            .header("content-type", "application/json")
+            .set_payload(body.to_string());
+
+        let resp = test::call_service(&app.app, req.to_request()).await;
+        assert!(resp.status().is_success(), "Expected 200 OK");
+
+        let subgraph_requests = subgraphs_server
+            .get_subgraph_requests_log("accounts")
+            .await
+            .expect("expected requests sent to accounts subgraph");
+
+        assert_eq!(
+            subgraph_requests.len(),
+            1,
+            "expected 1 request to accounts subgraph"
+        );
+    }
+
+    #[ntex::test]
+    async fn should_support_url_params() {
+        let subgraphs_server = SubgraphsServer::start().await;
+        let app = init_router_from_config_file("configs/persisted_documents/url_spec.yaml")
+            .await
+            .unwrap();
+        wait_for_readiness(&app.app).await;
+
+        let req = test::TestRequest::get().uri("/graphql/simple");
+        let resp = test::call_service(&app.app, req.to_request()).await;
+
+        let status = resp.status();
+        let body = test::read_body(resp).await;
+        assert!(
+            status.is_success(),
+            "Expected 200 OK, got {} with body {:#?}",
+            status,
+            body
+        );
+
+        let subgraph_requests = subgraphs_server
+            .get_subgraph_requests_log("accounts")
+            .await
+            .expect("expected requests sent to accounts subgraph");
+        assert_eq!(
+            subgraph_requests.len(),
+            1,
+            "expected 1 request to accounts subgraph"
+        );
+    }
+}
diff --git a/e2e/src/testkit.rs b/e2e/src/testkit.rs
index 638138801..4b61c417d 100644
--- a/e2e/src/testkit.rs
+++ b/e2e/src/testkit.rs
@@ -187,7 +187,9 @@ pub async fn init_router_from_config(
             web::App::new()
                 .state(shared_state.clone())
                 .state(schema_state.clone())
-                .configure(configure_ntex_app),
+                .configure(|service_config| {
+                    configure_ntex_app(service_config, &shared_state.router_config);
+                }),
     )
     .await;
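The `should_support_url_params` test above corresponds to a plain GET against a router configured with `graphql_endpoint: /graphql/{document_id}` — for example (assuming the default port and the `simple` document from `e2e/persisted_docs.json`):

```sh
curl 'http://localhost:4000/graphql/simple'
```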
diff --git a/lib/executor/Cargo.toml b/lib/executor/Cargo.toml
index 27f7af1bf..5f3f3990d 100644
--- a/lib/executor/Cargo.toml
+++ b/lib/executor/Cargo.toml
@@ -30,6 +30,7 @@ xxhash-rust = { workspace = true }
 tokio = { workspace = true, features = ["sync"] }
 dashmap = { workspace = true }
 vrl = { workspace = true }
+ntex = { workspace = true }
 ahash = "0.8.12"
 regex-automata = "0.4.10"
@@ -49,6 +50,7 @@ itoa = "1.0.15"
 ryu = "1.0.20"
 indexmap = "2.10.0"
 bumpalo = "3.19.0"
+once_cell = "1.21.3"
 
 [dev-dependencies]
 subgraphs = { path = "../../bench/subgraphs" }
diff --git a/lib/executor/src/execution/client_request_details.rs b/lib/executor/src/execution/client_request_details.rs
index 6985376cc..9b9cb9730 100644
--- a/lib/executor/src/execution/client_request_details.rs
+++ b/lib/executor/src/execution/client_request_details.rs
@@ -1,9 +1,10 @@
 use std::collections::BTreeMap;
 
 use bytes::Bytes;
-use http::Method;
+use http::{Method, Uri};
+use ntex::router::Path;
 use ntex_http::HeaderMap as NtexHeaderMap;
-use vrl::core::Value;
+use vrl::{core::Value, value::KeyString};
 
 use crate::utils::vrl::sonic_value_to_vrl_value;
 
@@ -16,6 +17,7 @@ pub struct OperationDetails<'exec> {
 pub struct ClientRequestDetails<'exec, 'req> {
     pub method: &'req Method,
     pub url: &'req http::Uri,
+    pub path_params: &'req Path<Uri>,
     pub headers: &'req NtexHeaderMap,
     pub operation: OperationDetails<'exec>,
     pub jwt: &'exec JwtRequestDetails<'req>,
@@ -31,43 +33,9 @@ pub enum JwtRequestDetails<'exec> {
     Unauthenticated,
 }
 
-impl From<&ClientRequestDetails<'_, '_>> for Value {
-    fn from(details: &ClientRequestDetails) -> Self {
-        // .request.headers
-        let headers_value = client_header_map_to_vrl_value(details.headers);
-
-        // .request.url
-        let url_value = Self::Object(BTreeMap::from([
-            (
-                "host".into(),
-                details.url.host().unwrap_or("unknown").into(),
-            ),
-            ("path".into(), details.url.path().into()),
-            (
-                "port".into(),
-                details
-                    .url
-                    .port_u16()
-                    .unwrap_or_else(|| {
-                        if details.url.scheme() == Some(&http::uri::Scheme::HTTPS) {
-                            443
-                        } else {
-                            80
-                        }
-                    })
-                    .into(),
-            ),
-        ]));
-
-        // .request.operation
-        let operation_value = Self::Object(BTreeMap::from([
-            ("name".into(), details.operation.name.into()),
-            ("type".into(), details.operation.kind.into()),
-            ("query".into(), details.operation.query.into()),
-        ]));
-
-        // .request.jwt
-        let jwt_value = match details.jwt {
+impl From<&JwtRequestDetails<'_>> for Value {
+    fn from(details: &JwtRequestDetails) -> Self {
+        match details {
             JwtRequestDetails::Authenticated {
                 token,
                 prefix,
                 scopes,
                 claims,
             } => {
@@ -101,19 +69,40 @@ impl From<&ClientRequestDetails<'_, '_>> for Value {
                 ("claims".into(), Value::Object(BTreeMap::new())),
                 ("scopes".into(), Value::Array(vec![])),
             ])),
-        };
+        }
+    }
+}
+
+impl From<&ClientRequestDetails<'_, '_>> for Value {
+    fn from(details: &ClientRequestDetails) -> Self {
+        // .request.headers
+        let headers_value = client_header_map_to_vrl_value(details.headers);
+
+        // .request.url
+        let url_value = client_url_to_vrl_value(details.url);
+
+        // .request.operation
+        let operation_value = Self::Object(BTreeMap::from([
+            ("name".into(), details.operation.name.into()),
+            ("type".into(), details.operation.kind.into()),
+            ("query".into(), details.operation.query.into()),
+        ]));
 
         Self::Object(BTreeMap::from([
             ("method".into(), details.method.as_str().into()),
             ("headers".into(), headers_value),
             ("url".into(), url_value),
             ("operation".into(), operation_value),
-            ("jwt".into(), jwt_value),
+            (
+                "path_params".into(),
+                client_path_params_to_vrl_value(details.path_params),
+            ),
+            ("jwt".into(), details.jwt.into()),
         ]))
     }
 }
client_header_map_to_vrl_value(headers: &ntex_http::HeaderMap) -> Value { +pub fn client_header_map_to_vrl_value(headers: &ntex_http::HeaderMap) -> Value { let mut obj = BTreeMap::new(); for (header_name, header_value) in headers.iter() { if let Ok(value) = header_value.to_str() { @@ -125,3 +114,31 @@ fn client_header_map_to_vrl_value(headers: &ntex_http::HeaderMap) -> Value { } Value::Object(obj) } + +pub fn client_url_to_vrl_value(url: &http::Uri) -> Value { + Value::Object(BTreeMap::from([ + ("host".into(), url.host().unwrap_or("unknown").into()), + ("path".into(), url.path().into()), + ( + "port".into(), + url.port_u16() + .unwrap_or_else(|| { + if url.scheme() == Some(&http::uri::Scheme::HTTPS) { + 443 + } else { + 80 + } + }) + .into(), + ), + ])) +} + +pub fn client_path_params_to_vrl_value(path_params: &Path) -> Value { + Value::Object( + path_params + .iter() + .map(|(k, v)| (k.into(), v.into())) + .collect::<BTreeMap<KeyString, Value>>(), + ) +} diff --git a/lib/executor/src/executors/error.rs b/lib/executor/src/executors/error.rs index 2234f524c..de69d47cb 100644 --- a/lib/executor/src/executors/error.rs +++ b/lib/executor/src/executors/error.rs @@ -1,4 +1,4 @@ -use vrl::{diagnostic::DiagnosticList, prelude::ExpressionError}; +use vrl::prelude::ExpressionError; use crate::response::graphql_error::{GraphQLError, GraphQLErrorExtensions}; @@ -34,21 +34,6 @@ impl From for GraphQLError { } impl SubgraphExecutorError { - pub fn new_endpoint_expression_build( - subgraph_name: String, - diagnostics: DiagnosticList, - ) -> Self { - SubgraphExecutorError::EndpointExpressionBuild( - subgraph_name, - diagnostics - .errors() - .into_iter() - .map(|d| d.code.to_string() + ": " + &d.message) - .collect::<Vec<String>>() - .join(", "), - ) - } - pub fn new_endpoint_expression_resolution_failure( subgraph_name: String, error: ExpressionError, diff --git a/lib/executor/src/executors/map.rs b/lib/executor/src/executors/map.rs index a3c297ad1..49b936f21 100644 --- a/lib/executor/src/executors/map.rs +++ b/lib/executor/src/executors/map.rs @@ -1,7 +1,6 @@ use std::{ collections::{BTreeMap, HashMap}, sync::Arc, - time::Duration, }; use bytes::{BufMut, Bytes, BytesMut}; @@ -15,16 +14,7 @@ use hyper_util::{ client::legacy::{connect::HttpConnector, Client}, rt::{TokioExecutor, TokioTimer}, }; use tokio::sync::{OnceCell, Semaphore}; use tracing::error; -use vrl::{ - compiler::compile as vrl_compile, - compiler::Program as VrlProgram, - compiler::TargetValue as VrlTargetValue, - core::Value as VrlValue, - prelude::Function as VrlFunction, - prelude::{state::RuntimeState as VrlState, Context as VrlContext, TimeZone as VrlTimeZone}, - stdlib::all as vrl_build_functions, - value::Secrets as VrlSecrets, -}; +use vrl::{compiler::Program as VrlProgram, core::Value as VrlValue}; use crate::{ execution::client_request_details::ClientRequestDetails, @@ -37,6 +27,7 @@ use crate::{ http::{HTTPSubgraphExecutor, HttpClient}, }, response::graphql_error::GraphQLError, + utils::expression::{compile_expression, execute_expression_with_value}, }; type SubgraphName = String; @@ -54,8 +45,6 @@ pub struct SubgraphExecutorMap { /// Mapping from subgraph name to VRL expression program expressions_by_subgraph: ExpressionsBySubgraphMap, config: Arc, - /// Precompiled VRL functions to be used in endpoint expressions. - vrl_functions: Vec<Box<dyn VrlFunction>>, client: Arc<HttpClient>, semaphores_by_origin: DashMap<String, Arc<Semaphore>>, max_connections_per_host: usize, @@ -67,9 +56,7 @@ impl SubgraphExecutorMap { let https = HttpsConnector::new(); let client: HttpClient = Client::builder(TokioExecutor::new()) .pool_timer(TokioTimer::new()) - .pool_idle_timeout(Duration::from_secs( - config.traffic_shaping.pool_idle_timeout_seconds, - )) + .pool_idle_timeout(config.traffic_shaping.pool_idle_timeout) .pool_max_idle_per_host(config.traffic_shaping.max_connections_per_host) .build(https); @@ -80,7 +67,6 @@ impl SubgraphExecutorMap { static_endpoints_by_subgraph: Default::default(), expressions_by_subgraph: Default::default(), config, - vrl_functions: vrl_build_functions(), client: Arc::new(client), semaphores_by_origin: Default::default(), max_connections_per_host, @@ -122,7 +108,7 @@ impl SubgraphExecutorMap { client_request: &ClientRequestDetails<'a, 'req>, ) -> HttpExecutionResponse { match self.get_or_create_executor(subgraph_name, client_request) { - Ok(Some(executor)) => executor.execute(execution_request).await, + Ok(executor) => executor.execute(execution_request).await, Err(err) => { error!( "Subgraph executor error for subgraph '{}': {}", @@ -130,13 +116,6 @@ impl SubgraphExecutorMap { ); self.internal_server_error_response(err.into(), subgraph_name) } - Ok(None) => { - error!( - "Subgraph executor not found for subgraph '{}'", - subgraph_name - ); - self.internal_server_error_response("Internal server error".into(), subgraph_name) - } } } @@ -165,15 +144,22 @@ impl SubgraphExecutorMap { &self, subgraph_name: &str, client_request: &ClientRequestDetails<'_, '_>, - ) -> Result, SubgraphExecutorError> { - let from_expression = - self.get_or_create_executor_from_expression(subgraph_name, client_request)?; - - if from_expression.is_some() { - return Ok(from_expression); - } - - Ok(self.get_executor_from_static_endpoint(subgraph_name)) + ) -> Result { + self.expressions_by_subgraph + .get(subgraph_name) + .map(|expression| { + self.get_or_create_executor_from_expression( + subgraph_name, + expression, + client_request, + ) + }) + .unwrap_or_else(|| { + self.get_executor_from_static_endpoint(subgraph_name) + .ok_or_else(|| { + SubgraphExecutorError::StaticEndpointNotFound(subgraph_name.to_string()) + }) + }) } /// Looks up a subgraph executor, @@ -183,67 +169,43 @@ impl SubgraphExecutorMap { fn get_or_create_executor_from_expression( &self, subgraph_name: &str, + expression: &VrlProgram, client_request: &ClientRequestDetails<'_, '_>, - ) -> Result, SubgraphExecutorError> { - if let Some(expression) = self.expressions_by_subgraph.get(subgraph_name) { - let original_url_value = VrlValue::Bytes(Bytes::from( - self.static_endpoints_by_subgraph - .get(subgraph_name) - .map(|endpoint| endpoint.value().clone()) - .ok_or_else(|| { - SubgraphExecutorError::StaticEndpointNotFound(subgraph_name.to_string()) - })?, - )); - let mut target = VrlTargetValue { - value: VrlValue::Object(BTreeMap::from([ - ("request".into(), client_request.into()), - ("original_url".into(), original_url_value), - ])), - metadata: VrlValue::Object(BTreeMap::new()), - secrets: VrlSecrets::default(), - }; - - let mut state = VrlState::default(); - let timezone = VrlTimeZone::default(); - let mut ctx = VrlContext::new(&mut target, &mut state, &timezone); - - // Resolve the expression to get an endpoint URL.
- let endpoint_result = expression.resolve(&mut ctx).map_err(|err| { - SubgraphExecutorError::new_endpoint_expression_resolution_failure( - subgraph_name.to_string(), - err, - ) - })?; - let endpoint_str = match endpoint_result.as_str() { - Some(s) => s.to_string(), - None => { - return Err(SubgraphExecutorError::EndpointExpressionWrongType( - subgraph_name.to_string(), - )); - } - }; - - // Check if an executor for this endpoint already exists. - let existing_executor = self - .executors_by_subgraph + ) -> Result { + let original_url_value = VrlValue::Bytes(Bytes::from( + self.static_endpoints_by_subgraph .get(subgraph_name) - .and_then(|endpoints| endpoints.get(&endpoint_str).map(|e| e.clone())); - - if let Some(executor) = existing_executor { - return Ok(Some(executor)); - } - + .map(|endpoint| endpoint.value().clone()) + .ok_or_else(|| { + SubgraphExecutorError::StaticEndpointNotFound(subgraph_name.to_string()) + })?, + )); + let value = VrlValue::Object(BTreeMap::from([ + ("request".into(), client_request.into()), + ("original_url".into(), original_url_value), + ])); + + // Resolve the expression to get an endpoint URL. + let endpoint_result = execute_expression_with_value(expression, value).map_err(|err| { + SubgraphExecutorError::new_endpoint_expression_resolution_failure( + subgraph_name.to_string(), + err, + ) + })?; + let endpoint_str = match endpoint_result.as_str() { + Some(s) => Ok(s), + None => Err(SubgraphExecutorError::EndpointExpressionWrongType( + subgraph_name.to_string(), + )), + }?; + + // Check if an executor for this endpoint already exists. + self.executors_by_subgraph + .get(subgraph_name) + .and_then(|endpoints| endpoints.get(endpoint_str.as_ref()).map(|e| e.clone())) + .map(Ok) // If not, create and register a new one. - self.register_executor(subgraph_name, &endpoint_str)?; - - let endpoints = self - .executors_by_subgraph - .get(subgraph_name) - .expect("Executor was just registered, should be present"); - return Ok(endpoints.get(&endpoint_str).map(|e| e.clone())); - } - - Ok(None) + .unwrap_or_else(|| self.register_executor(subgraph_name, endpoint_str.as_ref())) } /// Looks up a subgraph executor based on a static endpoint URL. 
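The endpoint expression is resolved against an object that exposes `.request` (method, headers, url, path_params, operation, jwt) and `.original_url` (the subgraph's static endpoint), as built above. A hypothetical YAML sketch of a per-request endpoint override (the `override_subgraph_urls` shape and the `x-canary` header are illustrative assumptions, not taken from this diff):

```yaml
override_subgraph_urls:
  accounts:
    expression: |
      if .request.headers."x-canary" == "true" {
        "https://accounts-canary.internal/graphql"
      } else {
        .original_url
      }
```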
@@ -269,12 +231,11 @@ impl SubgraphExecutorMap { subgraph_name: &str, expression: &str, ) -> Result<(), SubgraphExecutorError> { - let compilation_result = vrl_compile(expression, &self.vrl_functions).map_err(|e| { - SubgraphExecutorError::new_endpoint_expression_build(subgraph_name.to_string(), e) + let program = compile_expression(expression, None).map_err(|err| { + SubgraphExecutorError::EndpointExpressionBuild(subgraph_name.to_string(), err) })?; - self.expressions_by_subgraph - .insert(subgraph_name.to_string(), compilation_result.program); + self.expressions_by_subgraph + .insert(subgraph_name.to_string(), program); Ok(()) } @@ -293,7 +254,7 @@ impl SubgraphExecutorMap { &self, subgraph_name: &str, endpoint_str: &str, - ) -> Result<(), SubgraphExecutorError> { + ) -> Result { let endpoint_uri = endpoint_str.parse::<http::Uri>().map_err(|e| { SubgraphExecutorError::EndpointParseFailure(endpoint_str.to_string(), e.to_string()) })?; @@ -326,11 +287,13 @@ impl SubgraphExecutorMap { self.in_flight_requests.clone(), ); + let executor_arc = executor.to_boxed_arc(); + self.executors_by_subgraph .entry(subgraph_name.to_string()) .or_default() - .insert(endpoint_str.to_string(), executor.to_boxed_arc()); + .insert(endpoint_str.to_string(), executor_arc.clone()); - Ok(()) + Ok(executor_arc) } } diff --git a/lib/executor/src/headers/compile.rs b/lib/executor/src/headers/compile.rs index 5b1b14a9f..587cbee60 100644 --- a/lib/executor/src/headers/compile.rs +++ b/lib/executor/src/headers/compile.rs @@ -1,54 +1,28 @@ -use crate::headers::{ - errors::HeaderRuleCompileError, - plan::{ - HeaderAggregationStrategy, HeaderRulesPlan, RequestHeaderRule, RequestHeaderRules, - RequestInsertExpression, RequestInsertStatic, RequestPropagateNamed, RequestPropagateRegex, - RequestRemoveNamed, RequestRemoveRegex, ResponseHeaderRule, ResponseHeaderRules, - ResponseInsertExpression, ResponseInsertStatic, ResponsePropagateNamed, - ResponsePropagateRegex, ResponseRemoveNamed, ResponseRemoveRegex, +use crate::{ + headers::{ + errors::HeaderRuleCompileError, + plan::{ + HeaderAggregationStrategy, HeaderRulesPlan, RequestHeaderRule, RequestHeaderRules, + RequestInsertExpression, RequestInsertStatic, RequestPropagateNamed, + RequestPropagateRegex, RequestRemoveNamed, RequestRemoveRegex, ResponseHeaderRule, + ResponseHeaderRules, ResponseInsertExpression, ResponseInsertStatic, + ResponsePropagateNamed, ResponsePropagateRegex, ResponseRemoveNamed, + ResponseRemoveRegex, + }, + }, + utils::expression::compile_expression, }; use hive_router_config::headers as config; use http::HeaderName; use regex_automata::{meta, util::syntax::Config as SyntaxConfig}; -use vrl::{ - compiler::compile as vrl_compile, prelude::Function as VrlFunction, - stdlib::all as vrl_build_functions, -}; - -pub struct HeaderRuleCompilerContext { - vrl_functions: Vec<Box<dyn VrlFunction>>, -} - -impl Default for HeaderRuleCompilerContext { - fn default() -> Self { - Self::new() - } -} - -impl HeaderRuleCompilerContext { - pub fn new() -> Self { - Self { - vrl_functions: vrl_build_functions(), - } - } -} pub trait HeaderRuleCompiler<A> { - fn compile( - &self, - ctx: &HeaderRuleCompilerContext, - actions: &mut A, - ) -> Result<(), HeaderRuleCompileError>; + fn compile(&self, actions: &mut A) -> Result<(), HeaderRuleCompileError>; } impl HeaderRuleCompiler<Vec<RequestHeaderRule>> for config::RequestHeaderRule { - fn compile( - &self, - ctx: &HeaderRuleCompilerContext, - actions: &mut Vec<RequestHeaderRule>, - ) -> Result<(), HeaderRuleCompileError> { + fn compile(&self, actions: &mut Vec<RequestHeaderRule>) -> Result<(), HeaderRuleCompileError> { match self { config::RequestHeaderRule::Propagate(rule) => { let spec = materialize_match_spec( @@ -79,15 +53,13 @@ impl HeaderRuleCompiler<Vec<RequestHeaderRule>> for config::RequestHeaderRule { })); } config::InsertSource::Expression { expression } => { - let compilation_result = - vrl_compile(expression, &ctx.vrl_functions).map_err(|e| { - HeaderRuleCompileError::new_expression_build(rule.name.clone(), e) - })?; - + let program = compile_expression(expression, None).map_err(|err| { + HeaderRuleCompileError::ExpressionBuild(rule.name.clone(), err) + })?; actions.push(RequestHeaderRule::InsertExpression( RequestInsertExpression { name: build_header_name(&rule.name)?, - expression: Box::new(compilation_result.program), + expression: Box::new(program), }, )); } @@ -112,11 +84,7 @@ impl HeaderRuleCompiler<Vec<RequestHeaderRule>> for config::RequestHeaderRule { } impl HeaderRuleCompiler<Vec<ResponseHeaderRule>> for config::ResponseHeaderRule { - fn compile( - &self, - ctx: &HeaderRuleCompilerContext, - actions: &mut Vec<ResponseHeaderRule>, - ) -> Result<(), HeaderRuleCompileError> { + fn compile(&self, actions: &mut Vec<ResponseHeaderRule>) -> Result<(), HeaderRuleCompileError> { match self { config::ResponseHeaderRule::Propagate(rule) => { let aggregation_strategy = rule.algorithm.into(); @@ -159,15 +127,13 @@ impl HeaderRuleCompiler<Vec<ResponseHeaderRule>> for config::ResponseHeaderRule // - compilation_result.program.info().target_assignments // - compilation_result.program.info().target_queries // to determine what parts of the context are actually needed by the expression - let compilation_result = vrl_compile(expression, &ctx.vrl_functions) - .map_err(|e| { - HeaderRuleCompileError::new_expression_build(rule.name.clone(), e) - })?; - + let program = compile_expression(expression, None).map_err(|err| { + HeaderRuleCompileError::ExpressionBuild(rule.name.clone(), err) + })?; actions.push(ResponseHeaderRule::InsertExpression( ResponseInsertExpression { name: build_header_name(&rule.name)?, - expression: Box::new(compilation_result.program), + expression: Box::new(program), strategy: aggregation_strategy, }, )); @@ -196,19 +162,18 @@ impl HeaderRuleCompiler<Vec<ResponseHeaderRule>> for config::ResponseHeaderRule pub fn compile_headers_plan( cfg: &config::HeadersConfig, ) -> Result<HeaderRulesPlan, HeaderRuleCompileError> { - let ctx = HeaderRuleCompilerContext::new(); let mut request_plan = RequestHeaderRules::default(); let mut response_plan = ResponseHeaderRules::default(); if let Some(global_rules) = &cfg.all { - request_plan.global = compile_request_header_rules(&ctx, global_rules)?; - response_plan.global = compile_response_header_rules(&ctx, global_rules)?; + request_plan.global = compile_request_header_rules(global_rules)?; + response_plan.global = compile_response_header_rules(global_rules)?; } if let Some(subgraph_rules_map) = &cfg.subgraphs { for (subgraph_name, subgraph_rules) in subgraph_rules_map { - let request_actions = compile_request_header_rules(&ctx, subgraph_rules)?; - let response_actions = compile_response_header_rules(&ctx, subgraph_rules)?; + let request_actions = compile_request_header_rules(subgraph_rules)?; + let response_actions = compile_response_header_rules(subgraph_rules)?; request_plan .by_subgraph .insert(subgraph_name.clone(), request_actions); @@ -225,26 +190,24 @@ pub fn compile_headers_plan( } fn compile_request_header_rules( - ctx: &HeaderRuleCompilerContext, header_rules: &config::HeaderRules, ) -> Result<Vec<RequestHeaderRule>, HeaderRuleCompileError> { let mut request_actions = Vec::new(); if let Some(request_rule_entries) = &header_rules.request { for request_rule in request_rule_entries { - request_rule.compile(ctx, &mut request_actions)?; + request_rule.compile(&mut request_actions)?; } } Ok(request_actions) } fn compile_response_header_rules( - ctx: &HeaderRuleCompilerContext, header_rules: &config::HeaderRules, ) -> Result<Vec<ResponseHeaderRule>, HeaderRuleCompileError> { let mut response_actions = Vec::new(); if let Some(response_rule_entries) = &header_rules.response { for response_rule in response_rule_entries { - response_rule.compile(ctx, &mut response_actions)?; + response_rule.compile(&mut response_actions)?; } } Ok(response_actions) @@ -358,7 +321,7 @@ mod tests { use http::HeaderName; use crate::headers::{ - compile::{build_header_value, HeaderRuleCompiler, HeaderRuleCompilerContext}, + compile::{build_header_value, HeaderRuleCompiler}, errors::HeaderRuleCompileError, plan::{HeaderAggregationStrategy, RequestHeaderRule, ResponseHeaderRule}, }; @@ -378,9 +341,8 @@ mod tests { rename: None, default: None, }); - let ctx = HeaderRuleCompilerContext::new(); let mut actions = Vec::new(); - rule.compile(&ctx, &mut actions).unwrap(); + rule.compile(&mut actions).unwrap(); assert_eq!(actions.len(), 1); match &actions[0] { RequestHeaderRule::PropagateNamed(data) => { @@ -401,8 +363,7 @@ mod tests { }, }); let mut actions = Vec::new(); - let ctx = HeaderRuleCompilerContext::new(); - rule.compile(&ctx, &mut actions).unwrap(); + rule.compile(&mut actions).unwrap(); assert_eq!(actions.len(), 1); match &actions[0] { RequestHeaderRule::InsertStatic(data) => { @@ -423,8 +384,7 @@ mod tests { }, }); let mut actions = Vec::new(); - let ctx = HeaderRuleCompilerContext::new(); - rule.compile(&ctx, &mut actions).unwrap(); + rule.compile(&mut actions).unwrap(); assert_eq!(actions.len(), 1); match &actions[0] { RequestHeaderRule::RemoveNamed(data) => { @@ -449,8 +409,7 @@ mod tests { default: Some("def".to_string()), }); let mut actions = Vec::new(); - let ctx = HeaderRuleCompilerContext::new(); - let err = rule.compile(&ctx, &mut actions).unwrap_err(); + let err = rule.compile(&mut actions).unwrap_err(); match err { HeaderRuleCompileError::InvalidDefault => {} _ => panic!("Expected InvalidDefault error"), @@ -470,8 +429,7 @@ mod tests { algorithm: config::AggregationAlgo::First, }); let mut actions = Vec::new(); - let ctx = HeaderRuleCompilerContext::new(); - rule.compile(&ctx, &mut actions).unwrap(); + rule.compile(&mut actions).unwrap(); assert_eq!(actions.len(), 1); match &actions[0] { ResponseHeaderRule::PropagateNamed(data) => { diff --git a/lib/executor/src/headers/errors.rs b/lib/executor/src/headers/errors.rs index d53444877..6c2a806d2 100644 --- a/lib/executor/src/headers/errors.rs +++ b/lib/executor/src/headers/errors.rs @@ -1,6 +1,6 @@ use http::header::{InvalidHeaderName, InvalidHeaderValue}; use regex_automata::meta::BuildError; -use vrl::{diagnostic::DiagnosticList, prelude::ExpressionError}; +use vrl::prelude::ExpressionError; #[derive(thiserror::Error, Debug)] pub enum HeaderRuleCompileError { @@ -27,16 +27,8 @@ pub enum HeaderRuleRuntimeError { } impl HeaderRuleCompileError { - pub fn new_expression_build(header_name: String, diagnostics: DiagnosticList) -> Self { - HeaderRuleCompileError::ExpressionBuild( - header_name, - diagnostics - .errors() - .into_iter() - .map(|d| d.code.to_string() + ": " + &d.message) - .collect::<Vec<String>>() - .join(", "), - ) + pub fn new_expression_build(header_name: String, err: String) -> Self { + HeaderRuleCompileError::ExpressionBuild(header_name, err) } } diff --git a/lib/executor/src/headers/mod.rs b/lib/executor/src/headers/mod.rs index 62f9fe701..52ba7eac6 100644 --- a/lib/executor/src/headers/mod.rs +++ b/lib/executor/src/headers/mod.rs @@
-21,6 +21,7 @@ mod tests { }; use hive_router_config::parse_yaml_config; use http::{HeaderMap, HeaderName, HeaderValue}; + use ntex::router::Path; use ntex_http::HeaderMap as NtexHeaderMap; fn header_name_owned(s: &str) -> HeaderName { @@ -77,6 +78,7 @@ mod tests { method: &http::Method::POST, url: &"http://example.com".parse().unwrap(), headers: &client_headers, + path_params: &Path::default(), operation: OperationDetails { name: None, query: "{ __typename }", @@ -111,6 +113,7 @@ mod tests { method: &http::Method::POST, url: &"http://example.com".parse().unwrap(), headers: &client_headers, + path_params: &Path::default(), operation: OperationDetails { name: None, query: "{ __typename }", @@ -158,6 +161,7 @@ mod tests { method: &http::Method::POST, url: &"http://example.com".parse().unwrap(), headers: &client_headers, + path_params: &Path::default(), operation: OperationDetails { name: None, query: "{ __typename }", @@ -196,6 +200,7 @@ mod tests { method: &http::Method::POST, url: &"http://example.com".parse().unwrap(), headers: &client_headers, + path_params: &Path::default(), operation: OperationDetails { name: Some("MyQuery"), query: "{ __typename }", @@ -230,6 +235,7 @@ mod tests { method: &http::Method::POST, url: &"http://example.com".parse().unwrap(), headers: &client_headers, + path_params: &Path::default(), operation: OperationDetails { name: None, query: "{ __typename }", @@ -270,6 +276,7 @@ mod tests { method: &http::Method::POST, url: &"http://example.com".parse().unwrap(), headers: &client_headers, + path_params: &Path::default(), operation: OperationDetails { name: None, query: "{ __typename }", @@ -314,6 +321,7 @@ mod tests { method: &http::Method::POST, url: &"http://example.com".parse().unwrap(), headers: &client_headers, + path_params: &Path::default(), operation: OperationDetails { name: None, query: "{ __typename }", @@ -379,6 +387,7 @@ mod tests { method: &http::Method::POST, url: &"http://example.com".parse().unwrap(), headers: &client_headers, + path_params: &Path::default(), operation: OperationDetails { name: None, query: "{ __typename }", @@ -443,6 +452,7 @@ mod tests { method: &http::Method::POST, url: &"http://example.com".parse().unwrap(), headers: &client_headers, + path_params: &Path::default(), operation: OperationDetails { name: None, query: "{ __typename }", @@ -500,6 +510,7 @@ mod tests { method: &http::Method::POST, url: &"http://example.com".parse().unwrap(), headers: &client_headers, + path_params: &Path::default(), operation: OperationDetails { name: None, query: "{ __typename }", @@ -558,6 +569,7 @@ mod tests { method: &http::Method::POST, url: &"http://example.com".parse().unwrap(), headers: &client_headers, + path_params: &Path::default(), operation: OperationDetails { name: None, query: "{ __typename }", @@ -617,6 +629,7 @@ mod tests { method: &http::Method::POST, url: &"http://example.com".parse().unwrap(), headers: &client_headers, + path_params: &Path::default(), operation: OperationDetails { name: None, query: "{ __typename }", diff --git a/lib/executor/src/headers/request.rs b/lib/executor/src/headers/request.rs index 637ab0d58..ec22de509 100644 --- a/lib/executor/src/headers/request.rs +++ b/lib/executor/src/headers/request.rs @@ -1,12 +1,4 @@ -use std::collections::BTreeMap; - use http::HeaderMap; -use vrl::{ - compiler::TargetValue as VrlTargetValue, - core::Value as VrlValue, - prelude::{state::RuntimeState as VrlState, Context as VrlContext, TimeZone as VrlTimeZone}, - value::Secrets as VrlSecrets, -}; use crate::{ 
execution::client_request_details::ClientRequestDetails, @@ -19,6 +11,7 @@ use crate::{ }, sanitizer::{is_denied_header, is_never_join_header}, }, + utils::expression::execute_expression_with_value, }; pub fn modify_subgraph_request_headers( @@ -174,17 +167,7 @@ impl ApplyRequestHeader for RequestInsertExpression { if is_denied_header(&self.name) { return Ok(()); } - - let mut target = VrlTargetValue { - value: ctx.into(), - metadata: VrlValue::Object(BTreeMap::new()), - secrets: VrlSecrets::default(), - }; - - let mut state = VrlState::default(); - let timezone = VrlTimeZone::default(); - let mut ctx = VrlContext::new(&mut target, &mut state, &timezone); - let value = self.expression.resolve(&mut ctx).map_err(|err| { + let value = execute_expression_with_value(&self.expression, ctx.into()).map_err(|err| { HeaderRuleRuntimeError::new_expression_evaluation(self.name.to_string(), Box::new(err)) })?; diff --git a/lib/executor/src/headers/response.rs b/lib/executor/src/headers/response.rs index 6a5c34444..9cf45a768 100644 --- a/lib/executor/src/headers/response.rs +++ b/lib/executor/src/headers/response.rs @@ -1,4 +1,4 @@ -use std::{collections::BTreeMap, iter::once}; +use std::iter::once; use crate::{ execution::client_request_details::ClientRequestDetails, @@ -13,16 +13,11 @@ use crate::{ }, sanitizer::is_denied_header, }, + utils::expression::execute_expression_with_value, }; use super::sanitizer::is_never_join_header; use http::{header::InvalidHeaderValue, HeaderMap, HeaderName, HeaderValue}; -use vrl::{ - compiler::TargetValue as VrlTargetValue, - core::Value as VrlValue, - prelude::{state::RuntimeState as VrlState, Context as VrlContext, TimeZone as VrlTimeZone}, - value::Secrets as VrlSecrets, -}; pub fn apply_subgraph_response_headers( header_rule_plan: &HeaderRulesPlan, @@ -194,20 +189,9 @@ impl ApplyResponseHeader for ResponseInsertExpression { if is_denied_header(&self.name) { return Ok(()); } - - let mut target = VrlTargetValue { - value: ctx.into(), - metadata: VrlValue::Object(BTreeMap::new()), - secrets: VrlSecrets::default(), - }; - - let mut state = VrlState::default(); - let timezone = VrlTimeZone::default(); - let mut ctx = VrlContext::new(&mut target, &mut state, &timezone); - let value = self.expression.resolve(&mut ctx).map_err(|err| { - HeaderRuleRuntimeError::ExpressionEvaluation(self.name.to_string(), Box::new(err)) + let value = execute_expression_with_value(&self.expression, ctx.into()).map_err(|err| { + HeaderRuleRuntimeError::new_expression_evaluation(self.name.to_string(), Box::new(err)) })?; - if let Some(header_value) = vrl_value_to_header_value(value) { let strategy = if is_never_join_header(&self.name) { HeaderAggregationStrategy::Append diff --git a/lib/executor/src/utils/expression.rs b/lib/executor/src/utils/expression.rs new file mode 100644 index 000000000..faee2dc3b --- /dev/null +++ b/lib/executor/src/utils/expression.rs @@ -0,0 +1,49 @@ +use once_cell::sync::Lazy; +use std::collections::BTreeMap; +use vrl::{ + compiler::{compile as vrl_compile, Program as VrlProgram, TargetValue as VrlTargetValue}, + core::Value as VrlValue, + prelude::{ + state::RuntimeState as VrlState, Context as VrlContext, ExpressionError, Function, + TimeZone as VrlTimeZone, + }, + stdlib::all as vrl_build_functions, + value::Secrets as VrlSecrets, +}; + +static VRL_FUNCTIONS: Lazy<Vec<Box<dyn Function>>> = Lazy::new(vrl_build_functions); +static VRL_TIMEZONE: Lazy<VrlTimeZone> = Lazy::new(VrlTimeZone::default); + +pub fn compile_expression( + expression: &str, + functions: Option<&[Box<dyn Function>]>, +) -> Result<VrlProgram, String> { + let functions = functions.unwrap_or(&VRL_FUNCTIONS); + + let compilation_result = vrl_compile(expression, functions).map_err(|diagnostics| { + diagnostics + .errors() + .iter() + .map(|d| format!("{}: {}", d.code, d.message)) + .collect::<Vec<String>>() + .join(", ") + })?; + + Ok(compilation_result.program) +} + +pub fn execute_expression_with_value( + program: &VrlProgram, + value: VrlValue, +) -> Result<VrlValue, ExpressionError> { + let mut target = VrlTargetValue { + value, + metadata: VrlValue::Object(BTreeMap::new()), + secrets: VrlSecrets::default(), + }; + + let mut state = VrlState::default(); + let mut ctx = VrlContext::new(&mut target, &mut state, &VRL_TIMEZONE); + + program.resolve(&mut ctx) +} diff --git a/lib/executor/src/utils/mod.rs b/lib/executor/src/utils/mod.rs index fc4226984..0461bb8a8 100644 --- a/lib/executor/src/utils/mod.rs +++ b/lib/executor/src/utils/mod.rs @@ -1,3 +1,4 @@ pub mod consts; +pub mod expression; pub mod traverse; pub mod vrl; diff --git a/lib/router-config/src/env_overrides.rs b/lib/router-config/src/env_overrides.rs index f61012967..00ed40161 100644 --- a/lib/router-config/src/env_overrides.rs +++ b/lib/router-config/src/env_overrides.rs @@ -83,11 +83,19 @@ impl EnvVarOverrides { } if let Some(hive_console_cdn_endpoint) = self.hive_console_cdn_endpoint.take() { - config = config.set_override("supergraph.source", "hive")?; - config = config.set_override("supergraph.endpoint", hive_console_cdn_endpoint)?; + config = + config.set_override("supergraph.endpoint", hive_console_cdn_endpoint.clone())?; + + config = config.set_override( + "persisted_documents.source.hive.endpoint", + hive_console_cdn_endpoint, + )?; if let Some(hive_console_cdn_key) = self.hive_console_cdn_key.take() { - config = config.set_override("supergraph.key", hive_console_cdn_key)?; + config = config.set_override("supergraph.key", hive_console_cdn_key.clone())?; + + config = config + .set_override("persisted_documents.source.hive.key", hive_console_cdn_key)?; } else { return Err(EnvVarOverridesError::MissingRequiredEnvVar("HIVE_CDN_KEY")); } diff --git a/lib/router-config/src/http_server.rs b/lib/router-config/src/http_server.rs index c4bf9048e..351bbf9b3 100644 --- a/lib/router-config/src/http_server.rs +++ b/lib/router-config/src/http_server.rs @@ -17,6 +17,10 @@ pub struct HttpServerConfig { /// If you are running the router inside a Docker container, please ensure that the port is exposed correctly using `-p <host_port>:<container_port>` flag. #[serde(default = "http_server_port_default")] port: u16, + + #[serde(default = "graphql_endpoint_default")] + /// The GraphQL endpoint path.
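+ /// Path parameters (for example a pattern like `/graphql/{document_id}`) are
+ /// captured by the router and exposed to VRL expressions as `.request.path_params`.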
+ pub graphql_endpoint: String, } impl Default for HttpServerConfig { @@ -24,6 +28,7 @@ impl Default for HttpServerConfig { Self { host: http_server_host_default(), port: http_server_port_default(), + graphql_endpoint: graphql_endpoint_default(), } } } @@ -36,6 +41,10 @@ fn http_server_port_default() -> u16 { 4000 } +fn graphql_endpoint_default() -> String { + "/graphql".to_string() +} + impl HttpServerConfig { pub fn address(&self) -> String { format!("{}:{}", self.host, self.port) diff --git a/lib/router-config/src/lib.rs b/lib/router-config/src/lib.rs index 537244c9e..0d6c835a2 100644 --- a/lib/router-config/src/lib.rs +++ b/lib/router-config/src/lib.rs @@ -8,6 +8,7 @@ pub mod jwt_auth; pub mod log; pub mod override_labels; pub mod override_subgraph_urls; +pub mod persisted_documents; pub mod primitives; pub mod query_planner; pub mod supergraph; @@ -26,10 +27,11 @@ use crate::{ http_server::HttpServerConfig, log::LoggingConfig, override_labels::OverrideLabelsConfig, + persisted_documents::PersistedDocumentsConfig, primitives::file_path::with_start_path, query_planner::QueryPlannerConfig, supergraph::SupergraphSource, - traffic_shaping::TrafficShapingExecutorConfig, + traffic_shaping::TrafficShapingConfig, }; #[derive(Debug, Deserialize, Serialize, JsonSchema)] @@ -62,9 +64,9 @@ pub struct HiveRouterConfig { #[serde(default)] pub http: HttpServerConfig, - /// Configuration for the traffic-shaper executor. Use these configurations to control how requests are being executed to subgraphs. + /// Configuration for traffic shaping in the executor. Use these options to control how requests to subgraphs are executed. #[serde(default)] - pub traffic_shaping: TrafficShapingExecutorConfig, + pub traffic_shaping: TrafficShapingConfig, /// Configuration for the headers. #[serde(default)] @@ -92,6 +94,10 @@ pub struct HiveRouterConfig { /// Configuration for overriding labels. #[serde(default, skip_serializing_if = "HashMap::is_empty")] pub override_labels: OverrideLabelsConfig, + + /// Configuration for persisted documents. + #[serde(default, skip_serializing_if = "PersistedDocumentsConfig::is_disabled")] + pub persisted_documents: PersistedDocumentsConfig, } #[derive(Debug, thiserror::Error)] diff --git a/lib/router-config/src/persisted_documents.rs b/lib/router-config/src/persisted_documents.rs new file mode 100644 index 000000000..b92dae292 --- /dev/null +++ b/lib/router-config/src/persisted_documents.rs @@ -0,0 +1,162 @@ +use std::time::Duration; + +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; + +use crate::primitives::file_path::FilePath; + +#[derive(Debug, Deserialize, Serialize, JsonSchema, Clone)] +#[serde(deny_unknown_fields)] +pub struct PersistedDocumentsConfig { + /// Whether persisted documents support is enabled. + #[serde(default = "default_enabled")] + pub enabled: bool, + + /// Whether to allow arbitrary operations that are not persisted. + #[serde(default = "default_allow_arbitrary_operations")] + pub allow_arbitrary_operations: BoolOrExpression, + + /// The source of persisted documents. + #[serde(default = "default_source")] + pub source: PersistedDocumentsSource, + + /// The specification used to extract the persisted document ID from a request.
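+ ///
+ /// A minimal sketch of the accepted YAML shapes (the header name in the
+ /// custom expression is illustrative, not part of this crate):
+ ///
+ /// ```yaml
+ /// persisted_documents:
+ ///   enabled: true
+ ///   spec: relay   # or `hive` (the default) or `apollo`
+ /// ```
+ ///
+ /// ```yaml
+ /// persisted_documents:
+ ///   enabled: true
+ ///   spec:
+ ///     expression: '.request.headers."x-document-id"'
+ /// ```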
+ #[serde(default = "default_spec")] + pub spec: PersistedDocumentsSpec, +} + +impl PersistedDocumentsConfig { + pub fn is_disabled(&self) -> bool { + !self.enabled + } +} + +impl Default for PersistedDocumentsConfig { + fn default() -> Self { + Self { + enabled: default_enabled(), + allow_arbitrary_operations: default_allow_arbitrary_operations(), + source: default_source(), + spec: default_spec(), + } + } +} + +#[derive(Debug, Deserialize, Serialize, JsonSchema, Clone)] +#[serde(deny_unknown_fields)] +pub enum PersistedDocumentsSource { + #[serde(rename = "file")] + File { + /// The path to the file containing persisted operations. + path: FilePath, + }, + #[serde(rename = "hive")] + HiveConsole { + /// The CDN endpoint of the Hive Console target. + /// + /// Can also be set using the `HIVE_CDN_ENDPOINT` environment variable. + endpoint: String, + /// The CDN access token from the Hive Console target. + /// + /// Can also be set using the `HIVE_CDN_KEY` environment variable. + key: String, + /// Request timeout for the Hive Console CDN requests. + #[serde( + default = "default_hive_request_timeout", + deserialize_with = "humantime_serde::deserialize", + serialize_with = "humantime_serde::serialize" + )] + request_timeout: Duration, + /// Connection timeout for the Hive Console CDN requests. + #[serde( + default = "default_hive_connect_timeout", + deserialize_with = "humantime_serde::deserialize", + serialize_with = "humantime_serde::serialize" + )] + connect_timeout: Duration, + /// Number of times failed requests to the Hive Console CDN are retried. + /// + /// By default, an exponential backoff retry policy is used, with 3 attempts. + #[serde(default = "default_hive_retry_count")] + retry_count: u32, + /// Accept invalid SSL certificates. + /// Default: `false`. + #[serde(default = "default_accept_invalid_certs")] + accept_invalid_certs: bool, + + /// Size of the in-memory cache for persisted documents. + #[serde(default = "default_cache_size")] + cache_size: u64, + }, +} + +#[derive(Debug, Deserialize, Serialize, JsonSchema, Clone)] +#[serde(deny_unknown_fields, rename_all = "snake_case")] +pub enum PersistedDocumentsSpec { + /// Hive's persisted documents specification. + /// Expects the document ID to be found in the `documentId` field of the request's extra parameters. + /// This is the default specification. + Hive, + /// Apollo's persisted documents specification. + /// Expects the document ID to be found in the `extensions.persistedQuery.sha256Hash` field of the request's extra parameters. + Apollo, + /// Relay's persisted documents specification. + /// Expects the document ID to be found in the `doc_id` field of the request's extra parameters. + Relay, + /// A custom VRL expression to extract the persisted document ID from the request. + /// The expression should evaluate to a string representing the document ID. + Expression(String), +} + +#[derive(Debug, Deserialize, Serialize, JsonSchema, Clone)] +#[serde(untagged)] +pub enum BoolOrExpression { + /// A static boolean value. + Bool(bool), + /// A dynamic value computed by a VRL expression.
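+ /// The expression is evaluated per request and is expected to return a
+ /// boolean. A sketch driving it from a request header, as the e2e test
+ /// above does (the exact expression is illustrative):
+ ///
+ /// ```yaml
+ /// allow_arbitrary_operations:
+ ///   expression: '.request.headers."x-allow-arbitrary-operations" == "true"'
+ /// ```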
+ Expression { expression: String }, +} + +fn default_enabled() -> bool { + false +} + +fn default_allow_arbitrary_operations() -> BoolOrExpression { + BoolOrExpression::Bool(false) +} + +fn default_source() -> PersistedDocumentsSource { + PersistedDocumentsSource::HiveConsole { + endpoint: "".into(), + key: "".into(), + request_timeout: default_hive_request_timeout(), + connect_timeout: default_hive_connect_timeout(), + retry_count: default_hive_retry_count(), + accept_invalid_certs: default_accept_invalid_certs(), + cache_size: default_cache_size(), + } +} + +fn default_spec() -> PersistedDocumentsSpec { + PersistedDocumentsSpec::Hive +} + +fn default_hive_request_timeout() -> Duration { + Duration::from_secs(15) +} + +fn default_hive_connect_timeout() -> Duration { + Duration::from_secs(5) +} + +fn default_hive_retry_count() -> u32 { + 3 +} + +fn default_accept_invalid_certs() -> bool { + false +} + +fn default_cache_size() -> u64 { + 1000 +} diff --git a/lib/router-config/src/traffic_shaping.rs b/lib/router-config/src/traffic_shaping.rs index 02ed5ecdd..95d824910 100644 --- a/lib/router-config/src/traffic_shaping.rs +++ b/lib/router-config/src/traffic_shaping.rs @@ -1,16 +1,23 @@ +use std::time::Duration; + use schemars::JsonSchema; use serde::{Deserialize, Serialize}; #[derive(Debug, Deserialize, Serialize, JsonSchema, Clone)] #[serde(deny_unknown_fields)] -pub struct TrafficShapingExecutorConfig { +pub struct TrafficShapingConfig { /// Limits the concurrent amount of requests/connections per host/subgraph. #[serde(default = "default_max_connections_per_host")] pub max_connections_per_host: usize, /// Timeout for idle sockets being kept-alive. - #[serde(default = "default_pool_idle_timeout_seconds")] - pub pool_idle_timeout_seconds: u64, + #[serde( + default = "default_pool_idle_timeout", + deserialize_with = "humantime_serde::deserialize", + serialize_with = "humantime_serde::serialize" + )] + #[schemars(with = "String")] + pub pool_idle_timeout: Duration, /// Enables/disables request deduplication to subgraphs. /// @@ -20,11 +27,11 @@ pub struct TrafficShapingExecutorConfig { pub dedupe_enabled: bool, } -impl Default for TrafficShapingExecutorConfig { +impl Default for TrafficShapingConfig { fn default() -> Self { Self { max_connections_per_host: default_max_connections_per_host(), - pool_idle_timeout_seconds: default_pool_idle_timeout_seconds(), + pool_idle_timeout: default_pool_idle_timeout(), dedupe_enabled: default_dedupe_enabled(), } } @@ -34,8 +41,8 @@ fn default_max_connections_per_host() -> usize { 100 } -fn default_pool_idle_timeout_seconds() -> u64 { - 50 +fn default_pool_idle_timeout() -> Duration { + Duration::from_secs(50) } fn default_dedupe_enabled() -> bool {
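For reference, a sketch of the reworked traffic-shaping configuration; `max_connections_per_host` and `pool_idle_timeout` show the defaults defined above, and `pool_idle_timeout` now accepts humantime-style durations such as `50s` or `2m`:

```yaml
traffic_shaping:
  max_connections_per_host: 100
  pool_idle_timeout: 50s
  dedupe_enabled: true # request deduplication to subgraphs
```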