Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions bin/router/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -53,3 +53,5 @@ tokio-util = "0.7.16"
cookie = "0.18.1"
regex-automata = "0.4.10"
arc-swap = "1.7.1"
ntex-util = "2.15.0"
ntex-service = "3.5.0"
21 changes: 11 additions & 10 deletions bin/router/src/jwt/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -265,26 +265,27 @@ impl JwtAuthRuntime {
Ok(token_data)
}

pub fn validate_request(&self, request: &mut HttpRequest) -> Result<(), JwtError> {
pub fn validate_request(
&self,
request: &HttpRequest,
) -> Result<Option<JwtRequestContext>, JwtError> {
let valid_jwks = self.jwks.all();

match self.authenticate(&valid_jwks, request) {
Ok((token_payload, maybe_token_prefix, token)) => {
request.extensions_mut().insert(JwtRequestContext {
token_payload,
token_raw: token,
token_prefix: maybe_token_prefix,
});
}
Ok((token_payload, maybe_token_prefix, token)) => Ok(Some(JwtRequestContext {
token_payload,
token_raw: token,
token_prefix: maybe_token_prefix,
})),
Err(e) => {
warn!("jwt token error: {:?}", e);

if self.config.require_authentication.is_some_and(|v| v) {
return Err(e);
}

Ok(None)
}
}

Ok(())
}
}
42 changes: 31 additions & 11 deletions bin/router/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ mod http_utils;
mod jwt;
mod logger;
mod pipeline;
mod plugins;
mod schema_state;
mod shared_state;
mod supergraph;
Expand All @@ -19,7 +20,12 @@ use crate::{
},
jwt::JwtAuthRuntime,
logger::configure_logging,
pipeline::graphql_request_handler,
pipeline::{
error::PipelineError,
graphql_request_handler,
header::{RequestAccepts, APPLICATION_GRAPHQL_RESPONSE_JSON_STR},
},
plugins::plugins_service::PluginService,
};

pub use crate::{schema_state::SchemaState, shared_state::RouterSharedState};
Expand All @@ -33,7 +39,7 @@ use ntex::{
use tracing::{info, warn};

async fn graphql_endpoint_handler(
mut request: HttpRequest,
req: HttpRequest,
body_bytes: Bytes,
schema_state: web::types::State<Arc<SchemaState>>,
app_state: web::types::State<Arc<RouterSharedState>>,
Expand All @@ -45,26 +51,32 @@ async fn graphql_endpoint_handler(
if let Some(early_response) = app_state
.cors_runtime
.as_ref()
.and_then(|cors| cors.get_early_response(&request))
.and_then(|cors| cors.get_early_response(&req))
{
return early_response;
}

let mut res = graphql_request_handler(
&mut request,
let accept_ok = !req.accepts_content_type(&APPLICATION_GRAPHQL_RESPONSE_JSON_STR);

let mut response = match graphql_request_handler(
&req,
body_bytes,
supergraph,
app_state.get_ref(),
schema_state.get_ref(),
)
.await;
.await
{
Ok(response_with_req) => response_with_req,
Err(error) => return PipelineError { accept_ok, error }.into(),
};

// Apply CORS headers to the final response if CORS is configured.
if let Some(cors) = app_state.cors_runtime.as_ref() {
cors.set_headers(&request, res.headers_mut());
cors.set_headers(&req, response.headers_mut());
}

res
response
} else {
warn!("No supergraph available yet, unable to process request");

Expand All @@ -86,6 +98,7 @@ pub async fn router_entrypoint() -> Result<(), Box<dyn std::error::Error>> {

let maybe_error = web::HttpServer::new(move || {
web::App::new()
.wrap(PluginService)
.state(shared_state.clone())
.state(schema_state.clone())
.configure(configure_ntex_app)
Expand All @@ -112,10 +125,17 @@ pub async fn configure_app_from_config(
};

let router_config_arc = Arc::new(router_config);
let schema_state =
SchemaState::new_from_config(bg_tasks_manager, router_config_arc.clone()).await?;
let shared_state = Arc::new(RouterSharedState::new(
router_config_arc.clone(),
jwt_runtime,
)?);
let schema_state = SchemaState::new_from_config(
bg_tasks_manager,
router_config_arc.clone(),
shared_state.clone(),
)
.await?;
let schema_state_arc = Arc::new(schema_state);
let shared_state = Arc::new(RouterSharedState::new(router_config_arc, jwt_runtime)?);

Ok((shared_state, schema_state_arc))
}
Expand Down
19 changes: 9 additions & 10 deletions bin/router/src/pipeline/coerce_variables.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
use std::collections::HashMap;
use std::sync::Arc;

use hive_router_plan_executor::hooks::on_graphql_params::GraphQLParams;
use hive_router_plan_executor::hooks::on_supergraph_load::SupergraphData;
use hive_router_plan_executor::variables::collect_variables;
use hive_router_query_planner::state::supergraph_state::OperationKind;
use http::Method;
use ntex::web::HttpRequest;
use sonic_rs::Value;
use tracing::{error, trace, warn};

use crate::pipeline::error::{PipelineError, PipelineErrorFromAcceptHeader, PipelineErrorVariant};
use crate::pipeline::execution_request::ExecutionRequest;
use crate::pipeline::error::PipelineErrorVariant;
use crate::pipeline::normalize::GraphQLNormalizationPayload;
use crate::schema_state::SupergraphData;

#[derive(Clone, Debug)]
pub struct CoerceVariablesPayload {
Expand All @@ -22,22 +21,22 @@ pub struct CoerceVariablesPayload {
pub fn coerce_request_variables(
req: &HttpRequest,
supergraph: &SupergraphData,
execution_params: &mut ExecutionRequest,
normalized_operation: &Arc<GraphQLNormalizationPayload>,
) -> Result<CoerceVariablesPayload, PipelineError> {
graphql_params: &mut GraphQLParams,
normalized_operation: &GraphQLNormalizationPayload,
) -> Result<CoerceVariablesPayload, PipelineErrorVariant> {
if req.method() == Method::GET {
if let Some(OperationKind::Mutation) =
normalized_operation.operation_for_plan.operation_kind
{
error!("Mutation is not allowed over GET, stopping");

return Err(req.new_pipeline_error(PipelineErrorVariant::MutationNotAllowedOverHttpGet));
return Err(PipelineErrorVariant::MutationNotAllowedOverHttpGet);
}
}

match collect_variables(
&normalized_operation.operation_for_plan,
&mut execution_params.variables,
&mut graphql_params.variables,
&supergraph.metadata,
) {
Ok(values) => {
Expand All @@ -55,7 +54,7 @@ pub fn coerce_request_variables(
"failed to collect variables from incoming request: {}",
err_msg
);
Err(req.new_pipeline_error(PipelineErrorVariant::VariablesCoercionError(err_msg)))
Err(PipelineErrorVariant::VariablesCoercionError(err_msg))
}
}
}
8 changes: 4 additions & 4 deletions bin/router/src/pipeline/csrf_prevention.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use hive_router_config::csrf::CSRFPreventionConfig;
use ntex::web::HttpRequest;

use crate::pipeline::error::{PipelineError, PipelineErrorFromAcceptHeader, PipelineErrorVariant};
use crate::pipeline::error::PipelineErrorVariant;

// NON_PREFLIGHTED_CONTENT_TYPES are content types that do not require a preflight
// OPTIONS request. These are content types that are considered "simple" by the CORS
Expand All @@ -15,9 +15,9 @@ const NON_PREFLIGHTED_CONTENT_TYPES: [&str; 3] = [

#[inline]
pub fn perform_csrf_prevention(
req: &mut HttpRequest,
req: &HttpRequest,
csrf_config: &CSRFPreventionConfig,
) -> Result<(), PipelineError> {
) -> Result<(), PipelineErrorVariant> {
// If CSRF prevention is not configured or disabled, skip the checks.
if !csrf_config.enabled || csrf_config.required_headers.is_empty() {
return Ok(());
Expand All @@ -39,7 +39,7 @@ pub fn perform_csrf_prevention(
if has_required_header {
Ok(())
} else {
Err(req.new_pipeline_error(PipelineErrorVariant::CsrfPreventionFailed))
Err(PipelineErrorVariant::CsrfPreventionFailed)
}
}

Expand Down
Loading