diff --git a/.gitignore b/.gitignore index 38aa775e..cb65f5ef 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ /target artifacts/ +*.~undo-tree~ diff --git a/wolfssl-sys/build.rs b/wolfssl-sys/build.rs index 1c10ee23..899e9761 100644 --- a/wolfssl-sys/build.rs +++ b/wolfssl-sys/build.rs @@ -286,6 +286,8 @@ fn build_wolfssl(wolfssl_src: &Path) -> PathBuf { .enable("dtls-frag-ch", None) // Enable setting the D/TLS MTU size .enable("dtls-mtu", None) + // Enable pre-shared keys + .enable("psk", None) // Enable Secure Renegotiation .enable("secure-renegotiation", None) // Enable single threaded mode diff --git a/wolfssl/src/context.rs b/wolfssl/src/context.rs index 1dafdec8..ad06f07f 100644 --- a/wolfssl/src/context.rs +++ b/wolfssl/src/context.rs @@ -4,8 +4,13 @@ use crate::{ ssl::{Session, SessionConfig}, CurveGroup, Method, NewSessionError, RootCertificate, Secret, SslVerifyMode, }; -use std::os::raw::c_int; -use std::ptr::NonNull; +use std::{ + ffi::{c_void, CStr, CString}, + fmt::Debug, + os::raw::{c_char, c_int, c_uint}, + ptr::NonNull, + sync::Arc, +}; use thiserror::Error; /// Produces a [`Context`] once built. @@ -13,6 +18,7 @@ use thiserror::Error; pub struct ContextBuilder { ctx: NonNull, method: Method, + pre_shared_key_callbacks: Option>, } /// Error creating a [`ContextBuilder`] object. @@ -49,7 +55,11 @@ impl ContextBuilder { let ctx = unsafe { wolfssl_sys::wolfSSL_CTX_new(method_fn.as_ptr()) }; let ctx = NonNull::new(ctx).ok_or(NewContextBuilderError::CreateFailed)?; - Ok(Self { ctx, method }) + Ok(Self { + ctx, + method, + pre_shared_key_callbacks: None, + }) } /// When `cond` is True call fallible `func` on `Self` @@ -393,6 +403,152 @@ impl ContextBuilder { } } + unsafe extern "C" fn psk_server_callback( + ssl: *mut wolfssl_sys::WOLFSSL, + identity_ptr: *const c_char, + key_output_ptr: *mut u8, + max_key_length_c_uint: c_uint, + ) -> c_uint { + debug_assert!(!ssl.is_null()); + debug_assert!(!identity_ptr.is_null()); // this is never null, it points to an array in an `Arrays` struct + debug_assert!(!key_output_ptr.is_null()); + + // SAFETY: identity_ptr is in fact a C string + let identity: &CStr = unsafe { CStr::from_ptr(identity_ptr) }; + let max_key_length: usize = max_key_length_c_uint.try_into().unwrap(); + + // SAFETY: `wolfSSL_get_psk_callback_ctx` is undocumented, but the implementation simply + // gets a field out of the WOLFSSL object. + let stored_cbs_ptr_ptr: *const c_void = + unsafe { wolfssl_sys::wolfSSL_get_psk_callback_ctx(ssl) }; + // SAFETY: This is written in `Session::new_from_wolfssl_pointer` as a pointer to the + // contents of an Box, so should have stable address. The Box is stored at least until the + // end of the session and hence should be alive. + #[allow(clippy::borrowed_box)] + let stored_cbs: &Box = + unsafe { &*(stored_cbs_ptr_ptr as *const Box) }; + + let maybe_key = stored_cbs.psk_server_callback(identity, max_key_length); + match maybe_key { + Some(key) => { + assert!( + key.len() <= max_key_length, + "Key length {} returned by server callback was longer than maximum {}", + key.len(), + max_key_length + ); + // SAFETY: we've verified that the vec length is <= max_key_length, so we won't overrun + // the buffer provided to us. + unsafe { std::ptr::copy(key.as_ptr(), key_output_ptr, key.len()) }; + key.len().try_into().unwrap() + } + None => 0, + } + } + + unsafe extern "C" fn psk_client_callback( + ssl: *mut wolfssl_sys::WOLFSSL, + _hint: *const c_char, + identity_output: *mut c_char, + max_identity_length_c_uint: c_uint, + key_output: *mut u8, + max_key_length_c_uint: c_uint, + ) -> c_uint { + debug_assert!(!ssl.is_null()); + debug_assert!(!identity_output.is_null()); + debug_assert!(!key_output.is_null()); + + let max_identity_length: usize = max_identity_length_c_uint.try_into().unwrap(); + let max_key_length: usize = max_key_length_c_uint.try_into().unwrap(); + + // SAFETY: See `psk_server_callback` + let stored_cbs_ptr_ptr: *const c_void = + unsafe { wolfssl_sys::wolfSSL_get_psk_callback_ctx(ssl) }; + // SAFETY: See `psk_server_callback` + #[allow(clippy::borrowed_box)] + let stored_cbs: &Box = + unsafe { &*(stored_cbs_ptr_ptr as *const Box) }; + + let maybe_result = stored_cbs.psk_client_callback(max_identity_length, max_key_length); + match maybe_result { + Some(PreSharedKeyClientCallbackResult { identity, key }) => { + assert!( + identity.count_bytes() <= max_identity_length, + "Identity length {} was not less than maximum {}", + identity.count_bytes(), + max_identity_length + ); + assert!( + key.len() <= max_key_length, + "Key length {} was not less than maximum {}", + key.len(), + max_key_length + ); + + // SAFETY: See `psk_server_callback`. + unsafe { std::ptr::copy(key.as_ptr(), key_output, key.len()) }; + // SAFETY: See immediately above. +1 to account for nul terminator. + // `max_identity_length` is not including the nul terminator (the definition of the + // `client_identity` field in the `Arrays` struct in wolfssl `internal.h` has length + // `MAX_PSK_ID_LEN + NULL_TERM_LEN`, and `MAX_PSK_ID_LEN` is what is passed as the + // `max_identity_length`) + unsafe { + std::ptr::copy( + identity.as_ptr(), + identity_output, + identity.count_bytes() + 1, + ) + }; + + key.len().try_into().unwrap() + } + None => 0, + } + } + + /// Use a fixed pre-shared key for authentication + /// + /// See also: [with_pre_shared_key_callbacks] + pub fn with_pre_shared_key(self, psk: &[u8]) -> Self { + self.with_pre_shared_key_callbacks(Box::new(FixedPskCallbacks::new(psk))) + } + + /// Use pre-shared key callbacks for authentication + /// + /// Install custom client and server callbacks for pre-shared-key authentication. Calls either + /// `wolfSSL_CTX_set_psk_server_callback` or `wolfSSL_CTX_set_psk_client_callback` appropriately + /// using fixed callbacks provided by wolfssl-rs. Later, during session constrtuction, calls + /// `wolfSSL_set_psk_callback_ctx` to point to make the user-provided safe callbacks accessible + /// in the fixed callback. The fixed callback does the unsafe work and delegates the interesting + /// logic to the safe user-provided callback. + pub fn with_pre_shared_key_callbacks(self, callbacks: Box) -> Self { + if self.method.is_server() { + // SAFETY: `wolfSSL_CTX_set_psk_server_callback` isn't properly documented. It seems the + // only requirement is that the context is valid and the callback will be alive + // throughout the lifetime of the context and any created sessions; our callbacks are + // &'static. + unsafe { + wolfssl_sys::wolfSSL_CTX_set_psk_server_callback( + self.ctx.as_ptr(), + Some(Self::psk_server_callback), + ); + }; + } else { + // SAFETY: See above. + unsafe { + wolfssl_sys::wolfSSL_CTX_set_psk_client_callback( + self.ctx.as_ptr(), + Some(Self::psk_client_callback), + ); + }; + }; + + Self { + pre_shared_key_callbacks: Some(callbacks), + ..self + } + } + /// Wraps `wolfSSL_CTX_UseSecureRenegotiation` /// /// NOTE: No official documentation available for this api from wolfssl @@ -438,6 +594,7 @@ impl ContextBuilder { Context { method: self.method, ctx: ContextPointer(self.ctx), + pre_shared_key_callbacks: self.pre_shared_key_callbacks.map(Arc::new), } } } @@ -513,6 +670,7 @@ unsafe impl Send for WolfsslPointer {} pub struct Context { method: Method, ctx: ContextPointer, + pre_shared_key_callbacks: Option>>, } impl Context { @@ -534,7 +692,7 @@ impl Context { let ssl = WolfsslPointer(NonNull::new(ptr).ok_or(NewSessionError::CreateFailed)?); - Session::new_from_wolfssl_pointer(ssl, config) + Session::new_from_wolfssl_pointer(ssl, config, self.pre_shared_key_callbacks.clone()) } } @@ -555,6 +713,73 @@ impl Drop for Context { } } +/// Returned from the client callback in [PreSharedKeyCallbacks] +pub struct PreSharedKeyClientCallbackResult { + /// Should be an empty string if you don't need multiple identities. Else, an arbitrary string + /// that the server will be able to read to determine which PSK to use. + pub identity: CString, + /// The pre-shared key itself. + pub key: Vec, +} + +/// Callbacks that are used to provide a pre-shared key to wolfSSL. +pub trait PreSharedKeyCallbacks: Debug { + /// Called on the client before starting the connection. + /// + /// The installed wolfSSL callback will return 0 if None is returned from the Rust callback, + /// which means "fail". The wolfSSL docs are unclear what happens when the callback fails in + /// this way. + fn psk_client_callback( + &self, + max_identity_length: usize, + max_key_length: usize, + ) -> Option; + + /// Called on the server after receiving the client hello. + /// + /// Receives the identity set in the client callback. Return the key, or None on failure. + fn psk_server_callback(&self, identity: &CStr, max_key_length: usize) -> Option>; +} + +/// An implementation of PreSharedKeyCallbacks that uses a fixed buffer as the pre-shared key, which +/// is the most common usecase for pre shared keys. +#[derive(Debug)] +struct FixedPskCallbacks { + key: Vec, +} + +impl FixedPskCallbacks { + /// Construct a FixedPskCallbacks object that will always use the given key, ignoring identity. + fn new>>(key: T) -> FixedPskCallbacks { + FixedPskCallbacks { key: key.into() } + } +} + +impl PreSharedKeyCallbacks for FixedPskCallbacks { + fn psk_client_callback( + &self, + _max_identity_length: usize, + max_key_length: usize, + ) -> Option { + if self.key.len() > max_key_length { + return None; + } + + Some(PreSharedKeyClientCallbackResult { + identity: c"".into(), + key: self.key.clone(), + }) + } + + fn psk_server_callback(&self, _identity: &CStr, max_key_length: usize) -> Option> { + if self.key.len() > max_key_length { + return None; + } + + Some(self.key.clone()) + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/wolfssl/src/lib.rs b/wolfssl/src/lib.rs index 359684cf..ac2be015 100644 --- a/wolfssl/src/lib.rs +++ b/wolfssl/src/lib.rs @@ -241,6 +241,24 @@ impl Method { NonNull::new(ptr) } + + /// Returns true if this method is a server method + fn is_server(self) -> bool { + match self { + Self::DtlsClient => false, + Self::DtlsClientV1_2 => false, + Self::DtlsClientV1_3 => false, + Self::TlsClient => false, + Self::TlsClientV1_2 => false, + Self::TlsClientV1_3 => false, + Self::DtlsServer => true, + Self::DtlsServerV1_2 => true, + Self::DtlsServerV1_3 => true, + Self::TlsServer => true, + Self::TlsServerV1_2 => true, + Self::TlsServerV1_3 => true, + } + } } /// Corresponds to the various defined `WOLFSSL_*` curves diff --git a/wolfssl/src/ssl.rs b/wolfssl/src/ssl.rs index fbf9dc6e..7a4f5fef 100644 --- a/wolfssl/src/ssl.rs +++ b/wolfssl/src/ssl.rs @@ -1,6 +1,6 @@ use crate::{ callback::{IOCallbackResult, IOCallbacks}, - context::WolfsslPointer, + context::{PreSharedKeyCallbacks, WolfsslPointer}, error::{Error, Poll, PollResult, Result}, CurveGroup, ProtocolVersion, SslVerifyMode, TLS_MAX_RECORD_SIZE, }; @@ -10,6 +10,7 @@ use thiserror::Error; use std::{ ffi::{c_int, c_uchar, c_ushort, c_void}, + sync::Arc, time::Duration, }; @@ -170,6 +171,13 @@ pub struct Session { /// Box so we have a stable address to pass to FFI. io: Box, + /// The Arc is to ensure that the PreSharedKeyCallbacks live as long as any [Context] or + /// [Session] that is using them. We need to store a pointer to the PreSharedKeyCallbacks in the + /// WOLFSSL object in order to be able to actually call the callbacks, but we can't store an + /// `&dyn PreSharedKeyCallbacks` in the WOLFSSL context because that's a fat pointer and cannot + /// be cast to a C pointer. Instead, we use an extra `Box`, and store a pointer to that `Box` in + /// the WOLFSSL object. + pre_shared_key_callbacks: Option>>, #[cfg(feature = "debug")] secret_cb: Option>, @@ -194,10 +202,12 @@ impl Session { pub(crate) fn new_from_wolfssl_pointer( ssl: WolfsslPointer, config: SessionConfig, + pre_shared_key_callbacks: Option>>, ) -> std::result::Result { let mut session = Self { ssl, io: Box::new(config.io), + pre_shared_key_callbacks: pre_shared_key_callbacks.clone(), #[cfg(feature = "debug")] secret_cb: Default::default(), }; @@ -240,6 +250,11 @@ impl Session { session.set_verify(mode); } + // set_psk_callback_ctx uses the pre_shared_key_callbacks stored in the session object + session + .set_psk_callback_ctx() + .map_err(|e| NewSessionError::SetupFailed("set_psk_callback_ctx", e))?; + #[cfg(feature = "debug")] if let Some(keylogger) = config.keylogger { session @@ -366,6 +381,34 @@ impl Session { } } + /// Set the context pointer that will be available to PSK calbacks. + pub(crate) fn set_psk_callback_ctx(&mut self) -> Result<()> { + if let Some(psk_cbs) = self.pre_shared_key_callbacks.as_ref() { + let psk_cbs_ptr: *const Box = &**psk_cbs; + // SAFETY: No online docs. The implementation of `wolfSSL_set_psk_callback_ctx` simply + // assigns to `ssl->options.psk_ctx`. Per the [Library design][0] access is synchronized + // via the requirement for `&mut self` in `WelfsslPointer::as_ptr()`. Casting const to + // mut is safe since wolfSSL never mutates the callback ctx, and we don't either. + // + // The pre-shared key callbacks are guaranteed to last the lifetime of the SSL object + // because they're stored in the Session object. + // + // [0]: https://www.wolfssl.com/documentation/manuals/wolfssl/chapter09.html#thread-safety + match unsafe { + wolfssl_sys::wolfSSL_set_psk_callback_ctx( + self.ssl.as_ptr(), + psk_cbs_ptr as *mut c_void, + ) + } { + wolfssl_sys::WOLFSSL_SUCCESS_c_int => Ok(()), + e => Err(Error::fatal(e)), + } + } else { + // no psk -- do nothing + Ok(()) + } + } + /// Get a reference to the IOCB embedded in this session pub fn io_cb(&self) -> &IOCB { self.io.as_ref() @@ -1319,6 +1362,7 @@ impl Drop for Session { #[cfg(test)] mod tests { use super::*; + use crate::PreSharedKeyClientCallbackResult; use crate::{ context::ContextBuilder, Context, Method, RootCertificate, Secret, Session, TLS_MAX_RECORD_SIZE, @@ -1344,6 +1388,9 @@ mod tests { "/tests/data/server_key_der_2048" )); + const PSK: [u8; 8] = [0, 99, 8, 34, 2, 42, 3, 5]; + const PSK_IDENTITY: &std::ffi::CStr = c"test_identity"; + static INIT_ENV_LOGGER: OnceLock<()> = OnceLock::new(); // Panics if any I/O is attempted, use for tests where no I/O is expected @@ -1403,6 +1450,31 @@ mod tests { } } + #[derive(Debug, Clone)] + struct PskTrivialIdentityCallbacks {} + + impl PreSharedKeyCallbacks for PskTrivialIdentityCallbacks { + fn psk_client_callback( + &self, + _max_identity_length: usize, + _max_key_length: usize, + ) -> Option { + Some(PreSharedKeyClientCallbackResult { + identity: PSK_IDENTITY.into(), + key: PSK.into(), + }) + } + + fn psk_server_callback( + &self, + identity: &std::ffi::CStr, + _max_key_length: usize, + ) -> Option> { + assert!(identity == PSK_IDENTITY); + Some(PSK.into()) + } + } + struct TestClient { _ctx: Context, ssl: Session, @@ -1436,6 +1508,60 @@ mod tests { .unwrap() .build(); + make_connected_clients_from_contexts(client_ctx, server_ctx) + } + + /// Use a fixed pre-shared-key on both client and server. + fn make_connected_clients_with_method_psk( + client_method: Method, + server_method: Method, + ) -> (TestClient, TestClient) { + let client_ctx = ContextBuilder::new(client_method) + .unwrap_or_else(|e| panic!("new({client_method:?}): {e}")) + .with_pre_shared_key(&PSK) + .with_secure_renegotiation() + .unwrap() + .build(); + + let server_ctx = ContextBuilder::new(server_method) + .unwrap_or_else(|e| panic!("new({server_method:?}): {e}")) + .with_pre_shared_key(&PSK) + .with_secure_renegotiation() + .unwrap() + .build(); + + make_connected_clients_from_contexts(client_ctx, server_ctx) + } + + /// Unlike [make_connected_clients_with_method_psk], will use custom PSK callbacks to transmit + /// an "identity" from the client to the server, and verify it's the same on the server. + fn make_connected_clients_with_method_psk_custom_callbacks( + client_method: Method, + server_method: Method, + ) -> (TestClient, TestClient) { + let callbacks = Box::new(PskTrivialIdentityCallbacks {}); + + let client_ctx = ContextBuilder::new(client_method) + .unwrap_or_else(|e| panic!("new({client_method:?}): {e}")) + .with_pre_shared_key_callbacks(callbacks.clone()) + .with_secure_renegotiation() + .unwrap() + .build(); + + let server_ctx = ContextBuilder::new(server_method) + .unwrap_or_else(|e| panic!("new({server_method:?}): {e}")) + .with_pre_shared_key_callbacks(callbacks) + .with_secure_renegotiation() + .unwrap() + .build(); + + make_connected_clients_from_contexts(client_ctx, server_ctx) + } + + fn make_connected_clients_from_contexts( + client_ctx: Context, + server_ctx: Context, + ) -> (TestClient, TestClient) { let (client_io, server_io) = TestIOCallbacks::pair(); let client_read_buffer = client_io.r.clone(); @@ -1485,6 +1611,23 @@ mod tests { let _ = make_connected_clients(); } + #[test_case(Method::TlsClientV1_2, Method::TlsServerV1_2; "tls1.2")] + #[test_case(Method::TlsClientV1_3, Method::TlsServerV1_3; "tls1.3")] + fn try_negotiate_psk(client_method: Method, server_method: Method) { + INIT_ENV_LOGGER.get_or_init(env_logger::init); + + let _ = make_connected_clients_with_method_psk(client_method, server_method); + } + + #[test_case(Method::TlsClientV1_2, Method::TlsServerV1_2; "tls1.2")] + #[test_case(Method::TlsClientV1_3, Method::TlsServerV1_3; "tls1.3")] + fn try_negotiate_psk_custom_callbacks(client_method: Method, server_method: Method) { + INIT_ENV_LOGGER.get_or_init(env_logger::init); + + let _ = + make_connected_clients_with_method_psk_custom_callbacks(client_method, server_method); + } + #[test] fn try_write_trivial() { INIT_ENV_LOGGER.get_or_init(env_logger::init); @@ -1879,4 +2022,36 @@ mod tests { assert!(client_ssl.is_init_finished()); assert!(server_ssl.is_init_finished()); } + + #[test_case(Method::TlsClientV1_2, Method::TlsServerV1_2 => panics "verify mac problem")] + #[test_case(Method::TlsClientV1_3, Method::TlsServerV1_3 => panics "binder does not verify")] + fn test_wrong_psk(client_method: Method, server_method: Method) { + let client_ctx = ContextBuilder::new(client_method) + .unwrap_or_else(|e| panic!("new({client_method:?}): {e}")) + .with_pre_shared_key(&[1, 2, 3, 4]) + .with_secure_renegotiation() + .unwrap() + .build(); + + let server_ctx = ContextBuilder::new(server_method) + .unwrap_or_else(|e| panic!("new({server_method:?}): {e}")) + .with_pre_shared_key(&[4, 3, 2, 1]) + .with_secure_renegotiation() + .unwrap() + .build(); + + let (client_io, server_io) = TestIOCallbacks::pair(); + + let mut client_ssl = client_ctx + .new_session(SessionConfig::new(client_io)) + .unwrap(); + let mut server_ssl = server_ctx + .new_session(SessionConfig::new(server_io)) + .unwrap(); + + for _ in 0..7 { + let _ = client_ssl.try_negotiate().unwrap(); + let _ = server_ssl.try_negotiate().unwrap(); + } + } }