diff --git a/lib/rs/src/configuration.rs b/lib/rs/src/configuration.rs new file mode 100644 index 00000000000..0f786f4ee15 --- /dev/null +++ b/lib/rs/src/configuration.rs @@ -0,0 +1,178 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +/// Configuration for Thrift protocols. +#[derive(Debug, Clone)] +pub struct TConfiguration { + max_message_size: Option, + max_frame_size: Option, + max_recursion_depth: Option, + max_container_size: Option, + max_string_size: Option, +} + +impl TConfiguration { + // this value is used consistently across all Thrift libraries + pub const DEFAULT_MAX_MESSAGE_SIZE: usize = 100 * 1024 * 1024; + + // this value is used consistently across all Thrift libraries + pub const DEFAULT_MAX_FRAME_SIZE: usize = 16_384_000; + + pub const DEFAULT_RECURSION_LIMIT: usize = 64; + + pub const DEFAULT_CONTAINER_LIMIT: Option = None; + + pub const DEFAULT_STRING_LIMIT: usize = 100 * 1024 * 1024; + + pub fn no_limits() -> Self { + Self { + max_message_size: None, + max_frame_size: None, + max_recursion_depth: None, + max_container_size: None, + max_string_size: None, + } + } + + pub fn max_message_size(&self) -> Option { + self.max_message_size + } + + pub fn max_frame_size(&self) -> Option { + self.max_frame_size + } + + pub fn max_recursion_depth(&self) -> Option { + self.max_recursion_depth + } + + pub fn max_container_size(&self) -> Option { + self.max_container_size + } + + pub fn max_string_size(&self) -> Option { + self.max_string_size + } + + pub fn builder() -> TConfigurationBuilder { + TConfigurationBuilder::default() + } +} + +impl Default for TConfiguration { + fn default() -> Self { + Self { + max_message_size: Some(Self::DEFAULT_MAX_MESSAGE_SIZE), + max_frame_size: Some(Self::DEFAULT_MAX_FRAME_SIZE), + max_recursion_depth: Some(Self::DEFAULT_RECURSION_LIMIT), + max_container_size: Self::DEFAULT_CONTAINER_LIMIT, + max_string_size: Some(Self::DEFAULT_STRING_LIMIT), + } + } +} + +#[derive(Debug, Default)] +pub struct TConfigurationBuilder { + config: TConfiguration, +} + +impl TConfigurationBuilder { + pub fn max_message_size(mut self, limit: Option) -> Self { + self.config.max_message_size = limit; + self + } + + pub fn max_frame_size(mut self, limit: Option) -> Self { + self.config.max_frame_size = limit; + self + } + + pub fn max_recursion_depth(mut self, limit: Option) -> Self { + self.config.max_recursion_depth = limit; + self + } + + pub fn max_container_size(mut self, limit: Option) -> Self { + self.config.max_container_size = limit; + self + } + + pub fn max_string_size(mut self, limit: Option) -> Self { + self.config.max_string_size = limit; + self + } + + pub fn build(self) -> crate::Result { + if let (Some(frame_size), Some(message_size)) = + (self.config.max_frame_size, self.config.max_message_size) + { + if frame_size > message_size { + // FIXME: This should probably be a different error type. + return Err(crate::Error::Application(crate::ApplicationError::new( + crate::ApplicationErrorKind::Unknown, + format!( + "Invalid configuration: max_frame_size ({}) cannot exceed max_message_size ({})", + frame_size, message_size + ), + ))); + } + } + + Ok(self.config) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_custom_configuration_builder() { + let config = TConfiguration::builder() + .max_message_size(Some(1024)) + .max_frame_size(Some(512)) + .max_recursion_depth(Some(10)) + .max_container_size(Some(100)) + .max_string_size(Some(256)) + .build() + .unwrap(); + + assert_eq!(config.max_message_size(), Some(1024)); + assert_eq!(config.max_frame_size(), Some(512)); + assert_eq!(config.max_recursion_depth(), Some(10)); + assert_eq!(config.max_container_size(), Some(100)); + assert_eq!(config.max_string_size(), Some(256)); + } + + #[test] + fn test_invalid_configuration() { + // Test that builder catches invalid configurations + let result = TConfiguration::builder() + .max_frame_size(Some(1000)) + .max_message_size(Some(500)) // frame size > message size is invalid + .build(); + + assert!(result.is_err()); + match result { + Err(crate::Error::Application(e)) => { + assert!(e.message.contains("max_frame_size")); + assert!(e.message.contains("cannot exceed max_message_size")); + } + _ => panic!("Expected Application error"), + } + } +} diff --git a/lib/rs/src/lib.rs b/lib/rs/src/lib.rs index 2f6018810b3..d3804eccc08 100644 --- a/lib/rs/src/lib.rs +++ b/lib/rs/src/lib.rs @@ -21,10 +21,11 @@ //! Thrift server and client. It is divided into the following modules: //! //! 1. errors -//! 2. protocol -//! 3. transport -//! 4. server -//! 5. autogen +//! 2. configuration +//! 3. protocol +//! 4. transport +//! 5. server +//! 6. autogen //! //! The modules are layered as shown in the diagram below. The `autogen'd` //! layer is generated by the Thrift compiler's Rust plugin. It uses the @@ -82,6 +83,9 @@ pub use crate::errors::*; mod autogen; pub use crate::autogen::*; +mod configuration; +pub use crate::configuration::*; + /// Result type returned by all runtime library functions. /// /// As is convention this is a typedef of `std::result::Result` diff --git a/lib/rs/src/protocol/binary.rs b/lib/rs/src/protocol/binary.rs index b4b51f682e0..596285fb9f1 100644 --- a/lib/rs/src/protocol/binary.rs +++ b/lib/rs/src/protocol/binary.rs @@ -24,7 +24,7 @@ use super::{ }; use super::{TOutputProtocol, TOutputProtocolFactory, TSetIdentifier, TStructIdentifier, TType}; use crate::transport::{TReadTransport, TWriteTransport}; -use crate::{ProtocolError, ProtocolErrorKind}; +use crate::{ProtocolError, ProtocolErrorKind, TConfiguration}; const BINARY_PROTOCOL_VERSION_1: u32 = 0x8001_0000; @@ -57,6 +57,8 @@ where { strict: bool, pub transport: T, // FIXME: shouldn't be public + config: TConfiguration, + recursion_depth: usize, } impl TBinaryInputProtocol @@ -67,8 +69,29 @@ where /// /// Set `strict` to `true` if all incoming messages contain the protocol /// version number in the protocol header. - pub fn new(transport: T, strict: bool) -> TBinaryInputProtocol { - TBinaryInputProtocol { strict, transport } + pub fn new(transport: T, strict: bool) -> Self { + Self::with_config(transport, strict, TConfiguration::default()) + } + + pub fn with_config(transport: T, strict: bool, config: TConfiguration) -> Self { + TBinaryInputProtocol { + strict, + transport, + config, + recursion_depth: 0, + } + } + + fn check_recursion_depth(&self) -> crate::Result<()> { + if let Some(limit) = self.config.max_recursion_depth() { + if self.recursion_depth >= limit { + return Err(crate::Error::Protocol(ProtocolError::new( + ProtocolErrorKind::DepthLimit, + format!("Maximum recursion depth {} exceeded", limit), + ))); + } + } + Ok(()) } } @@ -78,6 +101,7 @@ where { #[allow(clippy::collapsible_if)] fn read_message_begin(&mut self) -> crate::Result { + // TODO: Once specialization is stable, call the message size tracking here let mut first_bytes = vec![0; 4]; self.transport.read_exact(&mut first_bytes[..])?; @@ -130,10 +154,13 @@ where } fn read_struct_begin(&mut self) -> crate::Result> { + self.check_recursion_depth()?; + self.recursion_depth += 1; Ok(None) } fn read_struct_end(&mut self) -> crate::Result<()> { + self.recursion_depth -= 1; Ok(()) } @@ -154,8 +181,28 @@ where } fn read_bytes(&mut self) -> crate::Result> { - let num_bytes = self.transport.read_i32::()? as usize; - let mut buf = vec![0u8; num_bytes]; + let num_bytes = self.transport.read_i32::()?; + + if num_bytes < 0 { + return Err(crate::Error::Protocol(ProtocolError::new( + ProtocolErrorKind::NegativeSize, + format!("Negative byte array size: {}", num_bytes), + ))); + } + + if let Some(max_size) = self.config.max_string_size() { + if num_bytes as usize > max_size { + return Err(crate::Error::Protocol(ProtocolError::new( + ProtocolErrorKind::SizeLimit, + format!( + "Byte array size {} exceeds maximum allowed size of {}", + num_bytes, max_size + ), + ))); + } + } + + let mut buf = vec![0u8; num_bytes as usize]; self.transport .read_exact(&mut buf) .map(|_| buf) @@ -206,6 +253,8 @@ where fn read_list_begin(&mut self) -> crate::Result { let element_type: TType = self.read_byte().and_then(field_type_from_u8)?; let size = self.read_i32()?; + let min_element_size = self.min_serialized_size(element_type); + super::check_container_size(&self.config, size, min_element_size)?; Ok(TListIdentifier::new(element_type, size)) } @@ -216,6 +265,8 @@ where fn read_set_begin(&mut self) -> crate::Result { let element_type: TType = self.read_byte().and_then(field_type_from_u8)?; let size = self.read_i32()?; + let min_element_size = self.min_serialized_size(element_type); + super::check_container_size(&self.config, size, min_element_size)?; Ok(TSetIdentifier::new(element_type, size)) } @@ -227,6 +278,12 @@ where let key_type: TType = self.read_byte().and_then(field_type_from_u8)?; let value_type: TType = self.read_byte().and_then(field_type_from_u8)?; let size = self.read_i32()?; + + let key_min_size = self.min_serialized_size(key_type); + let value_min_size = self.min_serialized_size(value_type); + let element_size = key_min_size + value_min_size; + super::check_container_size(&self.config, size, element_size)?; + Ok(TMapIdentifier::new(key_type, value_type, size)) } @@ -240,6 +297,26 @@ where fn read_byte(&mut self) -> crate::Result { self.transport.read_u8().map_err(From::from) } + + fn min_serialized_size(&self, field_type: TType) -> usize { + match field_type { + TType::Stop => 1, // 1 byte minimum + TType::Void => 1, // 1 byte minimum + TType::Bool => 1, // 1 byte + TType::I08 => 1, // 1 byte + TType::Double => 8, // 8 bytes + TType::I16 => 2, // 2 bytes + TType::I32 => 4, // 4 bytes + TType::I64 => 8, // 8 bytes + TType::String => 4, // 4 bytes for length prefix + TType::Struct => 1, // 1 byte minimum (stop field) + TType::Map => 4, // 4 bytes size + TType::Set => 4, // 4 bytes size + TType::List => 4, // 4 bytes size + TType::Uuid => 16, // 16 bytes + TType::Utf7 => 1, // 1 byte + } + } } /// Factory for creating instances of `TBinaryInputProtocol`. @@ -514,14 +591,13 @@ fn field_type_from_u8(b: u8) -> crate::Result { #[cfg(test)] mod tests { + use super::*; use crate::protocol::{ TFieldIdentifier, TInputProtocol, TListIdentifier, TMapIdentifier, TMessageIdentifier, TMessageType, TOutputProtocol, TSetIdentifier, TStructIdentifier, TType, }; use crate::transport::{ReadHalf, TBufferChannel, TIoChannel, WriteHalf}; - use super::*; - #[test] fn must_write_strict_message_call_begin() { let (_, mut o_prot) = test_objects(true); @@ -759,13 +835,26 @@ mod tests { fn must_round_trip_list_begin() { let (mut i_prot, mut o_prot) = test_objects(true); - let ident = TListIdentifier::new(TType::List, 900); + let ident = TListIdentifier::new(TType::I32, 4); assert!(o_prot.write_list_begin(&ident).is_ok()); + assert!(o_prot.write_i32(10).is_ok()); + assert!(o_prot.write_i32(20).is_ok()); + assert!(o_prot.write_i32(30).is_ok()); + assert!(o_prot.write_i32(40).is_ok()); + + assert!(o_prot.write_list_end().is_ok()); copy_write_buffer_to_read_buffer!(o_prot); let received_ident = assert_success!(i_prot.read_list_begin()); assert_eq!(&received_ident, &ident); + + assert_eq!(i_prot.read_i32().unwrap(), 10); + assert_eq!(i_prot.read_i32().unwrap(), 20); + assert_eq!(i_prot.read_i32().unwrap(), 30); + assert_eq!(i_prot.read_i32().unwrap(), 40); + + assert!(i_prot.read_list_end().is_ok()); } #[test] @@ -789,14 +878,25 @@ mod tests { fn must_round_trip_set_begin() { let (mut i_prot, mut o_prot) = test_objects(true); - let ident = TSetIdentifier::new(TType::I64, 2000); + let ident = TSetIdentifier::new(TType::I64, 3); assert!(o_prot.write_set_begin(&ident).is_ok()); + assert!(o_prot.write_i64(123).is_ok()); + assert!(o_prot.write_i64(456).is_ok()); + assert!(o_prot.write_i64(789).is_ok()); + + assert!(o_prot.write_set_end().is_ok()); copy_write_buffer_to_read_buffer!(o_prot); let received_ident_result = i_prot.read_set_begin(); assert!(received_ident_result.is_ok()); assert_eq!(&received_ident_result.unwrap(), &ident); + + assert_eq!(i_prot.read_i64().unwrap(), 123); + assert_eq!(i_prot.read_i64().unwrap(), 456); + assert_eq!(i_prot.read_i64().unwrap(), 789); + + assert!(i_prot.read_set_end().is_ok()); } #[test] @@ -820,13 +920,26 @@ mod tests { fn must_round_trip_map_begin() { let (mut i_prot, mut o_prot) = test_objects(true); - let ident = TMapIdentifier::new(TType::Map, TType::Set, 100); + let ident = TMapIdentifier::new(TType::String, TType::I32, 2); assert!(o_prot.write_map_begin(&ident).is_ok()); + assert!(o_prot.write_string("key1").is_ok()); + assert!(o_prot.write_i32(100).is_ok()); + assert!(o_prot.write_string("key2").is_ok()); + assert!(o_prot.write_i32(200).is_ok()); + + assert!(o_prot.write_map_end().is_ok()); copy_write_buffer_to_read_buffer!(o_prot); let received_ident = assert_success!(i_prot.read_map_begin()); assert_eq!(&received_ident, &ident); + + assert_eq!(i_prot.read_string().unwrap(), "key1"); + assert_eq!(i_prot.read_i32().unwrap(), 100); + assert_eq!(i_prot.read_string().unwrap(), "key2"); + assert_eq!(i_prot.read_i32().unwrap(), 200); + + assert!(i_prot.read_map_end().is_ok()); } #[test] @@ -963,7 +1076,7 @@ mod tests { TBinaryInputProtocol>, TBinaryOutputProtocol>, ) { - let mem = TBufferChannel::with_capacity(40, 40); + let mem = TBufferChannel::with_capacity(200, 200); let (r_mem, w_mem) = mem.split().unwrap(); @@ -981,4 +1094,154 @@ mod tests { assert!(write_fn(&mut o_prot).is_ok()); assert_eq!(o_prot.transport.write_bytes().len(), 0); } + + #[test] + fn must_enforce_recursion_depth_limit() { + let mem = TBufferChannel::with_capacity(40, 40); + let (r_mem, _) = mem.split().unwrap(); + + let config = TConfiguration::builder() + .max_recursion_depth(Some(2)) + .build() + .unwrap(); + let mut i_prot = TBinaryInputProtocol::with_config(r_mem, true, config); + + assert!(i_prot.read_struct_begin().is_ok()); + assert_eq!(i_prot.recursion_depth, 1); + + assert!(i_prot.read_struct_begin().is_ok()); + assert_eq!(i_prot.recursion_depth, 2); + + let result = i_prot.read_struct_begin(); + assert!(result.is_err()); + match result { + Err(crate::Error::Protocol(e)) => { + assert_eq!(e.kind, ProtocolErrorKind::DepthLimit); + } + _ => panic!("Expected protocol error with DepthLimit"), + } + + assert!(i_prot.read_struct_end().is_ok()); + assert_eq!(i_prot.recursion_depth, 1); + assert!(i_prot.read_struct_end().is_ok()); + assert_eq!(i_prot.recursion_depth, 0); + } + + #[test] + fn must_reject_negative_container_sizes() { + let mem = TBufferChannel::with_capacity(40, 40); + let (r_mem, mut w_mem) = mem.split().unwrap(); + + let mut i_prot = TBinaryInputProtocol::new(r_mem, true); + + w_mem.set_readable_bytes(&[0x0F, 0xFF, 0xFF, 0xFF, 0xFF]); + + let result = i_prot.read_list_begin(); + assert!(result.is_err()); + match result { + Err(crate::Error::Protocol(e)) => { + assert_eq!(e.kind, ProtocolErrorKind::NegativeSize); + } + _ => panic!("Expected protocol error with NegativeSize"), + } + } + + #[test] + fn must_enforce_container_size_limit() { + let mem = TBufferChannel::with_capacity(40, 40); + let (r_mem, mut w_mem) = mem.split().unwrap(); + + let config = TConfiguration::builder() + .max_container_size(Some(100)) + .build() + .unwrap(); + + let mut i_prot = TBinaryInputProtocol::with_config(r_mem, true, config); + + w_mem.set_readable_bytes(&[0x0F, 0x00, 0x00, 0x00, 0xC8]); + + let result = i_prot.read_list_begin(); + assert!(result.is_err()); + match result { + Err(crate::Error::Protocol(e)) => { + assert_eq!(e.kind, ProtocolErrorKind::SizeLimit); + assert!(e + .message + .contains("Container size 200 exceeds maximum allowed size of 100")); + } + _ => panic!("Expected protocol error with SizeLimit"), + } + } + + #[test] + fn must_allow_containers_within_limit() { + let mem = TBufferChannel::with_capacity(200, 200); + let (r_mem, mut w_mem) = mem.split().unwrap(); + + // Create protocol with container limit of 100 + let config = TConfiguration::builder() + .max_container_size(Some(100)) + .build() + .unwrap(); + let mut i_prot = TBinaryInputProtocol::with_config(r_mem, true, config); + + let mut data = vec![0x08]; // TType::I32 + data.extend_from_slice(&5i32.to_be_bytes()); // size = 5 + + for i in 1i32..=5i32 { + data.extend_from_slice(&(i * 10).to_be_bytes()); + } + + w_mem.set_readable_bytes(&data); + + let result = i_prot.read_list_begin(); + assert!(result.is_ok()); + let list_ident = result.unwrap(); + assert_eq!(list_ident.size, 5); + assert_eq!(list_ident.element_type, TType::I32); + } + + #[test] + fn must_enforce_string_size_limit() { + let mem = TBufferChannel::with_capacity(100, 100); + let (r_mem, mut w_mem) = mem.split().unwrap(); + + let config = TConfiguration::builder() + .max_string_size(Some(1000)) + .build() + .unwrap(); + let mut i_prot = TBinaryInputProtocol::with_config(r_mem, true, config); + + w_mem.set_readable_bytes(&[0x00, 0x00, 0x07, 0xD0]); + + let result = i_prot.read_string(); + assert!(result.is_err()); + match result { + Err(crate::Error::Protocol(e)) => { + assert_eq!(e.kind, ProtocolErrorKind::SizeLimit); + assert!(e + .message + .contains("Byte array size 2000 exceeds maximum allowed size of 1000")); + } + _ => panic!("Expected protocol error with SizeLimit"), + } + } + + #[test] + fn must_allow_strings_within_limit() { + let mem = TBufferChannel::with_capacity(100, 100); + let (r_mem, mut w_mem) = mem.split().unwrap(); + + let config = TConfiguration::builder() + .max_string_size(Some(1000)) + .build() + .unwrap(); + let mut i_prot = TBinaryInputProtocol::with_config(r_mem, true, config); + + w_mem.set_readable_bytes(&[0x00, 0x00, 0x00, 0x05, b'h', b'e', b'l', b'l', b'o']); + + let result = i_prot.read_string(); + assert!(result.is_ok()); + assert_eq!(result.unwrap(), "hello"); + } } diff --git a/lib/rs/src/protocol/compact.rs b/lib/rs/src/protocol/compact.rs index 4dc45cac2fd..7e1a7510e0c 100644 --- a/lib/rs/src/protocol/compact.rs +++ b/lib/rs/src/protocol/compact.rs @@ -26,6 +26,7 @@ use super::{ }; use super::{TOutputProtocol, TOutputProtocolFactory, TSetIdentifier, TStructIdentifier, TType}; use crate::transport::{TReadTransport, TWriteTransport}; +use crate::{ProtocolError, ProtocolErrorKind, TConfiguration}; const COMPACT_PROTOCOL_ID: u8 = 0x82; const COMPACT_VERSION: u8 = 0x01; @@ -64,6 +65,10 @@ where pending_read_bool_value: Option, // Underlying transport used for byte-level operations. transport: T, + // Configuration + config: TConfiguration, + // Current recursion depth + recursion_depth: usize, } impl TCompactInputProtocol @@ -72,11 +77,18 @@ where { /// Create a `TCompactInputProtocol` that reads bytes from `transport`. pub fn new(transport: T) -> TCompactInputProtocol { + Self::with_config(transport, TConfiguration::default()) + } + + /// Create a `TCompactInputProtocol` with custom configuration. + pub fn with_config(transport: T, config: TConfiguration) -> TCompactInputProtocol { TCompactInputProtocol { last_read_field_id: 0, read_field_id_stack: Vec::new(), pending_read_bool_value: None, transport, + config, + recursion_depth: 0, } } @@ -92,8 +104,23 @@ where self.transport.read_varint::()? as i32 }; + let min_element_size = self.min_serialized_size(element_type); + super::check_container_size(&self.config, element_count, min_element_size)?; + Ok((element_type, element_count)) } + + fn check_recursion_depth(&self) -> crate::Result<()> { + if let Some(limit) = self.config.max_recursion_depth() { + if self.recursion_depth >= limit { + return Err(crate::Error::Protocol(ProtocolError::new( + ProtocolErrorKind::DepthLimit, + format!("Maximum recursion depth {} exceeded", limit), + ))); + } + } + Ok(()) + } } impl TInputProtocol for TCompactInputProtocol @@ -101,6 +128,7 @@ where T: TReadTransport, { fn read_message_begin(&mut self) -> crate::Result { + // TODO: Once specialization is stable, call the message size tracking here let compact_id = self.read_byte()?; if compact_id != COMPACT_PROTOCOL_ID { Err(crate::Error::Protocol(crate::ProtocolError { @@ -145,12 +173,15 @@ where } fn read_struct_begin(&mut self) -> crate::Result> { + self.check_recursion_depth()?; + self.recursion_depth += 1; self.read_field_id_stack.push(self.last_read_field_id); self.last_read_field_id = 0; Ok(None) } fn read_struct_end(&mut self) -> crate::Result<()> { + self.recursion_depth -= 1; self.last_read_field_id = self .read_field_id_stack .pop() @@ -227,6 +258,19 @@ where fn read_bytes(&mut self) -> crate::Result> { let len = self.transport.read_varint::()?; + + if let Some(max_size) = self.config.max_string_size() { + if len as usize > max_size { + return Err(crate::Error::Protocol(ProtocolError::new( + ProtocolErrorKind::SizeLimit, + format!( + "Byte array size {} exceeds maximum allowed size of {}", + len, max_size + ), + ))); + } + } + let mut buf = vec![0u8; len as usize]; self.transport .read_exact(&mut buf) @@ -291,6 +335,12 @@ where let type_header = self.read_byte()?; let key_type = collection_u8_to_type((type_header & 0xF0) >> 4)?; let val_type = collection_u8_to_type(type_header & 0x0F)?; + + let key_min_size = self.min_serialized_size(key_type); + let value_min_size = self.min_serialized_size(val_type); + let element_size = key_min_size + value_min_size; + super::check_container_size(&self.config, element_count, element_size)?; + Ok(TMapIdentifier::new(key_type, val_type, element_count)) } } @@ -309,6 +359,30 @@ where .map_err(From::from) .map(|_| buf[0]) } + + fn min_serialized_size(&self, field_type: TType) -> usize { + compact_protocol_min_serialized_size(field_type) + } +} + +pub(crate) fn compact_protocol_min_serialized_size(field_type: TType) -> usize { + match field_type { + TType::Stop => 1, // 1 byte + TType::Void => 1, // 1 byte + TType::Bool => 1, // 1 byte + TType::I08 => 1, // 1 byte + TType::Double => 8, // 8 bytes (not varint encoded) + TType::I16 => 1, // 1 byte minimum (varint) + TType::I32 => 1, // 1 byte minimum (varint) + TType::I64 => 1, // 1 byte minimum (varint) + TType::String => 1, // 1 byte minimum for length (varint) + TType::Struct => 1, // 1 byte minimum (stop field) + TType::Map => 1, // 1 byte minimum + TType::Set => 1, // 1 byte minimum + TType::List => 1, // 1 byte minimum + TType::Uuid => 16, // 16 bytes + TType::Utf7 => 1, // 1 byte + } } impl io::Seek for TCompactInputProtocol @@ -2573,14 +2647,25 @@ mod tests { fn must_round_trip_small_sized_list_begin() { let (mut i_prot, mut o_prot) = test_objects(); - let ident = TListIdentifier::new(TType::I08, 10); - + let ident = TListIdentifier::new(TType::I32, 3); assert_success!(o_prot.write_list_begin(&ident)); + assert_success!(o_prot.write_i32(100)); + assert_success!(o_prot.write_i32(200)); + assert_success!(o_prot.write_i32(300)); + + assert_success!(o_prot.write_list_end()); + copy_write_buffer_to_read_buffer!(o_prot); let res = assert_success!(i_prot.read_list_begin()); assert_eq!(&res, &ident); + + assert_eq!(i_prot.read_i32().unwrap(), 100); + assert_eq!(i_prot.read_i32().unwrap(), 200); + assert_eq!(i_prot.read_i32().unwrap(), 300); + + assert_success!(i_prot.read_list_end()); } #[test] @@ -2600,10 +2685,9 @@ mod tests { #[test] fn must_round_trip_large_sized_list_begin() { - let (mut i_prot, mut o_prot) = test_objects(); + let (mut i_prot, mut o_prot) = test_objects_no_limits(); let ident = TListIdentifier::new(TType::Set, 47381); - assert_success!(o_prot.write_list_begin(&ident)); copy_write_buffer_to_read_buffer!(o_prot); @@ -2632,14 +2716,25 @@ mod tests { fn must_round_trip_small_sized_set_begin() { let (mut i_prot, mut o_prot) = test_objects(); - let ident = TSetIdentifier::new(TType::I16, 7); - + let ident = TSetIdentifier::new(TType::I16, 3); assert_success!(o_prot.write_set_begin(&ident)); + assert_success!(o_prot.write_i16(111)); + assert_success!(o_prot.write_i16(222)); + assert_success!(o_prot.write_i16(333)); + + assert_success!(o_prot.write_set_end()); + copy_write_buffer_to_read_buffer!(o_prot); let res = assert_success!(i_prot.read_set_begin()); assert_eq!(&res, &ident); + + assert_eq!(i_prot.read_i16().unwrap(), 111); + assert_eq!(i_prot.read_i16().unwrap(), 222); + assert_eq!(i_prot.read_i16().unwrap(), 333); + + assert_success!(i_prot.read_set_end()); } #[test] @@ -2658,10 +2753,9 @@ mod tests { #[test] fn must_round_trip_large_sized_set_begin() { - let (mut i_prot, mut o_prot) = test_objects(); + let (mut i_prot, mut o_prot) = test_objects_no_limits(); let ident = TSetIdentifier::new(TType::Map, 3_928_429); - assert_success!(o_prot.write_set_begin(&ident)); copy_write_buffer_to_read_buffer!(o_prot); @@ -2725,10 +2819,9 @@ mod tests { #[test] fn must_round_trip_map_begin() { - let (mut i_prot, mut o_prot) = test_objects(); + let (mut i_prot, mut o_prot) = test_objects_no_limits(); let ident = TMapIdentifier::new(TType::Map, TType::List, 1_928_349); - assert_success!(o_prot.write_map_begin(&ident)); copy_write_buffer_to_read_buffer!(o_prot); @@ -2804,7 +2897,7 @@ mod tests { TCompactInputProtocol>, TCompactOutputProtocol>, ) { - let mem = TBufferChannel::with_capacity(80, 80); + let mem = TBufferChannel::with_capacity(200, 200); let (r_mem, w_mem) = mem.split().unwrap(); @@ -2814,6 +2907,20 @@ mod tests { (i_prot, o_prot) } + fn test_objects_no_limits() -> ( + TCompactInputProtocol>, + TCompactOutputProtocol>, + ) { + let mem = TBufferChannel::with_capacity(200, 200); + + let (r_mem, w_mem) = mem.split().unwrap(); + + let i_prot = TCompactInputProtocol::with_config(r_mem, TConfiguration::no_limits()); + let o_prot = TCompactOutputProtocol::new(w_mem); + + (i_prot, o_prot) + } + #[test] fn must_read_write_double() { let (mut i_prot, mut o_prot) = test_objects(); @@ -2883,4 +2990,248 @@ mod tests { assert_success!(i_prot.read_list_end()); } + + #[test] + fn must_enforce_recursion_depth_limit() { + let channel = TBufferChannel::with_capacity(100, 100); + + // Create a configuration with a small recursion limit + let config = TConfiguration::builder() + .max_recursion_depth(Some(2)) + .build() + .unwrap(); + + let mut protocol = TCompactInputProtocol::with_config(channel, config); + + // First struct - should succeed + assert!(protocol.read_struct_begin().is_ok()); + + // Second struct - should succeed (at limit) + assert!(protocol.read_struct_begin().is_ok()); + + // Third struct - should fail (exceeds limit) + let result = protocol.read_struct_begin(); + assert!(result.is_err()); + match result { + Err(crate::Error::Protocol(e)) => { + assert_eq!(e.kind, ProtocolErrorKind::DepthLimit); + } + _ => panic!("Expected protocol error with DepthLimit"), + } + } + + #[test] + fn must_check_container_size_overflow() { + // Configure a small message size limit + let config = TConfiguration::builder() + .max_message_size(Some(1000)) + .max_frame_size(Some(1000)) + .build() + .unwrap(); + let transport = TBufferChannel::with_capacity(100, 0); + let mut i_prot = TCompactInputProtocol::with_config(transport, config); + + // Write a list header that would require more memory than message size limit + // List of 100 UUIDs (16 bytes each) = 1600 bytes > 1000 limit + i_prot.transport.set_readable_bytes(&[ + 0xFD, // element type UUID (0x0D) | count in next bytes (0xF0) + 0x64, // varint 100 + ]); + + let result = i_prot.read_list_begin(); + assert!(result.is_err()); + match result { + Err(crate::Error::Protocol(e)) => { + assert_eq!(e.kind, ProtocolErrorKind::SizeLimit); + assert!(e + .message + .contains("1600 bytes, exceeding message size limit of 1000")); + } + _ => panic!("Expected protocol error with SizeLimit"), + } + } + + #[test] + fn must_reject_negative_container_sizes() { + let mut channel = TBufferChannel::with_capacity(100, 100); + + let mut protocol = TCompactInputProtocol::new(channel.clone()); + + // Write header with negative size when decoded + // In compact protocol, lists/sets use a header byte followed by size + // We'll use 0x0F for element type and then a varint-encoded negative number + channel.set_readable_bytes(&[ + 0xF0, // Header: 15 in upper nibble (triggers varint read), List type in lower + 0xFF, 0xFF, 0xFF, 0xFF, 0x0F, // Varint encoding of -1 + ]); + + let result = protocol.read_list_begin(); + assert!(result.is_err()); + match result { + Err(crate::Error::Protocol(e)) => { + assert_eq!(e.kind, ProtocolErrorKind::NegativeSize); + } + _ => panic!("Expected protocol error with NegativeSize"), + } + } + + #[test] + fn must_enforce_container_size_limit() { + let channel = TBufferChannel::with_capacity(100, 100); + let (r_channel, mut w_channel) = channel.split().unwrap(); + + // Create protocol with explicit container size limit + let config = TConfiguration::builder() + .max_container_size(Some(1000)) + .build() + .unwrap(); + let mut protocol = TCompactInputProtocol::with_config(r_channel, config); + + // Write header with large size + // Compact protocol: 0xF0 means size >= 15 is encoded as varint + // Then we write a varint encoding 10000 (exceeds our limit of 1000) + w_channel.set_readable_bytes(&[ + 0xF0, // Header: 15 in upper nibble (triggers varint read), element type in lower + 0x90, 0x4E, // Varint encoding of 10000 + ]); + + let result = protocol.read_list_begin(); + assert!(result.is_err()); + match result { + Err(crate::Error::Protocol(e)) => { + assert_eq!(e.kind, ProtocolErrorKind::SizeLimit); + assert!(e.message.contains("exceeds maximum allowed size")); + } + _ => panic!("Expected protocol error with SizeLimit"), + } + } + + #[test] + fn must_handle_varint_size_overflow() { + // Test that compact protocol properly handles varint-encoded sizes that would cause overflow + let mut channel = TBufferChannel::with_capacity(100, 100); + + let mut protocol = TCompactInputProtocol::new(channel.clone()); + + // Create input that encodes a very large size using varint encoding + // 0xFA = list header with size >= 15 (so size follows as varint) + // Then multiple 0xFF bytes which in varint encoding create a very large number + channel.set_readable_bytes(&[ + 0xFA, // List header: size >= 15, element type = 0x0A + 0xFF, 0xFF, 0xFF, 0xFF, 0x7F, // Varint encoding of a huge number + ]); + + let result = protocol.read_list_begin(); + assert!(result.is_err()); + match result { + Err(crate::Error::Protocol(e)) => { + // The varint decoder might interpret this as negative, which is also fine + assert!( + e.kind == ProtocolErrorKind::SizeLimit + || e.kind == ProtocolErrorKind::NegativeSize, + "Expected SizeLimit or NegativeSize but got {:?}", + e.kind + ); + } + _ => panic!("Expected protocol error"), + } + } + + #[test] + fn must_enforce_string_size_limit() { + let channel = TBufferChannel::with_capacity(100, 100); + let (r_channel, mut w_channel) = channel.split().unwrap(); + + // Create protocol with string limit of 100 bytes + let config = TConfiguration::builder() + .max_string_size(Some(100)) + .build() + .unwrap(); + let mut protocol = TCompactInputProtocol::with_config(r_channel, config); + + // Write a varint-encoded string size that exceeds the limit + w_channel.set_readable_bytes(&[ + 0xC8, 0x01, // Varint encoding of 200 + ]); + + let result = protocol.read_string(); + assert!(result.is_err()); + match result { + Err(crate::Error::Protocol(e)) => { + assert_eq!(e.kind, ProtocolErrorKind::SizeLimit); + assert!(e.message.contains("exceeds maximum allowed size")); + } + _ => panic!("Expected protocol error with SizeLimit"), + } + } + + #[test] + fn must_allow_no_limit_configuration() { + let channel = TBufferChannel::with_capacity(40, 40); + + let config = TConfiguration::no_limits(); + let mut protocol = TCompactInputProtocol::with_config(channel, config); + + // Should be able to nest structs deeply without limit + for _ in 0..100 { + assert!(protocol.read_struct_begin().is_ok()); + } + + for _ in 0..100 { + assert!(protocol.read_struct_end().is_ok()); + } + } + + #[test] + fn must_allow_containers_within_limit() { + let channel = TBufferChannel::with_capacity(200, 200); + let (r_channel, mut w_channel) = channel.split().unwrap(); + + // Create protocol with container limit of 100 + let config = TConfiguration::builder() + .max_container_size(Some(100)) + .build() + .unwrap(); + let mut protocol = TCompactInputProtocol::with_config(r_channel, config); + + // Write a list with 5 i32 elements (well within limit of 100) + // Compact protocol: size < 15 is encoded in header + w_channel.set_readable_bytes(&[ + 0x55, // Header: size=5, element type=5 (i32) + // 5 varint-encoded i32 values + 0x0A, // 10 + 0x14, // 20 + 0x1E, // 30 + 0x28, // 40 + 0x32, // 50 + ]); + + let result = protocol.read_list_begin(); + assert!(result.is_ok()); + let list_ident = result.unwrap(); + assert_eq!(list_ident.size, 5); + assert_eq!(list_ident.element_type, TType::I32); + } + + #[test] + fn must_allow_strings_within_limit() { + let channel = TBufferChannel::with_capacity(100, 100); + let (r_channel, mut w_channel) = channel.split().unwrap(); + + let config = TConfiguration::builder() + .max_string_size(Some(1000)) + .build() + .unwrap(); + let mut protocol = TCompactInputProtocol::with_config(r_channel, config); + + // Write a string "hello" (5 bytes, well within limit) + w_channel.set_readable_bytes(&[ + 0x05, // Varint-encoded length: 5 + b'h', b'e', b'l', b'l', b'o', + ]); + + let result = protocol.read_string(); + assert!(result.is_ok()); + assert_eq!(result.unwrap(), "hello"); + } } diff --git a/lib/rs/src/protocol/mod.rs b/lib/rs/src/protocol/mod.rs index 8bbbb3c7cf0..573978a758b 100644 --- a/lib/rs/src/protocol/mod.rs +++ b/lib/rs/src/protocol/mod.rs @@ -62,7 +62,7 @@ use std::fmt; use std::fmt::{Display, Formatter}; use crate::transport::{TReadTransport, TWriteTransport}; -use crate::{ProtocolError, ProtocolErrorKind}; +use crate::{ProtocolError, ProtocolErrorKind, TConfiguration}; #[cfg(test)] macro_rules! assert_eq_written_bytes { @@ -262,6 +262,15 @@ pub trait TInputProtocol { /// /// This method should **never** be used in generated code. fn read_byte(&mut self) -> crate::Result; + + /// Get the minimum number of bytes a type will consume on the wire. + /// This picks the minimum possible across all protocols (so currently matches the compact protocol). + /// + /// This is used for pre-allocation size checks. + /// The actual data may be larger (e.g., for strings, lists, etc.). + fn min_serialized_size(&self, field_type: TType) -> usize { + self::compact::compact_protocol_min_serialized_size(field_type) + } } /// Converts Thrift identifiers, primitives, containers or structs into a @@ -444,6 +453,10 @@ where fn read_byte(&mut self) -> crate::Result { (**self).read_byte() } + + fn min_serialized_size(&self, field_type: TType) -> usize { + (**self).min_serialized_size(field_type) + } } impl

TOutputProtocol for Box

@@ -565,7 +578,7 @@ where /// let protocol = factory.create(Box::new(channel)); /// ``` pub trait TInputProtocolFactory { - // Create a `TInputProtocol` that reads bytes from `transport`. + /// Create a `TInputProtocol` that reads bytes from `transport`. fn create(&self, transport: Box) -> Box; } @@ -920,6 +933,69 @@ pub fn verify_required_field_exists(field_name: &str, field: &Option) -> c } } +/// Common container size validation used by all protocols. +/// +/// Checks that: +/// - Container size is not negative +/// - Container size doesn't exceed configured maximum +/// - Container size * element size doesn't overflow +/// - Container memory requirements don't exceed message size limit +pub(crate) fn check_container_size( + config: &TConfiguration, + container_size: i32, + element_size: usize, +) -> crate::Result<()> { + // Check for negative size + if container_size < 0 { + return Err(crate::Error::Protocol(ProtocolError::new( + ProtocolErrorKind::NegativeSize, + format!("Negative container size: {}", container_size), + ))); + } + + let size_as_usize = container_size as usize; + + // Check against configured max container size + if let Some(max_size) = config.max_container_size() { + if size_as_usize > max_size { + return Err(crate::Error::Protocol(ProtocolError::new( + ProtocolErrorKind::SizeLimit, + format!( + "Container size {} exceeds maximum allowed size of {}", + container_size, max_size + ), + ))); + } + } + + // Check for potential overflow + if let Some(min_bytes_needed) = size_as_usize.checked_mul(element_size) { + // TODO: When Rust trait specialization stabilizes, we can add more precise checks + // for transports that track exact remaining bytes. For now, we use the message + // size limit as a best-effort check. + if let Some(max_message_size) = config.max_message_size() { + if min_bytes_needed > max_message_size { + return Err(crate::Error::Protocol(ProtocolError::new( + ProtocolErrorKind::SizeLimit, + format!( + "Container would require {} bytes, exceeding message size limit of {}", + min_bytes_needed, max_message_size + ), + ))); + } + } + Ok(()) + } else { + Err(crate::Error::Protocol(ProtocolError::new( + ProtocolErrorKind::SizeLimit, + format!( + "Container size {} with element size {} bytes would result in overflow", + container_size, element_size + ), + ))) + } +} + /// Extract the field id from a Thrift field identifier. /// /// `field_ident` must *not* have `TFieldIdentifier.field_type` of type `TType::Stop`. diff --git a/lib/rs/src/transport/framed.rs b/lib/rs/src/transport/framed.rs index d8a7448725f..cf959be8bed 100644 --- a/lib/rs/src/transport/framed.rs +++ b/lib/rs/src/transport/framed.rs @@ -21,6 +21,7 @@ use std::io; use std::io::{Read, Write}; use super::{TReadTransport, TReadTransportFactory, TWriteTransport, TWriteTransportFactory}; +use crate::TConfiguration; /// Default capacity of the read buffer in bytes. const READ_CAPACITY: usize = 4096; @@ -61,6 +62,7 @@ where pos: usize, cap: usize, chan: C, + config: TConfiguration, } impl TFramedReadTransport @@ -81,6 +83,7 @@ where pos: 0, cap: 0, chan: channel, + config: TConfiguration::default(), } } } @@ -91,7 +94,28 @@ where { fn read(&mut self, b: &mut [u8]) -> io::Result { if self.cap - self.pos == 0 { - let message_size = self.chan.read_i32::()? as usize; + let frame_size_bytes = self.chan.read_i32::()?; + + if frame_size_bytes < 0 { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + format!("Negative frame size: {}", frame_size_bytes), + )); + } + + let message_size = frame_size_bytes as usize; + + if let Some(max_frame) = self.config.max_frame_size() { + if message_size > max_frame { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + format!( + "Frame size {} exceeds maximum allowed size of {}", + message_size, max_frame + ), + )); + } + } let buf_capacity = cmp::max(message_size, READ_CAPACITY); self.buf.resize(buf_capacity, 0); @@ -125,7 +149,6 @@ impl TReadTransportFactory for TFramedReadTransportFactory { Box::new(TFramedReadTransport::new(channel)) } } - /// Transport that writes framed messages. /// /// A `TFramedWriteTransport` maintains a fixed-size internal write buffer. All