diff --git a/rs/boundary_node/ic_boundary/src/http/middleware/process.rs b/rs/boundary_node/ic_boundary/src/http/middleware/process.rs index 60963b86db7d..7d1f70696071 100644 --- a/rs/boundary_node/ic_boundary/src/http/middleware/process.rs +++ b/rs/boundary_node/ic_boundary/src/http/middleware/process.rs @@ -1,5 +1,6 @@ use std::{sync::Arc, time::Duration}; +use axum::body::HttpBody; use axum::{Extension, body::Body, extract::Request, middleware::Next, response::IntoResponse}; use bytes::Bytes; use candid::{Decode, Principal}; @@ -9,6 +10,7 @@ use ic_types::messages::Blob; use serde::de::Error as SerdeDeError; use serde::{Deserialize, Deserializer, Serialize}; +use crate::routes::ReadStatePaths; use crate::{ core::{MAX_REQUEST_BODY_SIZE, decoder_config}, errors::{ApiError, ErrorCause, buffer_body_to_bytes}, @@ -23,7 +25,11 @@ const METHOD_HTTP: &str = "http_request"; const HEADERS_HIDE_HTTP_REQUEST: [&str; 4] = ["x-real-ip", "x-forwarded-for", "x-request-id", "user-agent"]; -// This is the subset of the request fields +/// This is the subset of the request fields. +/// +/// TODO: add sanity checks for Blob fields so that +/// we don't process too big forged requests. +/// E.g. the nonce is probably fixed-length etc. #[derive(Clone, Debug, Deserialize, Serialize)] struct ICRequestContent { sender: Principal, @@ -41,7 +47,7 @@ pub struct ICRequestEnvelope { content: ICRequestContent, } -// Restrict the method name to its max length +/// Restrict the method name to its max length pub const MAX_METHOD_NAME_LENGTH: usize = 20_000; fn check_method_name_length<'de, D>(deserializer: D) -> Result, D::Error> @@ -60,14 +66,44 @@ where Ok(s) } -// Middleware: preprocess the request before handing it over to handlers +/// Checks if given paths are cacheable +pub(crate) fn should_cache_paths(paths: &[Vec]) -> bool { + // Check that we have correct lengths + if paths.len() != 2 || paths.iter().any(|x| x.len() != 2) { + return false; + } + + // Check that 2nd labels are short enough to be Principals + if !paths + .iter() + .all(|x| x[1].0.len() <= Principal::MAX_LENGTH_IN_BYTES) + { + return false; + } + + // Check that we have a correct combination of 1st labels. + // This looks a bit ugly, but efficient. + [ + (&b"canister_ranges"[..], &b"subnet"[..]), + (&b"subnet"[..], &b"canister_ranges"[..]), + ] + .contains(&(&paths[0][0].0[..], &paths[1][0].0[..])) +} + +/// Middleware: preprocess the request before handing it over to handlers pub async fn preprocess_request( Extension(request_type): Extension, request: Request, next: Next, ) -> Result { // Consume body - let (parts, body) = request.into_parts(); + let (mut parts, body) = request.into_parts(); + + // Early check for the body size to avoid streaming too big requests + if body.size_hint().exact() > Some(MAX_REQUEST_BODY_SIZE as u64) { + return Err(ErrorCause::PayloadTooLarge(MAX_REQUEST_BODY_SIZE).into()); + } + let body = buffer_body_to_bytes(body, MAX_REQUEST_BODY_SIZE, Duration::from_secs(60)).await?; // Parse the request body @@ -100,6 +136,15 @@ pub async fn preprocess_request( (_, arg) => (arg, None), }; + // Check if it's a subnet read state request & it's eligible for caching. + // If it is - insert the paths into extensions. + if request_type.is_read_state_subnet() + && let Some(x) = content.paths + && should_cache_paths(&x) + { + parts.extensions.insert(ReadStatePaths::from(x)); + } + // Construct the context let ctx = RequestContext { request_type, @@ -111,11 +156,6 @@ pub async fn preprocess_request( arg: arg.map(|x| x.0), nonce: content.nonce.map(|x| x.0), http_request, - read_state_paths: content.paths.map(|p| { - p.into_iter() - .map(|v| v.into_iter().map(|b| b.0).collect()) - .collect() - }), }; let ctx = Arc::new(ctx); @@ -244,6 +284,7 @@ pub async fn postprocess_response(request: Request, next: Next) -> impl IntoResp mod tests { use super::*; use candid::Principal; + use ic_bn_lib_common::principal; use serde_cbor::Value; use std::collections::BTreeMap; @@ -342,4 +383,81 @@ mod tests { let content: ICRequestContent = serde_cbor::from_slice(&data).unwrap(); assert!(content.method_name.is_none()); } + + #[test] + fn test_should_cache_paths() { + let subnet_id = principal!("aaaaa-aa").as_slice().to_vec(); + + // Wraps with Blob + let wrapper = |paths: &[Vec>]| -> bool { + let paths = paths + .iter() + .map(|x| x.iter().map(|x| Blob(x.clone())).collect()) + .collect::>(); + + should_cache_paths(&paths) + }; + + // non-cacheable + assert!(!wrapper(&[vec![]])); + assert!(!wrapper(&[vec![b"canister_ranges".to_vec()]])); + assert!(!wrapper(&[vec![b"subnet".to_vec()]])); + assert!(!wrapper(&[ + vec![b"subnet".to_vec()], + vec![b"canister_ranges".to_vec()] + ])); + + assert!(!wrapper(&[ + vec![b"subnet".to_vec(), subnet_id.clone()], + vec![b"canister_ranges".to_vec(), subnet_id.clone()], + vec![b"some_other".to_vec(), b"label".to_vec()] + ])); + + assert!(!wrapper(&[ + vec![b"subnet".to_vec(), subnet_id.clone(), subnet_id.clone()], + vec![ + b"canister_ranges".to_vec(), + subnet_id.clone(), + subnet_id.clone() + ] + ])); + + assert!(!wrapper(&[ + vec![b"subnet".to_vec(), subnet_id.clone(),], + vec![b"subnet".to_vec(), subnet_id.clone(),] + ])); + + assert!(!wrapper(&[ + vec![b"canister_ranges".to_vec(), subnet_id.clone(),], + vec![b"canister_ranges".to_vec(), subnet_id.clone(),] + ])); + + // too long slices for a principal + assert!(!wrapper(&[ + vec![ + b"subnet".to_vec(), + b"aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa".to_vec(), + ], + vec![b"canister_ranges".to_vec(), subnet_id.clone()] + ])); + + assert!(!wrapper(&[ + vec![b"subnet".to_vec(), subnet_id.clone()], + vec![ + b"canister_ranges".to_vec(), + b"aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa".to_vec() + ] + ])); + + // cacheable + assert!(wrapper(&[ + vec![b"subnet".to_vec(), subnet_id.clone()], + vec![b"canister_ranges".to_vec(), subnet_id.clone()] + ])); + + assert!(wrapper(&[ + vec![b"canister_ranges".to_vec(), subnet_id.clone()], + vec![b"subnet".to_vec(), subnet_id] + ])); + } } diff --git a/rs/boundary_node/ic_boundary/src/http/middleware/subnet_read_state_cache.rs b/rs/boundary_node/ic_boundary/src/http/middleware/subnet_read_state_cache.rs index d0a789e33ae3..509b9186a15c 100644 --- a/rs/boundary_node/ic_boundary/src/http/middleware/subnet_read_state_cache.rs +++ b/rs/boundary_node/ic_boundary/src/http/middleware/subnet_read_state_cache.rs @@ -1,7 +1,7 @@ use std::{mem::size_of, sync::Arc, time::Duration}; use axum::{ - body::Body, + body::{Body, HttpBody}, extract::{Request, State}, middleware::Next, response::{IntoResponse, Response}, @@ -17,24 +17,22 @@ use moka::sync::Cache; use crate::{ errors::{ApiError, buffer_body_to_bytes}, - routes::RequestContext, + routes::ReadStatePaths, }; -type ReadStateLabel = Vec; -type ReadStatePath = Vec; -type ReadStatePaths = Vec; - #[derive(Clone, Debug, PartialEq, Eq, Hash)] struct CacheKey { subnet_id: SubnetId, paths: ReadStatePaths, } -fn weigh_entry(_key: &CacheKey, value: &Response) -> u32 { +fn weigh_entry(key: &CacheKey, value: &Response) -> u32 { let size = size_of::() + + key.paths.len() + size_of::>() + calc_headers_size(value.headers()) + value.body().len(); + size as u32 } @@ -107,35 +105,19 @@ impl SubnetReadStateCacheState { } } -fn is_cacheable_path(path: &ReadStatePath) -> bool { - path.len() == 2 && (path[0] == b"canister_ranges" || path[0] == b"subnet") -} - -fn should_cache_paths(paths: &ReadStatePaths) -> bool { - !paths.is_empty() && paths.iter().all(is_cacheable_path) -} - -fn build_cache_key(subnet_id: SubnetId, paths: &ReadStatePaths) -> CacheKey { - let mut paths = paths.clone(); - paths.sort(); - CacheKey { subnet_id, paths } -} - pub async fn subnet_read_state_cache_middleware( State(state): State>, - request: Request, + mut request: Request, next: Next, ) -> Result { let subnet_id = request.extensions().get::().copied(); - let ctx = request.extensions().get::>().cloned(); - let paths = ctx.as_ref().and_then(|ctx| ctx.read_state_paths.as_ref()); + let paths = request.extensions_mut().remove::(); - let (subnet_id, paths) = match (&subnet_id, &paths) { - (Some(sid), Some(paths)) if should_cache_paths(paths) => (*sid, *paths), - _ => return Ok(next.run(request).await), + let (Some(subnet_id), Some(paths)) = (subnet_id, paths) else { + return Ok(next.run(request).await); }; - let cache_key = build_cache_key(subnet_id, paths); + let cache_key = CacheKey { subnet_id, paths }; if let Some(cached) = state.cache.get(&cache_key) { state.hits.inc(); @@ -147,19 +129,21 @@ pub async fn subnet_read_state_cache_middleware( let response = next.run(request).await; - if response.status().is_success() { - let (parts, body) = response.into_parts(); - let body_bytes = - buffer_body_to_bytes(body, state.max_item_size, state.body_timeout).await?; + // Return response as-is if it failed or the advertised body size is too big + if !response.status().is_success() + || response.body().size_hint().exact() > Some(state.max_item_size as u64) + { + return Ok(response); + } - let cached = Response::from_parts(parts, body_bytes); - state.cache.insert(cache_key, cached.clone()); - state.update_gauges(); + let (parts, body) = response.into_parts(); + let body_bytes = buffer_body_to_bytes(body, state.max_item_size, state.body_timeout).await?; - Ok(cached.map(Body::from)) - } else { - Ok(response) - } + let cached = Response::from_parts(parts, body_bytes); + state.cache.insert(cache_key, cached.clone()); + state.update_gauges(); + + Ok(cached.map(Body::from)) } #[cfg(test)] @@ -170,26 +154,29 @@ mod tests { use axum::{Router, body::Body, http::Request, middleware, routing::post}; use http::StatusCode; - use ic_types::PrincipalId; + use ic_bn_lib_common::principal; + use ic_types::{PrincipalId, messages::Blob}; use tower::Service; - use crate::{http::RequestType, routes::RequestContext}; + use crate::http::middleware::process::should_cache_paths; const DEFAULT_TTL: Duration = Duration::from_secs(60); const DEFAULT_CACHE_SIZE: u64 = 1024 * 1024; const DEFAULT_MAX_ITEM_SIZE: usize = 1024 * 1024; const DEFAULT_BODY_TIMEOUT: Duration = Duration::from_secs(10); - fn make_request(subnet_id: SubnetId, paths: ReadStatePaths) -> Request { - let ctx = Arc::new(RequestContext { - request_type: RequestType::ReadStateSubnetV2, - read_state_paths: Some(paths), - ..Default::default() - }); + fn make_request(subnet_id: SubnetId, paths: Vec>>) -> Request { + let paths = paths + .iter() + .map(|x| x.iter().map(|x| Blob(x.clone())).collect()) + .collect::>(); let mut req = Request::post("/").body(Body::from("body")).unwrap(); + if should_cache_paths(&paths) { + req.extensions_mut().insert(ReadStatePaths::from(paths)); + } + req.extensions_mut().insert(subnet_id); - req.extensions_mut().insert(ctx); req } @@ -221,10 +208,12 @@ mod tests { (app, state) } - fn cacheable_paths() -> ReadStatePaths { + fn cacheable_paths() -> Vec>> { + let subnet_id = principal!("aaaaa-aa").as_slice().to_vec(); + vec![ - vec![b"canister_ranges".to_vec(), b"subnet_id_1".to_vec()], - vec![b"subnet".to_vec(), b"subnet_id_1".to_vec()], + vec![b"canister_ranges".to_vec(), subnet_id.clone()], + vec![b"subnet".to_vec(), subnet_id], ] } @@ -262,6 +251,9 @@ mod tests { #[tokio::test] async fn test_different_paths_are_separate_entries() { + let subnet_id_1 = test_subnet_id(0).get().as_slice().to_vec(); + let subnet_id_2 = test_subnet_id(1).get().as_slice().to_vec(); + let (mut app, state) = setup_app(DEFAULT_TTL, DEFAULT_CACHE_SIZE); let subnet = test_subnet_id(1); @@ -269,7 +261,10 @@ mod tests { // Request with canister_ranges for subnet A let req = make_request( subnet, - vec![vec![b"canister_ranges".to_vec(), b"aaa".to_vec()]], + vec![ + vec![b"subnet".to_vec(), subnet_id_1.clone()], + vec![b"canister_ranges".to_vec(), subnet_id_1], + ], ); app.call(req).await.unwrap(); assert_eq!(state.misses.get(), 1); @@ -277,8 +272,12 @@ mod tests { // Request with canister_ranges for subnet B: different paths = cache miss let req = make_request( subnet, - vec![vec![b"canister_ranges".to_vec(), b"bbb".to_vec()]], + vec![ + vec![b"subnet".to_vec(), subnet_id_2.clone()], + vec![b"canister_ranges".to_vec(), subnet_id_2], + ], ); + app.call(req).await.unwrap(); assert_eq!(state.misses.get(), 2); assert_eq!(state.hits.get(), 0); @@ -306,10 +305,12 @@ mod tests { async fn test_path_order_does_not_matter() { let (mut app, state) = setup_app(DEFAULT_TTL, DEFAULT_CACHE_SIZE); - let subnet = test_subnet_id(1); + let subnet = test_subnet_id(0); + let subnet_id_1 = test_subnet_id(0).get().as_slice().to_vec(); + let subnet_id_2 = test_subnet_id(1).get().as_slice().to_vec(); - let path_a = vec![b"canister_ranges".to_vec(), b"id_a".to_vec()]; - let path_b = vec![b"subnet".to_vec(), b"id_b".to_vec()]; + let path_a = vec![b"canister_ranges".to_vec(), subnet_id_1]; + let path_b = vec![b"subnet".to_vec(), subnet_id_2]; // Paths in order [A, B] let req = make_request(subnet, vec![path_a.clone(), path_b.clone()]); @@ -326,16 +327,8 @@ mod tests { async fn test_no_paths_bypasses_cache() { let (mut app, state) = setup_app(DEFAULT_TTL, DEFAULT_CACHE_SIZE); - // Request with no paths in context - let ctx = Arc::new(RequestContext { - request_type: RequestType::ReadStateSubnetV2, - read_state_paths: None, - ..Default::default() - }); - let mut req = Request::post("/").body(Body::from("body")).unwrap(); req.extensions_mut().insert(test_subnet_id(1)); - req.extensions_mut().insert(ctx); let resp = app.call(req).await.unwrap(); assert_eq!(resp.status(), StatusCode::OK); @@ -369,10 +362,11 @@ mod tests { let (mut app, state) = setup_app(DEFAULT_TTL, DEFAULT_CACHE_SIZE); let subnet = test_subnet_id(1); + let subnet_id = subnet.get().as_slice().to_vec(); // Mix of cacheable (canister_ranges) and non-cacheable (time) let paths = vec![ - vec![b"canister_ranges".to_vec(), b"subnet_id".to_vec()], + vec![b"canister_ranges".to_vec(), subnet_id], vec![b"time".to_vec()], ]; diff --git a/rs/boundary_node/ic_boundary/src/http/mod.rs b/rs/boundary_node/ic_boundary/src/http/mod.rs index c234c2c41e29..940301fe1743 100644 --- a/rs/boundary_node/ic_boundary/src/http/mod.rs +++ b/rs/boundary_node/ic_boundary/src/http/mod.rs @@ -62,6 +62,14 @@ impl RequestType { pub const fn is_call(&self) -> bool { matches!(self, Self::CallV2 | Self::CallV3 | Self::CallV4) } + + pub const fn is_read_state(&self) -> bool { + matches!(self, Self::ReadStateV2 | Self::ReadStateV3) + } + + pub const fn is_read_state_subnet(&self) -> bool { + matches!(self, Self::ReadStateSubnetV2 | Self::ReadStateSubnetV3) + } } // Try to categorize the error that we got from Reqwest call diff --git a/rs/boundary_node/ic_boundary/src/routes.rs b/rs/boundary_node/ic_boundary/src/routes.rs index 578315aa3c1e..05141ae87d81 100644 --- a/rs/boundary_node/ic_boundary/src/routes.rs +++ b/rs/boundary_node/ic_boundary/src/routes.rs @@ -14,7 +14,10 @@ use axum::{ use candid::{CandidType, Principal}; use ic_bn_lib::http::proxy; use ic_bn_lib_common::traits::http::Client as HttpClient; -use ic_types::{CanisterId, SubnetId, messages::ReplicaHealthStatus}; +use ic_types::{ + CanisterId, SubnetId, + messages::{Blob, ReplicaHealthStatus}, +}; use serde::Deserialize; use url::Url; @@ -35,6 +38,31 @@ pub struct HttpRequest { pub body: Vec, } +/// Paths for a read state call. +/// To avoid multiple caching the paths must be sorted. +#[derive(Debug, Clone, Default, Hash, Eq, PartialEq)] +pub struct ReadStatePaths(Vec>>); + +impl ReadStatePaths { + /// Returns the combined length of all labels + pub fn len(&self) -> usize { + self.0.iter().flat_map(|x| x.iter().map(|x| x.len())).sum() + } +} + +impl From>> for ReadStatePaths { + fn from(paths: Vec>) -> Self { + let mut paths: Vec>> = paths + .into_iter() + .map(|x| x.into_iter().map(|x| x.0).collect()) + .collect(); + + paths.sort(); + + Self(paths) + } +} + /// Per-request information #[derive(Debug, Clone, Default)] pub struct RequestContext { @@ -49,11 +77,8 @@ pub struct RequestContext { pub ingress_expiry: Option, pub arg: Option>, - // Filled in when the inner request is HTTP + /// Filled in when the inner request is HTTP pub http_request: Option, - - // Filled in for read_state requests - pub read_state_paths: Option>>>, } impl RequestContext {