Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 126 additions & 7 deletions rs/boundary_node/ic_boundary/src/http/middleware/process.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::{sync::Arc, time::Duration};

use axum::body::HttpBody;
use axum::{Extension, body::Body, extract::Request, middleware::Next, response::IntoResponse};
use bytes::Bytes;
use candid::{Decode, Principal};
Expand All @@ -23,7 +24,11 @@ const METHOD_HTTP: &str = "http_request";
const HEADERS_HIDE_HTTP_REQUEST: [&str; 4] =
["x-real-ip", "x-forwarded-for", "x-request-id", "user-agent"];

// This is the subset of the request fields
/// This is the subset of the request fields.
///
/// TODO: add sanity checks for Blob fields so that
/// we don't process too big forged requests.
/// E.g. the nonce is probably fixed-length etc.
#[derive(Clone, Debug, Deserialize, Serialize)]
struct ICRequestContent {
sender: Principal,
Expand Down Expand Up @@ -60,14 +65,44 @@ where
Ok(s)
}

// Middleware: preprocess the request before handing it over to handlers
/// Checks if given paths are cacheable.
///
/// A request is cacheable only when it asks for exactly two paths of two
/// labels each: one rooted at `subnet` and one at `canister_ranges` (in
/// either order), where each second label is short enough to be a Principal.
pub(crate) fn should_cache_paths(paths: &[Vec<Blob>]) -> bool {
    // There must be exactly two paths
    let [first, second] = paths else {
        return false;
    };

    // Each path must consist of exactly two labels
    if first.len() != 2 || second.len() != 2 {
        return false;
    }

    // The 2nd label of each path must be short enough to be a Principal
    if first[1].0.len() > Principal::MAX_LENGTH_IN_BYTES
        || second[1].0.len() > Principal::MAX_LENGTH_IN_BYTES
    {
        return false;
    }

    // The 1st labels must be `subnet` + `canister_ranges`, in either order
    let roots = (&first[0].0[..], &second[0].0[..]);
    roots == (&b"subnet"[..], &b"canister_ranges"[..])
        || roots == (&b"canister_ranges"[..], &b"subnet"[..])
}

/// Middleware: preprocess the request before handing it over to handlers
pub async fn preprocess_request(
Extension(request_type): Extension<RequestType>,
request: Request,
next: Next,
) -> Result<impl IntoResponse, ApiError> {
// Consume body
let (parts, body) = request.into_parts();

// Early check for the body size to avoid streaming too big requests
if body.size_hint().exact() > Some(MAX_REQUEST_BODY_SIZE as u64) {
return Err(ErrorCause::PayloadTooLarge(MAX_REQUEST_BODY_SIZE).into());
}

let body = buffer_body_to_bytes(body, MAX_REQUEST_BODY_SIZE, Duration::from_secs(60)).await?;

// Parse the request body
Expand Down Expand Up @@ -100,6 +135,16 @@ pub async fn preprocess_request(
(_, arg) => (arg, None),
};

// Check if it's a subnet read state request & it's eligible for caching
let read_state_paths = if request_type.is_read_state_subnet()
&& let Some(x) = content.paths
&& should_cache_paths(&x)
{
Some(x.into())
} else {
None
};

// Construct the context
let ctx = RequestContext {
request_type,
Expand All @@ -111,11 +156,7 @@ pub async fn preprocess_request(
arg: arg.map(|x| x.0),
nonce: content.nonce.map(|x| x.0),
http_request,
read_state_paths: content.paths.map(|p| {
p.into_iter()
.map(|v| v.into_iter().map(|b| b.0).collect())
.collect()
}),
read_state_paths,
};

let ctx = Arc::new(ctx);
Expand Down Expand Up @@ -244,6 +285,7 @@ pub async fn postprocess_response(request: Request, next: Next) -> impl IntoResp
mod tests {
use super::*;
use candid::Principal;
use ic_bn_lib_common::principal;
use serde_cbor::Value;
use std::collections::BTreeMap;

Expand Down Expand Up @@ -342,4 +384,81 @@ mod tests {
let content: ICRequestContent = serde_cbor::from_slice(&data).unwrap();
assert!(content.method_name.is_none());
}

#[test]
fn test_should_cache_paths() {
    let subnet_id = principal!("aaaaa-aa").as_slice().to_vec();
    // A label too long to ever be a valid Principal
    let oversized = b"aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa".to_vec();

    // Wraps raw byte paths into Blobs and runs the check
    let check = |paths: &[Vec<Vec<u8>>]| -> bool {
        let blobs: Vec<Vec<Blob>> = paths
            .iter()
            .map(|path| path.iter().cloned().map(Blob).collect())
            .collect();
        should_cache_paths(&blobs)
    };

    // Non-cacheable: wrong path/label counts
    assert!(!check(&[vec![]]));
    assert!(!check(&[vec![b"canister_ranges".to_vec()]]));
    assert!(!check(&[vec![b"subnet".to_vec()]]));
    assert!(!check(&[
        vec![b"subnet".to_vec()],
        vec![b"canister_ranges".to_vec()],
    ]));

    // Non-cacheable: three paths instead of two
    assert!(!check(&[
        vec![b"subnet".to_vec(), subnet_id.clone()],
        vec![b"canister_ranges".to_vec(), subnet_id.clone()],
        vec![b"some_other".to_vec(), b"label".to_vec()],
    ]));

    // Non-cacheable: three labels per path instead of two
    assert!(!check(&[
        vec![b"subnet".to_vec(), subnet_id.clone(), subnet_id.clone()],
        vec![
            b"canister_ranges".to_vec(),
            subnet_id.clone(),
            subnet_id.clone(),
        ],
    ]));

    // Non-cacheable: same 1st label twice
    assert!(!check(&[
        vec![b"subnet".to_vec(), subnet_id.clone()],
        vec![b"subnet".to_vec(), subnet_id.clone()],
    ]));
    assert!(!check(&[
        vec![b"canister_ranges".to_vec(), subnet_id.clone()],
        vec![b"canister_ranges".to_vec(), subnet_id.clone()],
    ]));

    // Non-cacheable: too long slices for a principal
    assert!(!check(&[
        vec![b"subnet".to_vec(), oversized.clone()],
        vec![b"canister_ranges".to_vec(), subnet_id.clone()],
    ]));
    assert!(!check(&[
        vec![b"subnet".to_vec(), subnet_id.clone()],
        vec![b"canister_ranges".to_vec(), oversized],
    ]));

    // Cacheable: correct shape, both label orders accepted
    assert!(check(&[
        vec![b"subnet".to_vec(), subnet_id.clone()],
        vec![b"canister_ranges".to_vec(), subnet_id.clone()],
    ]));
    assert!(check(&[
        vec![b"canister_ranges".to_vec(), subnet_id.clone()],
        vec![b"subnet".to_vec(), subnet_id],
    ]));
}
}
Loading
Loading