diff --git a/pgdog/src/lib.rs b/pgdog/src/lib.rs index ce4fb234..212d8d52 100644 --- a/pgdog/src/lib.rs +++ b/pgdog/src/lib.rs @@ -12,6 +12,7 @@ pub mod stats; #[cfg(feature = "tui")] pub mod tui; pub mod util; +pub mod wire_protocol; use tracing::level_filters::LevelFilter; use tracing_subscriber::{fmt, prelude::*, EnvFilter}; diff --git a/pgdog/src/wire_protocol/backend/authentication_cleartext_password.rs b/pgdog/src/wire_protocol/backend/authentication_cleartext_password.rs new file mode 100644 index 00000000..ccf88830 --- /dev/null +++ b/pgdog/src/wire_protocol/backend/authentication_cleartext_password.rs @@ -0,0 +1,150 @@ +//! Module: wire_protocol::backend::authentication_cleartext_password +//! +//! Provides parsing and serialization for the AuthenticationCleartextPassword message ('R' with code 3) in the protocol. +//! +//! - `AuthenticationCleartextPasswordFrame`: represents the AuthenticationCleartextPassword message. +//! - `AuthenticationCleartextPasswordError`: error types for parsing and encoding. +//! +//! Implements `WireSerializable` for easy conversion between raw bytes and `AuthenticationCleartextPasswordFrame`. + +use bytes::Bytes; +use std::{error::Error as StdError, fmt}; + +use crate::wire_protocol::WireSerializable; + +// ----------------------------------------------------------------------------- +// ----- ProtocolMessage ------------------------------------------------------- + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct AuthenticationCleartextPasswordFrame; + +// ----------------------------------------------------------------------------- +// ----- Error ----------------------------------------------------------------- + +#[derive(Debug)] +pub enum AuthenticationCleartextPasswordError { + UnexpectedTag(u8), + UnexpectedLength(u32), + UnexpectedAuthCode(i32), +} + +impl fmt::Display for AuthenticationCleartextPasswordError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + AuthenticationCleartextPasswordError::UnexpectedTag(t) => { + write!(f, "unexpected tag: {t:#X}") + } + AuthenticationCleartextPasswordError::UnexpectedLength(len) => { + write!(f, "unexpected length: {len}") + } + AuthenticationCleartextPasswordError::UnexpectedAuthCode(code) => { + write!(f, "unexpected auth code: {code}") + } + } + } +} + +impl StdError for AuthenticationCleartextPasswordError {} + +// ----------------------------------------------------------------------------- +// ----- WireSerializable ------------------------------------------------------ + +impl<'a> WireSerializable<'a> for AuthenticationCleartextPasswordFrame { + type Error = AuthenticationCleartextPasswordError; + + fn from_bytes(bytes: &'a [u8]) -> Result { + if bytes.len() < 9 { + return Err(AuthenticationCleartextPasswordError::UnexpectedLength( + bytes.len() as u32, + )); + } + + let tag = bytes[0]; + if tag != b'R' { + return Err(AuthenticationCleartextPasswordError::UnexpectedTag(tag)); + } + + let len = u32::from_be_bytes([bytes[1], bytes[2], bytes[3], bytes[4]]); + if len != 8 { + return Err(AuthenticationCleartextPasswordError::UnexpectedLength(len)); + } + + let code = i32::from_be_bytes([bytes[5], bytes[6], bytes[7], bytes[8]]); + if code != 3 { + return Err(AuthenticationCleartextPasswordError::UnexpectedAuthCode( + code, + )); + } + + Ok(AuthenticationCleartextPasswordFrame) + } + + fn to_bytes(&self) -> Result { + Ok(Bytes::from_static(b"R\x00\x00\x00\x08\x00\x00\x00\x03")) + } + + fn body_size(&self) -> usize { + 4 + } +} + +// ----------------------------------------------------------------------------- +// ----- Tests ----------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn serialize_auth_cleartext() { + let frame = AuthenticationCleartextPasswordFrame; + let bytes = frame.to_bytes().unwrap(); + let expected = b"R\x00\x00\x00\x08\x00\x00\x00\x03"; + assert_eq!(bytes.as_ref(), expected); + } + + #[test] + fn deserialize_auth_cleartext() { + let data = b"R\x00\x00\x00\x08\x00\x00\x00\x03"; + let frame = AuthenticationCleartextPasswordFrame::from_bytes(data).unwrap(); + let _ = frame; + } + + #[test] + fn roundtrip_auth_cleartext() { + let original = AuthenticationCleartextPasswordFrame; + let bytes = original.to_bytes().unwrap(); + let decoded = AuthenticationCleartextPasswordFrame::from_bytes(bytes.as_ref()).unwrap(); + assert_eq!(original, decoded); + } + + #[test] + fn invalid_tag() { + let data = b"X\x00\x00\x00\x08\x00\x00\x00\x03"; + let err = AuthenticationCleartextPasswordFrame::from_bytes(data).unwrap_err(); + matches!(err, AuthenticationCleartextPasswordError::UnexpectedTag(_)); + } + + #[test] + fn invalid_length() { + let data = b"R\x00\x00\x00\x09\x00\x00\x00\x03"; + let err = AuthenticationCleartextPasswordFrame::from_bytes(data).unwrap_err(); + matches!( + err, + AuthenticationCleartextPasswordError::UnexpectedLength(_) + ); + } + + #[test] + fn invalid_auth_code() { + let data = b"R\x00\x00\x00\x08\x00\x00\x00\x05"; + let err = AuthenticationCleartextPasswordFrame::from_bytes(data).unwrap_err(); + matches!( + err, + AuthenticationCleartextPasswordError::UnexpectedAuthCode(5) + ); + } +} + +// ----------------------------------------------------------------------------- +// ----------------------------------------------------------------------------- diff --git a/pgdog/src/wire_protocol/backend/authentication_gss.rs b/pgdog/src/wire_protocol/backend/authentication_gss.rs new file mode 100644 index 00000000..73d794fa --- /dev/null +++ b/pgdog/src/wire_protocol/backend/authentication_gss.rs @@ -0,0 +1,146 @@ +//! Module: wire_protocol::backend::authentication_gss +//! +//! Provides parsing and serialization for the AuthenticationGSS message ('R' with code 7) in the protocol. +//! +//! - `AuthenticationGssFrame`: represents the AuthenticationGSS message requesting GSSAPI authentication. +//! - `AuthenticationGssError`: error types for parsing and encoding. +//! +//! Implements `WireSerializable` for easy conversion between raw bytes and `AuthenticationGssFrame`. + +use bytes::Bytes; +use std::{error::Error as StdError, fmt}; + +use crate::wire_protocol::WireSerializable; + +// ----------------------------------------------------------------------------- +// ----- ProtocolMessage ------------------------------------------------------- + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct AuthenticationGssFrame; + +// ----------------------------------------------------------------------------- +// ----- Error ----------------------------------------------------------------- + +#[derive(Debug)] +pub enum AuthenticationGssError { + UnexpectedTag(u8), + UnexpectedLength(u32), + UnexpectedAuthType(i32), +} + +impl fmt::Display for AuthenticationGssError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + AuthenticationGssError::UnexpectedTag(t) => write!(f, "unexpected tag: {t:#X}"), + AuthenticationGssError::UnexpectedLength(len) => write!(f, "unexpected length: {len}"), + AuthenticationGssError::UnexpectedAuthType(t) => write!(f, "unexpected auth type: {t}"), + } + } +} + +impl StdError for AuthenticationGssError {} + +// ----------------------------------------------------------------------------- +// ----- WireSerializable ------------------------------------------------------ + +impl<'a> WireSerializable<'a> for AuthenticationGssFrame { + type Error = AuthenticationGssError; + + fn from_bytes(bytes: &'a [u8]) -> Result { + if bytes.len() < 5 { + return Err(AuthenticationGssError::UnexpectedLength(bytes.len() as u32)); + } + + let tag = bytes[0]; + if tag != b'R' { + return Err(AuthenticationGssError::UnexpectedTag(tag)); + } + + let len = u32::from_be_bytes([bytes[1], bytes[2], bytes[3], bytes[4]]); + if len != 8 { + return Err(AuthenticationGssError::UnexpectedLength(len)); + } + + if bytes.len() != 1 + len as usize { + return Err(AuthenticationGssError::UnexpectedLength(bytes.len() as u32)); + } + + let auth_type = i32::from_be_bytes([bytes[5], bytes[6], bytes[7], bytes[8]]); + if auth_type != 7 { + return Err(AuthenticationGssError::UnexpectedAuthType(auth_type)); + } + + Ok(AuthenticationGssFrame) + } + + fn to_bytes(&self) -> Result { + Ok(Bytes::from_static(b"R\x00\x00\x00\x08\x00\x00\x00\x07")) + } + + fn body_size(&self) -> usize { + 4 + } +} + +// ----------------------------------------------------------------------------- +// ----- Tests ----------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn serialize_authentication_gss() { + let frame = AuthenticationGssFrame; + let bytes = frame.to_bytes().unwrap(); + let expected = b"R\x00\x00\x00\x08\x00\x00\x00\x07"; + assert_eq!(bytes.as_ref(), expected); + } + + #[test] + fn deserialize_authentication_gss() { + let data = b"R\x00\x00\x00\x08\x00\x00\x00\x07"; + let frame = AuthenticationGssFrame::from_bytes(data).unwrap(); + // no state; just ensure no error + let _ = frame; + } + + #[test] + fn roundtrip_authentication_gss() { + let original = AuthenticationGssFrame; + let bytes = original.to_bytes().unwrap(); + let decoded = AuthenticationGssFrame::from_bytes(bytes.as_ref()).unwrap(); + assert_eq!(original, decoded); + } + + #[test] + fn invalid_tag() { + let data = b"X\x00\x00\x00\x08\x00\x00\x00\x07"; + let err = AuthenticationGssFrame::from_bytes(data).unwrap_err(); + matches!(err, AuthenticationGssError::UnexpectedTag(_)); + } + + #[test] + fn invalid_length() { + let data = b"R\x00\x00\x00\x09\x00\x00\x00\x07"; + let err = AuthenticationGssFrame::from_bytes(data).unwrap_err(); + matches!(err, AuthenticationGssError::UnexpectedLength(_)); + } + + #[test] + fn extra_data_after() { + let data = b"R\x00\x00\x00\x08\x00\x00\x00\x07\x00"; + let err = AuthenticationGssFrame::from_bytes(data).unwrap_err(); + matches!(err, AuthenticationGssError::UnexpectedLength(_)); + } + + #[test] + fn invalid_auth_type() { + let data = b"R\x00\x00\x00\x08\x00\x00\x00\x00"; + let err = AuthenticationGssFrame::from_bytes(data).unwrap_err(); + matches!(err, AuthenticationGssError::UnexpectedAuthType(0)); + } +} + +// ----------------------------------------------------------------------------- +// ----------------------------------------------------------------------------- diff --git a/pgdog/src/wire_protocol/backend/authentication_gss_continue.rs b/pgdog/src/wire_protocol/backend/authentication_gss_continue.rs new file mode 100644 index 00000000..974b4d80 --- /dev/null +++ b/pgdog/src/wire_protocol/backend/authentication_gss_continue.rs @@ -0,0 +1,221 @@ +//! Module: wire_protocol::backend::authentication_gss_continue +//! +//! Provides parsing and serialization for the AuthenticationGSSContinue message ('R' with code 8) in the protocol. +//! +//! - `AuthenticationGssContinueFrame`: represents the AuthenticationGSSContinue message with GSS/SSPI continuation data. +//! - `AuthenticationGssContinueError`: error types for parsing and encoding. +//! +//! Implements `WireSerializable` for easy conversion between raw bytes and `AuthenticationGssContinueFrame`. + +use bytes::{BufMut, Bytes, BytesMut}; +use std::{error::Error as StdError, fmt}; + +use crate::wire_protocol::WireSerializable; + +// ----------------------------------------------------------------------------- +// ----- ProtocolMessage ------------------------------------------------------- + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct AuthenticationGssContinueFrame<'a> { + pub data: &'a [u8], +} + +// ----------------------------------------------------------------------------- +// ----- Error ----------------------------------------------------------------- + +#[derive(Debug)] +pub enum AuthenticationGssContinueError { + UnexpectedTag(u8), + UnexpectedLength(u32), + UnexpectedAuthType(i32), + UnexpectedEof, +} + +impl fmt::Display for AuthenticationGssContinueError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + AuthenticationGssContinueError::UnexpectedTag(t) => write!(f, "unexpected tag: {t:#X}"), + AuthenticationGssContinueError::UnexpectedLength(len) => { + write!(f, "unexpected length: {len}") + } + AuthenticationGssContinueError::UnexpectedAuthType(t) => { + write!(f, "unexpected auth type: {t}") + } + AuthenticationGssContinueError::UnexpectedEof => write!(f, "unexpected EOF"), + } + } +} + +impl StdError for AuthenticationGssContinueError {} + +// ----------------------------------------------------------------------------- +// ----- WireSerializable ------------------------------------------------------ + +impl<'a> WireSerializable<'a> for AuthenticationGssContinueFrame<'a> { + type Error = AuthenticationGssContinueError; + + fn from_bytes(bytes: &'a [u8]) -> Result { + if bytes.len() < 9 { + return Err(AuthenticationGssContinueError::UnexpectedEof); + } + + let tag = bytes[0]; + if tag != b'R' { + return Err(AuthenticationGssContinueError::UnexpectedTag(tag)); + } + + let len = u32::from_be_bytes([bytes[1], bytes[2], bytes[3], bytes[4]]); + if len < 8 { + return Err(AuthenticationGssContinueError::UnexpectedLength(len)); + } + + if bytes.len() != 1 + len as usize { + return Err(AuthenticationGssContinueError::UnexpectedLength( + bytes.len() as u32, + )); + } + + let auth_type = i32::from_be_bytes([bytes[5], bytes[6], bytes[7], bytes[8]]); + if auth_type != 8 { + return Err(AuthenticationGssContinueError::UnexpectedAuthType( + auth_type, + )); + } + + let data_len = (len - 8) as usize; + let buf = &bytes[9..]; + if buf.len() < data_len { + return Err(AuthenticationGssContinueError::UnexpectedEof); + } + + let data = &buf[0..data_len]; + + Ok(AuthenticationGssContinueFrame { data }) + } + + fn to_bytes(&self) -> Result { + let body_len = 4 + self.data.len(); + let total_len = 4 + body_len; + + let mut buf = BytesMut::with_capacity(1 + total_len); + buf.put_u8(b'R'); + buf.put_u32(total_len as u32); + buf.put_i32(8); + buf.extend_from_slice(self.data); + + Ok(buf.freeze()) + } + + fn body_size(&self) -> usize { + 4 + self.data.len() + } +} + +// ----------------------------------------------------------------------------- +// ----- Tests ----------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + + fn make_frame(data: &[u8]) -> AuthenticationGssContinueFrame { + AuthenticationGssContinueFrame { data } + } + + #[test] + fn serialize_authentication_gss_continue_empty() { + let frame = make_frame(&[]); + let bytes = frame.to_bytes().unwrap(); + let expected = b"R\x00\x00\x00\x08\x00\x00\x00\x08"; + assert_eq!(bytes.as_ref(), expected); + } + + #[test] + fn serialize_authentication_gss_continue_with_data() { + let data = b"some_gss_data"; + let frame = make_frame(data); + let bytes = frame.to_bytes().unwrap(); + let mut expected = Vec::new(); + expected.extend_from_slice(b"R"); + expected.extend_from_slice(&((4 + 4 + data.len()) as u32).to_be_bytes()); + expected.extend_from_slice(&8i32.to_be_bytes()); + expected.extend_from_slice(data); + assert_eq!(bytes.as_ref(), expected.as_slice()); + } + + #[test] + fn deserialize_authentication_gss_continue_empty() { + let data_bytes = b"R\x00\x00\x00\x08\x00\x00\x00\x08"; + let frame = AuthenticationGssContinueFrame::from_bytes(data_bytes).unwrap(); + let data = frame.data; + let expected_data: &[u8] = &[]; + assert_eq!(data, expected_data); + } + + #[test] + fn deserialize_authentication_gss_continue_with_data() { + let payload = b"some_gss_data"; + let mut data_bytes = Vec::new(); + data_bytes.extend_from_slice(b"R"); + data_bytes.extend_from_slice(&((4 + 4 + payload.len()) as u32).to_be_bytes()); + data_bytes.extend_from_slice(&8i32.to_be_bytes()); + data_bytes.extend_from_slice(payload); + let frame = AuthenticationGssContinueFrame::from_bytes(&data_bytes).unwrap(); + assert_eq!(frame.data, payload); + } + + #[test] + fn roundtrip_authentication_gss_continue_empty() { + let original = make_frame(&[]); + let bytes = original.to_bytes().unwrap(); + let decoded = AuthenticationGssContinueFrame::from_bytes(bytes.as_ref()).unwrap(); + assert_eq!(original.data, decoded.data); + } + + #[test] + fn roundtrip_authentication_gss_continue_with_data() { + let data = b"test_data_123"; + let original = make_frame(data); + let bytes = original.to_bytes().unwrap(); + let decoded = AuthenticationGssContinueFrame::from_bytes(bytes.as_ref()).unwrap(); + assert_eq!(original.data, decoded.data); + } + + #[test] + fn invalid_tag() { + let data = b"X\x00\x00\x00\x08\x00\x00\x00\x08"; + let err = AuthenticationGssContinueFrame::from_bytes(data).unwrap_err(); + matches!(err, AuthenticationGssContinueError::UnexpectedTag(_)); + } + + #[test] + fn invalid_length_short() { + let data = b"R\x00\x00\x00\x07\x00\x00\x00\x08"; + let err = AuthenticationGssContinueFrame::from_bytes(data).unwrap_err(); + matches!(err, AuthenticationGssContinueError::UnexpectedLength(_)); + } + + #[test] + fn invalid_length_mismatch() { + let data = b"R\x00\x00\x00\x08\x00\x00\x00\x08\x00"; // extra byte + let err = AuthenticationGssContinueFrame::from_bytes(data).unwrap_err(); + matches!(err, AuthenticationGssContinueError::UnexpectedLength(_)); + } + + #[test] + fn invalid_auth_type() { + let data = b"R\x00\x00\x00\x08\x00\x00\x00\x07"; + let err = AuthenticationGssContinueFrame::from_bytes(data).unwrap_err(); + matches!(err, AuthenticationGssContinueError::UnexpectedAuthType(7)); + } + + #[test] + fn unexpected_eof() { + let data = b"R\x00\x00\x00\x0D\x00\x00\x00\x08abc"; // len=13, but only 3 data bytes + let err = AuthenticationGssContinueFrame::from_bytes(&data[0..12]).unwrap_err(); // truncate + matches!(err, AuthenticationGssContinueError::UnexpectedEof); + } +} + +// ----------------------------------------------------------------------------- +// ----------------------------------------------------------------------------- diff --git a/pgdog/src/wire_protocol/backend/authentication_kerberos_v5.rs b/pgdog/src/wire_protocol/backend/authentication_kerberos_v5.rs new file mode 100644 index 00000000..56647acc --- /dev/null +++ b/pgdog/src/wire_protocol/backend/authentication_kerberos_v5.rs @@ -0,0 +1,140 @@ +//! Module: wire_protocol::backend::authentication_kerberos_v5 +//! +//! Provides parsing and serialization for the AuthenticationKerberosV5 message ('R' with code 2) in the protocol. +//! +//! - `AuthenticationKerberosV5Frame`: represents the AuthenticationKerberosV5 message. +//! - `AuthenticationKerberosV5Error`: error types for parsing and encoding. +//! +//! Implements `WireSerializable` for easy conversion between raw bytes and `AuthenticationKerberosV5Frame`. + +use bytes::Bytes; +use std::{error::Error as StdError, fmt}; + +use crate::wire_protocol::WireSerializable; + +// ----------------------------------------------------------------------------- +// ----- ProtocolMessage ------------------------------------------------------- + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct AuthenticationKerberosV5Frame; + +// ----------------------------------------------------------------------------- +// ----- Error ----------------------------------------------------------------- + +#[derive(Debug)] +pub enum AuthenticationKerberosV5Error { + UnexpectedTag(u8), + UnexpectedLength(u32), + UnexpectedAuthCode(i32), +} + +impl fmt::Display for AuthenticationKerberosV5Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + AuthenticationKerberosV5Error::UnexpectedTag(t) => write!(f, "unexpected tag: {t:#X}"), + AuthenticationKerberosV5Error::UnexpectedLength(len) => { + write!(f, "unexpected length: {len}") + } + AuthenticationKerberosV5Error::UnexpectedAuthCode(code) => { + write!(f, "unexpected auth code: {code}") + } + } + } +} + +impl StdError for AuthenticationKerberosV5Error {} + +// ----------------------------------------------------------------------------- +// ----- WireSerializable ------------------------------------------------------ + +impl<'a> WireSerializable<'a> for AuthenticationKerberosV5Frame { + type Error = AuthenticationKerberosV5Error; + + fn from_bytes(bytes: &'a [u8]) -> Result { + if bytes.len() < 9 { + return Err(AuthenticationKerberosV5Error::UnexpectedLength( + bytes.len() as u32 + )); + } + + let tag = bytes[0]; + if tag != b'R' { + return Err(AuthenticationKerberosV5Error::UnexpectedTag(tag)); + } + + let len = u32::from_be_bytes([bytes[1], bytes[2], bytes[3], bytes[4]]); + if len != 8 { + return Err(AuthenticationKerberosV5Error::UnexpectedLength(len)); + } + + let code = i32::from_be_bytes([bytes[5], bytes[6], bytes[7], bytes[8]]); + if code != 2 { + return Err(AuthenticationKerberosV5Error::UnexpectedAuthCode(code)); + } + + Ok(AuthenticationKerberosV5Frame) + } + + fn to_bytes(&self) -> Result { + Ok(Bytes::from_static(b"R\x00\x00\x00\x08\x00\x00\x00\x02")) + } + + fn body_size(&self) -> usize { + 4 + } +} + +// ----------------------------------------------------------------------------- +// ----- Tests ----------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn serialize_auth_kerberos_v5() { + let frame = AuthenticationKerberosV5Frame; + let bytes = frame.to_bytes().unwrap(); + let expected = b"R\x00\x00\x00\x08\x00\x00\x00\x02"; + assert_eq!(bytes.as_ref(), expected); + } + + #[test] + fn deserialize_auth_kerberos_v5() { + let data = b"R\x00\x00\x00\x08\x00\x00\x00\x02"; + let frame = AuthenticationKerberosV5Frame::from_bytes(data).unwrap(); + let _ = frame; + } + + #[test] + fn roundtrip_auth_kerberos_v5() { + let original = AuthenticationKerberosV5Frame; + let bytes = original.to_bytes().unwrap(); + let decoded = AuthenticationKerberosV5Frame::from_bytes(bytes.as_ref()).unwrap(); + assert_eq!(original, decoded); + } + + #[test] + fn invalid_tag() { + let data = b"X\x00\x00\x00\x08\x00\x00\x00\x02"; + let err = AuthenticationKerberosV5Frame::from_bytes(data).unwrap_err(); + matches!(err, AuthenticationKerberosV5Error::UnexpectedTag(_)); + } + + #[test] + fn invalid_length() { + let data = b"R\x00\x00\x00\x09\x00\x00\x00\x02"; + let err = AuthenticationKerberosV5Frame::from_bytes(data).unwrap_err(); + matches!(err, AuthenticationKerberosV5Error::UnexpectedLength(_)); + } + + #[test] + fn invalid_auth_code() { + let data = b"R\x00\x00\x00\x08\x00\x00\x00\x03"; + let err = AuthenticationKerberosV5Frame::from_bytes(data).unwrap_err(); + matches!(err, AuthenticationKerberosV5Error::UnexpectedAuthCode(3)); + } +} + +// ----------------------------------------------------------------------------- +// ----------------------------------------------------------------------------- diff --git a/pgdog/src/wire_protocol/backend/authentication_md5_password.rs b/pgdog/src/wire_protocol/backend/authentication_md5_password.rs new file mode 100644 index 00000000..4ab53090 --- /dev/null +++ b/pgdog/src/wire_protocol/backend/authentication_md5_password.rs @@ -0,0 +1,177 @@ +//! Module: wire_protocol::backend::authentication_md5_password +//! +//! Provides parsing and serialization for the AuthenticationMD5Password message ('R' with code 5) in the protocol. +//! +//! - `AuthenticationMd5PasswordFrame`: represents the AuthenticationMD5Password message with a 4-byte salt. +//! - `AuthenticationMd5PasswordError`: error types for parsing and encoding. +//! +//! Implements `WireSerializable` for easy conversion between raw bytes and `AuthenticationMd5PasswordFrame`. + +use bytes::{BufMut, Bytes, BytesMut}; +use std::{error::Error as StdError, fmt}; + +use crate::wire_protocol::WireSerializable; + +// ----------------------------------------------------------------------------- +// ----- ProtocolMessage ------------------------------------------------------- + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct AuthenticationMd5PasswordFrame { + pub salt: [u8; 4], +} + +// ----------------------------------------------------------------------------- +// ----- Error ----------------------------------------------------------------- + +#[derive(Debug)] +pub enum AuthenticationMd5PasswordError { + UnexpectedTag(u8), + UnexpectedLength(u32), + UnexpectedAuthType(i32), +} + +impl fmt::Display for AuthenticationMd5PasswordError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + AuthenticationMd5PasswordError::UnexpectedTag(t) => write!(f, "unexpected tag: {t:#X}"), + AuthenticationMd5PasswordError::UnexpectedLength(len) => { + write!(f, "unexpected length: {len}") + } + AuthenticationMd5PasswordError::UnexpectedAuthType(t) => { + write!(f, "unexpected auth type: {t}") + } + } + } +} + +impl StdError for AuthenticationMd5PasswordError {} + +// ----------------------------------------------------------------------------- +// ----- WireSerializable ------------------------------------------------------ + +impl<'a> WireSerializable<'a> for AuthenticationMd5PasswordFrame { + type Error = AuthenticationMd5PasswordError; + + fn from_bytes(bytes: &'a [u8]) -> Result { + if bytes.len() < 13 { + return Err(AuthenticationMd5PasswordError::UnexpectedLength( + bytes.len() as u32, + )); + } + + let tag = bytes[0]; + if tag != b'R' { + return Err(AuthenticationMd5PasswordError::UnexpectedTag(tag)); + } + + let len = u32::from_be_bytes([bytes[1], bytes[2], bytes[3], bytes[4]]); + if len != 12 { + return Err(AuthenticationMd5PasswordError::UnexpectedLength(len)); + } + + if bytes.len() != 1 + len as usize { + return Err(AuthenticationMd5PasswordError::UnexpectedLength( + bytes.len() as u32, + )); + } + + let auth_type = i32::from_be_bytes([bytes[5], bytes[6], bytes[7], bytes[8]]); + if auth_type != 5 { + return Err(AuthenticationMd5PasswordError::UnexpectedAuthType( + auth_type, + )); + } + + let salt = [bytes[9], bytes[10], bytes[11], bytes[12]]; + + Ok(AuthenticationMd5PasswordFrame { salt }) + } + + fn to_bytes(&self) -> Result { + let mut buf = BytesMut::with_capacity(13); + buf.put_u8(b'R'); + buf.put_u32(12); + buf.put_i32(5); + buf.extend_from_slice(&self.salt); + Ok(buf.freeze()) + } + + fn body_size(&self) -> usize { + 8 + } +} + +// ----------------------------------------------------------------------------- +// ----- Tests ----------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + + fn make_frame() -> AuthenticationMd5PasswordFrame { + AuthenticationMd5PasswordFrame { + salt: [0x01, 0x02, 0x03, 0x04], + } + } + + #[test] + fn serialize_authentication_md5_password() { + let frame = make_frame(); + let bytes = frame.to_bytes().unwrap(); + let expected = b"R\x00\x00\x00\x0C\x00\x00\x00\x05\x01\x02\x03\x04"; + assert_eq!(bytes.as_ref(), expected); + } + + #[test] + fn deserialize_authentication_md5_password() { + let data = b"R\x00\x00\x00\x0C\x00\x00\x00\x05\x01\x02\x03\x04"; + let frame = AuthenticationMd5PasswordFrame::from_bytes(data).unwrap(); + assert_eq!(frame.salt, [0x01, 0x02, 0x03, 0x04]); + } + + #[test] + fn roundtrip_authentication_md5_password() { + let original = make_frame(); + let bytes = original.to_bytes().unwrap(); + let decoded = AuthenticationMd5PasswordFrame::from_bytes(bytes.as_ref()).unwrap(); + assert_eq!(original, decoded); + } + + #[test] + fn invalid_tag() { + let data = b"X\x00\x00\x00\x0C\x00\x00\x00\x05\x01\x02\x03\x04"; + let err = AuthenticationMd5PasswordFrame::from_bytes(data).unwrap_err(); + matches!(err, AuthenticationMd5PasswordError::UnexpectedTag(_)); + } + + #[test] + fn invalid_length() { + let data = b"R\x00\x00\x00\x0D\x00\x00\x00\x05\x01\x02\x03\x04"; + let err = AuthenticationMd5PasswordFrame::from_bytes(data).unwrap_err(); + matches!(err, AuthenticationMd5PasswordError::UnexpectedLength(_)); + } + + #[test] + fn extra_data_after() { + let data = b"R\x00\x00\x00\x0C\x00\x00\x00\x05\x01\x02\x03\x04\x00"; + let err = AuthenticationMd5PasswordFrame::from_bytes(data).unwrap_err(); + matches!(err, AuthenticationMd5PasswordError::UnexpectedLength(_)); + } + + #[test] + fn invalid_auth_type() { + let data = b"R\x00\x00\x00\x0C\x00\x00\x00\x07\x01\x02\x03\x04"; + let err = AuthenticationMd5PasswordFrame::from_bytes(data).unwrap_err(); + matches!(err, AuthenticationMd5PasswordError::UnexpectedAuthType(7)); + } + + #[test] + fn short_data() { + let data = b"R\x00\x00\x00\x0C\x00\x00\x00\x05\x01\x02\x03"; + let err = AuthenticationMd5PasswordFrame::from_bytes(data).unwrap_err(); + matches!(err, AuthenticationMd5PasswordError::UnexpectedLength(_)); + } +} + +// ----------------------------------------------------------------------------- +// ----------------------------------------------------------------------------- diff --git a/pgdog/src/wire_protocol/backend/authentication_ok.rs b/pgdog/src/wire_protocol/backend/authentication_ok.rs new file mode 100644 index 00000000..4d58dd9d --- /dev/null +++ b/pgdog/src/wire_protocol/backend/authentication_ok.rs @@ -0,0 +1,136 @@ +//! Module: wire_protocol::backend::authentication_ok +//! +//! Provides parsing and serialization for the AuthenticationOK message ('R' with code 0) in the protocol. +//! +//! - `AuthenticationOkFrame`: represents the AuthenticationOK message. +//! - `AuthenticationOkError`: error types for parsing and encoding. +//! +//! Implements `WireSerializable` for easy conversion between raw bytes and `AuthenticationOkFrame`. + +use bytes::Bytes; +use std::{error::Error as StdError, fmt}; + +use crate::wire_protocol::WireSerializable; + +// ----------------------------------------------------------------------------- +// ----- ProtocolMessage ------------------------------------------------------- + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct AuthenticationOkFrame; + +// ----------------------------------------------------------------------------- +// ----- Error ----------------------------------------------------------------- + +#[derive(Debug)] +pub enum AuthenticationOkError { + UnexpectedTag(u8), + UnexpectedLength(u32), + UnexpectedAuthCode(i32), +} + +impl fmt::Display for AuthenticationOkError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + AuthenticationOkError::UnexpectedTag(t) => write!(f, "unexpected tag: {t:#X}"), + AuthenticationOkError::UnexpectedLength(len) => write!(f, "unexpected length: {len}"), + AuthenticationOkError::UnexpectedAuthCode(code) => { + write!(f, "unexpected auth code: {code}") + } + } + } +} + +impl StdError for AuthenticationOkError {} + +// ----------------------------------------------------------------------------- +// ----- WireSerializable ------------------------------------------------------ + +impl<'a> WireSerializable<'a> for AuthenticationOkFrame { + type Error = AuthenticationOkError; + + fn from_bytes(bytes: &'a [u8]) -> Result { + if bytes.len() < 9 { + return Err(AuthenticationOkError::UnexpectedLength(bytes.len() as u32)); + } + + let tag = bytes[0]; + if tag != b'R' { + return Err(AuthenticationOkError::UnexpectedTag(tag)); + } + + let len = u32::from_be_bytes([bytes[1], bytes[2], bytes[3], bytes[4]]); + if len != 8 { + return Err(AuthenticationOkError::UnexpectedLength(len)); + } + + let code = i32::from_be_bytes([bytes[5], bytes[6], bytes[7], bytes[8]]); + if code != 0 { + return Err(AuthenticationOkError::UnexpectedAuthCode(code)); + } + + Ok(AuthenticationOkFrame) + } + + fn to_bytes(&self) -> Result { + Ok(Bytes::from_static(b"R\x00\x00\x00\x08\x00\x00\x00\x00")) + } + + fn body_size(&self) -> usize { + 4 + } +} + +// ----------------------------------------------------------------------------- +// ----- Tests ----------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn serialize_auth_ok() { + let frame = AuthenticationOkFrame; + let bytes = frame.to_bytes().unwrap(); + let expected = b"R\x00\x00\x00\x08\x00\x00\x00\x00"; + assert_eq!(bytes.as_ref(), expected); + } + + #[test] + fn deserialize_auth_ok() { + let data = b"R\x00\x00\x00\x08\x00\x00\x00\x00"; + let frame = AuthenticationOkFrame::from_bytes(data).unwrap(); + let _ = frame; + } + + #[test] + fn roundtrip_auth_ok() { + let original = AuthenticationOkFrame; + let bytes = original.to_bytes().unwrap(); + let decoded = AuthenticationOkFrame::from_bytes(bytes.as_ref()).unwrap(); + assert_eq!(original, decoded); + } + + #[test] + fn invalid_tag() { + let data = b"X\x00\x00\x00\x08\x00\x00\x00\x00"; + let err = AuthenticationOkFrame::from_bytes(data).unwrap_err(); + matches!(err, AuthenticationOkError::UnexpectedTag(_)); + } + + #[test] + fn invalid_length() { + let data = b"R\x00\x00\x00\x09\x00\x00\x00\x00"; + let err = AuthenticationOkFrame::from_bytes(data).unwrap_err(); + matches!(err, AuthenticationOkError::UnexpectedLength(_)); + } + + #[test] + fn invalid_auth_code() { + let data = b"R\x00\x00\x00\x08\x00\x00\x00\x03"; + let err = AuthenticationOkFrame::from_bytes(data).unwrap_err(); + matches!(err, AuthenticationOkError::UnexpectedAuthCode(3)); + } +} + +// ----------------------------------------------------------------------------- +// ----------------------------------------------------------------------------- diff --git a/pgdog/src/wire_protocol/backend/authentication_sasl.rs b/pgdog/src/wire_protocol/backend/authentication_sasl.rs new file mode 100644 index 00000000..2433fd65 --- /dev/null +++ b/pgdog/src/wire_protocol/backend/authentication_sasl.rs @@ -0,0 +1,246 @@ +//! Module: wire_protocol::backend::authentication_sasl +//! +//! Provides parsing and serialization for the AuthenticationSASL message ('R' with code 10) in the protocol. +//! +//! - `AuthenticationSaslFrame`: represents the AuthenticationSASL message with a list of supported mechanisms. +//! - `AuthenticationSaslError`: error types for parsing and encoding. +//! +//! Implements `WireSerializable` for easy conversion between raw bytes and `AuthenticationSaslFrame`. + +use bytes::{Buf, BufMut, Bytes, BytesMut}; +use std::{error::Error as StdError, fmt, str}; + +use crate::wire_protocol::shared_property_types::{SaslMechanism, SaslMechanismError}; +use crate::wire_protocol::WireSerializable; + +// ----------------------------------------------------------------------------- +// ----- ProtocolMessage ------------------------------------------------------- + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct AuthenticationSaslFrame { + pub mechanisms: Vec, +} + +// ----------------------------------------------------------------------------- +// ----- Error ----------------------------------------------------------------- + +#[derive(Debug)] +pub enum AuthenticationSaslError { + UnexpectedTag(u8), + UnexpectedLength(u32), + UnexpectedAuthCode(i32), + Utf8Error(str::Utf8Error), + UnexpectedEof, + InvalidLength, + MechanismError(SaslMechanismError), +} + +impl fmt::Display for AuthenticationSaslError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + AuthenticationSaslError::UnexpectedTag(t) => write!(f, "unexpected tag: {:#X}", t), + AuthenticationSaslError::UnexpectedLength(len) => { + write!(f, "unexpected length: {}", len) + } + AuthenticationSaslError::UnexpectedAuthCode(c) => { + write!(f, "unexpected auth code: {}", c) + } + AuthenticationSaslError::Utf8Error(e) => write!(f, "UTF-8 error: {}", e), + AuthenticationSaslError::UnexpectedEof => write!(f, "unexpected EOF"), + AuthenticationSaslError::InvalidLength => write!(f, "invalid length"), + AuthenticationSaslError::MechanismError(e) => write!(f, "SASL mechanism error: {}", e), + } + } +} + +impl StdError for AuthenticationSaslError { + fn source(&self) -> Option<&(dyn StdError + 'static)> { + match self { + AuthenticationSaslError::Utf8Error(e) => Some(e), + AuthenticationSaslError::MechanismError(e) => Some(e), + _ => None, + } + } +} + +// ----------------------------------------------------------------------------- +// ----- Helpers --------------------------------------------------------------- + +fn read_cstr<'a>(buf: &mut &'a [u8]) -> Result<&'a str, AuthenticationSaslError> { + if let Some(pos) = buf.iter().position(|&b| b == 0) { + let (raw, rest) = buf.split_at(pos); + *buf = &rest[1..]; + return str::from_utf8(raw).map_err(AuthenticationSaslError::Utf8Error); + } + Err(AuthenticationSaslError::UnexpectedEof) +} + +// ----------------------------------------------------------------------------- +// ----- WireSerializable ------------------------------------------------------ + +impl<'a> WireSerializable<'a> for AuthenticationSaslFrame { + type Error = AuthenticationSaslError; + + fn from_bytes(bytes: &'a [u8]) -> Result { + if bytes.len() < 10 { + return Err(AuthenticationSaslError::UnexpectedLength(bytes.len() as u32)); + } + if bytes[0] != b'R' { + return Err(AuthenticationSaslError::UnexpectedTag(bytes[0])); + } + let mut len_buf = &bytes[1..5]; + let len = len_buf.get_u32(); + if (bytes.len() - 1) != len as usize { + return Err(AuthenticationSaslError::UnexpectedLength(len)); + } + let mut rest = &bytes[5..]; + let code = rest.get_i32(); + if code != 10 { + return Err(AuthenticationSaslError::UnexpectedAuthCode(code)); + } + let mut body = &rest[..(len as usize - 8)]; + let mut mechanisms = Vec::new(); + loop { + let mech_str = read_cstr(&mut body)?; + if mech_str.is_empty() { + break; + } + let mech = SaslMechanism::from_str(mech_str) + .map_err(AuthenticationSaslError::MechanismError)?; + mechanisms.push(mech); + } + if !body.is_empty() { + return Err(AuthenticationSaslError::InvalidLength); + } + Ok(AuthenticationSaslFrame { mechanisms }) + } + + fn to_bytes(&self) -> Result { + let mut body = BytesMut::new(); + body.put_i32(10); + for mech in &self.mechanisms { + body.extend_from_slice(mech.as_str().as_bytes()); + body.put_u8(0); + } + body.put_u8(0); + + let mut buf = BytesMut::with_capacity(1 + 4 + body.len()); + buf.put_u8(b'R'); + buf.put_u32((body.len() + 4) as u32); + buf.extend_from_slice(&body); + Ok(buf.freeze()) + } + + fn body_size(&self) -> usize { + 4 + self + .mechanisms + .iter() + .map(|m| m.as_str().len() + 1) + .sum::() + + 1 + } +} + +// ----------------------------------------------------------------------------- +// ----- Tests ----------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + + fn make_frame() -> AuthenticationSaslFrame { + AuthenticationSaslFrame { + mechanisms: vec![SaslMechanism::ScramSha256], + } + } + + fn make_multi_frame() -> AuthenticationSaslFrame { + AuthenticationSaslFrame { + mechanisms: vec![SaslMechanism::ScramSha256, SaslMechanism::ScramSha256Plus], + } + } + + #[test] + fn serialize_auth_sasl() { + let frame = make_frame(); + let bytes = frame.to_bytes().unwrap(); + let expected = b"R\x00\x00\x00\x17\x00\x00\x00\x0ASCRAM-SHA-256\x00\x00"; + assert_eq!(bytes.as_ref(), expected); + } + + #[test] + fn deserialize_auth_sasl() { + let data = b"R\x00\x00\x00\x17\x00\x00\x00\x0ASCRAM-SHA-256\x00\x00"; + let frame = AuthenticationSaslFrame::from_bytes(data).unwrap(); + assert_eq!(frame.mechanisms, vec![SaslMechanism::ScramSha256]); + } + + #[test] + fn roundtrip_auth_sasl() { + let original = make_frame(); + let bytes = original.to_bytes().unwrap(); + let decoded = AuthenticationSaslFrame::from_bytes(bytes.as_ref()).unwrap(); + assert_eq!(decoded.mechanisms, original.mechanisms); + } + + #[test] + fn roundtrip_multi_mechanisms() { + let original = make_multi_frame(); + let bytes = original.to_bytes().unwrap(); + let decoded = AuthenticationSaslFrame::from_bytes(bytes.as_ref()).unwrap(); + assert_eq!(decoded.mechanisms, original.mechanisms); + } + + #[test] + fn invalid_tag() { + let data = b"X\x00\x00\x00\x17\x00\x00\x00\x0ASCRAM-SHA-256\x00\x00"; + let err = AuthenticationSaslFrame::from_bytes(data).unwrap_err(); + matches!(err, AuthenticationSaslError::UnexpectedTag(_)); + } + + #[test] + fn invalid_length() { + let data = b"R\x00\x00\x00\x18\x00\x00\x00\x0ASCRAM-SHA-256\x00\x00"; + let err = AuthenticationSaslFrame::from_bytes(data).unwrap_err(); + matches!(err, AuthenticationSaslError::UnexpectedLength(_)); + } + + #[test] + fn invalid_auth_code() { + let data = b"R\x00\x00\x00\x17\x00\x00\x00\x0BSCRAM-SHA-256\x00\x00"; + let err = AuthenticationSaslFrame::from_bytes(data).unwrap_err(); + matches!(err, AuthenticationSaslError::UnexpectedAuthCode(11)); + } + + #[test] + fn missing_terminator() { + let data = b"R\x00\x00\x00\x16\x00\x00\x00\x0ASCRAM-SHA-256\x00"; + let err = AuthenticationSaslFrame::from_bytes(data).unwrap_err(); + matches!(err, AuthenticationSaslError::UnexpectedEof); + } + + #[test] + fn extra_data_after_terminator() { + let data = b"R\x00\x00\x00\x18\x00\x00\x00\x0ASCRAM-SHA-256\x00\x00\x00"; + let err = AuthenticationSaslFrame::from_bytes(data).unwrap_err(); + matches!(err, AuthenticationSaslError::InvalidLength); + } + + #[test] + fn invalid_utf8() { + let mut bytes = make_frame().to_bytes().unwrap().to_vec(); + bytes[9] = 0xFF; + let err = AuthenticationSaslFrame::from_bytes(&bytes).unwrap_err(); + matches!(err, AuthenticationSaslError::Utf8Error(_)); + } + + #[test] + fn unsupported_mechanism() { + let data = b"R\x00\x00\x00\x16\x00\x00\x00\x0AUNKNOWN-MECH\x00\x00"; + let err = AuthenticationSaslFrame::from_bytes(data).unwrap_err(); + matches!(err, AuthenticationSaslError::MechanismError(_)); + } +} + +// ----------------------------------------------------------------------------- +// ----------------------------------------------------------------------------- diff --git a/pgdog/src/wire_protocol/backend/authentication_sasl_continue.rs b/pgdog/src/wire_protocol/backend/authentication_sasl_continue.rs new file mode 100644 index 00000000..83d0a757 --- /dev/null +++ b/pgdog/src/wire_protocol/backend/authentication_sasl_continue.rs @@ -0,0 +1,173 @@ +//! Module: wire_protocol::backend::authentication_sasl_continue +//! +//! Provides parsing and serialization for the AuthenticationSASLContinue message ('R' with code 11) in the protocol. +//! +//! - `AuthenticationSaslContinueFrame`: represents the AuthenticationSASLContinue message with SASL continuation data. +//! - `AuthenticationSaslContinueError`: error types for parsing and encoding. +//! +//! Implements `WireSerializable` for easy conversion between raw bytes and `AuthenticationSaslContinueFrame`. + +use bytes::{BufMut, Bytes, BytesMut}; +use std::{error::Error as StdError, fmt}; + +use crate::wire_protocol::WireSerializable; + +// ----------------------------------------------------------------------------- +// ----- ProtocolMessage ------------------------------------------------------- + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct AuthenticationSaslContinueFrame<'a> { + pub sasl_data: &'a [u8], +} + +// ----------------------------------------------------------------------------- +// ----- Error ----------------------------------------------------------------- + +#[derive(Debug)] +pub enum AuthenticationSaslContinueError { + UnexpectedTag(u8), + UnexpectedLength(u32), + UnexpectedAuthCode(i32), +} + +impl fmt::Display for AuthenticationSaslContinueError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + AuthenticationSaslContinueError::UnexpectedTag(t) => { + write!(f, "unexpected tag: {t:#X}") + } + AuthenticationSaslContinueError::UnexpectedLength(len) => { + write!(f, "unexpected length: {len}") + } + AuthenticationSaslContinueError::UnexpectedAuthCode(code) => { + write!(f, "unexpected auth code: {code}") + } + } + } +} + +impl StdError for AuthenticationSaslContinueError {} + +// ----------------------------------------------------------------------------- +// ----- WireSerializable ------------------------------------------------------ + +impl<'a> WireSerializable<'a> for AuthenticationSaslContinueFrame<'a> { + type Error = AuthenticationSaslContinueError; + + fn from_bytes(bytes: &'a [u8]) -> Result { + if bytes.len() < 9 { + return Err(AuthenticationSaslContinueError::UnexpectedLength( + bytes.len() as u32, + )); + } + + let tag = bytes[0]; + if tag != b'R' { + return Err(AuthenticationSaslContinueError::UnexpectedTag(tag)); + } + + let len = u32::from_be_bytes([bytes[1], bytes[2], bytes[3], bytes[4]]); + if (len as usize) + 1 != bytes.len() { + return Err(AuthenticationSaslContinueError::UnexpectedLength(len)); + } + + let code = i32::from_be_bytes([bytes[5], bytes[6], bytes[7], bytes[8]]); + if code != 11 { + return Err(AuthenticationSaslContinueError::UnexpectedAuthCode(code)); + } + + let sasl_data = &bytes[9..]; + + Ok(AuthenticationSaslContinueFrame { sasl_data }) + } + + fn to_bytes(&self) -> Result { + let body_len = 4 + self.sasl_data.len(); + let total_len = 4 + body_len; + + let mut buf = BytesMut::with_capacity(1 + total_len); + buf.put_u8(b'R'); + buf.put_u32(total_len as u32); + buf.put_i32(11); + buf.extend_from_slice(self.sasl_data); + + Ok(buf.freeze()) + } + + fn body_size(&self) -> usize { + 4 + self.sasl_data.len() + } +} + +// ----------------------------------------------------------------------------- +// ----- Tests ----------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + + fn make_frame<'a>() -> AuthenticationSaslContinueFrame<'a> { + AuthenticationSaslContinueFrame { + sasl_data: b"r=some_nonce,s=salt,i=4096", + } + } + + #[test] + fn serialize_auth_sasl_continue() { + let frame = make_frame(); + let bytes = frame.to_bytes().unwrap(); + // length = 4 (len) + 4 (code) + 26 (data) = 34 => 0x22 + let expected = b"R\x00\x00\x00\x22\x00\x00\x00\x0Br=some_nonce,s=salt,i=4096"; + assert_eq!(bytes.as_ref(), expected); + } + + #[test] + fn deserialize_auth_sasl_continue() { + let data = b"R\x00\x00\x00\x22\x00\x00\x00\x0Br=some_nonce,s=salt,i=4096"; + let frame = AuthenticationSaslContinueFrame::from_bytes(data).unwrap(); + assert_eq!(frame.sasl_data, b"r=some_nonce,s=salt,i=4096"); + } + + #[test] + fn roundtrip_auth_sasl_continue() { + let original = make_frame(); + let bytes = original.to_bytes().unwrap(); + let decoded = AuthenticationSaslContinueFrame::from_bytes(bytes.as_ref()).unwrap(); + assert_eq!(decoded.sasl_data, original.sasl_data); + } + + #[test] + fn empty_data() { + let frame = AuthenticationSaslContinueFrame { sasl_data: b"" }; + let bytes = frame.to_bytes().unwrap(); + let expected = b"R\x00\x00\x00\x08\x00\x00\x00\x0B"; + assert_eq!(bytes.as_ref(), expected); + + let decoded = AuthenticationSaslContinueFrame::from_bytes(bytes.as_ref()).unwrap(); + assert!(decoded.sasl_data.is_empty()); + } + + #[test] + fn invalid_tag() { + let data = b"X\x00\x00\x00\x1F\x00\x00\x00\x0Br=some_nonce,s=salt,i=4096"; + let err = AuthenticationSaslContinueFrame::from_bytes(data).unwrap_err(); + matches!(err, AuthenticationSaslContinueError::UnexpectedTag(_)); + } + + #[test] + fn invalid_length() { + let data = b"R\x00\x00\x00\x20\x00\x00\x00\x0Br=some_nonce,s=salt,i=4096"; + let err = AuthenticationSaslContinueFrame::from_bytes(data).unwrap_err(); + matches!(err, AuthenticationSaslContinueError::UnexpectedLength(_)); + } + + #[test] + fn invalid_auth_code() { + let data = b"R\x00\x00\x00\x1F\x00\x00\x00\x0Ar=some_nonce,s=salt,i=4096"; + let err = AuthenticationSaslContinueFrame::from_bytes(data).unwrap_err(); + matches!(err, AuthenticationSaslContinueError::UnexpectedAuthCode(10)); + } +} + +// ----------------------------------------------------------------------------- +// ----------------------------------------------------------------------------- diff --git a/pgdog/src/wire_protocol/backend/authentication_sasl_final.rs b/pgdog/src/wire_protocol/backend/authentication_sasl_final.rs new file mode 100644 index 00000000..ce8c0ad6 --- /dev/null +++ b/pgdog/src/wire_protocol/backend/authentication_sasl_final.rs @@ -0,0 +1,171 @@ +//! Module: wire_protocol::backend::authentication_sasl_final +//! +//! Provides parsing and serialization for the AuthenticationSASLFinal message ('R' with code 12) in the protocol. +//! +//! - `AuthenticationSaslFinalFrame`: represents the AuthenticationSASLFinal message with SASL final data. +//! - `AuthenticationSaslFinalError`: error types for parsing and encoding. +//! +//! Implements `WireSerializable` for easy conversion between raw bytes and `AuthenticationSaslFinalFrame`. + +use bytes::{BufMut, Bytes, BytesMut}; +use std::{error::Error as StdError, fmt}; + +use crate::wire_protocol::WireSerializable; + +// ----------------------------------------------------------------------------- +// ----- ProtocolMessage ------------------------------------------------------- + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct AuthenticationSaslFinalFrame<'a> { + pub sasl_data: &'a [u8], +} + +// ----------------------------------------------------------------------------- +// ----- Error ----------------------------------------------------------------- + +#[derive(Debug)] +pub enum AuthenticationSaslFinalError { + UnexpectedTag(u8), + UnexpectedLength(u32), + UnexpectedAuthCode(i32), +} + +impl fmt::Display for AuthenticationSaslFinalError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + AuthenticationSaslFinalError::UnexpectedTag(t) => write!(f, "unexpected tag: {t:#X}"), + AuthenticationSaslFinalError::UnexpectedLength(len) => { + write!(f, "unexpected length: {len}") + } + AuthenticationSaslFinalError::UnexpectedAuthCode(code) => { + write!(f, "unexpected auth code: {code}") + } + } + } +} + +impl StdError for AuthenticationSaslFinalError {} + +// ----------------------------------------------------------------------------- +// ----- WireSerializable ------------------------------------------------------ + +impl<'a> WireSerializable<'a> for AuthenticationSaslFinalFrame<'a> { + type Error = AuthenticationSaslFinalError; + + fn from_bytes(bytes: &'a [u8]) -> Result { + if bytes.len() < 9 { + return Err(AuthenticationSaslFinalError::UnexpectedLength( + bytes.len() as u32 + )); + } + + let tag = bytes[0]; + if tag != b'R' { + return Err(AuthenticationSaslFinalError::UnexpectedTag(tag)); + } + + let len = u32::from_be_bytes([bytes[1], bytes[2], bytes[3], bytes[4]]); + if (len as usize) + 1 != bytes.len() { + return Err(AuthenticationSaslFinalError::UnexpectedLength(len)); + } + + let code = i32::from_be_bytes([bytes[5], bytes[6], bytes[7], bytes[8]]); + if code != 12 { + return Err(AuthenticationSaslFinalError::UnexpectedAuthCode(code)); + } + + let sasl_data = &bytes[9..]; + + Ok(AuthenticationSaslFinalFrame { sasl_data }) + } + + fn to_bytes(&self) -> Result { + let body_len = 4 + self.sasl_data.len(); + let total_len = 4 + body_len; + + let mut buf = BytesMut::with_capacity(1 + total_len); + buf.put_u8(b'R'); + buf.put_u32(total_len as u32); + buf.put_i32(12); + buf.extend_from_slice(self.sasl_data); + + Ok(buf.freeze()) + } + + fn body_size(&self) -> usize { + 4 + self.sasl_data.len() + } +} + +// ----------------------------------------------------------------------------- +// ----- Tests ----------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + + fn make_frame<'a>() -> AuthenticationSaslFinalFrame<'a> { + AuthenticationSaslFinalFrame { + sasl_data: b"v=some_signature", + } + } + + #[test] + fn serialize_auth_sasl_final() { + let frame = make_frame(); + let bytes = frame.to_bytes().unwrap(); + // length = 24 (0x18) + let expected = b"R\x00\x00\x00\x18\x00\x00\x00\x0Cv=some_signature"; + assert_eq!(bytes.as_ref(), expected); + } + + #[test] + fn deserialize_auth_sasl_final() { + let data = b"R\x00\x00\x00\x18\x00\x00\x00\x0Cv=some_signature"; + let frame = AuthenticationSaslFinalFrame::from_bytes(data).unwrap(); + assert_eq!(frame.sasl_data, b"v=some_signature"); + } + + #[test] + fn roundtrip_auth_sasl_final() { + let original = make_frame(); + let bytes = original.to_bytes().unwrap(); + let decoded = AuthenticationSaslFinalFrame::from_bytes(bytes.as_ref()).unwrap(); + assert_eq!(decoded.sasl_data, original.sasl_data); + } + + #[test] + fn empty_data() { + let frame = AuthenticationSaslFinalFrame { sasl_data: b"" }; + let bytes = frame.to_bytes().unwrap(); + let expected = b"R\x00\x00\x00\x08\x00\x00\x00\x0C"; + assert_eq!(bytes.as_ref(), expected); + + let decoded = AuthenticationSaslFinalFrame::from_bytes(bytes.as_ref()).unwrap(); + assert!(decoded.sasl_data.is_empty()); + } + + #[test] + fn invalid_tag() { + let data = b"X\x00\x00\x00\x14\x00\x00\x00\x0Cv=some_signature"; + let err = AuthenticationSaslFinalFrame::from_bytes(data).unwrap_err(); + matches!(err, AuthenticationSaslFinalError::UnexpectedTag(_)); + } + + #[test] + fn invalid_length() { + let data = b"R\x00\x00\x00\x15\x00\x00\x00\x0Cv=some_signature"; + let err = AuthenticationSaslFinalFrame::from_bytes(data).unwrap_err(); + matches!(err, AuthenticationSaslFinalError::UnexpectedLength(_)); + } + + #[test] + fn invalid_auth_code() { + let data = b"R\x00\x00\x00\x14\x00\x00\x00\x0Bv=some_signature"; + let err = AuthenticationSaslFinalFrame::from_bytes(data).unwrap_err(); + matches!(err, AuthenticationSaslFinalError::UnexpectedAuthCode(11)); + } +} + +// ----------------------------------------------------------------------------- +// ----------------------------------------------------------------------------- diff --git a/pgdog/src/wire_protocol/backend/authentication_scm_credential.rs b/pgdog/src/wire_protocol/backend/authentication_scm_credential.rs new file mode 100644 index 00000000..05c1a3d3 --- /dev/null +++ b/pgdog/src/wire_protocol/backend/authentication_scm_credential.rs @@ -0,0 +1,158 @@ +//! Module: wire_protocol::backend::authentication_scm_credential +//! +//! Provides parsing and serialization for the AuthenticationSCMCredential message ('R' with code 6) in the protocol. +//! +//! - `AuthenticationScmCredentialFrame`: represents the AuthenticationSCMCredential message requesting SCM credential authentication. +//! - `AuthenticationScmCredentialError`: error types for parsing and encoding. +//! +//! Implements `WireSerializable` for easy conversion between raw bytes and `AuthenticationScmCredentialFrame`. + +use bytes::Bytes; +use std::{error::Error as StdError, fmt}; + +use crate::wire_protocol::WireSerializable; + +// ----------------------------------------------------------------------------- +// ----- ProtocolMessage ------------------------------------------------------- + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct AuthenticationScmCredentialFrame; + +// ----------------------------------------------------------------------------- +// ----- Error ----------------------------------------------------------------- + +#[derive(Debug)] +pub enum AuthenticationScmCredentialError { + UnexpectedTag(u8), + UnexpectedLength(u32), + UnexpectedAuthType(i32), +} + +impl fmt::Display for AuthenticationScmCredentialError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + AuthenticationScmCredentialError::UnexpectedTag(t) => { + write!(f, "unexpected tag: {t:#X}") + } + AuthenticationScmCredentialError::UnexpectedLength(len) => { + write!(f, "unexpected length: {len}") + } + AuthenticationScmCredentialError::UnexpectedAuthType(t) => { + write!(f, "unexpected auth type: {t}") + } + } + } +} + +impl StdError for AuthenticationScmCredentialError {} + +// ----------------------------------------------------------------------------- +// ----- WireSerializable ------------------------------------------------------ + +impl<'a> WireSerializable<'a> for AuthenticationScmCredentialFrame { + type Error = AuthenticationScmCredentialError; + + fn from_bytes(bytes: &'a [u8]) -> Result { + if bytes.len() < 9 { + return Err(AuthenticationScmCredentialError::UnexpectedLength( + bytes.len() as u32, + )); + } + + let tag = bytes[0]; + if tag != b'R' { + return Err(AuthenticationScmCredentialError::UnexpectedTag(tag)); + } + + let len = u32::from_be_bytes([bytes[1], bytes[2], bytes[3], bytes[4]]); + if len != 8 { + return Err(AuthenticationScmCredentialError::UnexpectedLength(len)); + } + + if bytes.len() != 1 + len as usize { + return Err(AuthenticationScmCredentialError::UnexpectedLength( + bytes.len() as u32, + )); + } + + let auth_type = i32::from_be_bytes([bytes[5], bytes[6], bytes[7], bytes[8]]); + if auth_type != 6 { + return Err(AuthenticationScmCredentialError::UnexpectedAuthType( + auth_type, + )); + } + + Ok(AuthenticationScmCredentialFrame) + } + + fn to_bytes(&self) -> Result { + Ok(Bytes::from_static(b"R\x00\x00\x00\x08\x00\x00\x00\x06")) + } + + fn body_size(&self) -> usize { + 4 + } +} + +// ----------------------------------------------------------------------------- +// ----- Tests ----------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn serialize_authentication_scm_credential() { + let frame = AuthenticationScmCredentialFrame; + let bytes = frame.to_bytes().unwrap(); + let expected = b"R\x00\x00\x00\x08\x00\x00\x00\x06"; + assert_eq!(bytes.as_ref(), expected); + } + + #[test] + fn deserialize_authentication_scm_credential() { + let data = b"R\x00\x00\x00\x08\x00\x00\x00\x06"; + let frame = AuthenticationScmCredentialFrame::from_bytes(data).unwrap(); + // no state; just ensure no error + let _ = frame; + } + + #[test] + fn roundtrip_authentication_scm_credential() { + let original = AuthenticationScmCredentialFrame; + let bytes = original.to_bytes().unwrap(); + let decoded = AuthenticationScmCredentialFrame::from_bytes(bytes.as_ref()).unwrap(); + assert_eq!(original, decoded); + } + + #[test] + fn invalid_tag() { + let data = b"X\x00\x00\x00\x08\x00\x00\x00\x06"; + let err = AuthenticationScmCredentialFrame::from_bytes(data).unwrap_err(); + matches!(err, AuthenticationScmCredentialError::UnexpectedTag(_)); + } + + #[test] + fn invalid_length() { + let data = b"R\x00\x00\x00\x09\x00\x00\x00\x06"; + let err = AuthenticationScmCredentialFrame::from_bytes(data).unwrap_err(); + matches!(err, AuthenticationScmCredentialError::UnexpectedLength(_)); + } + + #[test] + fn extra_data_after() { + let data = b"R\x00\x00\x00\x08\x00\x00\x00\x06\x00"; + let err = AuthenticationScmCredentialFrame::from_bytes(data).unwrap_err(); + matches!(err, AuthenticationScmCredentialError::UnexpectedLength(_)); + } + + #[test] + fn invalid_auth_type() { + let data = b"R\x00\x00\x00\x08\x00\x00\x00\x07"; + let err = AuthenticationScmCredentialFrame::from_bytes(data).unwrap_err(); + matches!(err, AuthenticationScmCredentialError::UnexpectedAuthType(7)); + } +} + +// ----------------------------------------------------------------------------- +// ----------------------------------------------------------------------------- diff --git a/pgdog/src/wire_protocol/backend/authentication_sspi.rs b/pgdog/src/wire_protocol/backend/authentication_sspi.rs new file mode 100644 index 00000000..00bd0853 --- /dev/null +++ b/pgdog/src/wire_protocol/backend/authentication_sspi.rs @@ -0,0 +1,136 @@ +//! Module: wire_protocol::backend::authentication_sspi +//! +//! Provides parsing and serialization for the AuthenticationSSPI message ('R' with code 9) in the protocol. +//! +//! - `AuthenticationSspiFrame`: represents the AuthenticationSSPI message. +//! - `AuthenticationSspiError`: error types for parsing and encoding. +//! +//! Implements `WireSerializable` for easy conversion between raw bytes and `AuthenticationSspiFrame`. + +use bytes::Bytes; +use std::{error::Error as StdError, fmt}; + +use crate::wire_protocol::WireSerializable; + +// ----------------------------------------------------------------------------- +// ----- ProtocolMessage ------------------------------------------------------- + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct AuthenticationSspiFrame; + +// ----------------------------------------------------------------------------- +// ----- Error ----------------------------------------------------------------- + +#[derive(Debug)] +pub enum AuthenticationSspiError { + UnexpectedTag(u8), + UnexpectedLength(u32), + UnexpectedAuthCode(i32), +} + +impl fmt::Display for AuthenticationSspiError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + AuthenticationSspiError::UnexpectedTag(t) => write!(f, "unexpected tag: {t:#X}"), + AuthenticationSspiError::UnexpectedLength(len) => write!(f, "unexpected length: {len}"), + AuthenticationSspiError::UnexpectedAuthCode(code) => { + write!(f, "unexpected auth code: {code}") + } + } + } +} + +impl StdError for AuthenticationSspiError {} + +// ----------------------------------------------------------------------------- +// ----- WireSerializable ------------------------------------------------------ + +impl<'a> WireSerializable<'a> for AuthenticationSspiFrame { + type Error = AuthenticationSspiError; + + fn from_bytes(bytes: &'a [u8]) -> Result { + if bytes.len() < 9 { + return Err(AuthenticationSspiError::UnexpectedLength(bytes.len() as u32)); + } + + let tag = bytes[0]; + if tag != b'R' { + return Err(AuthenticationSspiError::UnexpectedTag(tag)); + } + + let len = u32::from_be_bytes([bytes[1], bytes[2], bytes[3], bytes[4]]); + if len != 8 { + return Err(AuthenticationSspiError::UnexpectedLength(len)); + } + + let code = i32::from_be_bytes([bytes[5], bytes[6], bytes[7], bytes[8]]); + if code != 9 { + return Err(AuthenticationSspiError::UnexpectedAuthCode(code)); + } + + Ok(AuthenticationSspiFrame) + } + + fn to_bytes(&self) -> Result { + Ok(Bytes::from_static(b"R\x00\x00\x00\x08\x00\x00\x00\x09")) + } + + fn body_size(&self) -> usize { + 4 + } +} + +// ----------------------------------------------------------------------------- +// ----- Tests ----------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn serialize_auth_sspi() { + let frame = AuthenticationSspiFrame; + let bytes = frame.to_bytes().unwrap(); + let expected = b"R\x00\x00\x00\x08\x00\x00\x00\x09"; + assert_eq!(bytes.as_ref(), expected); + } + + #[test] + fn deserialize_auth_sspi() { + let data = b"R\x00\x00\x00\x08\x00\x00\x00\x09"; + let frame = AuthenticationSspiFrame::from_bytes(data).unwrap(); + let _ = frame; + } + + #[test] + fn roundtrip_auth_sspi() { + let original = AuthenticationSspiFrame; + let bytes = original.to_bytes().unwrap(); + let decoded = AuthenticationSspiFrame::from_bytes(bytes.as_ref()).unwrap(); + assert_eq!(original, decoded); + } + + #[test] + fn invalid_tag() { + let data = b"X\x00\x00\x00\x08\x00\x00\x00\x09"; + let err = AuthenticationSspiFrame::from_bytes(data).unwrap_err(); + matches!(err, AuthenticationSspiError::UnexpectedTag(_)); + } + + #[test] + fn invalid_length() { + let data = b"R\x00\x00\x00\x09\x00\x00\x00\x09"; + let err = AuthenticationSspiFrame::from_bytes(data).unwrap_err(); + matches!(err, AuthenticationSspiError::UnexpectedLength(_)); + } + + #[test] + fn invalid_auth_code() { + let data = b"R\x00\x00\x00\x08\x00\x00\x00\x0A"; + let err = AuthenticationSspiFrame::from_bytes(data).unwrap_err(); + matches!(err, AuthenticationSspiError::UnexpectedAuthCode(10)); + } +} + +// ----------------------------------------------------------------------------- +// ----------------------------------------------------------------------------- diff --git a/pgdog/src/wire_protocol/backend/backend_key_data.rs b/pgdog/src/wire_protocol/backend/backend_key_data.rs new file mode 100644 index 00000000..64fa5b73 --- /dev/null +++ b/pgdog/src/wire_protocol/backend/backend_key_data.rs @@ -0,0 +1,160 @@ +//! Module: wire_protocol::backend::backend_key_data +//! +//! Provides parsing and serialization for the BackendKeyData message ('K') in the protocol. +//! +//! - `BackendKeyDataFrame`: represents the BackendKeyData message with process ID and secret key. +//! - `BackendKeyDataError`: error types for parsing and encoding. +//! +//! Implements `WireSerializable` for easy conversion between raw bytes and `BackendKeyDataFrame`. + +use bytes::{BufMut, Bytes, BytesMut}; +use std::{error::Error as StdError, fmt}; + +use crate::wire_protocol::WireSerializable; + +// ----------------------------------------------------------------------------- +// ----- ProtocolMessage ------------------------------------------------------- + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct BackendKeyDataFrame { + pub process_id: i32, + pub secret_key: i32, +} + +// ----------------------------------------------------------------------------- +// ----- Error ----------------------------------------------------------------- + +#[derive(Debug)] +pub enum BackendKeyDataError { + UnexpectedTag(u8), + UnexpectedLength(u32), +} + +impl fmt::Display for BackendKeyDataError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + BackendKeyDataError::UnexpectedTag(t) => write!(f, "unexpected tag: {t:#X}"), + BackendKeyDataError::UnexpectedLength(len) => write!(f, "unexpected length: {len}"), + } + } +} + +impl StdError for BackendKeyDataError {} + +// ----------------------------------------------------------------------------- +// ----- WireSerializable ------------------------------------------------------ + +impl<'a> WireSerializable<'a> for BackendKeyDataFrame { + type Error = BackendKeyDataError; + + fn from_bytes(bytes: &'a [u8]) -> Result { + if bytes.len() < 13 { + return Err(BackendKeyDataError::UnexpectedLength(bytes.len() as u32)); + } + + let tag = bytes[0]; + if tag != b'K' { + return Err(BackendKeyDataError::UnexpectedTag(tag)); + } + + let len = u32::from_be_bytes([bytes[1], bytes[2], bytes[3], bytes[4]]); + if len != 12 { + return Err(BackendKeyDataError::UnexpectedLength(len)); + } + + if bytes.len() != 1 + len as usize { + return Err(BackendKeyDataError::UnexpectedLength(bytes.len() as u32)); + } + + let process_id = i32::from_be_bytes([bytes[5], bytes[6], bytes[7], bytes[8]]); + let secret_key = i32::from_be_bytes([bytes[9], bytes[10], bytes[11], bytes[12]]); + + Ok(BackendKeyDataFrame { + process_id, + secret_key, + }) + } + + fn to_bytes(&self) -> Result { + let mut buf = BytesMut::with_capacity(13); + buf.put_u8(b'K'); + buf.put_u32(12); + buf.put_i32(self.process_id); + buf.put_i32(self.secret_key); + Ok(buf.freeze()) + } + + fn body_size(&self) -> usize { + 8 + } +} + +// ----------------------------------------------------------------------------- +// ----- Tests ----------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + + fn make_frame() -> BackendKeyDataFrame { + BackendKeyDataFrame { + process_id: 1234, + secret_key: 5678, + } + } + + #[test] + fn serialize_backend_key_data() { + let frame = make_frame(); + let bytes = frame.to_bytes().unwrap(); + let expected = b"K\x00\x00\x00\x0C\x00\x00\x04\xD2\x00\x00\x16\x2E"; + assert_eq!(bytes.as_ref(), expected); + } + + #[test] + fn deserialize_backend_key_data() { + let data = b"K\x00\x00\x00\x0C\x00\x00\x04\xD2\x00\x00\x16\x2E"; + let frame = BackendKeyDataFrame::from_bytes(data).unwrap(); + assert_eq!(frame.process_id, 1234); + assert_eq!(frame.secret_key, 5678); + } + + #[test] + fn roundtrip_backend_key_data() { + let original = make_frame(); + let bytes = original.to_bytes().unwrap(); + let decoded = BackendKeyDataFrame::from_bytes(bytes.as_ref()).unwrap(); + assert_eq!(original, decoded); + } + + #[test] + fn invalid_tag() { + let data = b"R\x00\x00\x00\x0C\x00\x00\x04\xD2\x00\x00\x16\x2E"; + let err = BackendKeyDataFrame::from_bytes(data).unwrap_err(); + matches!(err, BackendKeyDataError::UnexpectedTag(_)); + } + + #[test] + fn invalid_length() { + let data = b"K\x00\x00\x00\x0D\x00\x00\x04\xD2\x00\x00\x16\x2E"; + let err = BackendKeyDataFrame::from_bytes(data).unwrap_err(); + matches!(err, BackendKeyDataError::UnexpectedLength(_)); + } + + #[test] + fn extra_data_after() { + let data = b"K\x00\x00\x00\x0C\x00\x00\x04\xD2\x00\x00\x16\x2E\x00"; + let err = BackendKeyDataFrame::from_bytes(data).unwrap_err(); + matches!(err, BackendKeyDataError::UnexpectedLength(_)); + } + + #[test] + fn short_data() { + let data = b"K\x00\x00\x00\x0C\x00\x00\x04\xD2\x00\x00\x16"; + let err = BackendKeyDataFrame::from_bytes(data).unwrap_err(); + matches!(err, BackendKeyDataError::UnexpectedLength(_)); + } +} + +// ----------------------------------------------------------------------------- +// ----------------------------------------------------------------------------- diff --git a/pgdog/src/wire_protocol/backend/bind_complete.rs b/pgdog/src/wire_protocol/backend/bind_complete.rs new file mode 100644 index 00000000..ca1c3c8b --- /dev/null +++ b/pgdog/src/wire_protocol/backend/bind_complete.rs @@ -0,0 +1,132 @@ +//! Module: wire_protocol::backend::bind_complete +//! +//! Provides parsing and serialization for the BindComplete message ('2') in the protocol. +//! +//! - `BindCompleteFrame`: represents the BindComplete message indicating bind operation completion. +//! - `BindCompleteError`: error types for parsing and encoding. +//! +//! Implements `WireSerializable` for easy conversion between raw bytes and `BindCompleteFrame`. + +use bytes::Bytes; +use std::{error::Error as StdError, fmt}; + +use crate::wire_protocol::WireSerializable; + +// ----------------------------------------------------------------------------- +// ----- ProtocolMessage ------------------------------------------------------- + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct BindCompleteFrame; + +// ----------------------------------------------------------------------------- +// ----- Error ----------------------------------------------------------------- + +#[derive(Debug)] +pub enum BindCompleteError { + UnexpectedTag(u8), + UnexpectedLength(u32), +} + +impl fmt::Display for BindCompleteError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + BindCompleteError::UnexpectedTag(t) => write!(f, "unexpected tag: {t:#X}"), + BindCompleteError::UnexpectedLength(len) => write!(f, "unexpected length: {len}"), + } + } +} + +impl StdError for BindCompleteError {} + +// ----------------------------------------------------------------------------- +// ----- WireSerializable ------------------------------------------------------ + +impl<'a> WireSerializable<'a> for BindCompleteFrame { + type Error = BindCompleteError; + + fn from_bytes(bytes: &'a [u8]) -> Result { + if bytes.len() < 5 { + return Err(BindCompleteError::UnexpectedLength(bytes.len() as u32)); + } + + // tag must be '2' + if bytes[0] != b'2' { + return Err(BindCompleteError::UnexpectedTag(bytes[0])); + } + + // length field must be exactly 4 (no body) + let len = u32::from_be_bytes([bytes[1], bytes[2], bytes[3], bytes[4]]); + if len != 4 { + return Err(BindCompleteError::UnexpectedLength(len)); + } + + // reject any extra or missing bytes beyond the 5-byte header + if bytes.len() != 1 + len as usize { + return Err(BindCompleteError::UnexpectedLength(bytes.len() as u32)); + } + + Ok(BindCompleteFrame) + } + + fn to_bytes(&self) -> Result { + Ok(Bytes::from_static(b"2\0\0\0\x04")) + } + + fn body_size(&self) -> usize { + 0 + } +} + +// ----------------------------------------------------------------------------- +// ----- Tests ----------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn serialize_bind_complete() { + let frame = BindCompleteFrame; + let bytes = frame.to_bytes().unwrap(); + let expected = b"2\x00\x00\x00\x04"; + assert_eq!(bytes.as_ref(), expected); + } + + #[test] + fn deserialize_bind_complete() { + let data = b"2\x00\x00\x00\x04"; + let _ = BindCompleteFrame::from_bytes(data).unwrap(); + } + + #[test] + fn roundtrip_bind_complete() { + let original = BindCompleteFrame; + let bytes = original.to_bytes().unwrap(); + let decoded = BindCompleteFrame::from_bytes(bytes.as_ref()).unwrap(); + assert_eq!(original, decoded); + } + + #[test] + fn invalid_tag() { + let data = b"3\x00\x00\x00\x04"; + let err = BindCompleteFrame::from_bytes(data).unwrap_err(); + matches!(err, BindCompleteError::UnexpectedTag(_)); + } + + #[test] + fn invalid_length() { + let data = b"2\x00\x00\x00\x05"; + let err = BindCompleteFrame::from_bytes(data).unwrap_err(); + matches!(err, BindCompleteError::UnexpectedLength(_)); + } + + #[test] + fn extra_data_after() { + let data = b"2\x00\x00\x00\x04\x00"; + let err = BindCompleteFrame::from_bytes(data).unwrap_err(); + matches!(err, BindCompleteError::UnexpectedLength(6)); + } +} + +// ----------------------------------------------------------------------------- +// ----------------------------------------------------------------------------- diff --git a/pgdog/src/wire_protocol/backend/close_complete.rs b/pgdog/src/wire_protocol/backend/close_complete.rs new file mode 100644 index 00000000..f2a53868 --- /dev/null +++ b/pgdog/src/wire_protocol/backend/close_complete.rs @@ -0,0 +1,139 @@ +//! Module: wire_protocol::backend::close_complete +//! +//! Provides parsing and serialization for the CloseComplete message ('3') in the protocol. +//! +//! - `CloseCompleteFrame`: represents the CloseComplete message indicating close operation completion. +//! - `CloseCompleteError`: error types for parsing and encoding. +//! +//! Implements `WireSerializable` for easy conversion between raw bytes and `CloseCompleteFrame`. + +use bytes::Bytes; +use std::{error::Error as StdError, fmt}; + +use crate::wire_protocol::WireSerializable; + +// ----------------------------------------------------------------------------- +// ----- ProtocolMessage ------------------------------------------------------- + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct CloseCompleteFrame; + +// ----------------------------------------------------------------------------- +// ----- Error ----------------------------------------------------------------- + +#[derive(Debug)] +pub enum CloseCompleteError { + UnexpectedTag(u8), + UnexpectedLength(u32), +} + +impl fmt::Display for CloseCompleteError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + CloseCompleteError::UnexpectedTag(t) => write!(f, "unexpected tag: {t:#X}"), + CloseCompleteError::UnexpectedLength(len) => write!(f, "unexpected length: {len}"), + } + } +} + +impl StdError for CloseCompleteError {} + +// ----------------------------------------------------------------------------- +// ----- WireSerializable ------------------------------------------------------ + +impl<'a> WireSerializable<'a> for CloseCompleteFrame { + type Error = CloseCompleteError; + + fn from_bytes(bytes: &'a [u8]) -> Result { + if bytes.len() < 5 { + return Err(CloseCompleteError::UnexpectedLength(bytes.len() as u32)); + } + + let tag = bytes[0]; + if tag != b'3' { + return Err(CloseCompleteError::UnexpectedTag(tag)); + } + + let len = u32::from_be_bytes([bytes[1], bytes[2], bytes[3], bytes[4]]); + if len != 4 { + return Err(CloseCompleteError::UnexpectedLength(len)); + } + + if bytes.len() != 5 { + return Err(CloseCompleteError::UnexpectedLength(bytes.len() as u32)); + } + + Ok(CloseCompleteFrame) + } + + fn to_bytes(&self) -> Result { + Ok(Bytes::from_static(b"3\0\0\0\x04")) + } + + fn body_size(&self) -> usize { + 0 + } +} + +// ----------------------------------------------------------------------------- +// ----- Tests ----------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn serialize_close_complete() { + let frame = CloseCompleteFrame; + let bytes = frame.to_bytes().unwrap(); + let expected = b"3\x00\x00\x00\x04"; + assert_eq!(bytes.as_ref(), expected); + } + + #[test] + fn deserialize_close_complete() { + let data = b"3\x00\x00\x00\x04"; + let frame = CloseCompleteFrame::from_bytes(data).unwrap(); + // no state; just ensure no error + let _ = frame; + } + + #[test] + fn roundtrip_close_complete() { + let original = CloseCompleteFrame; + let bytes = original.to_bytes().unwrap(); + let decoded = CloseCompleteFrame::from_bytes(bytes.as_ref()).unwrap(); + assert_eq!(original, decoded); + } + + #[test] + fn invalid_tag() { + let data = b"2\x00\x00\x00\x04"; + let err = CloseCompleteFrame::from_bytes(data).unwrap_err(); + matches!(err, CloseCompleteError::UnexpectedTag(_)); + } + + #[test] + fn invalid_length() { + let data = b"3\x00\x00\x00\x05"; + let err = CloseCompleteFrame::from_bytes(data).unwrap_err(); + matches!(err, CloseCompleteError::UnexpectedLength(_)); + } + + #[test] + fn extra_data_after() { + let data = b"3\x00\x00\x00\x04\x00"; + let err = CloseCompleteFrame::from_bytes(data).unwrap_err(); + matches!(err, CloseCompleteError::UnexpectedLength(_)); + } + + #[test] + fn short_data() { + let data = b"3\x00\x00\x00"; + let err = CloseCompleteFrame::from_bytes(data).unwrap_err(); + matches!(err, CloseCompleteError::UnexpectedLength(_)); + } +} + +// ----------------------------------------------------------------------------- +// ----------------------------------------------------------------------------- diff --git a/pgdog/src/wire_protocol/backend/command_complete.rs b/pgdog/src/wire_protocol/backend/command_complete.rs new file mode 100644 index 00000000..39334a09 --- /dev/null +++ b/pgdog/src/wire_protocol/backend/command_complete.rs @@ -0,0 +1,394 @@ +//! Module: wire_protocol::backend::command_complete +//! +//! Provides parsing and serialization for the CommandComplete message ('C') in the protocol. +//! +//! - `CommandCompleteFrame`: represents the CommandComplete message with the command tag. +//! - `CommandCompleteError`: error types for parsing and encoding. +//! +//! Implements `WireSerializable` for easy conversion between raw bytes and `CommandCompleteFrame`. + +use bytes::{BufMut, Bytes, BytesMut}; +use std::{error::Error as StdError, fmt, str}; + +use crate::wire_protocol::WireSerializable; + +// ----------------------------------------------------------------------------- +// ----- ProtocolMessage ------------------------------------------------------- + +#[derive(Debug, Clone, PartialEq)] +pub struct CommandCompleteFrame<'a> { + pub command_tag: CommandTag<'a>, +} + +// ----------------------------------------------------------------------------- +// ----- Subproperties --------------------------------------------------------- + +#[derive(Debug, Clone, PartialEq)] +pub enum CommandTag<'a> { + Insert { oid: u32, rows: u64 }, + Delete { rows: u64 }, + Update { rows: u64 }, + Merge { rows: u64 }, + Select { rows: u64 }, + Move { rows: u64 }, + Fetch { rows: u64 }, + Copy { rows: u64 }, + Other { tag: &'a str }, +} + +// ----------------------------------------------------------------------------- +// ----- Error ----------------------------------------------------------------- + +#[derive(Debug)] +pub enum CommandCompleteError { + UnexpectedTag(u8), + UnexpectedLength(u32), + Utf8Error(str::Utf8Error), + UnexpectedEof, + InvalidCommandTagFormat, +} + +impl fmt::Display for CommandCompleteError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + CommandCompleteError::UnexpectedTag(t) => write!(f, "unexpected tag: {t:#X}"), + CommandCompleteError::UnexpectedLength(len) => write!(f, "unexpected length: {len}"), + CommandCompleteError::Utf8Error(e) => write!(f, "UTF-8 error: {e}"), + CommandCompleteError::UnexpectedEof => write!(f, "unexpected EOF"), + CommandCompleteError::InvalidCommandTagFormat => { + write!(f, "invalid command tag format") + } + } + } +} + +impl StdError for CommandCompleteError { + fn source(&self) -> Option<&(dyn StdError + 'static)> { + match self { + CommandCompleteError::Utf8Error(e) => Some(e), + _ => None, + } + } +} + +// ----------------------------------------------------------------------------- +// ----- Helpers --------------------------------------------------------------- + +fn read_cstr<'a>(bytes: &'a [u8]) -> Result<(&'a str, usize), CommandCompleteError> { + let nul = bytes + .iter() + .position(|b| *b == 0) + .ok_or(CommandCompleteError::UnexpectedEof)?; + + let raw = &bytes[..nul]; + let s = str::from_utf8(raw).map_err(CommandCompleteError::Utf8Error)?; + + Ok((s, nul + 1)) +} + +fn num_digits(mut num: u64) -> usize { + if num == 0 { + 1 + } else { + let mut count = 0; + while num > 0 { + num /= 10; + count += 1; + } + count + } +} + +fn write_num(buf: &mut BytesMut, mut num: u64) { + if num == 0 { + buf.put_u8(b'0'); + return; + } + + let mut digits = [0u8; 20]; + let mut i = digits.len() - 1; + while num > 0 { + digits[i] = b'0' + (num % 10) as u8; + num /= 10; + i -= 1; + } + buf.extend_from_slice(&digits[i + 1..]); +} + +impl<'a> CommandTag<'a> { + fn parse(tag: &'a str) -> Result { + let parts: Vec<&str> = tag.split_whitespace().collect(); + match parts.as_slice() { + ["INSERT", oid_str, rows_str] => { + let oid = oid_str + .parse::() + .map_err(|_| CommandCompleteError::InvalidCommandTagFormat)?; + let rows = rows_str + .parse::() + .map_err(|_| CommandCompleteError::InvalidCommandTagFormat)?; + Ok(Self::Insert { oid, rows }) + } + ["DELETE", rows_str] => { + let rows = rows_str + .parse::() + .map_err(|_| CommandCompleteError::InvalidCommandTagFormat)?; + Ok(Self::Delete { rows }) + } + ["UPDATE", rows_str] => { + let rows = rows_str + .parse::() + .map_err(|_| CommandCompleteError::InvalidCommandTagFormat)?; + Ok(Self::Update { rows }) + } + ["MERGE", rows_str] => { + let rows = rows_str + .parse::() + .map_err(|_| CommandCompleteError::InvalidCommandTagFormat)?; + Ok(Self::Merge { rows }) + } + ["SELECT", rows_str] => { + let rows = rows_str + .parse::() + .map_err(|_| CommandCompleteError::InvalidCommandTagFormat)?; + Ok(Self::Select { rows }) + } + ["MOVE", rows_str] => { + let rows = rows_str + .parse::() + .map_err(|_| CommandCompleteError::InvalidCommandTagFormat)?; + Ok(Self::Move { rows }) + } + ["FETCH", rows_str] => { + let rows = rows_str + .parse::() + .map_err(|_| CommandCompleteError::InvalidCommandTagFormat)?; + Ok(Self::Fetch { rows }) + } + ["COPY", rows_str] => { + let rows = rows_str + .parse::() + .map_err(|_| CommandCompleteError::InvalidCommandTagFormat)?; + Ok(Self::Copy { rows }) + } + _ => Ok(Self::Other { tag }), + } + } + + fn tag_len(&self) -> usize { + match self { + Self::Insert { oid, rows } => 7 + num_digits(*oid as u64) + 1 + num_digits(*rows), // "INSERT " + oid + " " + rows + Self::Delete { rows } => 7 + num_digits(*rows), // "DELETE " + rows + Self::Update { rows } => 7 + num_digits(*rows), // "UPDATE " + rows + Self::Merge { rows } => 6 + num_digits(*rows), // "MERGE " + rows + Self::Select { rows } => 7 + num_digits(*rows), // "SELECT " + rows + Self::Move { rows } => 5 + num_digits(*rows), // "MOVE " + rows + Self::Fetch { rows } => 6 + num_digits(*rows), // "FETCH " + rows + Self::Copy { rows } => 5 + num_digits(*rows), // "COPY " + rows + Self::Other { tag } => tag.len(), + } + } + + fn write_to(&self, buf: &mut BytesMut) { + match self { + Self::Insert { oid, rows } => { + buf.extend_from_slice(b"INSERT "); + write_num(buf, *oid as u64); + buf.put_u8(b' '); + write_num(buf, *rows); + } + Self::Delete { rows } => { + buf.extend_from_slice(b"DELETE "); + write_num(buf, *rows); + } + Self::Update { rows } => { + buf.extend_from_slice(b"UPDATE "); + write_num(buf, *rows); + } + Self::Merge { rows } => { + buf.extend_from_slice(b"MERGE "); + write_num(buf, *rows); + } + Self::Select { rows } => { + buf.extend_from_slice(b"SELECT "); + write_num(buf, *rows); + } + Self::Move { rows } => { + buf.extend_from_slice(b"MOVE "); + write_num(buf, *rows); + } + Self::Fetch { rows } => { + buf.extend_from_slice(b"FETCH "); + write_num(buf, *rows); + } + Self::Copy { rows } => { + buf.extend_from_slice(b"COPY "); + write_num(buf, *rows); + } + Self::Other { tag } => { + buf.extend_from_slice(tag.as_bytes()); + } + } + } +} + +// ----------------------------------------------------------------------------- +// ----- WireSerializable ------------------------------------------------------ + +impl<'a> WireSerializable<'a> for CommandCompleteFrame<'a> { + type Error = CommandCompleteError; + + fn from_bytes(bytes: &'a [u8]) -> Result { + if bytes.len() < 5 { + return Err(CommandCompleteError::UnexpectedLength(bytes.len() as u32)); + } + + let tag = bytes[0]; + if tag != b'C' { + return Err(CommandCompleteError::UnexpectedTag(tag)); + } + + let len = u32::from_be_bytes([bytes[1], bytes[2], bytes[3], bytes[4]]); + if len as usize != bytes.len() - 1 { + return Err(CommandCompleteError::UnexpectedLength(len)); + } + + let (command_tag_str, consumed) = read_cstr(&bytes[5..])?; + + if consumed != bytes.len() - 5 { + return Err(CommandCompleteError::UnexpectedLength(len)); + } + + let command_tag = CommandTag::parse(command_tag_str)?; + + Ok(CommandCompleteFrame { command_tag }) + } + + fn to_bytes(&self) -> Result { + let tag_len = self.command_tag.tag_len(); + let body_len = tag_len + 1; + let total_len = 4 + body_len; + + let mut buf = BytesMut::with_capacity(1 + total_len); + buf.put_u8(b'C'); + buf.put_u32(total_len as u32); + self.command_tag.write_to(&mut buf); + buf.put_u8(0); + + Ok(buf.freeze()) + } + + fn body_size(&self) -> usize { + self.command_tag.tag_len() + 1 + } +} + +// ----------------------------------------------------------------------------- +// ----- Tests ----------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + + fn make_select_frame() -> CommandCompleteFrame<'static> { + CommandCompleteFrame { + command_tag: CommandTag::Select { rows: 1 }, + } + } + + fn make_insert_frame() -> CommandCompleteFrame<'static> { + CommandCompleteFrame { + command_tag: CommandTag::Insert { oid: 0, rows: 1 }, + } + } + + fn make_begin_frame() -> CommandCompleteFrame<'static> { + CommandCompleteFrame { + command_tag: CommandTag::Other { tag: "BEGIN" }, + } + } + + #[test] + fn serialize_command_complete() { + let frame = make_select_frame(); + let bytes = frame.to_bytes().unwrap(); + let expected = b"C\x00\x00\x00\x0DSELECT 1\x00"; + assert_eq!(bytes.as_ref(), expected); + } + + #[test] + fn deserialize_command_complete() { + let data = b"C\x00\x00\x00\x0DSELECT 1\x00"; + let frame = CommandCompleteFrame::from_bytes(data).unwrap(); + assert_eq!(frame.command_tag, CommandTag::Select { rows: 1 }); + } + + #[test] + fn roundtrip_command_complete() { + let original = make_select_frame(); + let bytes = original.to_bytes().unwrap(); + let decoded = CommandCompleteFrame::from_bytes(bytes.as_ref()).unwrap(); + assert_eq!(decoded.command_tag, original.command_tag); + } + + #[test] + fn roundtrip_insert() { + let original = make_insert_frame(); + let bytes = original.to_bytes().unwrap(); + let decoded = CommandCompleteFrame::from_bytes(bytes.as_ref()).unwrap(); + assert_eq!(decoded.command_tag, original.command_tag); + } + + #[test] + fn roundtrip_other() { + let original = make_begin_frame(); + let bytes = original.to_bytes().unwrap(); + let decoded = CommandCompleteFrame::from_bytes(bytes.as_ref()).unwrap(); + assert_eq!(decoded.command_tag, original.command_tag); + } + + #[test] + fn invalid_tag() { + let data = b"Q\x00\x00\x00\x0DSELECT 1\x00"; + let err = CommandCompleteFrame::from_bytes(data).unwrap_err(); + matches!(err, CommandCompleteError::UnexpectedTag(_)); + } + + #[test] + fn invalid_length() { + let data = b"C\x00\x00\x00\x0ESELECT 1\x00"; + let err = CommandCompleteFrame::from_bytes(data).unwrap_err(); + matches!(err, CommandCompleteError::UnexpectedLength(_)); + } + + #[test] + fn missing_null_terminator() { + let data = b"C\x00\x00\x00\x0DSELECT 1"; + let err = CommandCompleteFrame::from_bytes(data).unwrap_err(); + matches!(err, CommandCompleteError::UnexpectedEof); + } + + #[test] + fn extra_data_after_null() { + let data = b"C\x00\x00\x00\x0DSELECT 1\x00extra"; + let err = CommandCompleteFrame::from_bytes(data).unwrap_err(); + matches!(err, CommandCompleteError::UnexpectedLength(_)); + } + + #[test] + fn invalid_utf8() { + let mut bytes = make_select_frame().to_bytes().unwrap().to_vec(); + bytes[5] = 0xFF; // corrupt first byte + let err = CommandCompleteFrame::from_bytes(&bytes).unwrap_err(); + matches!(err, CommandCompleteError::Utf8Error(_)); + } + + #[test] + fn invalid_command_tag_format() { + let data = b"C\x00\x00\x00\x0CDELETE abc\x00"; + let err = CommandCompleteFrame::from_bytes(data).unwrap_err(); + matches!(err, CommandCompleteError::InvalidCommandTagFormat); + } +} + +// ----------------------------------------------------------------------------- +// ----------------------------------------------------------------------------- diff --git a/pgdog/src/wire_protocol/backend/copy_both_response.rs b/pgdog/src/wire_protocol/backend/copy_both_response.rs new file mode 100644 index 00000000..4f1c8b29 --- /dev/null +++ b/pgdog/src/wire_protocol/backend/copy_both_response.rs @@ -0,0 +1,234 @@ +//! Module: wire_protocol::backend::copy_both_response +//! +//! Provides parsing and serialization for the CopyBothResponse message ('W') in the protocol. +//! +//! - `CopyBothResponseFrame`: represents the CopyBothResponse message with overall format and column formats. +//! - `CopyBothResponseError`: error types for parsing and encoding. +//! +//! Implements `WireSerializable` for easy conversion between raw bytes and `CopyBothResponseFrame`. + +use bytes::{Buf, BufMut, Bytes, BytesMut}; +use std::{error::Error as StdError, fmt}; + +use crate::wire_protocol::shared_property_types::ResultFormat; +use crate::wire_protocol::WireSerializable; + +// ----------------------------------------------------------------------------- +// ----- ProtocolMessage ------------------------------------------------------- + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct CopyBothResponseFrame { + pub format: ResultFormat, + pub column_formats: Vec, +} + +// ----------------------------------------------------------------------------- +// ----- Error ----------------------------------------------------------------- + +#[derive(Debug)] +pub enum CopyBothResponseError { + UnexpectedTag(u8), + UnexpectedLength(u32), + InvalidFormatCode(i8), + InvalidColumnFormatCode(i16), +} + +impl fmt::Display for CopyBothResponseError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + CopyBothResponseError::UnexpectedTag(t) => write!(f, "unexpected tag: {t:#X}"), + CopyBothResponseError::UnexpectedLength(len) => write!(f, "unexpected length: {len}"), + CopyBothResponseError::InvalidFormatCode(c) => write!(f, "invalid format code: {c}"), + CopyBothResponseError::InvalidColumnFormatCode(c) => { + write!(f, "invalid column format code: {c}") + } + } + } +} + +impl StdError for CopyBothResponseError {} + +// ----------------------------------------------------------------------------- +// ----- Helpers --------------------------------------------------------------- + +fn decode_format_code(code: i16) -> Result { + match code { + 0 => Ok(ResultFormat::Text), + 1 => Ok(ResultFormat::Binary), + other => Err(CopyBothResponseError::InvalidColumnFormatCode(other)), + } +} + +fn encode_format_code(buf: &mut BytesMut, format: ResultFormat) { + buf.put_i16(match format { + ResultFormat::Text => 0, + ResultFormat::Binary => 1, + }); +} + +// ----------------------------------------------------------------------------- +// ----- WireSerializable ------------------------------------------------------ + +impl<'a> WireSerializable<'a> for CopyBothResponseFrame { + type Error = CopyBothResponseError; + + fn from_bytes(mut bytes: &'a [u8]) -> Result { + if bytes.remaining() < 5 { + return Err(CopyBothResponseError::UnexpectedLength(bytes.len() as u32)); + } + + let tag = bytes.get_u8(); + if tag != b'W' { + return Err(CopyBothResponseError::UnexpectedTag(tag)); + } + + let len = bytes.get_u32(); + if len as usize != bytes.remaining() + 4 { + return Err(CopyBothResponseError::UnexpectedLength(len)); + } + + let format_code = bytes.get_i8(); + let format = match format_code { + 0 => ResultFormat::Text, + 1 => ResultFormat::Binary, + other => return Err(CopyBothResponseError::InvalidFormatCode(other)), + }; + + let num_cols = bytes.get_i16() as usize; + + if bytes.remaining() != num_cols * 2 { + return Err(CopyBothResponseError::UnexpectedLength(len)); + } + + let mut column_formats = Vec::with_capacity(num_cols); + for _ in 0..num_cols { + let col_code = bytes.get_i16(); + column_formats.push(decode_format_code(col_code)?); + } + + Ok(CopyBothResponseFrame { + format, + column_formats, + }) + } + + fn to_bytes(&self) -> Result { + let body_size = 1 + 2 + self.column_formats.len() * 2; + let total_len = 4 + body_size; + + let mut buf = BytesMut::with_capacity(1 + total_len); + buf.put_u8(b'W'); + buf.put_u32(total_len as u32); + buf.put_i8(match self.format { + ResultFormat::Text => 0, + ResultFormat::Binary => 1, + }); + buf.put_i16(self.column_formats.len() as i16); + for fmt in &self.column_formats { + encode_format_code(&mut buf, *fmt); + } + + Ok(buf.freeze()) + } + + fn body_size(&self) -> usize { + 1 + 2 + self.column_formats.len() * 2 + } +} + +// ----------------------------------------------------------------------------- +// ----- Tests ----------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + + fn make_text_frame() -> CopyBothResponseFrame { + CopyBothResponseFrame { + format: ResultFormat::Text, + column_formats: vec![ResultFormat::Text, ResultFormat::Text], + } + } + + fn make_binary_frame() -> CopyBothResponseFrame { + CopyBothResponseFrame { + format: ResultFormat::Binary, + column_formats: vec![ResultFormat::Binary], + } + } + + #[test] + fn serialize_text() { + let frame = make_text_frame(); + let bytes = frame.to_bytes().unwrap(); + let expected = b"W\x00\x00\x00\x0B\x00\x00\x02\x00\x00\x00\x00"; + assert_eq!(bytes.as_ref(), expected); + } + + #[test] + fn deserialize_text() { + let data = b"W\x00\x00\x00\x0B\x00\x00\x02\x00\x00\x00\x00"; + let frame = CopyBothResponseFrame::from_bytes(data).unwrap(); + assert_eq!(frame.format, ResultFormat::Text); + assert_eq!( + frame.column_formats, + vec![ResultFormat::Text, ResultFormat::Text] + ); + } + + #[test] + fn roundtrip_text() { + let original = make_text_frame(); + let bytes = original.to_bytes().unwrap(); + let decoded = CopyBothResponseFrame::from_bytes(bytes.as_ref()).unwrap(); + assert_eq!(decoded.format, original.format); + assert_eq!(decoded.column_formats, original.column_formats); + } + + #[test] + fn roundtrip_binary() { + let original = make_binary_frame(); + let bytes = original.to_bytes().unwrap(); + let decoded = CopyBothResponseFrame::from_bytes(bytes.as_ref()).unwrap(); + assert_eq!(decoded.format, original.format); + assert_eq!(decoded.column_formats, original.column_formats); + } + + #[test] + fn invalid_tag() { + let data = b"H\x00\x00\x00\x0B\x00\x00\x02\x00\x00\x00\x00"; + let err = CopyBothResponseFrame::from_bytes(data).unwrap_err(); + matches!(err, CopyBothResponseError::UnexpectedTag(_)); + } + + #[test] + fn invalid_length() { + let data = b"W\x00\x00\x00\x0C\x00\x00\x02\x00\x00\x00\x00"; + let err = CopyBothResponseFrame::from_bytes(data).unwrap_err(); + matches!(err, CopyBothResponseError::UnexpectedLength(_)); + } + + #[test] + fn invalid_format_code() { + let data = b"W\x00\x00\x00\x09\x02\x00\x01\x00\x00"; + let err = CopyBothResponseFrame::from_bytes(data).unwrap_err(); + matches!(err, CopyBothResponseError::InvalidFormatCode(2)); + } + + #[test] + fn invalid_column_format_code() { + let data = b"W\x00\x00\x00\x09\x00\x00\x01\x00\x02"; + let err = CopyBothResponseFrame::from_bytes(data).unwrap_err(); + matches!(err, CopyBothResponseError::InvalidColumnFormatCode(2)); + } + + #[test] + fn short_column_formats() { + let data = b"W\x00\x00\x00\x0B\x00\x00\x02\x00\x00"; // missing last i16 + let err = CopyBothResponseFrame::from_bytes(data).unwrap_err(); + matches!(err, CopyBothResponseError::UnexpectedLength(_)); + } +} + +// ----------------------------------------------------------------------------- +// ----------------------------------------------------------------------------- diff --git a/pgdog/src/wire_protocol/backend/copy_data.rs b/pgdog/src/wire_protocol/backend/copy_data.rs new file mode 100644 index 00000000..5e7ce679 --- /dev/null +++ b/pgdog/src/wire_protocol/backend/copy_data.rs @@ -0,0 +1,6 @@ +//! Module: wire_protocol::backend::copy_data +//! +//! Re-exports the bidirectional CopyDataFrame and CopyDataError +//! to avoid duplicating the implementation. + +pub use crate::wire_protocol::bidirectional::copy_data::{CopyDataError, CopyDataFrame}; diff --git a/pgdog/src/wire_protocol/backend/copy_done.rs b/pgdog/src/wire_protocol/backend/copy_done.rs new file mode 100644 index 00000000..f09c9a76 --- /dev/null +++ b/pgdog/src/wire_protocol/backend/copy_done.rs @@ -0,0 +1,6 @@ +//! Module: wire_protocol::backend::copy_done +//! +//! Re-exports the bidirectional CopyDoneFrame and CopyDoneError +//! to avoid duplicating the implementation. + +pub use crate::wire_protocol::bidirectional::copy_done::{CopyDoneError, CopyDoneFrame}; diff --git a/pgdog/src/wire_protocol/backend/copy_in_response.rs b/pgdog/src/wire_protocol/backend/copy_in_response.rs new file mode 100644 index 00000000..756e7a30 --- /dev/null +++ b/pgdog/src/wire_protocol/backend/copy_in_response.rs @@ -0,0 +1,276 @@ +//! Module: wire_protocol::backend::copy_in_response +//! +//! Provides parsing and serialization for the CopyInResponse message ('G') in the protocol. +//! +//! - `CopyInResponseFrame`: represents the CopyInResponse message with overall format and per-column formats. +//! - `CopyInResponseError`: error types for parsing and encoding. +//! +//! Implements `WireSerializable` for easy conversion between raw bytes and `CopyInResponseFrame`. + +use bytes::{Buf, BufMut, Bytes, BytesMut}; +use std::{error::Error as StdError, fmt}; + +use crate::wire_protocol::shared_property_types::ResultFormat; +use crate::wire_protocol::WireSerializable; + +// ----------------------------------------------------------------------------- +// ----- ProtocolMessage ------------------------------------------------------- + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct CopyInResponseFrame { + pub overall_format: ResultFormat, + pub column_formats: Vec, +} + +// ----------------------------------------------------------------------------- +// ----- Error ----------------------------------------------------------------- + +#[derive(Debug)] +pub enum CopyInResponseError { + UnexpectedTag(u8), + UnexpectedLength(u32), + UnexpectedEof, + InvalidOverallFormat(i8), + InvalidColumnFormat(i16), +} + +impl fmt::Display for CopyInResponseError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + CopyInResponseError::UnexpectedTag(t) => write!(f, "unexpected tag: {t:#X}"), + CopyInResponseError::UnexpectedLength(len) => write!(f, "unexpected length: {len}"), + CopyInResponseError::UnexpectedEof => write!(f, "unexpected EOF"), + CopyInResponseError::InvalidOverallFormat(c) => { + write!(f, "invalid overall format code: {c}") + } + CopyInResponseError::InvalidColumnFormat(c) => { + write!(f, "invalid column format code: {c}") + } + } + } +} + +impl StdError for CopyInResponseError {} + +// ----------------------------------------------------------------------------- +// ----- WireSerializable ------------------------------------------------------ + +impl<'a> WireSerializable<'a> for CopyInResponseFrame { + type Error = CopyInResponseError; + + fn from_bytes(mut bytes: &'a [u8]) -> Result { + // need at least tag (1) + len (4) + if bytes.remaining() < 5 { + return Err(CopyInResponseError::UnexpectedEof); + } + + let tag = bytes.get_u8(); + if tag != b'G' { + return Err(CopyInResponseError::UnexpectedTag(tag)); + } + + // need length field + if bytes.remaining() < 4 { + return Err(CopyInResponseError::UnexpectedEof); + } + let len = bytes.get_u32(); + // minimum frame length = 4 (len field) + 1 (overall) + 2 (count) = 7 + if len < 7 { + return Err(CopyInResponseError::UnexpectedLength(len)); + } + let payload_len = (len - 4) as usize; + let rem = bytes.remaining(); + if rem < payload_len { + return Err(CopyInResponseError::UnexpectedEof); + } + if rem > payload_len { + return Err(CopyInResponseError::UnexpectedLength(len)); + } + + // now parse payload + let overall_code = bytes.get_i8(); + let overall_format = match overall_code { + 0 => ResultFormat::Text, + 1 => ResultFormat::Binary, + c => return Err(CopyInResponseError::InvalidOverallFormat(c)), + }; + + let num_i16 = bytes.get_i16(); + if num_i16 < 0 { + return Err(CopyInResponseError::UnexpectedLength(len)); + } + let num = num_i16 as usize; + + // expect exactly 2*num bytes left for column formats + if bytes.remaining() < 2 * num { + return Err(CopyInResponseError::UnexpectedEof); + } + + let mut column_formats = Vec::with_capacity(num); + for _ in 0..num { + let code = bytes.get_i16(); + let fmt = match code { + 0 => ResultFormat::Text, + 1 => ResultFormat::Binary, + c => return Err(CopyInResponseError::InvalidColumnFormat(c)), + }; + column_formats.push(fmt); + } + + Ok(CopyInResponseFrame { + overall_format, + column_formats, + }) + } + + fn to_bytes(&self) -> Result { + let mut body = BytesMut::with_capacity(self.body_size()); + body.put_i8(if matches!(self.overall_format, ResultFormat::Text) { + 0 + } else { + 1 + }); + body.put_i16(self.column_formats.len() as i16); + for fmt in &self.column_formats { + body.put_i16(if matches!(fmt, ResultFormat::Text) { + 0 + } else { + 1 + }); + } + + let mut frame = BytesMut::with_capacity(body.len() + 5); + frame.put_u8(b'G'); + frame.put_u32((body.len() + 4) as u32); + frame.extend_from_slice(&body); + + Ok(frame.freeze()) + } + + fn body_size(&self) -> usize { + 1 + 2 + 2 * self.column_formats.len() + } +} + +// ----------------------------------------------------------------------------- +// ----- Tests ----------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + + fn make_text_frame() -> CopyInResponseFrame { + CopyInResponseFrame { + overall_format: ResultFormat::Text, + column_formats: vec![ResultFormat::Text, ResultFormat::Text], + } + } + + fn make_binary_frame() -> CopyInResponseFrame { + CopyInResponseFrame { + overall_format: ResultFormat::Binary, + column_formats: vec![ResultFormat::Text, ResultFormat::Binary], + } + } + + #[test] + fn serialize_copy_in_response_text() { + let frame = make_text_frame(); + let bytes = frame.to_bytes().unwrap(); + // 'G' + length(4 + 1 + 2 + 4 = 11) + overall(0) + count(2) + fmt(0,0) + let expected = &[b'G', 0, 0, 0, 11, 0, 0, 2, 0, 0, 0, 0]; + assert_eq!(bytes.as_ref(), expected); + } + + #[test] + fn serialize_copy_in_response_binary() { + let frame = make_binary_frame(); + let bytes = frame.to_bytes().unwrap(); + // 'G' + length(11) + overall(1) + count(2) + fmt(0,1) + let expected = &[b'G', 0, 0, 0, 11, 1, 0, 2, 0, 0, 0, 1]; + assert_eq!(bytes.as_ref(), expected); + } + + #[test] + fn deserialize_copy_in_response_text() { + let data = &[b'G', 0, 0, 0, 11, 0, 0, 2, 0, 0, 0, 0]; + let frame = CopyInResponseFrame::from_bytes(data).unwrap(); + assert_eq!(frame.overall_format, ResultFormat::Text); + assert_eq!( + frame.column_formats, + vec![ResultFormat::Text, ResultFormat::Text] + ); + } + + #[test] + fn deserialize_copy_in_response_binary() { + let data = &[b'G', 0, 0, 0, 11, 1, 0, 2, 0, 0, 0, 1]; + let frame = CopyInResponseFrame::from_bytes(data).unwrap(); + assert_eq!(frame.overall_format, ResultFormat::Binary); + assert_eq!( + frame.column_formats, + vec![ResultFormat::Text, ResultFormat::Binary] + ); + } + + #[test] + fn roundtrip_copy_in_response_text() { + let original = make_text_frame(); + let bytes = original.to_bytes().unwrap(); + let decoded = CopyInResponseFrame::from_bytes(bytes.as_ref()).unwrap(); + assert_eq!(original, decoded); + } + + #[test] + fn roundtrip_copy_in_response_binary() { + let original = make_binary_frame(); + let bytes = original.to_bytes().unwrap(); + let decoded = CopyInResponseFrame::from_bytes(bytes.as_ref()).unwrap(); + assert_eq!(original, decoded); + } + + #[test] + fn invalid_tag() { + let data = vec![b'H', 0, 0, 0, 11, 0, 0, 2, 0, 0, 0, 0]; + let err = CopyInResponseFrame::from_bytes(&data).unwrap_err(); + assert!(matches!(err, CopyInResponseError::UnexpectedTag(_))); + } + + #[test] + fn invalid_length_short() { + let data = &[b'G', 0, 0, 0, 6]; + let err = CopyInResponseFrame::from_bytes(data).unwrap_err(); + assert!(matches!(err, CopyInResponseError::UnexpectedLength(_))); + } + + #[test] + fn invalid_length_mismatch() { + let data = vec![b'G', 0, 0, 0, 11, 0, 0, 2, 0, 0, 0, 0, 0]; // extra byte + let err = CopyInResponseFrame::from_bytes(&data).unwrap_err(); + assert!(matches!(err, CopyInResponseError::UnexpectedLength(_))); + } + + #[test] + fn unexpected_eof() { + let data = &[b'G', 0, 0, 0, 11, 0, 0, 2, 0, 0]; + let err = CopyInResponseFrame::from_bytes(data).unwrap_err(); + assert!(matches!(err, CopyInResponseError::UnexpectedEof)); + } + + #[test] + fn invalid_overall_format() { + let data = &[b'G', 0, 0, 0, 7, 2, 0, 0]; + let err = CopyInResponseFrame::from_bytes(data).unwrap_err(); + assert!(matches!(err, CopyInResponseError::InvalidOverallFormat(2))); + } + + #[test] + fn invalid_column_format() { + let data = &[b'G', 0, 0, 0, 9, 1, 0, 1, 0, 2]; + let err = CopyInResponseFrame::from_bytes(data).unwrap_err(); + assert!(matches!(err, CopyInResponseError::InvalidColumnFormat(2))); + } +} + +// ----------------------------------------------------------------------------- +// ----------------------------------------------------------------------------- diff --git a/pgdog/src/wire_protocol/backend/copy_out_response.rs b/pgdog/src/wire_protocol/backend/copy_out_response.rs new file mode 100644 index 00000000..96775833 --- /dev/null +++ b/pgdog/src/wire_protocol/backend/copy_out_response.rs @@ -0,0 +1,234 @@ +//! Module: wire_protocol::backend::copy_out_response +//! +//! Provides parsing and serialization for the CopyOutResponse message ('G') in the protocol. +//! +//! - `CopyOutResponseFrame`: represents the CopyOutResponse message with overall format and column formats. +//! - `CopyOutResponseError`: error types for parsing and encoding. +//! +//! Implements `WireSerializable` for easy conversion between raw bytes and `CopyOutResponseFrame`. + +use bytes::{Buf, BufMut, Bytes, BytesMut}; +use std::{error::Error as StdError, fmt}; + +use crate::wire_protocol::shared_property_types::ResultFormat; +use crate::wire_protocol::WireSerializable; + +// ----------------------------------------------------------------------------- +// ----- ProtocolMessage ------------------------------------------------------- + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct CopyOutResponseFrame { + pub format: ResultFormat, + pub column_formats: Vec, +} + +// ----------------------------------------------------------------------------- +// ----- Error ----------------------------------------------------------------- + +#[derive(Debug)] +pub enum CopyOutResponseError { + UnexpectedTag(u8), + UnexpectedLength(u32), + InvalidFormatCode(i8), + InvalidColumnFormatCode(i16), +} + +impl fmt::Display for CopyOutResponseError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + CopyOutResponseError::UnexpectedTag(t) => write!(f, "unexpected tag: {t:#X}"), + CopyOutResponseError::UnexpectedLength(len) => write!(f, "unexpected length: {len}"), + CopyOutResponseError::InvalidFormatCode(c) => write!(f, "invalid format code: {c}"), + CopyOutResponseError::InvalidColumnFormatCode(c) => { + write!(f, "invalid column format code: {c}") + } + } + } +} + +impl StdError for CopyOutResponseError {} + +// ----------------------------------------------------------------------------- +// ----- Helpers --------------------------------------------------------------- + +fn decode_format_code(code: i16) -> Result { + match code { + 0 => Ok(ResultFormat::Text), + 1 => Ok(ResultFormat::Binary), + other => Err(CopyOutResponseError::InvalidColumnFormatCode(other)), + } +} + +fn encode_format_code(buf: &mut BytesMut, format: ResultFormat) { + buf.put_i16(match format { + ResultFormat::Text => 0, + ResultFormat::Binary => 1, + }); +} + +// ----------------------------------------------------------------------------- +// ----- WireSerializable ------------------------------------------------------ + +impl<'a> WireSerializable<'a> for CopyOutResponseFrame { + type Error = CopyOutResponseError; + + fn from_bytes(mut bytes: &'a [u8]) -> Result { + if bytes.remaining() < 5 { + return Err(CopyOutResponseError::UnexpectedLength(bytes.len() as u32)); + } + + let tag = bytes.get_u8(); + if tag != b'G' { + return Err(CopyOutResponseError::UnexpectedTag(tag)); + } + + let len = bytes.get_u32(); + if len as usize != bytes.remaining() + 4 { + return Err(CopyOutResponseError::UnexpectedLength(len)); + } + + let format_code = bytes.get_i8(); + let format = match format_code { + 0 => ResultFormat::Text, + 1 => ResultFormat::Binary, + other => return Err(CopyOutResponseError::InvalidFormatCode(other)), + }; + + let num_cols = bytes.get_i16() as usize; + + if bytes.remaining() != num_cols * 2 { + return Err(CopyOutResponseError::UnexpectedLength(len)); + } + + let mut column_formats = Vec::with_capacity(num_cols); + for _ in 0..num_cols { + let col_code = bytes.get_i16(); + column_formats.push(decode_format_code(col_code)?); + } + + Ok(CopyOutResponseFrame { + format, + column_formats, + }) + } + + fn to_bytes(&self) -> Result { + let body_size = 1 + 2 + self.column_formats.len() * 2; + let total_len = 4 + body_size; + + let mut buf = BytesMut::with_capacity(1 + total_len); + buf.put_u8(b'G'); + buf.put_u32(total_len as u32); + buf.put_i8(match self.format { + ResultFormat::Text => 0, + ResultFormat::Binary => 1, + }); + buf.put_i16(self.column_formats.len() as i16); + for fmt in &self.column_formats { + encode_format_code(&mut buf, *fmt); + } + + Ok(buf.freeze()) + } + + fn body_size(&self) -> usize { + 1 + 2 + self.column_formats.len() * 2 + } +} + +// ----------------------------------------------------------------------------- +// ----- Tests ----------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + + fn make_text_frame() -> CopyOutResponseFrame { + CopyOutResponseFrame { + format: ResultFormat::Text, + column_formats: vec![ResultFormat::Text, ResultFormat::Text], + } + } + + fn make_binary_frame() -> CopyOutResponseFrame { + CopyOutResponseFrame { + format: ResultFormat::Binary, + column_formats: vec![ResultFormat::Binary], + } + } + + #[test] + fn serialize_text() { + let frame = make_text_frame(); + let bytes = frame.to_bytes().unwrap(); + let expected = b"G\x00\x00\x00\x0B\x00\x00\x02\x00\x00\x00\x00"; + assert_eq!(bytes.as_ref(), expected); + } + + #[test] + fn deserialize_text() { + let data = b"G\x00\x00\x00\x0B\x00\x00\x02\x00\x00\x00\x00"; + let frame = CopyOutResponseFrame::from_bytes(data).unwrap(); + assert_eq!(frame.format, ResultFormat::Text); + assert_eq!( + frame.column_formats, + vec![ResultFormat::Text, ResultFormat::Text] + ); + } + + #[test] + fn roundtrip_text() { + let original = make_text_frame(); + let bytes = original.to_bytes().unwrap(); + let decoded = CopyOutResponseFrame::from_bytes(bytes.as_ref()).unwrap(); + assert_eq!(decoded.format, original.format); + assert_eq!(decoded.column_formats, original.column_formats); + } + + #[test] + fn roundtrip_binary() { + let original = make_binary_frame(); + let bytes = original.to_bytes().unwrap(); + let decoded = CopyOutResponseFrame::from_bytes(bytes.as_ref()).unwrap(); + assert_eq!(decoded.format, original.format); + assert_eq!(decoded.column_formats, original.column_formats); + } + + #[test] + fn invalid_tag() { + let data = b"H\x00\x00\x00\x0B\x00\x00\x02\x00\x00\x00\x00"; + let err = CopyOutResponseFrame::from_bytes(data).unwrap_err(); + matches!(err, CopyOutResponseError::UnexpectedTag(_)); + } + + #[test] + fn invalid_length() { + let data = b"G\x00\x00\x00\x0C\x00\x00\x02\x00\x00\x00\x00"; + let err = CopyOutResponseFrame::from_bytes(data).unwrap_err(); + matches!(err, CopyOutResponseError::UnexpectedLength(_)); + } + + #[test] + fn invalid_format_code() { + let data = b"G\x00\x00\x00\x09\x02\x00\x01\x00\x00"; + let err = CopyOutResponseFrame::from_bytes(data).unwrap_err(); + matches!(err, CopyOutResponseError::InvalidFormatCode(2)); + } + + #[test] + fn invalid_column_format_code() { + let data = b"G\x00\x00\x00\x09\x00\x00\x01\x00\x02"; + let err = CopyOutResponseFrame::from_bytes(data).unwrap_err(); + matches!(err, CopyOutResponseError::InvalidColumnFormatCode(2)); + } + + #[test] + fn short_column_formats() { + let data = b"G\x00\x00\x00\x0B\x00\x00\x02\x00\x00"; // missing last i16 + let err = CopyOutResponseFrame::from_bytes(data).unwrap_err(); + matches!(err, CopyOutResponseError::UnexpectedLength(_)); + } +} + +// ----------------------------------------------------------------------------- +// ----------------------------------------------------------------------------- diff --git a/pgdog/src/wire_protocol/backend/data_row.rs b/pgdog/src/wire_protocol/backend/data_row.rs new file mode 100644 index 00000000..5100f6a1 --- /dev/null +++ b/pgdog/src/wire_protocol/backend/data_row.rs @@ -0,0 +1,277 @@ +//! Module: wire_protocol::backend::data_row +//! +//! Provides parsing and serialization for the DataRow message ('D') in the protocol. +//! +//! - `DataRowFrame`: represents the DataRow message with column values. +//! - `DataRowError`: error types for parsing and encoding. +//! +//! Implements `WireSerializable` for easy conversion between raw bytes and `DataRowFrame`. + +use bytes::{Buf, BufMut, Bytes, BytesMut}; +use std::{error::Error as StdError, fmt}; + +use crate::wire_protocol::WireSerializable; + +// ----------------------------------------------------------------------------- +// ----- ProtocolMessage ------------------------------------------------------- + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct DataRowFrame<'a> { + pub columns: Vec>, +} + +// ----------------------------------------------------------------------------- +// ----- Subproperties --------------------------------------------------------- + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum ColumnValue<'a> { + Null, + Value(&'a [u8]), +} + +// ----------------------------------------------------------------------------- +// ----- Error ----------------------------------------------------------------- + +#[derive(Debug)] +pub enum DataRowError { + UnexpectedTag(u8), + UnexpectedLength(u32), + UnexpectedEof, + InvalidColumnLength(i32), +} + +impl fmt::Display for DataRowError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + DataRowError::UnexpectedTag(t) => write!(f, "unexpected tag: {t:#X}"), + DataRowError::UnexpectedLength(len) => write!(f, "unexpected length: {len}"), + DataRowError::UnexpectedEof => write!(f, "unexpected EOF"), + DataRowError::InvalidColumnLength(len) => write!(f, "invalid column length: {len}"), + } + } +} + +impl StdError for DataRowError {} + +// ----------------------------------------------------------------------------- +// ----- WireSerializable ------------------------------------------------------ + +impl<'a> WireSerializable<'a> for DataRowFrame<'a> { + type Error = DataRowError; + + fn from_bytes(mut bytes: &'a [u8]) -> Result { + if bytes.remaining() < 7 { + return Err(DataRowError::UnexpectedEof); + } + + let tag = bytes.get_u8(); + if tag != b'D' { + return Err(DataRowError::UnexpectedTag(tag)); + } + + let len = bytes.get_u32(); + if len < 6 { + return Err(DataRowError::UnexpectedLength(len)); + } + + if bytes.remaining() != (len - 4) as usize { + return Err(DataRowError::UnexpectedLength(len)); + } + + let num_columns = bytes.get_i16(); + if num_columns < 0 { + return Err(DataRowError::UnexpectedLength(len)); + } + + let num = num_columns as usize; + let mut columns = Vec::with_capacity(num); + + for _ in 0..num { + if bytes.remaining() < 4 { + return Err(DataRowError::UnexpectedEof); + } + + let col_len = bytes.get_i32(); + let col_val = if col_len == -1 { + ColumnValue::Null + } else if col_len < 0 { + return Err(DataRowError::InvalidColumnLength(col_len)); + } else { + let col_len_usize = col_len as usize; + if bytes.remaining() < col_len_usize { + return Err(DataRowError::UnexpectedEof); + } + + let value = &bytes[0..col_len_usize]; + bytes = &bytes[col_len_usize..]; + ColumnValue::Value(value) + }; + + columns.push(col_val); + } + + if bytes.has_remaining() { + return Err(DataRowError::UnexpectedLength(len)); + } + + Ok(DataRowFrame { columns }) + } + + fn to_bytes(&self) -> Result { + let mut body = BytesMut::with_capacity(self.body_size()); + body.put_i16(self.columns.len() as i16); + for col in &self.columns { + match col { + ColumnValue::Null => body.put_i32(-1), + ColumnValue::Value(val) => { + body.put_i32(val.len() as i32); + body.extend_from_slice(val); + } + } + } + + let mut frame = BytesMut::with_capacity(body.len() + 5); + frame.put_u8(b'D'); + frame.put_u32((body.len() + 4) as u32); + frame.extend_from_slice(&body); + + Ok(frame.freeze()) + } + + fn body_size(&self) -> usize { + 2 + self + .columns + .iter() + .map(|col| { + 4 + match col { + ColumnValue::Null => 0, + ColumnValue::Value(val) => val.len(), + } + }) + .sum::() + } +} + +// ----------------------------------------------------------------------------- +// ----- Tests ----------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + + fn make_frame<'a>() -> DataRowFrame<'a> { + DataRowFrame { + columns: vec![ + ColumnValue::Null, + ColumnValue::Value(b"col2_value"), + ColumnValue::Value(&[]), // empty + ], + } + } + + #[test] + fn serialize_data_row() { + let frame = make_frame(); + let bytes = frame.to_bytes().unwrap(); + // 'D' + len(4 + 2 + 4*3 + 10 + 0) = len=4+2+12+10=28, u32=28 + // i16=3, -1, 10 + "col2_value", 0 + let expected = + b"D\x00\x00\x00\x1C\x00\x03\xff\xff\xff\xff\x00\x00\x00\x0Acol2_value\x00\x00\x00\x00"; + assert_eq!(bytes.as_ref(), expected); + } + + #[test] + fn deserialize_data_row() { + let data = + b"D\x00\x00\x00\x1C\x00\x03\xff\xff\xff\xff\x00\x00\x00\x0Acol2_value\x00\x00\x00\x00"; + let frame = DataRowFrame::from_bytes(data).unwrap(); + assert_eq!(frame.columns.len(), 3); + matches!(frame.columns[0], ColumnValue::Null); + if let ColumnValue::Value(val) = frame.columns[1] { + assert_eq!(val, b"col2_value"); + } else { + panic!("expected Value"); + } + if let ColumnValue::Value(val) = frame.columns[2] { + assert_eq!(val, b""); + } else { + panic!("expected Value"); + } + } + + #[test] + fn roundtrip_data_row() { + let original = make_frame(); + let bytes = original.to_bytes().unwrap(); + let decoded = DataRowFrame::from_bytes(bytes.as_ref()).unwrap(); + assert_eq!(decoded, original); + } + + #[test] + fn invalid_tag() { + let data = + b"E\x00\x00\x00\x1C\x00\x03\xff\xff\xff\xff\x00\x00\x00\x0Acol2_value\x00\x00\x00\x00"; + let err = DataRowFrame::from_bytes(data).unwrap_err(); + matches!(err, DataRowError::UnexpectedTag(_)); + } + + #[test] + fn invalid_length_short() { + let data = b"D\x00\x00\x00\x05"; + let err = DataRowFrame::from_bytes(data).unwrap_err(); + matches!(err, DataRowError::UnexpectedLength(_)); + } + + #[test] + fn unexpected_eof() { + let data = b"D\x00\x00\x00\x1C\x00\x03\xff\xff\xff\xff\x00\x00\x00\x0Acol2_valu"; // short by 1 + let err = DataRowFrame::from_bytes(data).unwrap_err(); + matches!(err, DataRowError::UnexpectedEof); + } + + #[test] + fn extra_data_after() { + let data = b"D\x00\x00\x00\x1C\x00\x03\xff\xff\xff\xff\x00\x00\x00\x0Acol2_value\x00\x00\x00\x00\x00"; + let err = DataRowFrame::from_bytes(data).unwrap_err(); + matches!(err, DataRowError::UnexpectedLength(_)); + } + + #[test] + fn invalid_column_length() { + let data = b"D\x00\x00\x00\x0E\x00\x01\xff\xff\xff\xfe"; // len=-2 + let err = DataRowFrame::from_bytes(data).unwrap_err(); + matches!(err, DataRowError::InvalidColumnLength(-2)); + } + + #[test] + fn empty_row() { + let frame = DataRowFrame { columns: vec![] }; + let bytes = frame.to_bytes().unwrap(); + let expected = b"D\x00\x00\x00\x06\x00\x00"; + assert_eq!(bytes.as_ref(), expected); + let decoded = DataRowFrame::from_bytes(expected).unwrap(); + assert_eq!(decoded.columns.len(), 0); + } + + #[test] + fn single_empty_value() { + let frame = DataRowFrame { + columns: vec![ColumnValue::Value(&[])], + }; + + let bytes = frame.to_bytes().unwrap(); + let expected = b"D\x00\x00\x00\x0A\x00\x01\x00\x00\x00\x00"; + assert_eq!(bytes.as_ref(), expected); + + let decoded = DataRowFrame::from_bytes(expected).unwrap(); + + if let [ColumnValue::Value(val)] = &decoded.columns[..] { + assert_eq!(*val, &[] as &[u8]); + } else { + panic!("expected empty Value"); + } + } +} + +// ----------------------------------------------------------------------------- +// ----------------------------------------------------------------------------- diff --git a/pgdog/src/wire_protocol/backend/empty_query_response.rs b/pgdog/src/wire_protocol/backend/empty_query_response.rs new file mode 100644 index 00000000..a3085d2a --- /dev/null +++ b/pgdog/src/wire_protocol/backend/empty_query_response.rs @@ -0,0 +1,120 @@ +//! Module: wire_protocol::backend::empty_query_response +//! +//! Provides parsing and serialization for the EmptyQueryResponse message ('I') in the protocol. +//! +//! - `EmptyQueryResponseFrame`: represents the EmptyQueryResponse message. +//! - `EmptyQueryResponseError`: error types for parsing and encoding. +//! +//! Implements `WireSerializable` for easy conversion between raw bytes and `EmptyQueryResponseFrame`. + +use bytes::Bytes; +use std::{error::Error as StdError, fmt}; + +use crate::wire_protocol::WireSerializable; + +// ----------------------------------------------------------------------------- +// ----- ProtocolMessage ------------------------------------------------------- + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct EmptyQueryResponseFrame; + +// ----------------------------------------------------------------------------- +// ----- Error ----------------------------------------------------------------- + +#[derive(Debug)] +pub enum EmptyQueryResponseError { + UnexpectedTag(u8), + UnexpectedLength(u32), +} + +impl fmt::Display for EmptyQueryResponseError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + EmptyQueryResponseError::UnexpectedTag(t) => write!(f, "unexpected tag: {t:#X}"), + EmptyQueryResponseError::UnexpectedLength(len) => write!(f, "unexpected length: {len}"), + } + } +} + +impl StdError for EmptyQueryResponseError {} + +// ----------------------------------------------------------------------------- +// ----- WireSerializable ------------------------------------------------------ + +impl<'a> WireSerializable<'a> for EmptyQueryResponseFrame { + type Error = EmptyQueryResponseError; + + fn from_bytes(bytes: &'a [u8]) -> Result { + if bytes.len() < 5 { + return Err(EmptyQueryResponseError::UnexpectedLength(bytes.len() as u32)); + } + + let tag = bytes[0]; + if tag != b'I' { + return Err(EmptyQueryResponseError::UnexpectedTag(tag)); + } + + let len = u32::from_be_bytes([bytes[1], bytes[2], bytes[3], bytes[4]]); + if len != 4 { + return Err(EmptyQueryResponseError::UnexpectedLength(len)); + } + + Ok(EmptyQueryResponseFrame) + } + + fn to_bytes(&self) -> Result { + Ok(Bytes::from_static(b"I\x00\x00\x00\x04")) + } + + fn body_size(&self) -> usize { + 0 + } +} + +// ----------------------------------------------------------------------------- +// ----- Tests ----------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn serialize_empty_query_response() { + let frame = EmptyQueryResponseFrame; + let bytes = frame.to_bytes().unwrap(); + let expected = b"I\x00\x00\x00\x04"; + assert_eq!(bytes.as_ref(), expected); + } + + #[test] + fn deserialize_empty_query_response() { + let data = b"I\x00\x00\x00\x04"; + let frame = EmptyQueryResponseFrame::from_bytes(data).unwrap(); + let _ = frame; + } + + #[test] + fn roundtrip_empty_query_response() { + let original = EmptyQueryResponseFrame; + let bytes = original.to_bytes().unwrap(); + let decoded = EmptyQueryResponseFrame::from_bytes(bytes.as_ref()).unwrap(); + assert_eq!(original, decoded); + } + + #[test] + fn invalid_tag() { + let data = b"X\x00\x00\x00\x04"; + let err = EmptyQueryResponseFrame::from_bytes(data).unwrap_err(); + matches!(err, EmptyQueryResponseError::UnexpectedTag(_)); + } + + #[test] + fn invalid_length() { + let data = b"I\x00\x00\x00\x05"; + let err = EmptyQueryResponseFrame::from_bytes(data).unwrap_err(); + matches!(err, EmptyQueryResponseError::UnexpectedLength(5)); + } +} + +// ----------------------------------------------------------------------------- +// ----------------------------------------------------------------------------- diff --git a/pgdog/src/wire_protocol/backend/error_response.rs b/pgdog/src/wire_protocol/backend/error_response.rs new file mode 100644 index 00000000..2dda908e --- /dev/null +++ b/pgdog/src/wire_protocol/backend/error_response.rs @@ -0,0 +1,382 @@ +//! Module: wire_protocol::backend::error_response +//! +//! Provides parsing and serialization for the ErrorResponse message ('E') in the protocol. +//! +//! - `ErrorResponseFrame`: represents the ErrorResponse message with a list of error fields. +//! - `ErrorResponseError`: error types for parsing and encoding. +//! +//! Implements `WireSerializable` for easy conversion between raw bytes and `ErrorResponseFrame`. + +use crate::wire_protocol::WireSerializable; +use bytes::{BufMut, Bytes, BytesMut}; +use std::{error::Error as StdError, fmt, str}; + +// ----------------------------------------------------------------------------- +// ----- ProtocolMessage ------------------------------------------------------- + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ErrorResponseFrame<'a> { + pub fields: Vec>, +} + +// ----------------------------------------------------------------------------- +// ----- Subproperties --------------------------------------------------------- + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ErrorFieldCode { + SeverityLocalized, + SeverityNonLocalized, + SqlState, + Message, + Detail, + Hint, + Position, + InternalPosition, + InternalQuery, + Where, + SchemaName, + TableName, + ColumnName, + DataTypeName, + ConstraintName, + File, + Line, + Routine, + Unknown(char), +} + +impl ErrorFieldCode { + pub fn from_char(c: char) -> Self { + match c { + 'S' => Self::SeverityLocalized, + 'V' => Self::SeverityNonLocalized, + 'C' => Self::SqlState, + 'M' => Self::Message, + 'D' => Self::Detail, + 'H' => Self::Hint, + 'P' => Self::Position, + 'p' => Self::InternalPosition, + 'q' => Self::InternalQuery, + 'W' => Self::Where, + 's' => Self::SchemaName, + 't' => Self::TableName, + 'c' => Self::ColumnName, + 'd' => Self::DataTypeName, + 'n' => Self::ConstraintName, + 'F' => Self::File, + 'L' => Self::Line, + 'R' => Self::Routine, + other => Self::Unknown(other), + } + } + + pub fn to_char(&self) -> char { + match self { + Self::SeverityLocalized => 'S', + Self::SeverityNonLocalized => 'V', + Self::SqlState => 'C', + Self::Message => 'M', + Self::Detail => 'D', + Self::Hint => 'H', + Self::Position => 'P', + Self::InternalPosition => 'p', + Self::InternalQuery => 'q', + Self::Where => 'W', + Self::SchemaName => 's', + Self::TableName => 't', + Self::ColumnName => 'c', + Self::DataTypeName => 'd', + Self::ConstraintName => 'n', + Self::File => 'F', + Self::Line => 'L', + Self::Routine => 'R', + Self::Unknown(c) => *c, + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ErrorField<'a> { + pub code: ErrorFieldCode, + pub value: &'a str, +} + +// ----------------------------------------------------------------------------- +// ----- Error ----------------------------------------------------------------- + +#[derive(Debug)] +pub enum ErrorResponseError { + UnexpectedTag(u8), + UnexpectedLength(u32), + Utf8Error(str::Utf8Error), + UnexpectedEof, + InvalidFieldCode(u8), +} + +impl fmt::Display for ErrorResponseError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + ErrorResponseError::UnexpectedTag(t) => write!(f, "unexpected tag: {t:#X}"), + ErrorResponseError::UnexpectedLength(len) => write!(f, "unexpected length: {len}"), + ErrorResponseError::Utf8Error(e) => write!(f, "UTF-8 error: {e}"), + ErrorResponseError::UnexpectedEof => write!(f, "unexpected EOF"), + ErrorResponseError::InvalidFieldCode(c) => write!(f, "invalid field code: {c:#X}"), + } + } +} + +impl StdError for ErrorResponseError { + fn source(&self) -> Option<&(dyn StdError + 'static)> { + if let ErrorResponseError::Utf8Error(e) = self { + Some(e) + } else { + None + } + } +} + +// ----------------------------------------------------------------------------- +// ----- WireSerializable ------------------------------------------------------ + +impl<'a> WireSerializable<'a> for ErrorResponseFrame<'a> { + type Error = ErrorResponseError; + + fn from_bytes(bytes: &'a [u8]) -> Result { + // Need at least tag + length + if bytes.len() < 5 { + return Err(ErrorResponseError::UnexpectedLength(bytes.len() as u32)); + } + // Tag check + if bytes[0] != b'E' { + return Err(ErrorResponseError::UnexpectedTag(bytes[0])); + } + // Read length field + let len_field = u32::from_be_bytes([bytes[1], bytes[2], bytes[3], bytes[4]]); + // The rest is payload + let payload = &bytes[5..]; + let payload_len = payload.len(); + + // Parse fields + let mut offset = 0; + let mut fields = Vec::new(); + let mut seen_terminator = false; + + while offset < payload_len { + let code_byte = payload[offset]; + offset += 1; + + // final terminator + if code_byte == 0 { + seen_terminator = true; + break; + } + let code_char = code_byte as char; + if !code_char.is_ascii() { + return Err(ErrorResponseError::InvalidFieldCode(code_byte)); + } + + // find the NUL ending this field + let rest = &payload[offset..]; + let pos = rest + .iter() + .position(|&b| b == 0) + .ok_or(ErrorResponseError::UnexpectedEof)?; + let raw = &rest[..pos]; + let value = str::from_utf8(raw).map_err(ErrorResponseError::Utf8Error)?; + + // advance past the field data + its NUL + offset += pos + 1; + + fields.push(ErrorField { + code: ErrorFieldCode::from_char(code_char), + value, + }); + } + + if !seen_terminator { + // we ran out of bytes without hitting the 0 + return Err(ErrorResponseError::UnexpectedEof); + } + + // No extra bytes allowed after the terminator + if offset != payload_len { + return Err(ErrorResponseError::UnexpectedLength(len_field)); + } + + // Tests expect len_field == payload_len + 1 + if (len_field as usize) != payload_len + 1 { + return Err(ErrorResponseError::UnexpectedLength(len_field)); + } + + Ok(ErrorResponseFrame { fields }) + } + + fn to_bytes(&self) -> Result { + // build payload: each field + its NUL, then final 0 + let mut body = BytesMut::with_capacity(self.body_size()); + for field in &self.fields { + body.put_u8(field.code.to_char() as u8); + body.extend_from_slice(field.value.as_bytes()); + body.put_u8(0); + } + // final terminator + body.put_u8(0); + + // length = payload.len() + 1 (per your tests) + let len_field = (body.len() + 1) as u32; + + let mut frame = BytesMut::with_capacity(body.len() + 5); + frame.put_u8(b'E'); + frame.put_u32(len_field); + frame.extend_from_slice(&body); + Ok(frame.freeze()) + } + + fn body_size(&self) -> usize { + // sum of (code + value + NUL) for each field, plus one final terminator + 1 + self + .fields + .iter() + .map(|f| 1 + f.value.len() + 1) + .sum::() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn make_simple_error<'a>() -> ErrorResponseFrame<'a> { + ErrorResponseFrame { + fields: vec![ + ErrorField { + code: ErrorFieldCode::SeverityLocalized, + value: "ERROR", + }, + ErrorField { + code: ErrorFieldCode::Message, + value: "permission denied", + }, + ], + } + } + + fn make_detailed_error<'a>() -> ErrorResponseFrame<'a> { + ErrorResponseFrame { + fields: vec![ + ErrorField { + code: ErrorFieldCode::SeverityNonLocalized, + value: "ERROR", + }, + ErrorField { + code: ErrorFieldCode::SqlState, + value: "42501", + }, + ErrorField { + code: ErrorFieldCode::Message, + value: "permission denied for table test", + }, + ErrorField { + code: ErrorFieldCode::Detail, + value: "some detail", + }, + ErrorField { + code: ErrorFieldCode::Hint, + value: "grant permission", + }, + ], + } + } + + #[test] + fn serialize_simple() { + let frame = make_simple_error(); + let bytes = frame.to_bytes().unwrap(); + // length = payload.len() + 1 = 28 = 0x1C + let expected = b"E\x00\x00\x00\x1CSERROR\x00Mpermission denied\x00\x00"; + assert_eq!(bytes.as_ref(), expected); + } + + #[test] + fn deserialize_simple() { + let data = b"E\x00\x00\x00\x1CSERROR\x00Mpermission denied\x00\x00"; + let frame = ErrorResponseFrame::from_bytes(data).unwrap(); + assert_eq!(frame.fields.len(), 2); + assert_eq!(frame.fields[0].code, ErrorFieldCode::SeverityLocalized); + assert_eq!(frame.fields[0].value, "ERROR"); + assert_eq!(frame.fields[1].code, ErrorFieldCode::Message); + assert_eq!(frame.fields[1].value, "permission denied"); + } + + #[test] + fn roundtrip_simple() { + let original = make_simple_error(); + let bytes = original.to_bytes().unwrap(); + let decoded = ErrorResponseFrame::from_bytes(bytes.as_ref()).unwrap(); + assert_eq!(decoded.fields, original.fields); + } + + #[test] + fn roundtrip_detailed() { + let original = make_detailed_error(); + let bytes = original.to_bytes().unwrap(); + let decoded = ErrorResponseFrame::from_bytes(bytes.as_ref()).unwrap(); + assert_eq!(decoded.fields, original.fields); + } + + #[test] + fn unknown_field() { + // length = payload.len() (10) + 1 = 11 = 0x0B + let data = b"E\x00\x00\x00\x0BXunknown\x00\x00"; + let frame = ErrorResponseFrame::from_bytes(data).unwrap(); + assert_eq!(frame.fields.len(), 1); + assert!(matches!(frame.fields[0].code, ErrorFieldCode::Unknown('X'))); + assert_eq!(frame.fields[0].value, "unknown"); + } + + #[test] + fn invalid_tag() { + let data = b"N\x00\x00\x00\x1CSERROR\x00Mpermission denied\x00\x00"; + let err = ErrorResponseFrame::from_bytes(data).unwrap_err(); + matches!(err, ErrorResponseError::UnexpectedTag(_)); + } + + #[test] + fn invalid_length() { + let data = b"E\x00\x00\x00\x1DSERROR\x00Mpermission denied\x00\x00"; + let err = ErrorResponseFrame::from_bytes(data).unwrap_err(); + matches!(err, ErrorResponseError::UnexpectedLength(_)); + } + + #[test] + fn missing_terminator() { + let data = b"E\x00\x00\x00\x1CSERROR\x00Mpermission denied\x00"; + let err = ErrorResponseFrame::from_bytes(data).unwrap_err(); + matches!(err, ErrorResponseError::UnexpectedEof); + } + + #[test] + fn extra_after_terminator() { + let data = b"E\x00\x00\x00\x1CSERROR\x00Mpermission denied\x00\x00\x00"; + let err = ErrorResponseFrame::from_bytes(data).unwrap_err(); + matches!(err, ErrorResponseError::UnexpectedLength(_)); + } + + #[test] + fn invalid_utf8() { + let mut bytes = make_simple_error().to_bytes().unwrap().to_vec(); + bytes[6] = 0xFF; // corrupt in the value portion + let err = ErrorResponseFrame::from_bytes(&bytes).unwrap_err(); + matches!(err, ErrorResponseError::Utf8Error(_)); + } + + #[test] + fn invalid_field_code() { + let data = b"E\x00\x00\x00\x0D\xFFvalue\x00\x00"; + let err = ErrorResponseFrame::from_bytes(data).unwrap_err(); + matches!(err, ErrorResponseError::InvalidFieldCode(0xFF)); + } +} + +// ----------------------------------------------------------------------------- +// ----------------------------------------------------------------------------- diff --git a/pgdog/src/wire_protocol/backend/function_call_response.rs b/pgdog/src/wire_protocol/backend/function_call_response.rs new file mode 100644 index 00000000..33bea4f2 --- /dev/null +++ b/pgdog/src/wire_protocol/backend/function_call_response.rs @@ -0,0 +1,216 @@ +//! Module: wire_protocol::backend::function_call_response +//! +//! Provides parsing and serialization for the FunctionCallResponse message ('V') in the protocol. +//! +//! - `FunctionCallResponseFrame`: represents the FunctionCallResponse message with the optional result value. +//! - `FunctionCallResponseError`: error types for parsing and encoding. +//! +//! Implements `WireSerializable` for easy conversion between raw bytes and `FunctionCallResponseFrame`. + +use bytes::{BufMut, Bytes, BytesMut}; +use std::{error::Error as StdError, fmt}; + +use crate::wire_protocol::WireSerializable; + +// ----------------------------------------------------------------------------- +// ----- ProtocolMessage ------------------------------------------------------- + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct FunctionCallResponseFrame<'a> { + pub result: Option<&'a [u8]>, +} + +// ----------------------------------------------------------------------------- +// ----- Error ----------------------------------------------------------------- + +#[derive(Debug)] +pub enum FunctionCallResponseError { + UnexpectedTag(u8), + UnexpectedLength(u32), + InvalidValueLength, +} + +impl fmt::Display for FunctionCallResponseError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + FunctionCallResponseError::UnexpectedTag(t) => write!(f, "unexpected tag: {t:#X}"), + FunctionCallResponseError::UnexpectedLength(len) => { + write!(f, "unexpected length: {len}") + } + FunctionCallResponseError::InvalidValueLength => write!(f, "invalid value length"), + } + } +} + +impl StdError for FunctionCallResponseError {} + +// ----------------------------------------------------------------------------- +// ----- WireSerializable ------------------------------------------------------ + +impl<'a> WireSerializable<'a> for FunctionCallResponseFrame<'a> { + type Error = FunctionCallResponseError; + + fn from_bytes(bytes: &'a [u8]) -> Result { + if bytes.len() < 9 { + return Err(FunctionCallResponseError::UnexpectedLength( + bytes.len() as u32 + )); + } + + let tag = bytes[0]; + if tag != b'V' { + return Err(FunctionCallResponseError::UnexpectedTag(tag)); + } + + let len = u32::from_be_bytes([bytes[1], bytes[2], bytes[3], bytes[4]]); + if len as usize != bytes.len() - 1 { + return Err(FunctionCallResponseError::UnexpectedLength(len)); + } + + let val_len_i32 = i32::from_be_bytes([bytes[5], bytes[6], bytes[7], bytes[8]]); + + let result = if val_len_i32 == -1 { + if bytes.len() != 9 { + return Err(FunctionCallResponseError::UnexpectedLength(len)); + } + None + } else if val_len_i32 < 0 { + return Err(FunctionCallResponseError::InvalidValueLength); + } else { + let val_len = val_len_i32 as usize; + if bytes.len() != 9 + val_len { + return Err(FunctionCallResponseError::UnexpectedLength(len)); + } + Some(&bytes[9..9 + val_len]) + }; + + Ok(FunctionCallResponseFrame { result }) + } + + fn to_bytes(&self) -> Result { + let val_len_i32 = self.result.map_or(-1i32, |r| r.len() as i32); + let value_size = if val_len_i32 >= 0 { + val_len_i32 as usize + } else { + 0 + }; + let contents_size = 4 + value_size; // val_len + value + let length = 4 + contents_size; // length field + contents + + let mut buf = BytesMut::with_capacity(1 + length); + buf.put_u8(b'V'); + buf.put_u32(length as u32); + buf.put_i32(val_len_i32); + if let Some(result) = self.result { + buf.extend_from_slice(result); + } + + Ok(buf.freeze()) + } + + fn body_size(&self) -> usize { + 4 + self.result.map_or(0, |r| r.len()) + } +} + +// ----------------------------------------------------------------------------- +// ----- Tests ----------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + + fn make_frame_with_result<'a>() -> FunctionCallResponseFrame<'a> { + FunctionCallResponseFrame { + result: Some(b"result_value"), + } + } + + fn make_frame_null() -> FunctionCallResponseFrame<'static> { + FunctionCallResponseFrame { result: None } + } + + #[test] + fn serialize_with_result() { + let frame = make_frame_with_result(); + let bytes = frame.to_bytes().unwrap(); + let expected = b"V\x00\x00\x00\x14\x00\x00\x00\x0Cresult_value"; + assert_eq!(bytes.as_ref(), expected); + } + + #[test] + fn deserialize_with_result() { + let data = b"V\x00\x00\x00\x14\x00\x00\x00\x0Cresult_value"; + let frame = FunctionCallResponseFrame::from_bytes(data).unwrap(); + assert_eq!(frame.result, Some(b"result_value".as_ref())); + } + + #[test] + fn roundtrip_with_result() { + let original = make_frame_with_result(); + let bytes = original.to_bytes().unwrap(); + let decoded = FunctionCallResponseFrame::from_bytes(bytes.as_ref()).unwrap(); + assert_eq!(decoded.result, original.result); + } + + #[test] + fn serialize_null() { + let frame = make_frame_null(); + let bytes = frame.to_bytes().unwrap(); + let expected = b"V\x00\x00\x00\x08\xff\xff\xff\xff"; + assert_eq!(bytes.as_ref(), expected); + } + + #[test] + fn deserialize_null() { + let data = b"V\x00\x00\x00\x08\xff\xff\xff\xff"; + let frame = FunctionCallResponseFrame::from_bytes(data).unwrap(); + assert_eq!(frame.result, None); + } + + #[test] + fn roundtrip_null() { + let original = make_frame_null(); + let bytes = original.to_bytes().unwrap(); + let decoded = FunctionCallResponseFrame::from_bytes(bytes.as_ref()).unwrap(); + assert_eq!(decoded.result, original.result); + } + + #[test] + fn invalid_tag() { + let data = b"X\x00\x00\x00\x08\xff\xff\xff\xff"; + let err = FunctionCallResponseFrame::from_bytes(data).unwrap_err(); + matches!(err, FunctionCallResponseError::UnexpectedTag(_)); + } + + #[test] + fn invalid_length_null() { + let data = b"V\x00\x00\x00\x09\xff\xff\xff\xff"; + let err = FunctionCallResponseFrame::from_bytes(data).unwrap_err(); + matches!(err, FunctionCallResponseError::UnexpectedLength(_)); + } + + #[test] + fn extra_bytes_null() { + let data = b"V\x00\x00\x00\x08\xff\xff\xff\xff\x00"; + let err = FunctionCallResponseFrame::from_bytes(data).unwrap_err(); + matches!(err, FunctionCallResponseError::UnexpectedLength(_)); + } + + #[test] + fn invalid_value_length_negative() { + let data = b"V\x00\x00\x00\x08\xff\xff\xff\xfe"; + let err = FunctionCallResponseFrame::from_bytes(data).unwrap_err(); + matches!(err, FunctionCallResponseError::InvalidValueLength); + } + + #[test] + fn short_value() { + let data = b"V\x00\x00\x00\x14\x00\x00\x00\x0Cresult_valu"; // one byte short + let err = FunctionCallResponseFrame::from_bytes(data).unwrap_err(); + matches!(err, FunctionCallResponseError::UnexpectedLength(_)); + } +} + +// ----------------------------------------------------------------------------- +// ----------------------------------------------------------------------------- diff --git a/pgdog/src/wire_protocol/backend/mod.rs b/pgdog/src/wire_protocol/backend/mod.rs new file mode 100644 index 00000000..e0fffe6e --- /dev/null +++ b/pgdog/src/wire_protocol/backend/mod.rs @@ -0,0 +1,176 @@ +pub mod authentication_cleartext_password; +pub mod authentication_gss; +pub mod authentication_gss_continue; +pub mod authentication_kerberos_v5; +pub mod authentication_md5_password; +pub mod authentication_ok; +pub mod authentication_sasl; +pub mod authentication_sasl_continue; +pub mod authentication_sasl_final; +pub mod authentication_scm_credential; +pub mod authentication_sspi; +pub mod backend_key_data; +pub mod bind_complete; +pub mod close_complete; +pub mod command_complete; +pub mod copy_both_response; +pub mod copy_data; +pub mod copy_done; +pub mod copy_in_response; +pub mod copy_out_response; +pub mod data_row; +pub mod empty_query_response; +pub mod error_response; +pub mod function_call_response; +pub mod negotiate_protocol_version; +pub mod no_data; +pub mod notice_response; +pub mod notification_response; +pub mod parameter_description; +pub mod parameter_status; +pub mod parse_complete; +pub mod portal_suspended; +pub mod ready_for_query; +pub mod row_description; + +use crate::wire_protocol::backend::authentication_cleartext_password::AuthenticationCleartextPasswordFrame; +use crate::wire_protocol::backend::authentication_gss::AuthenticationGssFrame; +use crate::wire_protocol::backend::authentication_gss_continue::AuthenticationGssContinueFrame; +use crate::wire_protocol::backend::authentication_kerberos_v5::AuthenticationKerberosV5Frame; +use crate::wire_protocol::backend::authentication_md5_password::AuthenticationMd5PasswordFrame; +use crate::wire_protocol::backend::authentication_ok::AuthenticationOkFrame; +use crate::wire_protocol::backend::authentication_sasl::AuthenticationSaslFrame; +use crate::wire_protocol::backend::authentication_sasl_continue::AuthenticationSaslContinueFrame; +use crate::wire_protocol::backend::authentication_sasl_final::AuthenticationSaslFinalFrame; +use crate::wire_protocol::backend::authentication_scm_credential::AuthenticationScmCredentialFrame; +use crate::wire_protocol::backend::authentication_sspi::AuthenticationSspiFrame; +use crate::wire_protocol::backend::backend_key_data::BackendKeyDataFrame; +use crate::wire_protocol::backend::bind_complete::BindCompleteFrame; +use crate::wire_protocol::backend::close_complete::CloseCompleteFrame; +use crate::wire_protocol::backend::command_complete::CommandCompleteFrame; +use crate::wire_protocol::backend::copy_both_response::CopyBothResponseFrame; +use crate::wire_protocol::backend::copy_data::CopyDataFrame; +use crate::wire_protocol::backend::copy_done::CopyDoneFrame; +use crate::wire_protocol::backend::copy_in_response::CopyInResponseFrame; +use crate::wire_protocol::backend::copy_out_response::CopyOutResponseFrame; +use crate::wire_protocol::backend::data_row::DataRowFrame; +use crate::wire_protocol::backend::empty_query_response::EmptyQueryResponseFrame; +use crate::wire_protocol::backend::error_response::ErrorResponseFrame; +use crate::wire_protocol::backend::function_call_response::FunctionCallResponseFrame; +use crate::wire_protocol::backend::negotiate_protocol_version::NegotiateProtocolVersionFrame; +use crate::wire_protocol::backend::no_data::NoDataFrame; +use crate::wire_protocol::backend::notice_response::NoticeResponseFrame; +use crate::wire_protocol::backend::notification_response::NotificationResponseFrame; +use crate::wire_protocol::backend::parameter_description::ParameterDescriptionFrame; +use crate::wire_protocol::backend::parameter_status::ParameterStatusFrame; +use crate::wire_protocol::backend::parse_complete::ParseCompleteFrame; +use crate::wire_protocol::backend::portal_suspended::PortalSuspendedFrame; +use crate::wire_protocol::backend::ready_for_query::ReadyForQueryFrame; +use crate::wire_protocol::backend::row_description::RowDescriptionFrame; + +/// Represents any backend-initiated protocol message. +/// Bidirectional protocol messages are also included. +#[derive(Debug)] +pub enum BackendProtocolMessage<'a> { + /// AuthenticationCleartextPassword message + AuthenticationCleartextPassword(AuthenticationCleartextPasswordFrame), + + /// AuthenticationGss message + AuthenticationGss(AuthenticationGssFrame), + + /// AuthenticationGssContinue message + AuthenticationGssContinue(AuthenticationGssContinueFrame<'a>), + + /// AuthenticationKerberosV5 message + AuthenticationKerberosV5(AuthenticationKerberosV5Frame), + + /// AuthenticationMd5Password message + AuthenticationMd5Password(AuthenticationMd5PasswordFrame), + + /// AuthenticationOk message + AuthenticationOk(AuthenticationOkFrame), + + /// AuthenticationSasl message + AuthenticationSasl(AuthenticationSaslFrame), + + /// AuthenticationSaslContinue message + AuthenticationSaslContinue(AuthenticationSaslContinueFrame<'a>), + + /// AuthenticationSaslFinal message + AuthenticationSaslFinal(AuthenticationSaslFinalFrame<'a>), + + /// AuthenticationScmCredential message + AuthenticationScmCredential(AuthenticationScmCredentialFrame), + + /// AuthenticationSspi message + AuthenticationSspi(AuthenticationSspiFrame), + + /// BackendKeyData message + BackendKeyData(BackendKeyDataFrame), + + /// BindComplete message + BindComplete(BindCompleteFrame), + + /// CloseComplete message + CloseComplete(CloseCompleteFrame), + + /// CommandComplete message + CommandComplete(CommandCompleteFrame<'a>), + + /// CopyBothResponse message + CopyBothResponse(CopyBothResponseFrame), + + /// CopyData message for COPY operations + CopyData(CopyDataFrame<'a>), + + /// CopyDone message for COPY operations + CopyDone(CopyDoneFrame), + + /// CopyInResponse message + CopyInResponse(CopyInResponseFrame), + + /// CopyOutResponse message + CopyOutResponse(CopyOutResponseFrame), + + /// DataRow message + DataRow(DataRowFrame<'a>), + + /// EmptyQueryResponse message + EmptyQueryResponse(EmptyQueryResponseFrame), + + /// ErrorResponse message + ErrorResponse(ErrorResponseFrame<'a>), + + /// FunctionCallResponse message + FunctionCallResponse(FunctionCallResponseFrame<'a>), + + /// NegotiateProtocolVersion message + NegotiateProtocolVersion(NegotiateProtocolVersionFrame<'a>), + + /// NoData message + NoData(NoDataFrame), + + /// NoticeResponse message + NoticeResponse(NoticeResponseFrame<'a>), + + /// NotificationResponse message + NotificationResponse(NotificationResponseFrame<'a>), + + /// ParameterDescription message + ParameterDescription(ParameterDescriptionFrame), + + /// ParameterStatus message + ParameterStatus(ParameterStatusFrame<'a>), + + /// ParseComplete message + ParseComplete(ParseCompleteFrame), + + /// PortalSuspended message + PortalSuspended(PortalSuspendedFrame), + + /// ReadyForQuery message + ReadyForQuery(ReadyForQueryFrame), + + /// RowDescription message + RowDescription(RowDescriptionFrame<'a>), +} diff --git a/pgdog/src/wire_protocol/backend/negotiate_protocol_version.rs b/pgdog/src/wire_protocol/backend/negotiate_protocol_version.rs new file mode 100644 index 00000000..b255b3a4 --- /dev/null +++ b/pgdog/src/wire_protocol/backend/negotiate_protocol_version.rs @@ -0,0 +1,262 @@ +//! Module: wire_protocol::backend::negotiate_protocol_version +//! +//! Provides parsing and serialization for the NegotiateProtocolVersion message ('v') in the protocol. +//! +//! - `NegotiateProtocolVersionFrame`: represents the NegotiateProtocolVersion message with supported minor version and unrecognized options. +//! - `NegotiateProtocolVersionError`: error types for parsing and encoding. +//! +//! Implements `WireSerializable` for easy conversion between raw bytes and `NegotiateProtocolVersionFrame`. + +use bytes::{BufMut, Bytes, BytesMut}; + +use std::{error::Error as StdError, fmt, str}; + +use crate::wire_protocol::WireSerializable; + +// ----------------------------------------------------------------------------- +// ----- ProtocolMessage ------------------------------------------------------- + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct NegotiateProtocolVersionFrame<'a> { + pub newest_minor_version: i32, + pub unrecognized_options: Vec<&'a str>, +} + +// ----------------------------------------------------------------------------- +// ----- Error ----------------------------------------------------------------- + +#[derive(Debug)] +pub enum NegotiateProtocolVersionError { + UnexpectedTag(u8), + UnexpectedLength(u32), + UnexpectedEof, + Utf8Error(str::Utf8Error), +} + +impl fmt::Display for NegotiateProtocolVersionError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + NegotiateProtocolVersionError::UnexpectedTag(t) => write!(f, "unexpected tag: {t:#X}"), + NegotiateProtocolVersionError::UnexpectedLength(len) => { + write!(f, "unexpected length: {len}") + } + NegotiateProtocolVersionError::UnexpectedEof => write!(f, "unexpected EOF"), + NegotiateProtocolVersionError::Utf8Error(e) => write!(f, "UTF-8 error: {e}"), + } + } +} + +impl StdError for NegotiateProtocolVersionError { + fn source(&self) -> Option<&(dyn StdError + 'static)> { + match self { + NegotiateProtocolVersionError::Utf8Error(e) => Some(e), + _ => None, + } + } +} + +// ----------------------------------------------------------------------------- +// ----- Helpers --------------------------------------------------------------- + +fn read_cstr<'a>(buf: &mut &'a [u8]) -> Result<&'a str, NegotiateProtocolVersionError> { + let nul = buf + .iter() + .position(|b| *b == 0) + .ok_or(NegotiateProtocolVersionError::UnexpectedEof)?; + let (raw, rest) = buf.split_at(nul); + *buf = &rest[1..]; // skip NUL + Ok(str::from_utf8(raw).map_err(NegotiateProtocolVersionError::Utf8Error)?) +} + +// ----------------------------------------------------------------------------- +// ----- WireSerializable ------------------------------------------------------ + +impl<'a> WireSerializable<'a> for NegotiateProtocolVersionFrame<'a> { + type Error = NegotiateProtocolVersionError; + + fn from_bytes(bytes_full: &'a [u8]) -> Result { + // need at least tag+len + if bytes_full.len() < 5 { + return Err(NegotiateProtocolVersionError::UnexpectedEof); + } + if bytes_full[0] != b'v' { + return Err(NegotiateProtocolVersionError::UnexpectedTag(bytes_full[0])); + } + + let len = u32::from_be_bytes([bytes_full[1], bytes_full[2], bytes_full[3], bytes_full[4]]); + // body must contain at least 8 bytes (minor version + count) + if len < 8 { + return Err(NegotiateProtocolVersionError::UnexpectedLength(len)); + } + + let total_len = (len as usize) + 1; // tag + length bytes + if bytes_full.len() < total_len { + return Err(NegotiateProtocolVersionError::UnexpectedEof); + } + if bytes_full.len() > total_len { + return Err(NegotiateProtocolVersionError::UnexpectedLength(len)); + } + + let mut payload = &bytes_full[5..total_len]; + // read minor version + num_options + if payload.len() < 8 { + return Err(NegotiateProtocolVersionError::UnexpectedEof); + } + let newest_minor_version = i32::from_be_bytes(payload[0..4].try_into().unwrap()); + let num_options = i32::from_be_bytes(payload[4..8].try_into().unwrap()); + if num_options < 0 { + return Err(NegotiateProtocolVersionError::UnexpectedLength(len)); + } + + payload = &payload[8..]; + let mut unrecognized_options = Vec::with_capacity(num_options as usize); + for _ in 0..num_options { + let opt = read_cstr(&mut payload)?; + unrecognized_options.push(opt); + } + + if !payload.is_empty() { + return Err(NegotiateProtocolVersionError::UnexpectedLength(len)); + } + + Ok(NegotiateProtocolVersionFrame { + newest_minor_version, + unrecognized_options, + }) + } + + fn to_bytes(&self) -> Result { + let mut body = BytesMut::with_capacity(self.body_size()); + body.put_i32(self.newest_minor_version); + body.put_i32(self.unrecognized_options.len() as i32); + for opt in &self.unrecognized_options { + body.extend_from_slice(opt.as_bytes()); + body.put_u8(0); + } + + let mut frame = BytesMut::with_capacity(body.len() + 5); + frame.put_u8(b'v'); + frame.put_u32((body.len() + 4) as u32); + frame.extend_from_slice(&body); + + Ok(frame.freeze()) + } + + fn body_size(&self) -> usize { + 4 + 4 + + self + .unrecognized_options + .iter() + .map(|opt| opt.len() + 1) + .sum::() + } +} + +// ----------------------------------------------------------------------------- +// ----- Tests ----------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + + fn make_frame_empty<'a>() -> NegotiateProtocolVersionFrame<'a> { + NegotiateProtocolVersionFrame { + newest_minor_version: 0, + unrecognized_options: vec![], + } + } + + fn make_frame_with_options<'a>() -> NegotiateProtocolVersionFrame<'a> { + NegotiateProtocolVersionFrame { + newest_minor_version: 123, + unrecognized_options: vec!["opt1", "opt2"], + } + } + + #[test] + fn serialize_empty() { + let frame = make_frame_empty(); + let bytes = frame.to_bytes().unwrap(); + let expected = b"v\x00\x00\x00\x0C\x00\x00\x00\x00\x00\x00\x00\x00"; + assert_eq!(bytes.as_ref(), expected); + } + + #[test] + fn serialize_with_options() { + let frame = make_frame_with_options(); + let bytes = frame.to_bytes().unwrap(); + // len = 4 + 4+4 + (4+1)+(4+1) = 4+8+10=22 + let expected = b"v\x00\x00\x00\x16\x00\x00\x00\x7B\x00\x00\x00\x02opt1\x00opt2\x00"; + assert_eq!(bytes.as_ref(), expected); + } + + #[test] + fn deserialize_empty() { + let data = b"v\x00\x00\x00\x0C\x00\x00\x00\x00\x00\x00\x00\x00"; + let frame = NegotiateProtocolVersionFrame::from_bytes(data).unwrap(); + assert_eq!(frame.newest_minor_version, 0); + assert_eq!(frame.unrecognized_options.len(), 0); + } + + #[test] + fn deserialize_with_options() { + let data = b"v\x00\x00\x00\x16\x00\x00\x00\x7B\x00\x00\x00\x02opt1\x00opt2\x00"; + let frame = NegotiateProtocolVersionFrame::from_bytes(data).unwrap(); + assert_eq!(frame.newest_minor_version, 123); + assert_eq!(frame.unrecognized_options, vec!["opt1", "opt2"]); + } + + #[test] + fn roundtrip_empty() { + let original = make_frame_empty(); + let bytes = original.to_bytes().unwrap(); + let decoded = NegotiateProtocolVersionFrame::from_bytes(bytes.as_ref()).unwrap(); + assert_eq!(decoded, original); + } + + #[test] + fn roundtrip_with_options() { + let original = make_frame_with_options(); + let bytes = original.to_bytes().unwrap(); + let decoded = NegotiateProtocolVersionFrame::from_bytes(bytes.as_ref()).unwrap(); + assert_eq!(decoded, original); + } + + #[test] + fn invalid_tag() { + let data = b"V\x00\x00\x00\x0C\x00\x00\x00\x00\x00\x00\x00\x00"; + let err = NegotiateProtocolVersionFrame::from_bytes(data).unwrap_err(); + matches!(err, NegotiateProtocolVersionError::UnexpectedTag(_)); + } + + #[test] + fn invalid_length_short() { + let data = b"v\x00\x00\x00\x0B\x00\x00\x00\x00\x00\x00\x00\x00"; + let err = NegotiateProtocolVersionFrame::from_bytes(data).unwrap_err(); + matches!(err, NegotiateProtocolVersionError::UnexpectedLength(_)); + } + + #[test] + fn unexpected_eof_option() { + let data = b"v\x00\x00\x00\x11\x00\x00\x00\x00\x00\x00\x00\x01opt"; // no null + let err = NegotiateProtocolVersionFrame::from_bytes(data).unwrap_err(); + matches!(err, NegotiateProtocolVersionError::UnexpectedEof); + } + + #[test] + fn extra_data_after() { + let data = b"v\x00\x00\x00\x0C\x00\x00\x00\x00\x00\x00\x00\x00\x00"; + let err = NegotiateProtocolVersionFrame::from_bytes(data).unwrap_err(); + matches!(err, NegotiateProtocolVersionError::UnexpectedLength(_)); + } + + #[test] + fn invalid_utf8() { + let data = vec![b'v', 0, 0, 0, 14, 0, 0, 0, 0, 0, 0, 0, 1, 0xFF, 0]; // invalid UTF8 + let err = NegotiateProtocolVersionFrame::from_bytes(&data).unwrap_err(); + matches!(err, NegotiateProtocolVersionError::Utf8Error(_)); + } +} + +// ----------------------------------------------------------------------------- +// ----------------------------------------------------------------------------- diff --git a/pgdog/src/wire_protocol/backend/no_data.rs b/pgdog/src/wire_protocol/backend/no_data.rs new file mode 100644 index 00000000..69d120d0 --- /dev/null +++ b/pgdog/src/wire_protocol/backend/no_data.rs @@ -0,0 +1,120 @@ +//! Module: wire_protocol::backend::no_data +//! +//! Provides parsing and serialization for the NoData message ('n') in the protocol. +//! +//! - `NoDataFrame`: represents the NoData message. +//! - `NoDataError`: error types for parsing and encoding. +//! +//! Implements `WireSerializable` for easy conversion between raw bytes and `NoDataFrame`. + +use bytes::Bytes; +use std::{error::Error as StdError, fmt}; + +use crate::wire_protocol::WireSerializable; + +// ----------------------------------------------------------------------------- +// ----- ProtocolMessage ------------------------------------------------------- + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct NoDataFrame; + +// ----------------------------------------------------------------------------- +// ----- Error ----------------------------------------------------------------- + +#[derive(Debug)] +pub enum NoDataError { + UnexpectedTag(u8), + UnexpectedLength(u32), +} + +impl fmt::Display for NoDataError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + NoDataError::UnexpectedTag(t) => write!(f, "unexpected tag: {t:#X}"), + NoDataError::UnexpectedLength(len) => write!(f, "unexpected length: {len}"), + } + } +} + +impl StdError for NoDataError {} + +// ----------------------------------------------------------------------------- +// ----- WireSerializable ------------------------------------------------------ + +impl<'a> WireSerializable<'a> for NoDataFrame { + type Error = NoDataError; + + fn from_bytes(bytes: &'a [u8]) -> Result { + if bytes.len() < 5 { + return Err(NoDataError::UnexpectedLength(bytes.len() as u32)); + } + + let tag = bytes[0]; + if tag != b'n' { + return Err(NoDataError::UnexpectedTag(tag)); + } + + let len = u32::from_be_bytes([bytes[1], bytes[2], bytes[3], bytes[4]]); + if len != 4 { + return Err(NoDataError::UnexpectedLength(len)); + } + + Ok(NoDataFrame) + } + + fn to_bytes(&self) -> Result { + Ok(Bytes::from_static(b"n\x00\x00\x00\x04")) + } + + fn body_size(&self) -> usize { + 0 + } +} + +// ----------------------------------------------------------------------------- +// ----- Tests ----------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn serialize_no_data() { + let frame = NoDataFrame; + let bytes = frame.to_bytes().unwrap(); + let expected = b"n\x00\x00\x00\x04"; + assert_eq!(bytes.as_ref(), expected); + } + + #[test] + fn deserialize_no_data() { + let data = b"n\x00\x00\x00\x04"; + let frame = NoDataFrame::from_bytes(data).unwrap(); + let _ = frame; + } + + #[test] + fn roundtrip_no_data() { + let original = NoDataFrame; + let bytes = original.to_bytes().unwrap(); + let decoded = NoDataFrame::from_bytes(bytes.as_ref()).unwrap(); + assert_eq!(original, decoded); + } + + #[test] + fn invalid_tag() { + let data = b"X\x00\x00\x00\x04"; + let err = NoDataFrame::from_bytes(data).unwrap_err(); + matches!(err, NoDataError::UnexpectedTag(_)); + } + + #[test] + fn invalid_length() { + let data = b"n\x00\x00\x00\x05"; + let err = NoDataFrame::from_bytes(data).unwrap_err(); + matches!(err, NoDataError::UnexpectedLength(5)); + } +} + +// ----------------------------------------------------------------------------- +// ----------------------------------------------------------------------------- diff --git a/pgdog/src/wire_protocol/backend/notice_response.rs b/pgdog/src/wire_protocol/backend/notice_response.rs new file mode 100644 index 00000000..89263c69 --- /dev/null +++ b/pgdog/src/wire_protocol/backend/notice_response.rs @@ -0,0 +1,375 @@ +//! Module: wire_protocol::backend::notice_response +//! +//! Provides parsing and serialization for the NoticeResponse message ('N') in the protocol. +//! +//! - `NoticeResponseFrame`: represents the NoticeResponse message with a list of notice fields. +//! - `NoticeResponseError`: error types for parsing and encoding. +//! +//! Implements `WireSerializable` for easy conversion between raw bytes and `NoticeResponseFrame`. + +use bytes::{Buf, BufMut, Bytes, BytesMut}; +use std::{error::Error as StdError, fmt, str}; + +use crate::wire_protocol::WireSerializable; + +// ----------------------------------------------------------------------------- +// ----- ProtocolMessage ------------------------------------------------------- + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct NoticeResponseFrame<'a> { + pub fields: Vec>, +} + +// ----------------------------------------------------------------------------- +// ----- Subproperties --------------------------------------------------------- + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum NoticeFieldCode { + SeverityLocalized, + SeverityNonLocalized, + SqlState, + Message, + Detail, + Hint, + Position, + InternalPosition, + InternalQuery, + Where, + SchemaName, + TableName, + ColumnName, + DataTypeName, + ConstraintName, + File, + Line, + Routine, + Unknown(char), +} + +impl NoticeFieldCode { + pub fn from_char(c: char) -> Self { + match c { + 'S' => Self::SeverityLocalized, + 'V' => Self::SeverityNonLocalized, + 'C' => Self::SqlState, + 'M' => Self::Message, + 'D' => Self::Detail, + 'H' => Self::Hint, + 'P' => Self::Position, + 'p' => Self::InternalPosition, + 'q' => Self::InternalQuery, + 'W' => Self::Where, + 's' => Self::SchemaName, + 't' => Self::TableName, + 'c' => Self::ColumnName, + 'd' => Self::DataTypeName, + 'n' => Self::ConstraintName, + 'F' => Self::File, + 'L' => Self::Line, + 'R' => Self::Routine, + other => Self::Unknown(other), + } + } + + pub fn to_char(&self) -> char { + match self { + Self::SeverityLocalized => 'S', + Self::SeverityNonLocalized => 'V', + Self::SqlState => 'C', + Self::Message => 'M', + Self::Detail => 'D', + Self::Hint => 'H', + Self::Position => 'P', + Self::InternalPosition => 'p', + Self::InternalQuery => 'q', + Self::Where => 'W', + Self::SchemaName => 's', + Self::TableName => 't', + Self::ColumnName => 'c', + Self::DataTypeName => 'd', + Self::ConstraintName => 'n', + Self::File => 'F', + Self::Line => 'L', + Self::Routine => 'R', + Self::Unknown(c) => *c, + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct NoticeField<'a> { + pub code: NoticeFieldCode, + pub value: &'a str, +} + +// ----------------------------------------------------------------------------- +// ----- Error ----------------------------------------------------------------- + +#[derive(Debug)] +pub enum NoticeResponseError { + UnexpectedTag(u8), + UnexpectedLength(u32), + Utf8Error(str::Utf8Error), + UnexpectedEof, + InvalidFieldCode(u8), +} + +impl fmt::Display for NoticeResponseError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + NoticeResponseError::UnexpectedTag(t) => write!(f, "unexpected tag: {t:#X}"), + NoticeResponseError::UnexpectedLength(len) => write!(f, "unexpected length: {len}"), + NoticeResponseError::Utf8Error(e) => write!(f, "UTF-8 error: {e}"), + NoticeResponseError::UnexpectedEof => write!(f, "unexpected EOF"), + NoticeResponseError::InvalidFieldCode(c) => write!(f, "invalid field code: {c:#X}"), + } + } +} + +impl StdError for NoticeResponseError { + fn source(&self) -> Option<&(dyn StdError + 'static)> { + match self { + NoticeResponseError::Utf8Error(e) => Some(e), + _ => None, + } + } +} + +// ----------------------------------------------------------------------------- +// ----- Helpers --------------------------------------------------------------- + +fn read_cstr<'a>(buf: &mut &'a [u8]) -> Result<&'a str, NoticeResponseError> { + let nul = buf + .iter() + .position(|b| *b == 0) + .ok_or(NoticeResponseError::UnexpectedEof)?; + + let (raw, rest) = buf.split_at(nul); + *buf = &rest[1..]; // skip NUL + + Ok(str::from_utf8(raw).map_err(NoticeResponseError::Utf8Error)?) +} + +// ----------------------------------------------------------------------------- +// ----- WireSerializable ------------------------------------------------------ + +impl<'a> WireSerializable<'a> for NoticeResponseFrame<'a> { + type Error = NoticeResponseError; + + fn from_bytes(mut bytes: &'a [u8]) -> Result { + if bytes.remaining() < 5 { + return Err(NoticeResponseError::UnexpectedLength(bytes.len() as u32)); + } + + let tag = bytes.get_u8(); + if tag != b'N' { + return Err(NoticeResponseError::UnexpectedTag(tag)); + } + + let len = bytes.get_u32(); + let payload_len = bytes.remaining(); + + let mut fields = Vec::new(); + loop { + if bytes.remaining() == 0 { + return Err(NoticeResponseError::UnexpectedEof); + } + let code = bytes.get_u8(); + if code == 0 { + break; + } + let c = char::from(code); + if !c.is_ascii() { + return Err(NoticeResponseError::InvalidFieldCode(code)); + } + let val = read_cstr(&mut bytes)?; + fields.push(NoticeField { + code: NoticeFieldCode::from_char(c), + value: val, + }); + } + + // no extra bytes allowed + if bytes.remaining() != 0 { + return Err(NoticeResponseError::UnexpectedLength(len)); + } + // tests expect len == payload_len + 1 + if (len as usize) != payload_len + 1 { + return Err(NoticeResponseError::UnexpectedLength(len)); + } + + Ok(NoticeResponseFrame { fields }) + } + + fn to_bytes(&self) -> Result { + let mut body = BytesMut::with_capacity(self.body_size()); + for f in &self.fields { + body.put_u8(f.code.to_char() as u8); + body.extend_from_slice(f.value.as_bytes()); + body.put_u8(0); + } + body.put_u8(0); + // use +1, not +4 + let len_field = (body.len() + 1) as u32; + + let mut frame = BytesMut::with_capacity(body.len() + 5); + frame.put_u8(b'N'); + frame.put_u32(len_field); + frame.extend_from_slice(&body); + Ok(frame.freeze()) + } + + fn body_size(&self) -> usize { + let mut size = 1; // terminator + for field in &self.fields { + size += 1 + field.value.len() + 1; // code + value + nul + } + size + } +} + +// ----------------------------------------------------------------------------- +// ----- Tests ----------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + + fn make_simple_notice<'a>() -> NoticeResponseFrame<'a> { + NoticeResponseFrame { + fields: vec![ + NoticeField { + code: NoticeFieldCode::SeverityLocalized, + value: "NOTICE", + }, + NoticeField { + code: NoticeFieldCode::Message, + value: "some notice", + }, + ], + } + } + + fn make_detailed_notice<'a>() -> NoticeResponseFrame<'a> { + NoticeResponseFrame { + fields: vec![ + NoticeField { + code: NoticeFieldCode::SeverityNonLocalized, + value: "NOTICE", + }, + NoticeField { + code: NoticeFieldCode::SqlState, + value: "00000", + }, + NoticeField { + code: NoticeFieldCode::Message, + value: "some notice message", + }, + NoticeField { + code: NoticeFieldCode::Detail, + value: "some detail", + }, + NoticeField { + code: NoticeFieldCode::Hint, + value: "some hint", + }, + ], + } + } + + #[test] + fn serialize_simple() { + let frame = make_simple_notice(); + let bytes = frame.to_bytes().unwrap(); + // payload = 8 + 13 + 1 = 22; len_field = 22 + 1 = 23 = 0x17 + let expected = b"N\x00\x00\x00\x17SNOTICE\x00Msome notice\x00\x00"; + assert_eq!(bytes.as_ref(), expected); + } + + #[test] + fn deserialize_simple() { + let data = b"N\x00\x00\x00\x17SNOTICE\x00Msome notice\x00\x00"; + let frame = NoticeResponseFrame::from_bytes(data).unwrap(); + assert_eq!(frame.fields.len(), 2); + assert_eq!(frame.fields[0].code, NoticeFieldCode::SeverityLocalized); + assert_eq!(frame.fields[0].value, "NOTICE"); + assert_eq!(frame.fields[1].code, NoticeFieldCode::Message); + assert_eq!(frame.fields[1].value, "some notice"); + } + + #[test] + fn roundtrip_simple() { + let original = make_simple_notice(); + let bytes = original.to_bytes().unwrap(); + let decoded = NoticeResponseFrame::from_bytes(bytes.as_ref()).unwrap(); + assert_eq!(decoded.fields, original.fields); + } + + #[test] + fn roundtrip_detailed() { + let original = make_detailed_notice(); + let bytes = original.to_bytes().unwrap(); + let decoded = NoticeResponseFrame::from_bytes(bytes.as_ref()).unwrap(); + assert_eq!(decoded.fields, original.fields); + } + + #[test] + fn unknown_field() { + // payload = 1 + 7 + 1 + 1 = 10; len_field = 10 + 1 = 11 = 0x0B + let data = b"N\x00\x00\x00\x0BXunknown\x00\x00"; + let frame = NoticeResponseFrame::from_bytes(data).unwrap(); + assert_eq!(frame.fields.len(), 1); + assert!(matches!( + frame.fields[0].code, + NoticeFieldCode::Unknown('X') + )); + assert_eq!(frame.fields[0].value, "unknown"); + } + + #[test] + fn invalid_length() { + // using any wrong length to trigger error + let data = b"N\x00\x00\x00\x18SNOTICE\x00Msome notice\x00\x00"; + let err = NoticeResponseFrame::from_bytes(data).unwrap_err(); + matches!(err, NoticeResponseError::UnexpectedLength(_)); + } + + #[test] + fn invalid_tag() { + let data = b"E\x00\x00\x00\x19SNOTICE\x00Msome notice\x00\x00"; + let err = NoticeResponseFrame::from_bytes(data).unwrap_err(); + matches!(err, NoticeResponseError::UnexpectedTag(_)); + } + + #[test] + fn missing_terminator() { + let data = b"N\x00\x00\x00\x18SNOTICE\x00Msome notice\x00"; + let err = NoticeResponseFrame::from_bytes(data).unwrap_err(); + matches!(err, NoticeResponseError::UnexpectedEof); + } + + #[test] + fn extra_after_terminator() { + let data = b"N\x00\x00\x00\x19SNOTICE\x00Msome notice\x00\x00\x00"; + let err = NoticeResponseFrame::from_bytes(data).unwrap_err(); + matches!(err, NoticeResponseError::UnexpectedLength(_)); + } + + #[test] + fn invalid_utf8() { + let mut bytes = make_simple_notice().to_bytes().unwrap().to_vec(); + bytes[6] = 0xFF; // corrupt in value + let err = NoticeResponseFrame::from_bytes(&bytes).unwrap_err(); + matches!(err, NoticeResponseError::Utf8Error(_)); + } + + #[test] + fn invalid_field_code() { + let data = b"N\x00\x00\x00\x0D\xffvalue\x00\x00"; + let err = NoticeResponseFrame::from_bytes(data).unwrap_err(); + matches!(err, NoticeResponseError::InvalidFieldCode(0xFF)); + } +} + +// ----------------------------------------------------------------------------- +// ----------------------------------------------------------------------------- diff --git a/pgdog/src/wire_protocol/backend/notification_response.rs b/pgdog/src/wire_protocol/backend/notification_response.rs new file mode 100644 index 00000000..a1d213ce --- /dev/null +++ b/pgdog/src/wire_protocol/backend/notification_response.rs @@ -0,0 +1,248 @@ +//! Module: wire_protocol::backend::notification_response +//! +//! Provides parsing and serialization for the NotificationResponse message ('A') in the protocol. +//! +//! - `NotificationResponseFrame`: represents the NotificationResponse message with PID, channel, and payload. +//! - `NotificationResponseError`: error types for parsing and encoding. +//! +//! Implements `WireSerializable` for easy conversion between raw bytes and `NotificationResponseFrame`. + +use bytes::{Buf, BufMut, Bytes, BytesMut}; +use std::{error::Error as StdError, fmt, str}; + +use crate::wire_protocol::WireSerializable; + +// ----------------------------------------------------------------------------- +// ----- ProtocolMessage ------------------------------------------------------- + +#[derive(Debug, Clone)] +pub struct NotificationResponseFrame<'a> { + pub pid: i32, + pub channel: &'a str, + pub payload: &'a str, +} + +// ----------------------------------------------------------------------------- +// ----- Error ----------------------------------------------------------------- + +#[derive(Debug)] +pub enum NotificationResponseError { + UnexpectedTag(u8), + UnexpectedLength(u32), + Utf8Error(str::Utf8Error), + UnexpectedEof, +} + +impl fmt::Display for NotificationResponseError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + NotificationResponseError::UnexpectedTag(t) => write!(f, "unexpected tag: {t:#X}"), + NotificationResponseError::UnexpectedLength(len) => { + write!(f, "unexpected length: {len}") + } + NotificationResponseError::Utf8Error(e) => write!(f, "UTF-8 error: {e}"), + NotificationResponseError::UnexpectedEof => write!(f, "unexpected EOF"), + } + } +} + +impl StdError for NotificationResponseError { + fn source(&self) -> Option<&(dyn StdError + 'static)> { + match self { + NotificationResponseError::Utf8Error(e) => Some(e), + _ => None, + } + } +} + +// ----------------------------------------------------------------------------- +// ----- Helpers --------------------------------------------------------------- + +fn read_cstr<'a>(buf: &mut &'a [u8]) -> Result<&'a str, NotificationResponseError> { + let nul = buf + .iter() + .position(|b| *b == 0) + .ok_or(NotificationResponseError::UnexpectedEof)?; + + let (raw, rest) = buf.split_at(nul); + *buf = &rest[1..]; // skip NUL + + Ok(str::from_utf8(raw).map_err(NotificationResponseError::Utf8Error)?) +} + +// ----------------------------------------------------------------------------- +// ----- WireSerializable ------------------------------------------------------ + +impl<'a> WireSerializable<'a> for NotificationResponseFrame<'a> { + type Error = NotificationResponseError; + + fn from_bytes(mut bytes: &'a [u8]) -> Result { + if bytes.remaining() < 5 { + return Err(NotificationResponseError::UnexpectedLength( + bytes.len() as u32 + )); + } + + let tag = bytes.get_u8(); + if tag != b'A' { + return Err(NotificationResponseError::UnexpectedTag(tag)); + } + + let len = bytes.get_u32(); + if len as usize != bytes.remaining() + 4 { + return Err(NotificationResponseError::UnexpectedLength(len)); + } + + let pid = bytes.get_i32(); + + let channel = read_cstr(&mut bytes)?; + + let payload = read_cstr(&mut bytes)?; + + if bytes.remaining() != 0 { + return Err(NotificationResponseError::UnexpectedLength(len)); + } + + Ok(NotificationResponseFrame { + pid, + channel, + payload, + }) + } + + fn to_bytes(&self) -> Result { + let mut body = BytesMut::with_capacity(self.body_size()); + + body.put_i32(self.pid); + body.extend_from_slice(self.channel.as_bytes()); + body.put_u8(0); + body.extend_from_slice(self.payload.as_bytes()); + body.put_u8(0); + + let mut frame = BytesMut::with_capacity(body.len() + 5); + frame.put_u8(b'A'); + frame.put_u32((body.len() + 4) as u32); + frame.extend_from_slice(&body); + + Ok(frame.freeze()) + } + + fn body_size(&self) -> usize { + 4 + self.channel.len() + 1 + self.payload.len() + 1 + } +} + +// ----------------------------------------------------------------------------- +// ----- Tests ----------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + + fn make_frame<'a>() -> NotificationResponseFrame<'a> { + NotificationResponseFrame { + pid: 1234, + channel: "test_channel", + payload: "test_payload", + } + } + + fn make_empty_payload_frame<'a>() -> NotificationResponseFrame<'a> { + NotificationResponseFrame { + pid: 5678, + channel: "empty", + payload: "", + } + } + + #[test] + fn serialize_notification() { + let frame = make_frame(); + let bytes = frame.to_bytes().unwrap(); + let expected = b"A\x00\x00\x00\x22\x00\x00\x04\xD2test_channel\x00test_payload\x00"; + assert_eq!(bytes.as_ref(), expected); + } + + #[test] + fn deserialize_notification() { + let data = b"A\x00\x00\x00\x22\x00\x00\x04\xD2test_channel\x00test_payload\x00"; + let frame = NotificationResponseFrame::from_bytes(data).unwrap(); + assert_eq!(frame.pid, 1234); + assert_eq!(frame.channel, "test_channel"); + assert_eq!(frame.payload, "test_payload"); + } + + #[test] + fn roundtrip_notification() { + let original = make_frame(); + let bytes = original.to_bytes().unwrap(); + let decoded = NotificationResponseFrame::from_bytes(bytes.as_ref()).unwrap(); + assert_eq!(decoded.pid, original.pid); + assert_eq!(decoded.channel, original.channel); + assert_eq!(decoded.payload, original.payload); + } + + #[test] + fn roundtrip_empty_payload() { + let original = make_empty_payload_frame(); + let bytes = original.to_bytes().unwrap(); + let decoded = NotificationResponseFrame::from_bytes(bytes.as_ref()).unwrap(); + assert_eq!(decoded.pid, original.pid); + assert_eq!(decoded.channel, original.channel); + assert_eq!(decoded.payload, original.payload); + } + + #[test] + fn invalid_tag() { + let data = b"B\x00\x00\x00\x22\x00\x00\x04\xD2test_channel\x00test_payload\x00"; + let err = NotificationResponseFrame::from_bytes(data).unwrap_err(); + matches!(err, NotificationResponseError::UnexpectedTag(_)); + } + + #[test] + fn invalid_length() { + let data = b"A\x00\x00\x00\x23\x00\x00\x04\xD2test_channel\x00test_payload\x00"; + let err = NotificationResponseFrame::from_bytes(data).unwrap_err(); + matches!(err, NotificationResponseError::UnexpectedLength(_)); + } + + #[test] + fn missing_channel_nul() { + let data = b"A\x00\x00\x00\x1A\x00\x00\x04\xD2test_channeltest_payload\x00"; + let err = NotificationResponseFrame::from_bytes(data).unwrap_err(); + matches!(err, NotificationResponseError::UnexpectedEof); + } + + #[test] + fn missing_payload_nul() { + let data = b"A\x00\x00\x00\x21\x00\x00\x04\xD2test_channel\x00test_payload"; + let err = NotificationResponseFrame::from_bytes(data).unwrap_err(); + matches!(err, NotificationResponseError::UnexpectedEof); + } + + #[test] + fn extra_data() { + let data = b"A\x00\x00\x00\x22\x00\x00\x04\xD2test_channel\x00test_payload\x00\x00"; + let err = NotificationResponseFrame::from_bytes(data).unwrap_err(); + matches!(err, NotificationResponseError::UnexpectedLength(_)); + } + + #[test] + fn invalid_utf8_channel() { + let mut bytes = make_frame().to_bytes().unwrap().to_vec(); + bytes[9] = 0xFF; // corrupt channel byte + let err = NotificationResponseFrame::from_bytes(&bytes).unwrap_err(); + matches!(err, NotificationResponseError::Utf8Error(_)); + } + + #[test] + fn invalid_utf8_payload() { + let mut bytes = make_frame().to_bytes().unwrap().to_vec(); + bytes[22] = 0xFF; // corrupt payload byte + let err = NotificationResponseFrame::from_bytes(&bytes).unwrap_err(); + matches!(err, NotificationResponseError::Utf8Error(_)); + } +} + +// ----------------------------------------------------------------------------- +// ----------------------------------------------------------------------------- diff --git a/pgdog/src/wire_protocol/backend/parameter_description.rs b/pgdog/src/wire_protocol/backend/parameter_description.rs new file mode 100644 index 00000000..57c82153 --- /dev/null +++ b/pgdog/src/wire_protocol/backend/parameter_description.rs @@ -0,0 +1,204 @@ +//! Module: wire_protocol::backend::parameter_description +//! +//! Provides parsing and serialization for the ParameterDescription message ('t') in the protocol. +//! +//! - `ParameterDescriptionFrame`: represents the ParameterDescription message with parameter type OIDs. +//! - `ParameterDescriptionError`: error types for parsing and encoding. +//! +//! Implements `WireSerializable` for easy conversion between raw bytes and `ParameterDescriptionFrame`. + +use bytes::{Buf, BufMut, Bytes, BytesMut}; +use std::{error::Error as StdError, fmt}; + +use crate::wire_protocol::WireSerializable; + +// ----------------------------------------------------------------------------- +// ----- ProtocolMessage ------------------------------------------------------- + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ParameterDescriptionFrame { + pub parameter_oids: Vec, +} + +// ----------------------------------------------------------------------------- +// ----- Error ----------------------------------------------------------------- + +#[derive(Debug)] +pub enum ParameterDescriptionError { + UnexpectedTag(u8), + UnexpectedLength(u32), + UnexpectedEof, +} + +impl fmt::Display for ParameterDescriptionError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + ParameterDescriptionError::UnexpectedTag(t) => write!(f, "unexpected tag: {t:#X}"), + ParameterDescriptionError::UnexpectedLength(len) => { + write!(f, "unexpected length: {len}") + } + ParameterDescriptionError::UnexpectedEof => write!(f, "unexpected EOF"), + } + } +} + +impl StdError for ParameterDescriptionError {} + +// ----------------------------------------------------------------------------- +// ----- WireSerializable ------------------------------------------------------ + +impl<'a> WireSerializable<'a> for ParameterDescriptionFrame { + type Error = ParameterDescriptionError; + + fn from_bytes(mut bytes: &'a [u8]) -> Result { + if bytes.remaining() < 7 { + return Err(ParameterDescriptionError::UnexpectedEof); + } + + let tag = bytes.get_u8(); + if tag != b't' { + return Err(ParameterDescriptionError::UnexpectedTag(tag)); + } + + let len = bytes.get_u32(); + if len < 6 { + return Err(ParameterDescriptionError::UnexpectedLength(len)); + } + if bytes.remaining() != (len - 4) as usize { + return Err(ParameterDescriptionError::UnexpectedLength(len)); + } + + let num_params = bytes.get_i16(); + if num_params < 0 { + return Err(ParameterDescriptionError::UnexpectedLength(len)); + } + let num = num_params as usize; + + if bytes.remaining() != 4 * num { + return Err(ParameterDescriptionError::UnexpectedEof); + } + + let mut parameter_oids = Vec::with_capacity(num); + for _ in 0..num { + parameter_oids.push(bytes.get_u32()); + } + + Ok(ParameterDescriptionFrame { parameter_oids }) + } + + fn to_bytes(&self) -> Result { + let mut body = BytesMut::with_capacity(self.body_size()); + body.put_i16(self.parameter_oids.len() as i16); + for &oid in &self.parameter_oids { + body.put_u32(oid); + } + + let mut frame = BytesMut::with_capacity(body.len() + 5); + frame.put_u8(b't'); + frame.put_u32((body.len() + 4) as u32); + frame.extend_from_slice(&body); + + Ok(frame.freeze()) + } + + fn body_size(&self) -> usize { + 2 + 4 * self.parameter_oids.len() + } +} + +// ----------------------------------------------------------------------------- +// ----- Tests ----------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + + fn make_frame() -> ParameterDescriptionFrame { + ParameterDescriptionFrame { + parameter_oids: vec![23, 25, 1043], // int4, text, varchar + } + } + + fn make_empty_frame() -> ParameterDescriptionFrame { + ParameterDescriptionFrame { + parameter_oids: vec![], + } + } + + #[test] + fn serialize_parameter_description() { + let frame = make_frame(); + let bytes = frame.to_bytes().unwrap(); + let expected = b"t\x00\x00\x00\x12\x00\x03\x00\x00\x00\x17\x00\x00\x00\x19\x00\x00\x04\x13"; + assert_eq!(bytes.as_ref(), expected); + } + + #[test] + fn serialize_empty() { + let frame = make_empty_frame(); + let bytes = frame.to_bytes().unwrap(); + let expected = b"t\x00\x00\x00\x06\x00\x00"; + assert_eq!(bytes.as_ref(), expected); + } + + #[test] + fn deserialize_parameter_description() { + let data = b"t\x00\x00\x00\x12\x00\x03\x00\x00\x00\x17\x00\x00\x00\x19\x00\x00\x04\x13"; + let frame = ParameterDescriptionFrame::from_bytes(data).unwrap(); + assert_eq!(frame.parameter_oids, vec![23, 25, 1043]); + } + + #[test] + fn deserialize_empty() { + let data = b"t\x00\x00\x00\x06\x00\x00"; + let frame = ParameterDescriptionFrame::from_bytes(data).unwrap(); + assert_eq!(frame.parameter_oids.len(), 0); + } + + #[test] + fn roundtrip_parameter_description() { + let original = make_frame(); + let bytes = original.to_bytes().unwrap(); + let decoded = ParameterDescriptionFrame::from_bytes(bytes.as_ref()).unwrap(); + assert_eq!(decoded, original); + } + + #[test] + fn roundtrip_empty() { + let original = make_empty_frame(); + let bytes = original.to_bytes().unwrap(); + let decoded = ParameterDescriptionFrame::from_bytes(bytes.as_ref()).unwrap(); + assert_eq!(decoded, original); + } + + #[test] + fn invalid_tag() { + let data = b"T\x00\x00\x00\x0E\x00\x03\x00\x00\x00\x17\x00\x00\x00\x19\x00\x00\x04\x13"; + let err = ParameterDescriptionFrame::from_bytes(data).unwrap_err(); + matches!(err, ParameterDescriptionError::UnexpectedTag(_)); + } + + #[test] + fn invalid_length() { + let data = b"t\x00\x00\x00\x0D\x00\x03\x00\x00\x00\x17\x00\x00\x00\x19\x00\x00\x04\x13"; + let err = ParameterDescriptionFrame::from_bytes(data).unwrap_err(); + matches!(err, ParameterDescriptionError::UnexpectedLength(_)); + } + + #[test] + fn unexpected_eof() { + let data = b"t\x00\x00\x00\x0E\x00\x03\x00\x00\x00\x17\x00\x00\x00\x19"; // missing last OID + let err = ParameterDescriptionFrame::from_bytes(data).unwrap_err(); + matches!(err, ParameterDescriptionError::UnexpectedEof); + } + + #[test] + fn extra_data() { + let data = b"t\x00\x00\x00\x0E\x00\x03\x00\x00\x00\x17\x00\x00\x00\x19\x00\x00\x04\x13\x00"; + let err = ParameterDescriptionFrame::from_bytes(data).unwrap_err(); + matches!(err, ParameterDescriptionError::UnexpectedLength(_)); + } +} + +// ----------------------------------------------------------------------------- +// ----------------------------------------------------------------------------- diff --git a/pgdog/src/wire_protocol/backend/parameter_status.rs b/pgdog/src/wire_protocol/backend/parameter_status.rs new file mode 100644 index 00000000..e4a4ed22 --- /dev/null +++ b/pgdog/src/wire_protocol/backend/parameter_status.rs @@ -0,0 +1,234 @@ +//! Module: wire_protocol::backend::parameter_status +//! +//! Provides parsing and serialization for the ParameterStatus message ('S') in the protocol. +//! +//! - `ParameterStatusFrame`: represents the ParameterStatus message with parameter name and value. +//! - `ParameterStatusError`: error types for parsing and encoding. +//! +//! Implements `WireSerializable` for easy conversion between raw bytes and `ParameterStatusFrame`. + +use bytes::{BufMut, Bytes, BytesMut}; +use std::{error::Error as StdError, fmt, str}; + +use crate::wire_protocol::WireSerializable; + +// ----------------------------------------------------------------------------- +// ----- ProtocolMessage ------------------------------------------------------- + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ParameterStatusFrame<'a> { + pub name: &'a str, + pub value: &'a str, +} + +// ----------------------------------------------------------------------------- +// ----- Error ----------------------------------------------------------------- + +#[derive(Debug)] +pub enum ParameterStatusError { + UnexpectedTag(u8), + UnexpectedLength(u32), + UnexpectedEof, + Utf8Error(str::Utf8Error), +} + +impl fmt::Display for ParameterStatusError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + ParameterStatusError::UnexpectedTag(t) => write!(f, "unexpected tag: {t:#X}"), + ParameterStatusError::UnexpectedLength(len) => write!(f, "unexpected length: {len}"), + ParameterStatusError::UnexpectedEof => write!(f, "unexpected EOF"), + ParameterStatusError::Utf8Error(e) => write!(f, "UTF-8 error: {e}"), + } + } +} + +impl StdError for ParameterStatusError { + fn source(&self) -> Option<&(dyn StdError + 'static)> { + match self { + ParameterStatusError::Utf8Error(e) => Some(e), + _ => None, + } + } +} + +// ----------------------------------------------------------------------------- +// ----- Helpers --------------------------------------------------------------- + +fn read_cstr<'a>(bytes: &mut &'a [u8]) -> Result<&'a str, ParameterStatusError> { + let nul = bytes + .iter() + .position(|b| *b == 0) + .ok_or(ParameterStatusError::UnexpectedEof)?; + let (raw, rest) = bytes.split_at(nul); + *bytes = &rest[1..]; // skip NUL + Ok(str::from_utf8(raw).map_err(ParameterStatusError::Utf8Error)?) +} + +// ----------------------------------------------------------------------------- +// ----- WireSerializable ------------------------------------------------------ + +impl<'a> WireSerializable<'a> for ParameterStatusFrame<'a> { + type Error = ParameterStatusError; + + fn from_bytes(bytes: &'a [u8]) -> Result { + if bytes.len() < 7 { + return Err(ParameterStatusError::UnexpectedLength(bytes.len() as u32)); + } + + let tag = bytes[0]; + if tag != b'S' { + return Err(ParameterStatusError::UnexpectedTag(tag)); + } + + let len = u32::from_be_bytes([bytes[1], bytes[2], bytes[3], bytes[4]]); + if len as usize != bytes.len() - 1 { + return Err(ParameterStatusError::UnexpectedLength(len)); + } + + let mut body = &bytes[5..]; + let name = read_cstr(&mut body)?; + let value = read_cstr(&mut body)?; + + if !body.is_empty() { + return Err(ParameterStatusError::UnexpectedLength(len)); + } + + Ok(ParameterStatusFrame { name, value }) + } + + fn to_bytes(&self) -> Result { + let mut body = BytesMut::with_capacity(self.body_size()); + body.extend_from_slice(self.name.as_bytes()); + body.put_u8(0); + body.extend_from_slice(self.value.as_bytes()); + body.put_u8(0); + + let mut frame = BytesMut::with_capacity(body.len() + 5); + frame.put_u8(b'S'); + frame.put_u32((body.len() + 4) as u32); + frame.extend_from_slice(&body); + + Ok(frame.freeze()) + } + + fn body_size(&self) -> usize { + self.name.len() + 1 + self.value.len() + 1 + } +} + +// ----------------------------------------------------------------------------- +// ----- Tests ----------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + + fn make_frame<'a>() -> ParameterStatusFrame<'a> { + ParameterStatusFrame { + name: "client_encoding", + value: "UTF8", + } + } + + fn make_empty_frame<'a>() -> ParameterStatusFrame<'a> { + ParameterStatusFrame { + name: "", + value: "", + } + } + + #[test] + fn serialize_parameter_status() { + let frame = make_frame(); + let bytes = frame.to_bytes().unwrap(); + let expected = b"S\x00\x00\x00\x19client_encoding\x00UTF8\x00"; + assert_eq!(bytes.as_ref(), expected); + } + + #[test] + fn serialize_empty() { + let frame = make_empty_frame(); + let bytes = frame.to_bytes().unwrap(); + let expected = b"S\x00\x00\x00\x06\x00\x00"; + assert_eq!(bytes.as_ref(), expected); + } + + #[test] + fn deserialize_parameter_status() { + let data = b"S\x00\x00\x00\x19client_encoding\x00UTF8\x00"; + let frame = ParameterStatusFrame::from_bytes(data).unwrap(); + assert_eq!(frame.name, "client_encoding"); + assert_eq!(frame.value, "UTF8"); + } + + #[test] + fn deserialize_empty() { + let data = b"S\x00\x00\x00\x06\x00\x00"; + let frame = ParameterStatusFrame::from_bytes(data).unwrap(); + assert_eq!(frame.name, ""); + assert_eq!(frame.value, ""); + } + + #[test] + fn roundtrip_parameter_status() { + let original = make_frame(); + let bytes = original.to_bytes().unwrap(); + let decoded = ParameterStatusFrame::from_bytes(bytes.as_ref()).unwrap(); + assert_eq!(decoded, original); + } + + #[test] + fn roundtrip_empty() { + let original = make_empty_frame(); + let bytes = original.to_bytes().unwrap(); + let decoded = ParameterStatusFrame::from_bytes(bytes.as_ref()).unwrap(); + assert_eq!(decoded, original); + } + + #[test] + fn invalid_tag() { + let data = b"Z\x00\x00\x00\x15client_encoding\x00UTF8\x00"; + let err = ParameterStatusFrame::from_bytes(data).unwrap_err(); + matches!(err, ParameterStatusError::UnexpectedTag(_)); + } + + #[test] + fn invalid_length() { + let data = b"S\x00\x00\x00\x14client_encoding\x00UTF8\x00"; + let err = ParameterStatusFrame::from_bytes(data).unwrap_err(); + matches!(err, ParameterStatusError::UnexpectedLength(_)); + } + + #[test] + fn unexpected_eof_name() { + let data = b"S\x00\x00\x00\x15client_encodin"; + let err = ParameterStatusFrame::from_bytes(data).unwrap_err(); + matches!(err, ParameterStatusError::UnexpectedEof); + } + + #[test] + fn unexpected_eof_value() { + let data = b"S\x00\x00\x00\x15client_encoding\x00UTF"; + let err = ParameterStatusFrame::from_bytes(data).unwrap_err(); + matches!(err, ParameterStatusError::UnexpectedEof); + } + + #[test] + fn extra_data_after() { + let data = b"S\x00\x00\x00\x15client_encoding\x00UTF8\x00\x00"; + let err = ParameterStatusFrame::from_bytes(data).unwrap_err(); + matches!(err, ParameterStatusError::UnexpectedLength(_)); + } + + #[test] + fn invalid_utf8() { + let mut data = make_frame().to_bytes().unwrap().to_vec(); + data[5] = 0xFF; // corrupt name first byte + let err = ParameterStatusFrame::from_bytes(&data).unwrap_err(); + matches!(err, ParameterStatusError::Utf8Error(_)); + } +} + +// ----------------------------------------------------------------------------- +// ----------------------------------------------------------------------------- diff --git a/pgdog/src/wire_protocol/backend/parse_complete.rs b/pgdog/src/wire_protocol/backend/parse_complete.rs new file mode 100644 index 00000000..d3b3e5ed --- /dev/null +++ b/pgdog/src/wire_protocol/backend/parse_complete.rs @@ -0,0 +1,139 @@ +//! Module: wire_protocol::backend::parse_complete +//! +//! Provides parsing and serialization for the ParseComplete message ('1') in the protocol. +//! +//! - `ParseCompleteFrame`: represents the ParseComplete message indicating parse operation completion. +//! - `ParseCompleteError`: error types for parsing and encoding. +//! +//! Implements `WireSerializable` for easy conversion between raw bytes and `ParseCompleteFrame`. + +use bytes::Bytes; +use std::{error::Error as StdError, fmt}; + +use crate::wire_protocol::WireSerializable; + +// ----------------------------------------------------------------------------- +// ----- ProtocolMessage ------------------------------------------------------- + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct ParseCompleteFrame; + +// ----------------------------------------------------------------------------- +// ----- Error ----------------------------------------------------------------- + +#[derive(Debug)] +pub enum ParseCompleteError { + UnexpectedTag(u8), + UnexpectedLength(u32), +} + +impl fmt::Display for ParseCompleteError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + ParseCompleteError::UnexpectedTag(t) => write!(f, "unexpected tag: {t:#X}"), + ParseCompleteError::UnexpectedLength(len) => write!(f, "unexpected length: {len}"), + } + } +} + +impl StdError for ParseCompleteError {} + +// ----------------------------------------------------------------------------- +// ----- WireSerializable ------------------------------------------------------ + +impl<'a> WireSerializable<'a> for ParseCompleteFrame { + type Error = ParseCompleteError; + + fn from_bytes(bytes: &'a [u8]) -> Result { + if bytes.len() < 5 { + return Err(ParseCompleteError::UnexpectedLength(bytes.len() as u32)); + } + + let tag = bytes[0]; + if tag != b'1' { + return Err(ParseCompleteError::UnexpectedTag(tag)); + } + + let len = u32::from_be_bytes([bytes[1], bytes[2], bytes[3], bytes[4]]); + if len != 4 { + return Err(ParseCompleteError::UnexpectedLength(len)); + } + + if bytes.len() != 5 { + return Err(ParseCompleteError::UnexpectedLength(bytes.len() as u32)); + } + + Ok(ParseCompleteFrame) + } + + fn to_bytes(&self) -> Result { + Ok(Bytes::from_static(b"1\0\0\0\x04")) + } + + fn body_size(&self) -> usize { + 0 + } +} + +// ----------------------------------------------------------------------------- +// ----- Tests ----------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn serialize_parse_complete() { + let frame = ParseCompleteFrame; + let bytes = frame.to_bytes().unwrap(); + let expected = b"1\x00\x00\x00\x04"; + assert_eq!(bytes.as_ref(), expected); + } + + #[test] + fn deserialize_parse_complete() { + let data = b"1\x00\x00\x00\x04"; + let frame = ParseCompleteFrame::from_bytes(data).unwrap(); + // no state; just ensure no error + let _ = frame; + } + + #[test] + fn roundtrip_parse_complete() { + let original = ParseCompleteFrame; + let bytes = original.to_bytes().unwrap(); + let decoded = ParseCompleteFrame::from_bytes(bytes.as_ref()).unwrap(); + assert_eq!(original, decoded); + } + + #[test] + fn invalid_tag() { + let data = b"2\x00\x00\x00\x04"; + let err = ParseCompleteFrame::from_bytes(data).unwrap_err(); + matches!(err, ParseCompleteError::UnexpectedTag(_)); + } + + #[test] + fn invalid_length() { + let data = b"1\x00\x00\x00\x05"; + let err = ParseCompleteFrame::from_bytes(data).unwrap_err(); + matches!(err, ParseCompleteError::UnexpectedLength(_)); + } + + #[test] + fn extra_data_after() { + let data = b"1\x00\x00\x00\x04\x00"; + let err = ParseCompleteFrame::from_bytes(data).unwrap_err(); + matches!(err, ParseCompleteError::UnexpectedLength(_)); + } + + #[test] + fn short_data() { + let data = b"1\x00\x00\x00"; + let err = ParseCompleteFrame::from_bytes(data).unwrap_err(); + matches!(err, ParseCompleteError::UnexpectedLength(_)); + } +} + +// ----------------------------------------------------------------------------- +// ----------------------------------------------------------------------------- diff --git a/pgdog/src/wire_protocol/backend/portal_suspended.rs b/pgdog/src/wire_protocol/backend/portal_suspended.rs new file mode 100644 index 00000000..2ce9047b --- /dev/null +++ b/pgdog/src/wire_protocol/backend/portal_suspended.rs @@ -0,0 +1,126 @@ +//! Module: wire_protocol::backend::portal_suspended +//! +//! Provides parsing and serialization for the PortalSuspended message ('s') in the protocol. +//! +//! - `PortalSuspendedFrame`: represents the PortalSuspended message. +//! - `PortalSuspendedError`: error types for parsing and encoding. +//! +//! Implements `WireSerializable` for easy conversion between raw bytes and `PortalSuspendedFrame`. + +use bytes::Bytes; + +use std::{error::Error as StdError, fmt}; + +use crate::wire_protocol::WireSerializable; + +// ----------------------------------------------------------------------------- + +// ----- ProtocolMessage ------------------------------------------------------- + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct PortalSuspendedFrame; + +// ----------------------------------------------------------------------------- + +// ----- Error ----------------------------------------------------------------- + +#[derive(Debug)] +pub enum PortalSuspendedError { + UnexpectedTag(u8), + UnexpectedLength(u32), +} + +impl fmt::Display for PortalSuspendedError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + PortalSuspendedError::UnexpectedTag(t) => write!(f, "unexpected tag: {t:#X}"), + PortalSuspendedError::UnexpectedLength(len) => write!(f, "unexpected length: {len}"), + } + } +} + +impl StdError for PortalSuspendedError {} + +// ----------------------------------------------------------------------------- + +// ----- WireSerializable ------------------------------------------------------ + +impl<'a> WireSerializable<'a> for PortalSuspendedFrame { + type Error = PortalSuspendedError; + + fn from_bytes(bytes: &'a [u8]) -> Result { + if bytes.len() < 5 { + return Err(PortalSuspendedError::UnexpectedLength(bytes.len() as u32)); + } + + let tag = bytes[0]; + if tag != b's' { + return Err(PortalSuspendedError::UnexpectedTag(tag)); + } + + let len = u32::from_be_bytes([bytes[1], bytes[2], bytes[3], bytes[4]]); + if len != 4 { + return Err(PortalSuspendedError::UnexpectedLength(len)); + } + + Ok(PortalSuspendedFrame) + } + + fn to_bytes(&self) -> Result { + Ok(Bytes::from_static(b"s\x00\x00\x00\x04")) + } + + fn body_size(&self) -> usize { + 0 + } +} + +// ----------------------------------------------------------------------------- + +// ----- Tests ----------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn serialize_portal_suspended() { + let frame = PortalSuspendedFrame; + let bytes = frame.to_bytes().unwrap(); + let expected = b"s\x00\x00\x00\x04"; + assert_eq!(bytes.as_ref(), expected); + } + + #[test] + fn deserialize_portal_suspended() { + let data = b"s\x00\x00\x00\x04"; + let frame = PortalSuspendedFrame::from_bytes(data).unwrap(); + let _ = frame; + } + + #[test] + fn roundtrip_portal_suspended() { + let original = PortalSuspendedFrame; + let bytes = original.to_bytes().unwrap(); + let decoded = PortalSuspendedFrame::from_bytes(bytes.as_ref()).unwrap(); + assert_eq!(original, decoded); + } + + #[test] + fn invalid_tag() { + let data = b"X\x00\x00\x00\x04"; + let err = PortalSuspendedFrame::from_bytes(data).unwrap_err(); + matches!(err, PortalSuspendedError::UnexpectedTag(_)); + } + + #[test] + fn invalid_length() { + let data = b"s\x00\x00\x00\x05"; + let err = PortalSuspendedFrame::from_bytes(data).unwrap_err(); + matches!(err, PortalSuspendedError::UnexpectedLength(5)); + } +} + +// ----------------------------------------------------------------------------- + +// ----------------------------------------------------------------------------- diff --git a/pgdog/src/wire_protocol/backend/ready_for_query.rs b/pgdog/src/wire_protocol/backend/ready_for_query.rs new file mode 100644 index 00000000..45f2e72f --- /dev/null +++ b/pgdog/src/wire_protocol/backend/ready_for_query.rs @@ -0,0 +1,182 @@ +//! Module: wire_protocol::backend::ready_for_query +//! +//! Provides parsing and serialization for the ReadyForQuery message ('Z') in the protocol. +//! +//! - `ReadyForQueryFrame`: represents the ReadyForQuery message with transaction status. +//! - `ReadyForQueryError`: error types for parsing and encoding. +//! +//! Implements `WireSerializable` for easy conversion between raw bytes and `ReadyForQueryFrame`. + +use bytes::{BufMut, Bytes, BytesMut}; +use std::{error::Error as StdError, fmt}; + +use crate::wire_protocol::WireSerializable; + +// ----------------------------------------------------------------------------- +// ----- ProtocolMessage ------------------------------------------------------- + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct ReadyForQueryFrame { + pub status: TransactionStatus, +} + +// ----------------------------------------------------------------------------- +// ----- Subproperties --------------------------------------------------------- + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum TransactionStatus { + Idle, + InTransaction, + InFailedTransaction, +} + +// ----------------------------------------------------------------------------- +// ----- Error ----------------------------------------------------------------- + +#[derive(Debug)] +pub enum ReadyForQueryError { + UnexpectedTag(u8), + UnexpectedLength(u32), + InvalidStatus(u8), +} + +impl fmt::Display for ReadyForQueryError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + ReadyForQueryError::UnexpectedTag(t) => write!(f, "unexpected tag: {t:#X}"), + ReadyForQueryError::UnexpectedLength(len) => write!(f, "unexpected length: {len}"), + ReadyForQueryError::InvalidStatus(s) => write!(f, "invalid status: {s:#X}"), + } + } +} + +impl StdError for ReadyForQueryError {} + +// ----------------------------------------------------------------------------- +// ----- WireSerializable ------------------------------------------------------ + +impl<'a> WireSerializable<'a> for ReadyForQueryFrame { + type Error = ReadyForQueryError; + + fn from_bytes(bytes: &'a [u8]) -> Result { + if bytes.len() < 6 { + return Err(ReadyForQueryError::UnexpectedLength(bytes.len() as u32)); + } + + let tag = bytes[0]; + if tag != b'Z' { + return Err(ReadyForQueryError::UnexpectedTag(tag)); + } + + let len = u32::from_be_bytes([bytes[1], bytes[2], bytes[3], bytes[4]]); + if len != 5 { + return Err(ReadyForQueryError::UnexpectedLength(len)); + } + + let status_byte = bytes[5]; + let status = match status_byte { + b'I' => TransactionStatus::Idle, + b'T' => TransactionStatus::InTransaction, + b'E' => TransactionStatus::InFailedTransaction, + other => return Err(ReadyForQueryError::InvalidStatus(other)), + }; + + Ok(ReadyForQueryFrame { status }) + } + + fn to_bytes(&self) -> Result { + let status_byte = match self.status { + TransactionStatus::Idle => b'I', + TransactionStatus::InTransaction => b'T', + TransactionStatus::InFailedTransaction => b'E', + }; + let mut buf = BytesMut::with_capacity(6); + buf.put_u8(b'Z'); + buf.put_u32(5); + buf.put_u8(status_byte); + Ok(buf.freeze()) + } + + fn body_size(&self) -> usize { + 1 + } +} + +// ----------------------------------------------------------------------------- +// ----- Tests ----------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn serialize_idle() { + let frame = ReadyForQueryFrame { + status: TransactionStatus::Idle, + }; + let bytes = frame.to_bytes().unwrap(); + let expected = b"Z\x00\x00\x00\x05I"; + assert_eq!(bytes.as_ref(), expected); + } + + #[test] + fn deserialize_idle() { + let data = b"Z\x00\x00\x00\x05I"; + let frame = ReadyForQueryFrame::from_bytes(data).unwrap(); + assert_eq!(frame.status, TransactionStatus::Idle); + } + + #[test] + fn roundtrip_idle() { + let original = ReadyForQueryFrame { + status: TransactionStatus::Idle, + }; + let bytes = original.to_bytes().unwrap(); + let decoded = ReadyForQueryFrame::from_bytes(bytes.as_ref()).unwrap(); + assert_eq!(decoded.status, original.status); + } + + #[test] + fn roundtrip_in_transaction() { + let original = ReadyForQueryFrame { + status: TransactionStatus::InTransaction, + }; + let bytes = original.to_bytes().unwrap(); + let decoded = ReadyForQueryFrame::from_bytes(bytes.as_ref()).unwrap(); + assert_eq!(decoded.status, original.status); + } + + #[test] + fn roundtrip_in_failed_transaction() { + let original = ReadyForQueryFrame { + status: TransactionStatus::InFailedTransaction, + }; + let bytes = original.to_bytes().unwrap(); + let decoded = ReadyForQueryFrame::from_bytes(bytes.as_ref()).unwrap(); + assert_eq!(decoded.status, original.status); + } + + #[test] + fn invalid_tag() { + let data = b"X\x00\x00\x00\x05I"; + let err = ReadyForQueryFrame::from_bytes(data).unwrap_err(); + matches!(err, ReadyForQueryError::UnexpectedTag(_)); + } + + #[test] + fn invalid_length() { + let data = b"Z\x00\x00\x00\x06I"; + let err = ReadyForQueryFrame::from_bytes(data).unwrap_err(); + matches!(err, ReadyForQueryError::UnexpectedLength(_)); + } + + #[test] + fn invalid_status() { + let data = b"Z\x00\x00\x00\x05X"; + let err = ReadyForQueryFrame::from_bytes(data).unwrap_err(); + matches!(err, ReadyForQueryError::InvalidStatus(_)); + } +} + +// ----------------------------------------------------------------------------- +// ----------------------------------------------------------------------------- diff --git a/pgdog/src/wire_protocol/backend/row_description.rs b/pgdog/src/wire_protocol/backend/row_description.rs new file mode 100644 index 00000000..dd2fe721 --- /dev/null +++ b/pgdog/src/wire_protocol/backend/row_description.rs @@ -0,0 +1,413 @@ +//! Module: wire_protocol::backend::row_description +//! +//! Provides parsing and serialization for the RowDescription message ('T') in the protocol. +//! +//! - `RowDescriptionFrame`: represents the RowDescription message with a list of field descriptions. +//! - `RowDescriptionError`: error types for parsing and encoding. +//! +//! Implements `WireSerializable` for easy conversion between raw bytes and `RowDescriptionFrame`. + +use bytes::{Buf, BufMut, Bytes, BytesMut}; + +use std::{error::Error as StdError, fmt, str}; + +use crate::wire_protocol::shared_property_types::ResultFormat; +use crate::wire_protocol::WireSerializable; + +// ----------------------------------------------------------------------------- +// ----- ProtocolMessage ------------------------------------------------------- + +#[derive(Debug, Clone, PartialEq)] +pub struct RowDescriptionFrame<'a> { + pub fields: Vec>, +} + +// ----------------------------------------------------------------------------- +// ----- Subproperties --------------------------------------------------------- + +#[derive(Debug, Clone, PartialEq)] +pub struct RowField<'a> { + pub name: &'a str, + pub table_oid: u32, + pub column_attr: i16, + pub type_oid: u32, + pub type_size: i16, + pub type_modifier: i32, + pub format: ResultFormat, +} + +// ----------------------------------------------------------------------------- +// ----- Error ----------------------------------------------------------------- + +#[derive(Debug)] +pub enum RowDescriptionError { + UnexpectedTag(u8), + UnexpectedLength(u32), + Utf8Error(str::Utf8Error), + UnexpectedEof, + InvalidFormatCode(i16), +} + +impl fmt::Display for RowDescriptionError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + RowDescriptionError::UnexpectedTag(t) => write!(f, "unexpected tag: {t:#X}"), + RowDescriptionError::UnexpectedLength(len) => write!(f, "unexpected length: {len}"), + RowDescriptionError::Utf8Error(e) => write!(f, "UTF-8 error: {e}"), + RowDescriptionError::UnexpectedEof => write!(f, "unexpected EOF"), + RowDescriptionError::InvalidFormatCode(c) => write!(f, "invalid format code: {c}"), + } + } +} + +impl StdError for RowDescriptionError { + fn source(&self) -> Option<&(dyn StdError + 'static)> { + match self { + RowDescriptionError::Utf8Error(e) => Some(e), + _ => None, + } + } +} + +// ----------------------------------------------------------------------------- +// ----- Helpers --------------------------------------------------------------- + +fn read_cstr<'a>(buf: &mut &'a [u8]) -> Result<&'a str, RowDescriptionError> { + let nul = buf + .iter() + .position(|b| *b == 0) + .ok_or(RowDescriptionError::UnexpectedEof)?; + + let (raw, rest) = buf.split_at(nul); + *buf = &rest[1..]; // skip NUL + + Ok(str::from_utf8(raw).map_err(RowDescriptionError::Utf8Error)?) +} + +fn decode_format_code(code: i16) -> Result { + match code { + 0 => Ok(ResultFormat::Text), + 1 => Ok(ResultFormat::Binary), + other => Err(RowDescriptionError::InvalidFormatCode(other)), + } +} + +fn encode_format_code(buf: &mut BytesMut, format: ResultFormat) { + buf.put_i16(match format { + ResultFormat::Text => 0, + ResultFormat::Binary => 1, + }); +} + +// ----------------------------------------------------------------------------- +// ----- WireSerializable ------------------------------------------------------ + +impl<'a> WireSerializable<'a> for RowDescriptionFrame<'a> { + type Error = RowDescriptionError; + + fn from_bytes(mut bytes: &'a [u8]) -> Result { + if bytes.remaining() < 5 { + return Err(RowDescriptionError::UnexpectedLength(bytes.len() as u32)); + } + + let tag = bytes.get_u8(); + if tag != b'T' { + return Err(RowDescriptionError::UnexpectedTag(tag)); + } + + let len = bytes.get_u32(); + if len as usize != bytes.remaining() + 4 { + return Err(RowDescriptionError::UnexpectedLength(len)); + } + + let num_fields = bytes.get_i16() as usize; + + let mut fields = Vec::with_capacity(num_fields); + + for _ in 0..num_fields { + let name = read_cstr(&mut bytes)?; + + if bytes.remaining() < 18 { + return Err(RowDescriptionError::UnexpectedEof); + } + + let table_oid = bytes.get_u32(); + let column_attr = bytes.get_i16(); + let type_oid = bytes.get_u32(); + let type_size = bytes.get_i16(); + let type_modifier = bytes.get_i32(); + let format_code = bytes.get_i16(); + let format = decode_format_code(format_code)?; + + fields.push(RowField { + name, + table_oid, + column_attr, + type_oid, + type_size, + type_modifier, + format, + }); + } + + if bytes.has_remaining() { + return Err(RowDescriptionError::UnexpectedLength(len)); + } + + Ok(RowDescriptionFrame { fields }) + } + + fn to_bytes(&self) -> Result { + let mut body = BytesMut::with_capacity(self.body_size()); + + body.put_i16(self.fields.len() as i16); + + for field in &self.fields { + body.extend_from_slice(field.name.as_bytes()); + body.put_u8(0); + body.put_u32(field.table_oid); + body.put_i16(field.column_attr); + body.put_u32(field.type_oid); + body.put_i16(field.type_size); + body.put_i32(field.type_modifier); + encode_format_code(&mut body, field.format); + } + + let mut frame = BytesMut::with_capacity(body.len() + 5); + frame.put_u8(b'T'); + frame.put_u32((body.len() + 4) as u32); + frame.extend_from_slice(&body); + + Ok(frame.freeze()) + } + + fn body_size(&self) -> usize { + 2 + self + .fields + .iter() + .map(|f| f.name.len() + 1 + 4 + 2 + 4 + 2 + 4 + 2) + .sum::() + } +} + +// ----------------------------------------------------------------------------- +// ----- Tests ----------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + + fn make_frame<'a>() -> RowDescriptionFrame<'a> { + RowDescriptionFrame { + fields: vec![ + RowField { + name: "id", + table_oid: 16384, + column_attr: 1, + type_oid: 23, + type_size: 4, + type_modifier: -1, + format: ResultFormat::Text, + }, + RowField { + name: "name", + table_oid: 16384, + column_attr: 2, + type_oid: 25, + type_size: -1, + type_modifier: -1, + format: ResultFormat::Text, + }, + ], + } + } + + fn make_binary_frame<'a>() -> RowDescriptionFrame<'a> { + RowDescriptionFrame { + fields: vec![RowField { + name: "value", + table_oid: 0, + column_attr: 0, + type_oid: 17, + type_size: -1, + type_modifier: -1, + format: ResultFormat::Binary, + }], + } + } + + #[test] + fn serialize_row_description() { + let frame = make_frame(); + let bytes = frame.to_bytes().unwrap(); + + let mut expected = Vec::new(); + // Tag + expected.push(b'T'); + // Length (50 bytes total length including length itself) + expected.extend_from_slice(&50u32.to_be_bytes()); + // Number of fields (2) + expected.extend_from_slice(&2i16.to_be_bytes()); + // First field: "id" + expected.extend_from_slice(b"id\0"); + // Table OID (16384) + expected.extend_from_slice(&16384u32.to_be_bytes()); + // Column attribute number (1) + expected.extend_from_slice(&1i16.to_be_bytes()); + // Type OID (23) + expected.extend_from_slice(&23u32.to_be_bytes()); + // Type size (4) + expected.extend_from_slice(&4i16.to_be_bytes()); + // Type modifier (-1) + expected.extend_from_slice(&(-1i32).to_be_bytes()); + // Format (0, Text) + expected.extend_from_slice(&0i16.to_be_bytes()); + // Second field: "name" + expected.extend_from_slice(b"name\0"); + // Table OID (16384) + expected.extend_from_slice(&16384u32.to_be_bytes()); + // Column attribute number (2) + expected.extend_from_slice(&2i16.to_be_bytes()); + // Type OID (25) + expected.extend_from_slice(&25u32.to_be_bytes()); + // Type size (-1) + expected.extend_from_slice(&(-1i16).to_be_bytes()); + // Type modifier (-1) + expected.extend_from_slice(&(-1i32).to_be_bytes()); + // Format (0, Text) + expected.extend_from_slice(&0i16.to_be_bytes()); + + let expected = expected.as_slice(); + assert_eq!(bytes.as_ref(), expected); + } + + #[test] + fn deserialize_row_description() { + let mut data = Vec::new(); + + // Tag + data.push(b'T'); + // Length (50 bytes total length including length itself) + data.extend_from_slice(&50u32.to_be_bytes()); + // Number of fields (2) + data.extend_from_slice(&2i16.to_be_bytes()); + // First field: "id" + data.extend_from_slice(b"id\0"); + // Table OID (16384) + data.extend_from_slice(&16384u32.to_be_bytes()); + // Column attribute number (1) + data.extend_from_slice(&1i16.to_be_bytes()); + // Type OID (23) + data.extend_from_slice(&23u32.to_be_bytes()); + // Type size (4) + data.extend_from_slice(&4i16.to_be_bytes()); + // Type modifier (-1) + data.extend_from_slice(&(-1i32).to_be_bytes()); + // Format (0, Text) + data.extend_from_slice(&0i16.to_be_bytes()); + // Second field: "name" + data.extend_from_slice(b"name\0"); + // Table OID (16384) + data.extend_from_slice(&16384u32.to_be_bytes()); + // Column attribute number (2) + data.extend_from_slice(&2i16.to_be_bytes()); + // Type OID (25) + data.extend_from_slice(&25u32.to_be_bytes()); + // Type size (-1) + data.extend_from_slice(&(-1i16).to_be_bytes()); + // Type modifier (-1) + data.extend_from_slice(&(-1i32).to_be_bytes()); + // Format (0, Text) + data.extend_from_slice(&0i16.to_be_bytes()); + + let data = data.as_slice(); + + let frame = RowDescriptionFrame::from_bytes(data).unwrap(); + assert_eq!(frame.fields.len(), 2); + assert_eq!(frame.fields[0].name, "id"); + assert_eq!(frame.fields[0].table_oid, 16384); + assert_eq!(frame.fields[0].column_attr, 1); + assert_eq!(frame.fields[0].type_oid, 23); + assert_eq!(frame.fields[0].type_size, 4); + assert_eq!(frame.fields[0].type_modifier, -1); + assert_eq!(frame.fields[0].format, ResultFormat::Text); + assert_eq!(frame.fields[1].name, "name"); + assert_eq!(frame.fields[1].table_oid, 16384); + assert_eq!(frame.fields[1].column_attr, 2); + assert_eq!(frame.fields[1].type_oid, 25); + assert_eq!(frame.fields[1].type_size, -1); + assert_eq!(frame.fields[1].type_modifier, -1); + assert_eq!(frame.fields[1].format, ResultFormat::Text); + } + + #[test] + fn roundtrip_row_description() { + let original = make_frame(); + let bytes = original.to_bytes().unwrap(); + let decoded = RowDescriptionFrame::from_bytes(bytes.as_ref()).unwrap(); + assert_eq!(decoded.fields, original.fields); + } + + #[test] + fn roundtrip_binary() { + let original = make_binary_frame(); + let bytes = original.to_bytes().unwrap(); + let decoded = RowDescriptionFrame::from_bytes(bytes.as_ref()).unwrap(); + assert_eq!(decoded.fields, original.fields); + } + + #[test] + fn invalid_tag() { + let data = b"U\x00\x00\x00\x2F\x00\x02id\x00\x00\x00\x40\x00\x00\x01\x00\x00\x00\x17\x00\x04\xFF\xFF\xFF\xFF\x00\x00name\x00\x00\x00\x40\x00\x00\x02\x00\x00\x00\x19\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\x00\x00"; + let err = RowDescriptionFrame::from_bytes(data).unwrap_err(); + matches!(err, RowDescriptionError::UnexpectedTag(_)); + } + + #[test] + fn invalid_length() { + let data = b"T\x00\x00\x00\x30\x00\x02id\x00\x00\x00\x40\x00\x00\x01\x00\x00\x00\x17\x00\x04\xFF\xFF\xFF\xFF\x00\x00name\x00\x00\x00\x40\x00\x00\x02\x00\x00\x00\x19\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\x00\x00"; + let err = RowDescriptionFrame::from_bytes(data).unwrap_err(); + matches!(err, RowDescriptionError::UnexpectedLength(_)); + } + + #[test] + fn missing_nul_in_name() { + let data = b"T\x00\x00\x00\x2F\x00\x02id\x00\x00\x00\x40\x00\x00\x01\x00\x00\x00\x17\x00\x04\xFF\xFF\xFF\xFF\x00\x00name"; + let err = RowDescriptionFrame::from_bytes(data).unwrap_err(); + matches!(err, RowDescriptionError::UnexpectedEof); + } + + #[test] + fn invalid_utf8_name() { + let mut bytes = make_frame().to_bytes().unwrap().to_vec(); + bytes[7] = 0xFF; // corrupt name byte + let err = RowDescriptionFrame::from_bytes(&bytes).unwrap_err(); + matches!(err, RowDescriptionError::Utf8Error(_)); + } + + #[test] + fn invalid_format_code() { + let mut bytes = make_frame().to_bytes().unwrap().to_vec(); + bytes[25] = 0x00; + bytes[26] = 0x02; // invalid format 2 + let err = RowDescriptionFrame::from_bytes(&bytes).unwrap_err(); + matches!(err, RowDescriptionError::InvalidFormatCode(2)); + } + + #[test] + fn short_field_data() { + let data = b"T\x00\x00\x00\x2E\x00\x02id\x00\x00\x00\x40\x00\x00\x01\x00\x00\x00\x17\x00\x04\xFF\xFF\xFF\xFF\x00\x00name\x00\x00\x00\x40\x00\x00\x02\x00\x00\x00\x19\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\x00"; // one byte short + let err = RowDescriptionFrame::from_bytes(data).unwrap_err(); + matches!(err, RowDescriptionError::UnexpectedEof); + } + + #[test] + fn extra_data() { + let data = b"T\x00\x00\x00\x2F\x00\x02id\x00\x00\x00\x40\x00\x00\x01\x00\x00\x00\x17\x00\x04\xFF\xFF\xFF\xFF\x00\x00name\x00\x00\x00\x40\x00\x00\x02\x00\x00\x00\x19\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\x00\x00\x00"; + let err = RowDescriptionFrame::from_bytes(data).unwrap_err(); + matches!(err, RowDescriptionError::UnexpectedLength(_)); + } +} + +// ----------------------------------------------------------------------------- +// ----------------------------------------------------------------------------- diff --git a/pgdog/src/wire_protocol/bidirectional/copy_data.rs b/pgdog/src/wire_protocol/bidirectional/copy_data.rs new file mode 100644 index 00000000..26d50efa --- /dev/null +++ b/pgdog/src/wire_protocol/bidirectional/copy_data.rs @@ -0,0 +1,138 @@ +//! Module: wire_protocol::bidirectional::copy_data +//! +//! Provides parsing and serialization for the CopyData message ('d') in the extended protocol. +//! This message can be used by both the client and the server. +//! +//! - `CopyDataFrame`: represents a CopyData message carrying a chunk of data. +//! - `CopyDataError`: error types for parsing and encoding. +//! +//! Implements `WireSerializable` for easy conversion between raw bytes and `CopyDataFrame`. + +use bytes::{BufMut, Bytes, BytesMut}; +use std::{error::Error as StdError, fmt}; + +use crate::wire_protocol::WireSerializable; + +// ----------------------------------------------------------------------------- +// ----- ProtocolMessage ------------------------------------------------------- + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct CopyDataFrame<'a> { + pub data: &'a [u8], +} + +// ----------------------------------------------------------------------------- +// ----- Error ----------------------------------------------------------------- + +#[derive(Debug, PartialEq, Eq)] +pub enum CopyDataError { + UnexpectedTag(u8), + UnexpectedLength(u32), +} + +impl fmt::Display for CopyDataError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + CopyDataError::UnexpectedTag(t) => write!(f, "unexpected tag: {:#X}", t), + CopyDataError::UnexpectedLength(len) => write!(f, "unexpected length: {}", len), + } + } +} + +impl StdError for CopyDataError {} + +// ----------------------------------------------------------------------------- +// ----- WireSerializable ------------------------------------------------------ + +impl<'a> WireSerializable<'a> for CopyDataFrame<'a> { + type Error = CopyDataError; + + fn from_bytes(bytes: &'a [u8]) -> Result { + if bytes.len() < 5 { + return Err(CopyDataError::UnexpectedLength(bytes.len() as u32)); + } + let tag = bytes[0]; + if tag != b'd' { + return Err(CopyDataError::UnexpectedTag(tag)); + } + let len = u32::from_be_bytes([bytes[1], bytes[2], bytes[3], bytes[4]]); + if len as usize != bytes.len() - 1 { + return Err(CopyDataError::UnexpectedLength(len)); + } + Ok(CopyDataFrame { data: &bytes[5..] }) + } + + fn to_bytes(&self) -> Result { + let total = 4 + self.data.len(); + let mut buf = BytesMut::with_capacity(1 + total); + buf.put_u8(b'd'); + buf.put_u32(total as u32); + buf.put_slice(self.data); + + Ok(buf.freeze()) + } + + fn body_size(&self) -> usize { + self.data.len() + } +} + +// ----------------------------------------------------------------------------- +// ----- Tests ----------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + use bytes::BytesMut; + + fn make_frame() -> CopyDataFrame<'static> { + CopyDataFrame { data: b"hello" } + } + + #[test] + fn roundtrip() { + let frame = make_frame(); + let encoded = frame.to_bytes().unwrap(); + let decoded = CopyDataFrame::from_bytes(encoded.as_ref()).unwrap(); + assert_eq!(decoded.data, frame.data); + } + + #[test] + fn empty_payload() { + // an empty chunk is valid: length = 4, no data + let frame = CopyDataFrame { data: &[] }; + let bytes = frame.to_bytes().unwrap(); + assert_eq!(bytes.as_ref(), &[b'd', 0, 0, 0, 4]); + let decoded = CopyDataFrame::from_bytes(bytes.as_ref()).unwrap(); + assert_eq!(decoded.data, &[] as &[u8]); + } + + #[test] + fn unexpected_tag() { + let mut buf = BytesMut::new(); + buf.put_u8(b'x'); // wrong tag + buf.put_u32(4); + buf.put_slice(b"test"); + let raw = buf.freeze().to_vec(); + let err = CopyDataFrame::from_bytes(raw.as_ref()).unwrap_err(); + assert!(matches!(err, CopyDataError::UnexpectedTag(t) if t == b'x')); + } + + #[test] + fn unexpected_length_mismatch() { + let mut buf = BytesMut::new(); + buf.put_u8(b'd'); + buf.put_u32(10); + buf.put_slice(b"short"); + let raw = buf.freeze().to_vec(); + let err = CopyDataFrame::from_bytes(raw.as_ref()).unwrap_err(); + assert!(matches!(err, CopyDataError::UnexpectedLength(10))); + } + + #[test] + fn unexpected_length_short_buffer() { + let raw = b"d\x00\x00"; // too short + let err = CopyDataFrame::from_bytes(raw).unwrap_err(); + assert!(matches!(err, CopyDataError::UnexpectedLength(len) if len == raw.len() as u32)); + } +} diff --git a/pgdog/src/wire_protocol/bidirectional/copy_done.rs b/pgdog/src/wire_protocol/bidirectional/copy_done.rs new file mode 100644 index 00000000..56f29126 --- /dev/null +++ b/pgdog/src/wire_protocol/bidirectional/copy_done.rs @@ -0,0 +1,120 @@ +//! Module: wire_protocol::frontend::copy_done +//! +//! Provides parsing and serialization for the CopyDone message ('c') in the extended protocol. +//! +//! - `CopyDoneFrame`: represents a CopyDone message sent by the client to indicate the end of COPY data. +//! +//! Implements `WireSerializable` for conversion between raw bytes and `CopyDoneFrame`. + +use bytes::Bytes; +use std::{error::Error as StdError, fmt}; + +use crate::wire_protocol::WireSerializable; + +// ----------------------------------------------------------------------------- +// ----- ProtocolMessage ------------------------------------------------------- + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct CopyDoneFrame; + +// ----------------------------------------------------------------------------- +// ----- Error ----------------------------------------------------------------- + +#[derive(Debug)] +pub enum CopyDoneError { + UnexpectedTag(u8), + UnexpectedLength(u32), +} + +impl fmt::Display for CopyDoneError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + CopyDoneError::UnexpectedTag(t) => write!(f, "unexpected tag: {t:#X}"), + CopyDoneError::UnexpectedLength(len) => write!(f, "unexpected length: {len}"), + } + } +} + +impl StdError for CopyDoneError {} + +// ----------------------------------------------------------------------------- +// ----- WireSerializable ------------------------------------------------------ + +impl<'a> WireSerializable<'a> for CopyDoneFrame { + type Error = CopyDoneError; + + fn from_bytes(bytes: &'a [u8]) -> Result { + if bytes.len() < 5 { + return Err(CopyDoneError::UnexpectedLength(bytes.len() as u32)); + } + + let tag = bytes[0]; + if tag != b'c' { + return Err(CopyDoneError::UnexpectedTag(tag)); + } + + let len = u32::from_be_bytes([bytes[1], bytes[2], bytes[3], bytes[4]]); + if len != 4 { + return Err(CopyDoneError::UnexpectedLength(len)); + } + + Ok(CopyDoneFrame) + } + + fn to_bytes(&self) -> Result { + Ok(Bytes::from_static(b"c\0\0\0\x04")) + } + + fn body_size(&self) -> usize { + 0 + } +} + +// ----------------------------------------------------------------------------- +// ----- Tests ----------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn serialize_copy_done() { + let copy_done = CopyDoneFrame; + let bytes = copy_done.to_bytes().unwrap(); + let expected_bytes = Bytes::from_static(&[b'c', 0, 0, 0, 4]); + assert_eq!(bytes, expected_bytes); + } + + #[test] + fn deserialize_copy_done() { + let data = &[b'c', 0, 0, 0, 4][..]; + let copy_done = CopyDoneFrame::from_bytes(data).unwrap(); + // no state; just ensure no error + let _ = copy_done; + } + + #[test] + fn roundtrip_copy_done() { + let original = CopyDoneFrame; + let bytes = original.to_bytes().unwrap(); + let decoded = CopyDoneFrame::from_bytes(bytes.as_ref()).unwrap(); + assert_eq!(original, decoded); + } + + #[test] + fn invalid_tag() { + let data = &[b'Q', 0, 0, 0, 4][..]; + let err = CopyDoneFrame::from_bytes(data).unwrap_err(); + matches!(err, CopyDoneError::UnexpectedTag(_)); + } + + #[test] + fn invalid_length() { + let data = &[b'c', 0, 0, 0, 5][..]; + let err = CopyDoneFrame::from_bytes(data).unwrap_err(); + matches!(err, CopyDoneError::UnexpectedLength(5)); + } +} + +// ----------------------------------------------------------------------------- +// ----------------------------------------------------------------------------- diff --git a/pgdog/src/wire_protocol/bidirectional/mod.rs b/pgdog/src/wire_protocol/bidirectional/mod.rs new file mode 100644 index 00000000..ed60fbbf --- /dev/null +++ b/pgdog/src/wire_protocol/bidirectional/mod.rs @@ -0,0 +1,2 @@ +pub mod copy_data; +pub mod copy_done; diff --git a/pgdog/src/wire_protocol/frontend/bind.rs b/pgdog/src/wire_protocol/frontend/bind.rs new file mode 100644 index 00000000..ef1b823f --- /dev/null +++ b/pgdog/src/wire_protocol/frontend/bind.rs @@ -0,0 +1,356 @@ +//! Module: wire_protocol::frontend::bind +//! +//! Provides parsing and serialization for the Bind message ('B') in the extended protocol. +//! +//! - `BindFrame`: represents a Bind message with portal, statement, parameters, and result formats. +//! - `Parameter`: enum distinguishes between text and binary parameter payloads. +//! - `ResultFormat`: indicates text or binary format for results. +//! - `BindFrameError`: error types for parsing and encoding. +//! +//! Implements `WireSerializable` for easy conversion between raw bytes and `BindFrame`. + +use bytes::{Buf, BufMut, Bytes, BytesMut}; +use std::{error::Error as StdError, fmt, str}; + +use crate::wire_protocol::shared_property_types::{Parameter, ResultFormat}; +use crate::wire_protocol::WireSerializable; + +// ----------------------------------------------------------------------------- +// ----- ProtocolMessage ------------------------------------------------------- + +#[derive(Debug)] +pub struct BindFrame<'a> { + pub portal: &'a str, + pub statement: &'a str, + pub params: Vec>, + pub result_formats: Vec, +} + +// ----------------------------------------------------------------------------- +// ----- Error ----------------------------------------------------------------- + +#[derive(Debug)] +pub enum BindFrameError { + Utf8Error(str::Utf8Error), + UnexpectedEof, + InvalidLength, + InvalidFormatCode(i16), + UnexpectedTag(u8), +} + +impl fmt::Display for BindFrameError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + BindFrameError::Utf8Error(e) => write!(f, "UTF-8 error: {e}"), + BindFrameError::UnexpectedEof => write!(f, "unexpected EOF"), + BindFrameError::InvalidLength => write!(f, "invalid length or format code"), + BindFrameError::InvalidFormatCode(c) => write!(f, "invalid format code: {c}"), + BindFrameError::UnexpectedTag(t) => write!(f, "unexpected tag: {t:#X}"), + } + } +} + +impl StdError for BindFrameError { + fn source(&self) -> Option<&(dyn StdError + 'static)> { + match self { + BindFrameError::Utf8Error(e) => Some(e), + _ => None, + } + } +} + +// ----------------------------------------------------------------------------- +// ----- Helpers --------------------------------------------------------------- + +#[inline] +fn encode_format_code(buf: &mut BytesMut, is_binary: bool) { + buf.put_i16(if is_binary { 1 } else { 0 }); +} + +#[inline] +fn decode_format_code(code: i16) -> Result { + match code { + 0 => Ok(false), + 1 => Ok(true), + other => Err(BindFrameError::InvalidFormatCode(other)), + } +} + +fn read_cstr<'a>(buf: &mut &'a [u8]) -> Result<&'a str, BindFrameError> { + let nul = buf + .iter() + .position(|b| *b == 0) + .ok_or(BindFrameError::UnexpectedEof)?; + let (raw, rest) = buf.split_at(nul); + *buf = &rest[1..]; // skip NUL + Ok(str::from_utf8(raw).map_err(BindFrameError::Utf8Error)?) +} + +// ----------------------------------------------------------------------------- +// ----- WireSerializable ------------------------------------------------------ + +impl<'a> WireSerializable<'a> for BindFrame<'a> { + type Error = BindFrameError; + + fn from_bytes(mut bytes: &'a [u8]) -> Result { + if bytes.remaining() < 5 { + return Err(BindFrameError::UnexpectedEof); + } + let tag = bytes.get_u8(); + if tag != b'B' { + return Err(BindFrameError::UnexpectedTag(tag)); + } + let _len = bytes.get_u32(); + + let portal = read_cstr(&mut bytes)?; + let statement = read_cstr(&mut bytes)?; + + // parameter format codes + let fmt_count = bytes.get_i16(); + let mut param_fmts = Vec::with_capacity(fmt_count as usize); + for _ in 0..fmt_count { + param_fmts.push(decode_format_code(bytes.get_i16())?); + } + + // parameters + let param_count = bytes.get_i16() as usize; + let mut params = Vec::with_capacity(param_count); + for idx in 0..param_count { + let val_len = bytes.get_i32(); + let is_binary = match fmt_count { + 0 => false, + 1 => param_fmts[0], + _ => param_fmts[idx], + }; + if val_len == -1 { + params.push(Parameter::Binary(&[])); + continue; + } + let len = val_len as usize; + let slice = &bytes[..len]; + bytes.advance(len); + if is_binary { + params.push(Parameter::Binary(slice)); + } else { + params.push(Parameter::Text( + str::from_utf8(slice).map_err(BindFrameError::Utf8Error)?, + )); + } + } + + // result formats + let res_fmt_count = bytes.get_i16(); + let mut result_formats = Vec::with_capacity(res_fmt_count as usize); + for _ in 0..res_fmt_count { + let is_bin = decode_format_code(bytes.get_i16())?; + result_formats.push(if is_bin { + ResultFormat::Binary + } else { + ResultFormat::Text + }); + } + + Ok(BindFrame { + portal, + statement, + params, + result_formats, + }) + } + + fn to_bytes(&self) -> Result { + let mut body = BytesMut::with_capacity(self.body_size()); + + // portal\0 + statement\0 + body.extend_from_slice(self.portal.as_bytes()); + body.put_u8(0); + body.extend_from_slice(self.statement.as_bytes()); + body.put_u8(0); + + // param format codes + body.put_i16(self.params.len() as i16); + for p in &self.params { + encode_format_code(&mut body, matches!(p, Parameter::Binary(_))); + } + + // parameter values + body.put_i16(self.params.len() as i16); + for p in &self.params { + match p { + Parameter::Text(s) => { + body.put_i32(s.len() as i32); + body.extend_from_slice(s.as_bytes()); + } + Parameter::Binary(b) if !b.is_empty() => { + body.put_i32(b.len() as i32); + body.extend_from_slice(b); + } + Parameter::Binary(_) => { + body.put_i32(-1); + } + } + } + + // result formats + body.put_i16(self.result_formats.len() as i16); + for fmt in &self.result_formats { + encode_format_code(&mut body, matches!(fmt, ResultFormat::Binary)); + } + + // wrap with tag + length + let mut frame = BytesMut::with_capacity(body.len() + 5); + frame.put_u8(b'B'); + frame.put_u32((body.len() + 4) as u32); + frame.extend_from_slice(&body); + + Ok(frame.freeze()) + } + + fn body_size(&self) -> usize { + let mut n = 0; + n += self.portal.len() + 1; + n += self.statement.len() + 1; + n += 2 + self.params.len() * 2; + n += 2; + for p in &self.params { + n += 4; + match p { + Parameter::Text(s) => n += s.len(), + Parameter::Binary(b) => n += b.len(), + } + } + n += 2 + self.result_formats.len() * 2; + n + } +} + +// ----------------------------------------------------------------------------- +// ----- Tests ----------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + + fn make_frame<'a>() -> BindFrame<'a> { + BindFrame { + portal: "", + statement: "stmt", + params: vec![Parameter::Text("42")], + result_formats: vec![ResultFormat::Text], + } + } + + fn make_binary_email_frame(email: &str) -> Vec { + let mut body = BytesMut::new(); + + // portal\0 + body.extend_from_slice("".as_bytes()); + body.put_u8(0); + + // statement\0 + body.extend_from_slice("stmt".as_bytes()); + body.put_u8(0); + + // one binary param + body.put_i16(1); // param format count + body.put_i16(1); // format code = binary + body.put_i16(1); // param count + body.put_i32(email.len() as i32); + body.extend_from_slice(email.as_bytes()); + + // no result formats + body.put_i16(0); + + let mut frame = BytesMut::new(); + frame.put_u8(b'B'); + frame.put_u32((body.len() + 4) as u32); + frame.extend_from_slice(&body); + frame.to_vec() + } + + #[test] + fn roundtrip_text_param() { + let frame = make_frame(); + let encoded = frame.to_bytes().unwrap(); + let decoded = BindFrame::from_bytes(encoded.as_ref()).unwrap(); + + assert_eq!(decoded.portal, frame.portal); + assert_eq!(decoded.statement, frame.statement); + + match &decoded.params[0] { + Parameter::Text(t) => assert_eq!(*t, "42"), + _ => panic!("expected text param"), + } + + matches!(decoded.result_formats[0], ResultFormat::Text); + } + + #[test] + fn roundtrip_null_param_binary_format() { + let frame = BindFrame { + portal: "super_cool_mega_portal", + statement: "super_cool_mega_statement", + params: vec![Parameter::Binary(&[])], + result_formats: vec![ResultFormat::Binary], + }; + let encoded = frame.to_bytes().unwrap(); + let decoded = BindFrame::from_bytes(encoded.as_ref()).unwrap(); + matches!(decoded.params[0], Parameter::Binary(_)); + matches!(decoded.result_formats[0], ResultFormat::Binary); + } + + #[test] + fn roundtrip_binary_email_param() { + let email = "person@example.com"; + let buf1 = make_binary_email_frame(email); + let frame1 = BindFrame::from_bytes(buf1.as_slice()).unwrap(); + + let raw = if let Parameter::Binary(bytes) = frame1.params[0] { + bytes + } else { + &[] + }; + + assert_eq!(raw, email.as_bytes()); + + let buf2 = frame1.to_bytes().unwrap(); + let frame2 = BindFrame::from_bytes(buf2.as_ref()).unwrap(); + let raw2 = if let Parameter::Binary(b) = frame2.params[0] { + b + } else { + &[] + }; + assert_eq!(raw2, email.as_bytes()); + } + + #[test] + fn invalid_tag() { + let mut bytes = make_frame().to_bytes().unwrap().to_vec(); + bytes[0] = b'Q'; // corrupt the tag + + let err = BindFrame::from_bytes(bytes.as_slice()).unwrap_err(); + matches!(err, BindFrameError::UnexpectedTag(_)); + } + + #[test] + fn invalid_format_code() { + // produce a good frame then flip the first format code to 2 + let mut bytes = make_frame().to_bytes().unwrap().to_vec(); + + let mut offset = 0; + offset += 5; // header + offset += 0; // portal_name = "" + offset += 1; // NULL terminator + offset += 4; // statement = "stmt" (4 bytes) + offset += 1; // NULL terminator + + bytes[offset + 2] = 0; // count high byte already 0 + bytes[offset + 3] = 2; // invalid code 2 + + let err = BindFrame::from_bytes(bytes.as_slice()).unwrap_err(); + matches!(err, BindFrameError::InvalidFormatCode(2)); + } +} + +// ----------------------------------------------------------------------------- +// ----------------------------------------------------------------------------- diff --git a/pgdog/src/wire_protocol/frontend/cancel_request.rs b/pgdog/src/wire_protocol/frontend/cancel_request.rs new file mode 100644 index 00000000..f8a59b36 --- /dev/null +++ b/pgdog/src/wire_protocol/frontend/cancel_request.rs @@ -0,0 +1,150 @@ +//! Module: wire_protocol::frontend::cancel_request +//! +//! Provides parsing and serialization for the CancelRequest message in the protocol. +//! +//! Note: Unlike regular protocol messages, CancelRequest has no tag byte and is typically +//! sent over a separate connection to interrupt a running query. +//! +//! - `CancelRequestFrame`: represents a CancelRequest message with backend PID and secret key. +//! - `CancelRequestError`: error types for parsing and encoding. +//! +//! Implements `WireSerializable` for easy conversion between raw bytes and `CancelRequestFrame`. + +use crate::wire_protocol::WireSerializable; +use bytes::{Buf, BufMut, Bytes, BytesMut}; +use std::{error::Error as StdError, fmt}; + +// ----------------------------------------------------------------------------- +// ----- ProtocolMessage ------------------------------------------------------- + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct CancelRequestFrame { + pub pid: i32, + pub secret: i32, +} + +// ----------------------------------------------------------------------------- +// ----- Error ----------------------------------------------------------------- + +#[derive(Debug)] +pub enum CancelRequestError { + UnexpectedLength(usize), + UnexpectedCode(i32), +} + +impl fmt::Display for CancelRequestError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + CancelRequestError::UnexpectedLength(len) => write!(f, "unexpected length: {len}"), + CancelRequestError::UnexpectedCode(code) => write!(f, "unexpected code: {code}"), + } + } +} + +impl StdError for CancelRequestError {} + +// ----------------------------------------------------------------------------- +// ----- WireSerializable ------------------------------------------------------ + +impl<'a> WireSerializable<'a> for CancelRequestFrame { + type Error = CancelRequestError; + + fn from_bytes(bytes: &'a [u8]) -> Result { + if bytes.len() != 16 { + return Err(CancelRequestError::UnexpectedLength(bytes.len())); + } + + let mut buf = bytes; + + let len = buf.get_i32(); + if len != 16 { + return Err(CancelRequestError::UnexpectedLength(len as usize)); + } + + let code = buf.get_i32(); + if code != 80877102 { + return Err(CancelRequestError::UnexpectedCode(code)); + } + + let pid = buf.get_i32(); + let secret = buf.get_i32(); + + Ok(CancelRequestFrame { pid, secret }) + } + + fn to_bytes(&self) -> Result { + let mut buf = BytesMut::with_capacity(16); + buf.put_i32(16); + buf.put_i32(80877102); + buf.put_i32(self.pid); + buf.put_i32(self.secret); + Ok(buf.freeze()) + } + + fn body_size(&self) -> usize { + 12 // code + pid + secret + } +} + +// ----------------------------------------------------------------------------- +// ----- Tests ----------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + + fn make_frame() -> CancelRequestFrame { + CancelRequestFrame { + pid: 1234, + secret: 5678, + } + } + + #[test] + fn roundtrip() { + let frame = make_frame(); + let encoded = frame.to_bytes().unwrap(); + let decoded = CancelRequestFrame::from_bytes(encoded.as_ref()).unwrap(); + assert_eq!(decoded.pid, frame.pid); + assert_eq!(decoded.secret, frame.secret); + } + + #[test] + fn unexpected_length() { + let mut buf = BytesMut::new(); + buf.put_i32(16); + buf.put_i32(80877102); + buf.put_i32(1234); + // missing secret + let raw = buf.freeze().to_vec(); + let err = CancelRequestFrame::from_bytes(&raw).unwrap_err(); + matches!(err, CancelRequestError::UnexpectedLength(12)); + } + + #[test] + fn unexpected_code() { + let mut buf = BytesMut::new(); + buf.put_i32(16); + buf.put_i32(999999); + buf.put_i32(1234); + buf.put_i32(5678); + let raw = buf.freeze().to_vec(); + let err = CancelRequestFrame::from_bytes(&raw).unwrap_err(); + matches!(err, CancelRequestError::UnexpectedCode(999999)); + } + + #[test] + fn unexpected_length_in_message() { + let mut buf = BytesMut::new(); + buf.put_i32(20); // wrong length + buf.put_i32(80877102); + buf.put_i32(1234); + buf.put_i32(5678); + let raw = buf.freeze().to_vec(); + let err = CancelRequestFrame::from_bytes(&raw).unwrap_err(); + matches!(err, CancelRequestError::UnexpectedLength(20)); + } +} + +// ----------------------------------------------------------------------------- +// ----------------------------------------------------------------------------- diff --git a/pgdog/src/wire_protocol/frontend/close.rs b/pgdog/src/wire_protocol/frontend/close.rs new file mode 100644 index 00000000..d9766ab4 --- /dev/null +++ b/pgdog/src/wire_protocol/frontend/close.rs @@ -0,0 +1,235 @@ +//! Module: wire_protocol::frontend::close +//! +//! Provides parsing and serialization for the Close message ('C') in the extended protocol. +//! +//! - `CloseFrame`: represents a Close message to close a portal or prepared statement. +//! - `CloseTarget`: enum distinguishing between portal and statement. +//! - `CloseError`: error types for parsing and encoding. +//! +//! Implements `WireSerializable` for easy conversion between raw bytes and `CloseFrame`. + +use bytes::{Buf, BufMut, Bytes, BytesMut}; +use std::{error::Error as StdError, fmt, str}; + +use crate::wire_protocol::WireSerializable; + +// ----------------------------------------------------------------------------- +// ----- ProtocolMessage ------------------------------------------------------- + +#[derive(Debug, PartialEq, Eq)] +pub struct CloseFrame<'a> { + pub target: CloseTarget, + pub name: &'a str, +} + +// ----------------------------------------------------------------------------- +// ----- Properties :: CloseTarget --------------------------------------------- + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum CloseTarget { + Portal, + Statement, +} + +// ----------------------------------------------------------------------------- +// ----- Error ----------------------------------------------------------------- + +#[derive(Debug)] +pub enum CloseError { + Utf8Error(str::Utf8Error), + UnexpectedEof, + InvalidLength, + InvalidTarget(u8), + UnexpectedTag(u8), +} + +impl fmt::Display for CloseError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + CloseError::Utf8Error(e) => write!(f, "UTF-8 error: {e}"), + CloseError::UnexpectedEof => write!(f, "unexpected EOF"), + CloseError::InvalidLength => write!(f, "invalid length"), + CloseError::InvalidTarget(t) => write!(f, "invalid target: {t:#X}"), + CloseError::UnexpectedTag(t) => write!(f, "unexpected tag: {t:#X}"), + } + } +} + +impl StdError for CloseError { + fn source(&self) -> Option<&(dyn StdError + 'static)> { + match self { + CloseError::Utf8Error(e) => Some(e), + _ => None, + } + } +} + +// ----------------------------------------------------------------------------- +// ----- Helpers --------------------------------------------------------------- + +#[inline] +fn read_cstr<'a>(buf: &mut &'a [u8]) -> Result<&'a str, CloseError> { + let nul = buf + .iter() + .position(|b| *b == 0) + .ok_or(CloseError::UnexpectedEof)?; + + let (raw, rest) = buf.split_at(nul); + *buf = &rest[1..]; // skip NUL + Ok(str::from_utf8(raw).map_err(CloseError::Utf8Error)?) +} + +// ----------------------------------------------------------------------------- +// ----- WireSerializable ------------------------------------------------------ + +impl<'a> WireSerializable<'a> for CloseFrame<'a> { + type Error = CloseError; + + fn from_bytes(mut bytes: &'a [u8]) -> Result { + if bytes.remaining() < 6 { + return Err(CloseError::UnexpectedEof); + } + + let tag = bytes.get_u8(); + if tag != b'C' { + return Err(CloseError::UnexpectedTag(tag)); + } + + let len = bytes.get_u32(); + if len < 5 || (len - 4) as usize != bytes.remaining() { + return Err(CloseError::InvalidLength); + } + + let target_byte = bytes.get_u8(); + let target = match target_byte { + b'P' => CloseTarget::Portal, + b'S' => CloseTarget::Statement, + _ => return Err(CloseError::InvalidTarget(target_byte)), + }; + + let name = read_cstr(&mut bytes)?; + + if !bytes.is_empty() { + return Err(CloseError::InvalidLength); + } + + Ok(CloseFrame { target, name }) + } + + fn to_bytes(&self) -> Result { + let body_size = 1 + self.name.len() + 1; + let total_len = 4 + body_size; + + let mut buf = BytesMut::with_capacity(1 + total_len); + buf.put_u8(b'C'); + buf.put_u32(total_len as u32); + + let target_byte = match self.target { + CloseTarget::Portal => b'P', + CloseTarget::Statement => b'S', + }; + buf.put_u8(target_byte); + buf.extend_from_slice(self.name.as_bytes()); + buf.put_u8(0); + + Ok(buf.freeze()) + } + + fn body_size(&self) -> usize { + 1 + self.name.len() + 1 + } +} + +// ----------------------------------------------------------------------------- +// ----- Tests ----------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + + fn make_portal_frame() -> CloseFrame<'static> { + CloseFrame { + target: CloseTarget::Portal, + name: "my_portal", + } + } + + fn make_statement_frame() -> CloseFrame<'static> { + CloseFrame { + target: CloseTarget::Statement, + name: "my_stmt", + } + } + + #[test] + fn roundtrip_portal() { + let frame = make_portal_frame(); + let encoded = frame.to_bytes().unwrap(); + let decoded = CloseFrame::from_bytes(encoded.as_ref()).unwrap(); + assert_eq!(decoded.target, frame.target); + assert_eq!(decoded.name, frame.name); + } + + #[test] + fn roundtrip_statement() { + let frame = make_statement_frame(); + let encoded = frame.to_bytes().unwrap(); + let decoded = CloseFrame::from_bytes(encoded.as_ref()).unwrap(); + assert_eq!(decoded.target, frame.target); + assert_eq!(decoded.name, frame.name); + } + + #[test] + fn roundtrip_empty_name() { + let frame = CloseFrame { + target: CloseTarget::Portal, + name: "", + }; + let encoded = frame.to_bytes().unwrap(); + let decoded = CloseFrame::from_bytes(encoded.as_ref()).unwrap(); + assert_eq!(decoded.target, frame.target); + assert_eq!(decoded.name, ""); + } + + #[test] + fn unexpected_tag() { + let mut bytes = make_portal_frame().to_bytes().unwrap().to_vec(); + bytes[0] = b'Q'; // wrong tag + let err = CloseFrame::from_bytes(&bytes).unwrap_err(); + matches!(err, CloseError::UnexpectedTag(t) if t == b'Q'); + } + + #[test] + fn invalid_target() { + let mut bytes = make_portal_frame().to_bytes().unwrap().to_vec(); + let offset = 5; // after header + bytes[offset] = b'X'; // invalid target + let err = CloseFrame::from_bytes(&bytes).unwrap_err(); + matches!(err, CloseError::InvalidTarget(t) if t == b'X'); + } + + #[test] + fn invalid_length_short() { + let bytes = b"C\x00\x00\x00\x05P"; // length 5, but body should be at least 2 (type + nul) + let err = CloseFrame::from_bytes(bytes).unwrap_err(); + matches!(err, CloseError::UnexpectedEof); + } + + #[test] + fn invalid_length_mismatch() { + let mut bytes = make_portal_frame().to_bytes().unwrap().to_vec(); + bytes[1] = 0xFF; // corrupt length + let err = CloseFrame::from_bytes(&bytes).unwrap_err(); + matches!(err, CloseError::InvalidLength); + } + + #[test] + fn unexpected_eof_no_nul() { + let bytes = b"C\x00\x00\x00\x07Pmy"; // no nul terminator + let err = CloseFrame::from_bytes(bytes).unwrap_err(); + matches!(err, CloseError::UnexpectedEof); + } +} + +// ----------------------------------------------------------------------------- +// ----------------------------------------------------------------------------- diff --git a/pgdog/src/wire_protocol/frontend/copy_data.rs b/pgdog/src/wire_protocol/frontend/copy_data.rs new file mode 100644 index 00000000..7f97b3ab --- /dev/null +++ b/pgdog/src/wire_protocol/frontend/copy_data.rs @@ -0,0 +1,6 @@ +//! Module: wire_protocol::frontend::copy_data +//! +//! Re-exports the bidirectional CopyDataFrame and CopyDataError +//! to avoid duplicating the implementation. + +pub use crate::wire_protocol::bidirectional::copy_data::{CopyDataError, CopyDataFrame}; diff --git a/pgdog/src/wire_protocol/frontend/copy_done.rs b/pgdog/src/wire_protocol/frontend/copy_done.rs new file mode 100644 index 00000000..db02dd6e --- /dev/null +++ b/pgdog/src/wire_protocol/frontend/copy_done.rs @@ -0,0 +1,6 @@ +//! Module: wire_protocol::frontend::copy_done +//! +//! Re-exports the bidirectional CopyDoneFrame and CopyDoneError +//! to avoid duplicating the implementation. + +pub use crate::wire_protocol::bidirectional::copy_done::{CopyDoneError, CopyDoneFrame}; diff --git a/pgdog/src/wire_protocol/frontend/copy_fail.rs b/pgdog/src/wire_protocol/frontend/copy_fail.rs new file mode 100644 index 00000000..6c40af01 --- /dev/null +++ b/pgdog/src/wire_protocol/frontend/copy_fail.rs @@ -0,0 +1,188 @@ +//! Module: wire_protocol::frontend::copy_fail +//! +//! Provides parsing and serialization for the CopyFail message ('f') in the extended protocol. +//! +//! - `CopyFailFrame`: represents a CopyFail message sent by the client to indicate failure during COPY, with an error message. +//! +//! Implements `WireSerializable` for conversion between raw bytes and `CopyFailFrame`. + +use bytes::{BufMut, Bytes, BytesMut}; + +use std::{error::Error as StdError, fmt, str}; + +use crate::wire_protocol::WireSerializable; + +// ----------------------------------------------------------------------------- +// ----- ProtocolMessage ------------------------------------------------------- + +#[derive(Debug, Clone)] +pub struct CopyFailFrame<'a> { + pub message: &'a str, +} + +// ----------------------------------------------------------------------------- +// ----- Error ----------------------------------------------------------------- + +#[derive(Debug)] +pub enum CopyFailError { + UnexpectedTag(u8), + UnexpectedLength(u32), + Utf8Error(str::Utf8Error), + UnexpectedEof, +} + +impl fmt::Display for CopyFailError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + CopyFailError::UnexpectedTag(t) => write!(f, "unexpected tag: {t:#X}"), + CopyFailError::UnexpectedLength(len) => write!(f, "unexpected length: {len}"), + CopyFailError::Utf8Error(e) => write!(f, "UTF-8 error: {e}"), + CopyFailError::UnexpectedEof => write!(f, "unexpected EOF"), + } + } +} + +impl StdError for CopyFailError { + fn source(&self) -> Option<&(dyn StdError + 'static)> { + match self { + CopyFailError::Utf8Error(e) => Some(e), + _ => None, + } + } +} + +// ----------------------------------------------------------------------------- +// ----- Helpers --------------------------------------------------------------- + +fn read_cstr<'a>(bytes: &'a [u8]) -> Result<(&'a str, usize), CopyFailError> { + let nul = bytes + .iter() + .position(|b| *b == 0) + .ok_or(CopyFailError::UnexpectedEof)?; + let raw = &bytes[..nul]; + let s = str::from_utf8(raw).map_err(CopyFailError::Utf8Error)?; + Ok((s, nul + 1)) +} + +// ----------------------------------------------------------------------------- +// ----- WireSerializable ------------------------------------------------------ + +impl<'a> WireSerializable<'a> for CopyFailFrame<'a> { + type Error = CopyFailError; + + fn from_bytes(bytes: &'a [u8]) -> Result { + if bytes.len() < 5 { + return Err(CopyFailError::UnexpectedLength(bytes.len() as u32)); + } + + let tag = bytes[0]; + if tag != b'f' { + return Err(CopyFailError::UnexpectedTag(tag)); + } + + let len = u32::from_be_bytes([bytes[1], bytes[2], bytes[3], bytes[4]]); + if len as usize != bytes.len() - 1 { + return Err(CopyFailError::UnexpectedLength(len)); + } + + let (message, consumed) = read_cstr(&bytes[5..])?; + if consumed != bytes.len() - 5 { + return Err(CopyFailError::UnexpectedLength(len)); + } + + Ok(CopyFailFrame { message }) + } + + fn to_bytes(&self) -> Result { + let body_len = self.message.len() + 1; + let total_len = 4 + body_len; + + let mut buf = BytesMut::with_capacity(1 + total_len); + buf.put_u8(b'f'); + buf.put_u32(total_len as u32); + buf.extend_from_slice(self.message.as_bytes()); + buf.put_u8(0); + + Ok(buf.freeze()) + } + + fn body_size(&self) -> usize { + self.message.len() + 1 + } +} + +// ----------------------------------------------------------------------------- +// ----- Tests ----------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + + fn make_frame() -> CopyFailFrame<'static> { + CopyFailFrame { + message: "failure reason", + } + } + + #[test] + fn serialize_copy_fail() { + let frame = make_frame(); + let bytes = frame.to_bytes().unwrap(); + let expected = b"f\x00\x00\x00\x13failure reason\x00"; + assert_eq!(bytes.as_ref(), expected); + } + + #[test] + fn deserialize_copy_fail() { + let data = b"f\x00\x00\x00\x13failure reason\x00"; + let frame = CopyFailFrame::from_bytes(data).unwrap(); + assert_eq!(frame.message, "failure reason"); + } + + #[test] + fn roundtrip_copy_fail() { + let original = make_frame(); + let bytes = original.to_bytes().unwrap(); + let decoded = CopyFailFrame::from_bytes(bytes.as_ref()).unwrap(); + assert_eq!(decoded.message, original.message); + } + + #[test] + fn invalid_tag() { + let data = b"Q\x00\x00\x00\x13failure reason\x00"; + let err = CopyFailFrame::from_bytes(data).unwrap_err(); + matches!(err, CopyFailError::UnexpectedTag(_)); + } + + #[test] + fn invalid_length() { + let data = b"f\x00\x00\x00\x14failure reason\x00"; + let err = CopyFailFrame::from_bytes(data).unwrap_err(); + matches!(err, CopyFailError::UnexpectedLength(_)); + } + + #[test] + fn missing_null_terminator() { + let data = b"f\x00\x00\x00\x13failure reason"; + let err = CopyFailFrame::from_bytes(data).unwrap_err(); + matches!(err, CopyFailError::UnexpectedEof); + } + + #[test] + fn extra_data_after_null() { + let data = b"f\x00\x00\x00\x13failure reason\x00extra"; + let err = CopyFailFrame::from_bytes(data).unwrap_err(); + matches!(err, CopyFailError::UnexpectedLength(_)); + } + + #[test] + fn invalid_utf8() { + let mut bytes = make_frame().to_bytes().unwrap().to_vec(); + bytes[10] = 0xFF; // corrupt a byte to invalid UTF-8 + let err = CopyFailFrame::from_bytes(&bytes).unwrap_err(); + matches!(err, CopyFailError::Utf8Error(_)); + } +} + +// ----------------------------------------------------------------------------- +// ----------------------------------------------------------------------------- diff --git a/pgdog/src/wire_protocol/frontend/describe_portal.rs b/pgdog/src/wire_protocol/frontend/describe_portal.rs new file mode 100644 index 00000000..884a0f6b --- /dev/null +++ b/pgdog/src/wire_protocol/frontend/describe_portal.rs @@ -0,0 +1,160 @@ +//! Module: wire_protocol::frontend::describe_portal +//! +//! Purpose: “Describe this portal so I know its row layout.” + +use bytes::{BufMut, Bytes, BytesMut}; +use std::{error::Error as StdError, fmt}; + +use crate::wire_protocol::WireSerializable; + +// ----------------------------------------------------------------------------- +// ----- ProtocolMessage ------------------------------------------------------- + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct DescribePortalFrame<'a> { + pub name: &'a str, +} + +// ----------------------------------------------------------------------------- +// ----- Error ----------------------------------------------------------------- + +#[derive(Debug)] +pub enum DescribePortalError { + UnexpectedTag(u8), + UnexpectedLength(u32), + NoTerminator, + InvalidUtf8(std::str::Utf8Error), +} + +impl fmt::Display for DescribePortalError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::UnexpectedTag(t) => write!(f, "unexpected tag: {:#X}", t), + Self::UnexpectedLength(n) => write!(f, "unexpected length: {}", n), + Self::NoTerminator => write!(f, "missing null terminator"), + Self::InvalidUtf8(e) => write!(f, "invalid UTF-8: {}", e), + } + } +} +impl StdError for DescribePortalError { + fn source(&self) -> Option<&(dyn StdError + 'static)> { + if let Self::InvalidUtf8(e) = self { + Some(e) + } else { + None + } + } +} + +// ----------------------------------------------------------------------------- +// ----- WireSerializable ------------------------------------------------------ + +impl<'a> WireSerializable<'a> for DescribePortalFrame<'a> { + type Error = DescribePortalError; + + fn from_bytes(bytes: &'a [u8]) -> Result { + if bytes.len() < 7 { + return Err(Self::Error::UnexpectedLength(bytes.len() as u32)); + } + if bytes[0] != b'D' { + return Err(Self::Error::UnexpectedTag(bytes[0])); + } + let declared = u32::from_be_bytes([bytes[1], bytes[2], bytes[3], bytes[4]]); + if declared as usize != bytes.len() - 1 { + return Err(Self::Error::UnexpectedLength(declared)); + } + if bytes[5] != b'P' { + return Err(Self::Error::UnexpectedTag(bytes[5])); + } + + let rest = &bytes[6..]; + let nul = rest + .iter() + .position(|b| *b == 0) + .ok_or(Self::Error::NoTerminator)?; + if nul + 1 != rest.len() { + return Err(Self::Error::NoTerminator); + } + + let name = std::str::from_utf8(&rest[..nul]).map_err(Self::Error::InvalidUtf8)?; + Ok(Self { name }) + } + + fn to_bytes(&self) -> Result { + let n = self.name.as_bytes(); + let body_len = 1 + n.len() + 1; + let total = 4 + body_len as u32; + + let mut buf = BytesMut::with_capacity(1 + total as usize); + buf.put_u8(b'D'); + buf.put_u32(total); + buf.put_u8(b'P'); + buf.put_slice(n); + buf.put_u8(0); + + Ok(buf.freeze()) + } + + fn body_size(&self) -> usize { + 1 + self.name.len() + 1 + } +} + +// ----------------------------------------------------------------------------- +// ----- Tests ----------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + + fn make_frame(name: &str) -> DescribePortalFrame { + DescribePortalFrame { name } + } + + #[test] + fn empty_name() { + let empty_name = ""; + let frame = make_frame(empty_name); + let encoded = frame.to_bytes().unwrap(); + let decoded = DescribePortalFrame::from_bytes(encoded.as_ref()).unwrap(); + assert_eq!(frame, decoded); + assert_eq!(frame.name, empty_name); + } + + #[test] + fn named() { + let name = "port"; + let frame = make_frame(name); + let encoded = frame.to_bytes().unwrap(); + let decoded = DescribePortalFrame::from_bytes(encoded.as_ref()).unwrap(); + assert_eq!(frame, decoded); + assert_eq!(frame.name, name); + } + + #[test] + fn bad_tag_not_describe() { + let name = "_pgdog_portal"; + let mut buf = make_frame(name).to_bytes().unwrap().to_vec(); + buf[0] = b'X'; // corrupt the tag + + assert!(matches!( + DescribePortalFrame::from_bytes(&buf), + Err(DescribePortalError::UnexpectedTag(b'X')) + )); + } + + #[test] + fn bad_kind_tag() { + let name = "_pgdog_portal"; + let mut buf = make_frame(name).to_bytes().unwrap().to_vec(); + buf[5] = b'S'; // overwrite the kind byte ('P' → 'S') + + assert!(matches!( + DescribePortalFrame::from_bytes(&buf), + Err(DescribePortalError::UnexpectedTag(b'S')) + )); + } +} + +// ----------------------------------------------------------------------------- +// ----------------------------------------------------------------------------- diff --git a/pgdog/src/wire_protocol/frontend/describe_statement.rs b/pgdog/src/wire_protocol/frontend/describe_statement.rs new file mode 100644 index 00000000..3c5f9f66 --- /dev/null +++ b/pgdog/src/wire_protocol/frontend/describe_statement.rs @@ -0,0 +1,160 @@ +//! Module: wire_protocol::frontend::describe_statement +//! +//! Describe Statement ('D' + 'S') + +use bytes::{BufMut, Bytes, BytesMut}; +use std::{error::Error as StdError, fmt}; + +use crate::wire_protocol::WireSerializable; + +// ----------------------------------------------------------------------------- +// ----- ProtocolMessage ------------------------------------------------------- + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct DescribeStatementFrame<'a> { + pub name: &'a str, +} + +// ----------------------------------------------------------------------------- +// ----- Error ----------------------------------------------------------------- + +#[derive(Debug)] +pub enum DescribeStmtError { + UnexpectedTag(u8), + UnexpectedLength(u32), + NoTerminator, + InvalidUtf8(std::str::Utf8Error), +} + +impl fmt::Display for DescribeStmtError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::UnexpectedTag(t) => write!(f, "unexpected tag: {:#X}", t), + Self::UnexpectedLength(n) => write!(f, "unexpected length: {}", n), + Self::NoTerminator => write!(f, "missing null terminator"), + Self::InvalidUtf8(e) => write!(f, "invalid UTF-8: {}", e), + } + } +} +impl StdError for DescribeStmtError { + fn source(&self) -> Option<&(dyn StdError + 'static)> { + if let Self::InvalidUtf8(e) = self { + Some(e) + } else { + None + } + } +} + +// ----------------------------------------------------------------------------- +// ----- WireSerializable ------------------------------------------------------ + +impl<'a> WireSerializable<'a> for DescribeStatementFrame<'a> { + type Error = DescribeStmtError; + + fn from_bytes(bytes: &'a [u8]) -> Result { + if bytes.len() < 7 { + return Err(Self::Error::UnexpectedLength(bytes.len() as u32)); + } + if bytes[0] != b'D' { + return Err(Self::Error::UnexpectedTag(bytes[0])); + } + let declared = u32::from_be_bytes([bytes[1], bytes[2], bytes[3], bytes[4]]); + if declared as usize != bytes.len() - 1 { + return Err(Self::Error::UnexpectedLength(declared)); + } + if bytes[5] != b'S' { + return Err(Self::Error::UnexpectedTag(bytes[5])); + } + + let rest = &bytes[6..]; + let nul = rest + .iter() + .position(|b| *b == 0) + .ok_or(Self::Error::NoTerminator)?; + if nul + 1 != rest.len() { + return Err(Self::Error::NoTerminator); + } + + let name = std::str::from_utf8(&rest[..nul]).map_err(Self::Error::InvalidUtf8)?; + Ok(Self { name }) + } + + fn to_bytes(&self) -> Result { + let n = self.name.as_bytes(); + let body_len = 1 + n.len() + 1; + let total = 4 + body_len as u32; + + let mut buf = BytesMut::with_capacity(1 + total as usize); + buf.put_u8(b'D'); + buf.put_u32(total); + buf.put_u8(b'S'); + buf.put_slice(n); + buf.put_u8(0); + + Ok(buf.freeze()) + } + + fn body_size(&self) -> usize { + 1 + self.name.len() + 1 + } +} + +// ----------------------------------------------------------------------------- +// ----- Tests ----------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + + fn make_frame(name: &str) -> DescribeStatementFrame { + DescribeStatementFrame { name } + } + + #[test] + fn empty_name() { + let empty_name = ""; + let frame = make_frame(empty_name); + let encoded = frame.to_bytes().unwrap(); + let decoded = DescribeStatementFrame::from_bytes(encoded.as_ref()).unwrap(); + assert_eq!(frame, decoded); + assert_eq!(frame.name, empty_name); + } + + #[test] + fn named() { + let name = "port"; + let frame = make_frame(name); + let encoded = frame.to_bytes().unwrap(); + let decoded = DescribeStatementFrame::from_bytes(encoded.as_ref()).unwrap(); + assert_eq!(frame, decoded); + assert_eq!(frame.name, name); + } + + #[test] + fn bad_tag_not_describe() { + let name = "_pgdog_statement"; + let mut buf = make_frame(name).to_bytes().unwrap().to_vec(); + buf[0] = b'X'; // corrupt the tag + + assert!(matches!( + DescribeStatementFrame::from_bytes(&buf), + Err(DescribeStmtError::UnexpectedTag(b'X')) + )); + } + + #[test] + fn bad_kind_tag() { + let name = "_pgdog_statement"; + let mut buf = make_frame(name).to_bytes().unwrap().to_vec(); + buf[5] = b'P'; // overwrite the kind byte ('S' → 'P') + + assert!(matches!( + DescribeStatementFrame::from_bytes(&buf), + Err(DescribeStmtError::UnexpectedTag(b'P')) + )); + } +} + +// ----------------------------------------------------------------------------- +// ----------------------------------------------------------------------------- diff --git a/pgdog/src/wire_protocol/frontend/execute.rs b/pgdog/src/wire_protocol/frontend/execute.rs new file mode 100644 index 00000000..ea2d22c2 --- /dev/null +++ b/pgdog/src/wire_protocol/frontend/execute.rs @@ -0,0 +1,214 @@ +//! Module: wire_protocol::frontend::execute +//! +//! Provides parsing and serialization for the Execute message ('E') in the +//! extended query protocol. +//! +//! - `ExecuteFrame`: represents a request to run a portal, containing the portal +//! name (empty string == unnamed) and `max_rows` (0 == unlimited). +//! - Implements `WireSerializable` for lossless conversion between raw bytes and +//! `ExecuteFrame` instances. +//! +//! Frame layout (FE → BE) +//! 0 -> 'E' (message tag) +//! 1-4 -> Int32 length: size of body + 4 +//! 5...n -> portal name (null-terminated UTF-8) +//! n+1...n+4 -> Int32 `max_rows` (0 == unlimited) + +use bytes::{Buf, BufMut, Bytes, BytesMut}; +use std::{error::Error as StdError, fmt, str}; + +use crate::wire_protocol::WireSerializable; + +// ----------------------------------------------------------------------------- +// ----- ProtocolMessage ------------------------------------------------------- + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ExecuteFrame<'a> { + pub portal: &'a str, + pub max_rows: u32, +} + +// ----------------------------------------------------------------------------- +// ----- Error ----------------------------------------------------------------- + +#[derive(Debug)] +pub enum ExecuteError { + Utf8Error(str::Utf8Error), + UnexpectedTag(u8), + UnexpectedLength(u32), + UnexpectedEof, +} + +impl fmt::Display for ExecuteError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Utf8Error(e) => write!(f, "utf-8 error: {}", e), + Self::UnexpectedTag(t) => write!(f, "unexpected tag: {:#X}", t), + Self::UnexpectedLength(n) => write!(f, "unexpected length: {}", n), + Self::UnexpectedEof => write!(f, "unexpected end-of-frame"), + } + } +} +impl StdError for ExecuteError {} + +// ----------------------------------------------------------------------------- +// ----- WireSerializable ------------------------------------------------------ + +impl<'a> WireSerializable<'a> for ExecuteFrame<'a> { + type Error = ExecuteError; + + fn from_bytes(mut bytes: &'a [u8]) -> Result { + // Need at least tag + len + if bytes.len() < 5 { + return Err(Self::Error::UnexpectedEof); + } + + // Tag check + let tag = bytes.get_u8(); + if tag != b'E' { + return Err(Self::Error::UnexpectedTag(tag)); + } + + // Declared length (body + 4) + let declared = bytes.get_u32(); + if declared as usize != bytes.len() { + return Err(Self::Error::UnexpectedLength(declared)); + } + + // portal (null-terminated) + let null_pos = bytes + .iter() + .position(|b| *b == 0) + .ok_or(Self::Error::UnexpectedEof)?; + + let portal = str::from_utf8(&bytes[..null_pos]).map_err(Self::Error::Utf8Error)?; + + // Remaining after \0 must be exactly 4 bytes for max_rows + if bytes.len() < null_pos + 1 + 4 { + return Err(Self::Error::UnexpectedEof); + } + + let max_rows = (&bytes[null_pos + 1..null_pos + 5]).get_u32(); + + // Extra trailing data? + if bytes.len() != null_pos + 1 + 4 { + return Err(Self::Error::UnexpectedLength(declared)); + } + + Ok(Self { portal, max_rows }) + } + + fn to_bytes(&self) -> Result { + let mut buf = BytesMut::with_capacity(1 + 4 + self.portal.len() + 1 + 4); + buf.put_u8(b'E'); + + let len = (self.portal.len() + 1 + 4) as u32; + buf.put_u32(len); + + buf.put_slice(self.portal.as_bytes()); + buf.put_u8(0); // null terminator + buf.put_u32(self.max_rows); + + Ok(buf.freeze()) + } + + fn body_size(&self) -> usize { + self.portal.len() + 1 + 4 + } +} + +// ----------------------------------------------------------------------------- +// ----- Tests ----------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + + fn make_frame<'a>() -> ExecuteFrame<'a> { + ExecuteFrame { + portal: "", + max_rows: 0, + } + } + + fn make_named_frame<'a>() -> ExecuteFrame<'a> { + ExecuteFrame { + portal: "__pgdog_1", + max_rows: 0, + } + } + + #[test] + fn roundtrip() { + let frame = make_frame(); + let encoded = frame.to_bytes().unwrap(); + let decoded = ExecuteFrame::from_bytes(&encoded).unwrap(); + assert_eq!(decoded, frame); + } + + #[test] + fn named_roundtrip() { + let frame = make_named_frame(); + let encoded = frame.to_bytes().unwrap(); + let decoded = ExecuteFrame::from_bytes(&encoded).unwrap(); + assert_eq!(decoded, frame); + } + + /// Ensure the encoder writes the correct length field for a named portal. + /// + /// Frame anatomy for an Execute message: + /// Byte 0 : 'E' tag + /// Bytes 1–4 : Int32 length = size of *body* + 4 (Postgres rule) + /// Body layout : portal_name + NUL terminator + max_rows(Int32) + /// + /// For the portal "__pgdog_1": + /// • portal bytes = 9 + /// • NUL terminator = 1 + /// • max_rows field = 4 + /// => body size = 14 + /// => length field must = 14 + /// => total frame size = tag(1) + len(4) + body(14) = 19 bytes + #[test] + fn named_len() { + let frame = make_named_frame(); + let encoded = frame.to_bytes().unwrap(); + + // Pull out the 4-byte length field (big-endian) at bytes 1–4. + let declared_len = u32::from_be_bytes([encoded[1], encoded[2], encoded[3], encoded[4]]); + + // Check the declared length as well as the overall byte count. + assert_eq!(declared_len, 14, "length field should equal body size"); + assert_eq!(encoded.len(), 1 + 4 + 14, "total frame size mismatch"); + } + + #[test] + fn unexpected_tag() { + let mut bad = BytesMut::new(); + bad.put_u8(b'X'); // wrong tag, should be 'E' + bad.put_u32(5); + bad.put_u8(0); + bad.put_u32(0); + + assert!(matches!( + ExecuteFrame::from_bytes(&bad), + Err(ExecuteError::UnexpectedTag(b'X')) + )); + } + + #[test] + fn length_mismatch() { + let mut bad = BytesMut::new(); + bad.put_u8(b'E'); + bad.put_u32(999); // bogus + bad.put_u8(0); + bad.put_u32(0); + + assert!(matches!( + ExecuteFrame::from_bytes(&bad), + Err(ExecuteError::UnexpectedLength(999)) + )); + } +} + +// ----------------------------------------------------------------------------- +// ----------------------------------------------------------------------------- diff --git a/pgdog/src/wire_protocol/frontend/flush.rs b/pgdog/src/wire_protocol/frontend/flush.rs new file mode 100644 index 00000000..1f296874 --- /dev/null +++ b/pgdog/src/wire_protocol/frontend/flush.rs @@ -0,0 +1,121 @@ +//! Module: wire_protocol::frontend::flush +//! +//! Provides parsing and serialization for the Flush message ('H') in the +//! extended protocol. +//! +//! - `FlushFrame`: represents a Flush message sent by the client to force the +//! backend to deliver any pending results. +//! +//! Implements `WireSerializable` for conversion between raw bytes and +//! `FlushFrame`. + +use bytes::Bytes; +use std::{error::Error as StdError, fmt}; + +use crate::wire_protocol::WireSerializable; + +// ----------------------------------------------------------------------------- +// ----- Message --------------------------------------------------------------- + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct FlushFrame; + +// ----------------------------------------------------------------------------- +// ----- Error ----------------------------------------------------------------- + +#[derive(Debug)] +pub enum FlushError { + UnexpectedTag(u8), + UnexpectedLength(u32), +} + +impl fmt::Display for FlushError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + FlushError::UnexpectedTag(t) => write!(f, "unexpected tag: {t:#X}"), + FlushError::UnexpectedLength(len) => write!(f, "unexpected length: {len}"), + } + } +} + +impl StdError for FlushError {} + +// ----------------------------------------------------------------------------- +// ----- WireSerializable ------------------------------------------------------ + +impl<'a> WireSerializable<'a> for FlushFrame { + type Error = FlushError; + + fn from_bytes(bytes: &'a [u8]) -> Result { + if bytes.len() < 5 { + return Err(FlushError::UnexpectedLength(bytes.len() as u32)); + } + + let tag = bytes[0]; + if tag != b'H' { + return Err(FlushError::UnexpectedTag(tag)); + } + + let len = u32::from_be_bytes([bytes[1], bytes[2], bytes[3], bytes[4]]); + if len != 4 { + return Err(FlushError::UnexpectedLength(len)); + } + + Ok(FlushFrame) + } + + fn to_bytes(&self) -> Result { + Ok(Bytes::from_static(b"H\0\0\0\x04")) + } + + fn body_size(&self) -> usize { + 0 + } +} + +// ----------------------------------------------------------------------------- +// ----- Tests ----------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn serialize_flush() { + let flush = FlushFrame; + let bytes = flush.to_bytes().unwrap(); + let expected = Bytes::from_static(&[b'H', 0, 0, 0, 4]); + assert_eq!(bytes, expected); + } + + #[test] + fn deserialize_flush() { + let data = &[b'H', 0, 0, 0, 4][..]; + let _ = FlushFrame::from_bytes(data).unwrap(); + } + + #[test] + fn roundtrip_flush() { + let original = FlushFrame; + let bytes = original.to_bytes().unwrap(); + let decoded = FlushFrame::from_bytes(bytes.as_ref()).unwrap(); + assert_eq!(original, decoded); + } + + #[test] + fn invalid_tag() { + let data = &[b'Q', 0, 0, 0, 4][..]; + let err = FlushFrame::from_bytes(data).unwrap_err(); + matches!(err, FlushError::UnexpectedTag(_)); + } + + #[test] + fn invalid_length() { + let data = &[b'H', 0, 0, 0, 5][..]; + let err = FlushFrame::from_bytes(data).unwrap_err(); + matches!(err, FlushError::UnexpectedLength(5)); + } +} + +// ----------------------------------------------------------------------------- +// ----------------------------------------------------------------------------- diff --git a/pgdog/src/wire_protocol/frontend/function_call.rs b/pgdog/src/wire_protocol/frontend/function_call.rs new file mode 100644 index 00000000..e2ce1aa2 --- /dev/null +++ b/pgdog/src/wire_protocol/frontend/function_call.rs @@ -0,0 +1,378 @@ +//! Module: wire_protocol::frontend::function_call +//! +//! Provides parsing and serialization for the FunctionCall message ('F') in the extended protocol. +//! +//! - `FunctionCallFrame`: represents a FunctionCall message with function OID, parameters, and result format. +//! - `Parameter`: reused from bind; distinguishes between text and binary parameter payloads. +//! - `ResultFormat`: reused from bind; indicates text or binary format for result. +//! - `FunctionCallFrameError`: error types for parsing and encoding. +//! +//! Implements `WireSerializable` for easy conversion between raw bytes and `FunctionCallFrame`. + +use bytes::{Buf, BufMut, Bytes, BytesMut}; +use std::{error::Error as StdError, fmt, str}; + +use crate::wire_protocol::shared_property_types::{Parameter, ResultFormat}; +use crate::wire_protocol::WireSerializable; + +// ----------------------------------------------------------------------------- +// ----- ProtocolMessage ------------------------------------------------------- + +#[derive(Debug)] +pub struct FunctionCallFrame<'a> { + pub function_oid: u32, + pub params: Vec>, + pub result_format: ResultFormat, +} + +// ----------------------------------------------------------------------------- +// ----- Error ----------------------------------------------------------------- + +#[derive(Debug)] +pub enum FunctionCallFrameError { + Utf8Error(str::Utf8Error), + UnexpectedEof, + InvalidLength, + InvalidFormatCode(i16), + UnexpectedTag(u8), +} + +impl fmt::Display for FunctionCallFrameError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + FunctionCallFrameError::Utf8Error(e) => write!(f, "UTF-8 error: {e}"), + FunctionCallFrameError::UnexpectedEof => write!(f, "unexpected EOF"), + FunctionCallFrameError::InvalidLength => write!(f, "invalid length or format code"), + FunctionCallFrameError::InvalidFormatCode(c) => write!(f, "invalid format code: {c}"), + FunctionCallFrameError::UnexpectedTag(t) => write!(f, "unexpected tag: {t:#X}"), + } + } +} + +impl StdError for FunctionCallFrameError { + fn source(&self) -> Option<&(dyn StdError + 'static)> { + match self { + FunctionCallFrameError::Utf8Error(e) => Some(e), + _ => None, + } + } +} + +// ----------------------------------------------------------------------------- +// ----- Helpers --------------------------------------------------------------- + +#[inline] +fn encode_format_code(buf: &mut BytesMut, is_binary: bool) { + buf.put_i16(if is_binary { 1 } else { 0 }); +} + +#[inline] +fn decode_format_code(code: i16) -> Result { + match code { + 0 => Ok(false), + 1 => Ok(true), + other => Err(FunctionCallFrameError::InvalidFormatCode(other)), + } +} + +// ----------------------------------------------------------------------------- +// ----- WireSerializable ------------------------------------------------------ + +impl<'a> WireSerializable<'a> for FunctionCallFrame<'a> { + type Error = FunctionCallFrameError; + + fn from_bytes(mut bytes: &'a [u8]) -> Result { + if bytes.remaining() < 5 { + return Err(FunctionCallFrameError::UnexpectedEof); + } + + let tag = bytes.get_u8(); + if tag != b'F' { + return Err(FunctionCallFrameError::UnexpectedTag(tag)); + } + + let len = bytes.get_u32(); + if len < 4 || (len - 4) as usize != bytes.remaining() { + return Err(FunctionCallFrameError::InvalidLength); + } + + if bytes.remaining() < 4 { + return Err(FunctionCallFrameError::UnexpectedEof); + } + let function_oid = bytes.get_u32(); + + // parameter format codes + let fmt_count = bytes.get_i16(); + if fmt_count < 0 { + return Err(FunctionCallFrameError::InvalidLength); + } + + let mut param_fmts = Vec::with_capacity(fmt_count as usize); + for _ in 0..fmt_count { + if bytes.remaining() < 2 { + return Err(FunctionCallFrameError::UnexpectedEof); + } + param_fmts.push(decode_format_code(bytes.get_i16())?); + } + + // parameters + if bytes.remaining() < 2 { + return Err(FunctionCallFrameError::UnexpectedEof); + } + + let param_count = bytes.get_i16() as usize; + let mut params = Vec::with_capacity(param_count); + for idx in 0..param_count { + if bytes.remaining() < 4 { + return Err(FunctionCallFrameError::UnexpectedEof); + } + + let val_len = bytes.get_i32(); + let is_binary = match fmt_count { + 0 => false, + 1 => param_fmts[0], + _ => param_fmts.get(idx).copied().unwrap_or(false), + }; + + if val_len == -1 { + params.push(Parameter::Binary(&[])); + continue; + } + + if val_len < 0 { + return Err(FunctionCallFrameError::InvalidLength); + } + + let len = val_len as usize; + if bytes.remaining() < len { + return Err(FunctionCallFrameError::UnexpectedEof); + } + + let slice = &bytes[..len]; + bytes.advance(len); + + if is_binary { + params.push(Parameter::Binary(slice)); + } else { + params.push(Parameter::Text( + str::from_utf8(slice).map_err(FunctionCallFrameError::Utf8Error)?, + )); + } + } + + // result format + if bytes.remaining() < 2 { + return Err(FunctionCallFrameError::UnexpectedEof); + } + + let result_code = bytes.get_i16(); + let is_bin = decode_format_code(result_code)?; + + let result_format = if is_bin { + ResultFormat::Binary + } else { + ResultFormat::Text + }; + + if bytes.has_remaining() { + return Err(FunctionCallFrameError::InvalidLength); + } + + Ok(FunctionCallFrame { + function_oid, + params, + result_format, + }) + } + + fn to_bytes(&self) -> Result { + let mut body = BytesMut::with_capacity(self.body_size()); + body.put_u32(self.function_oid); + + // param format codes (always per-param) + body.put_i16(self.params.len() as i16); + for p in &self.params { + encode_format_code(&mut body, matches!(p, Parameter::Binary(_))); + } + + // parameter values + body.put_i16(self.params.len() as i16); + for p in &self.params { + match p { + Parameter::Text(s) => { + body.put_i32(s.len() as i32); + body.extend_from_slice(s.as_bytes()); + } + Parameter::Binary(b) if !b.is_empty() => { + body.put_i32(b.len() as i32); + body.extend_from_slice(b); + } + _ => { + body.put_i32(-1); + } + } + } + + // result format + encode_format_code( + &mut body, + matches!(self.result_format, ResultFormat::Binary), + ); + + // wrap with tag + length + let mut frame = BytesMut::with_capacity(body.len() + 5); + frame.put_u8(b'F'); + frame.put_u32((body.len() + 4) as u32); + frame.extend_from_slice(&body); + + Ok(frame.freeze()) + } + + fn body_size(&self) -> usize { + let mut n = 4; // oid + n += 2 + self.params.len() * 2; // formats + n += 2; // param count + for p in &self.params { + n += 4; + match p { + Parameter::Text(s) => n += s.len(), + Parameter::Binary(b) => n += b.len(), + } + } + n += 2; // result format + n + } +} + +// ----------------------------------------------------------------------------- +// ----- Tests ----------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + + fn make_text_frame() -> FunctionCallFrame<'static> { + FunctionCallFrame { + function_oid: 1234, + params: vec![Parameter::Text("value")], + result_format: ResultFormat::Text, + } + } + + fn make_binary_null_frame() -> FunctionCallFrame<'static> { + FunctionCallFrame { + function_oid: 5678, + params: vec![Parameter::Binary(&[]), Parameter::Binary(b"\x01\x02")], + result_format: ResultFormat::Binary, + } + } + + #[test] + fn roundtrip_text_param() { + let frame = make_text_frame(); + let encoded = frame.to_bytes().unwrap(); + let decoded = FunctionCallFrame::from_bytes(encoded.as_ref()).unwrap(); + assert_eq!(decoded.function_oid, frame.function_oid); + if let Parameter::Text(t) = &decoded.params[0] { + assert_eq!(*t, "value"); + } else { + panic!("expected text param"); + } + assert!(matches!(decoded.result_format, ResultFormat::Text)); + } + + #[test] + fn roundtrip_binary_null() { + let frame = make_binary_null_frame(); + let encoded = frame.to_bytes().unwrap(); + let decoded = FunctionCallFrame::from_bytes(encoded.as_ref()).unwrap(); + assert_eq!(decoded.function_oid, frame.function_oid); + assert!(matches!(decoded.params[0], Parameter::Binary(ref b) if b.is_empty())); + assert!(matches!(decoded.params[1], Parameter::Binary(ref b) if b == b"\x01\x02")); + assert!(matches!(decoded.result_format, ResultFormat::Binary)); + } + + #[test] + fn deserialize_with_global_format() { + // Manually craft bytes with fmt_count=1 (all binary) + let mut body = BytesMut::new(); + body.put_u32(1234); // oid + body.put_i16(1); // fmt_count + body.put_i16(1); // binary + body.put_i16(1); // param_count + body.put_i32(2); // len + body.put_slice(b"\x03\x04"); + body.put_i16(0); // result text + let mut frame_bytes = BytesMut::new(); + frame_bytes.put_u8(b'F'); + frame_bytes.put_u32((body.len() + 4) as u32); + frame_bytes.extend_from_slice(&body); + let decoded = FunctionCallFrame::from_bytes(frame_bytes.as_ref()).unwrap(); + assert_eq!(decoded.function_oid, 1234); + assert!(matches!(decoded.params[0], Parameter::Binary(ref b) if b == b"\x03\x04")); + assert!(matches!(decoded.result_format, ResultFormat::Text)); + } + + #[test] + fn deserialize_with_zero_formats() { + // fmt_count=0 (all text) + let mut body = BytesMut::new(); + body.put_u32(1234); + body.put_i16(0); // fmt_count + body.put_i16(1); // param_count + body.put_i32(5); // len + body.put_slice(b"hello"); + body.put_i16(1); // result binary + let mut frame_bytes = BytesMut::new(); + frame_bytes.put_u8(b'F'); + frame_bytes.put_u32((body.len() + 4) as u32); + frame_bytes.extend_from_slice(&body); + let decoded = FunctionCallFrame::from_bytes(frame_bytes.as_ref()).unwrap(); + if let Parameter::Text(t) = &decoded.params[0] { + assert_eq!(*t, "hello"); + } else { + panic!("expected text"); + } + assert!(matches!(decoded.result_format, ResultFormat::Binary)); + } + + #[test] + fn invalid_tag() { + let mut bytes = make_text_frame().to_bytes().unwrap().to_vec(); + bytes[0] = b'Q'; + let err = FunctionCallFrame::from_bytes(&bytes).unwrap_err(); + matches!(err, FunctionCallFrameError::UnexpectedTag(_)); + } + + #[test] + fn invalid_format_code() { + let mut bytes = make_text_frame().to_bytes().unwrap().to_vec(); + // Corrupt a param format code to 2 + // Offset: tag(1)+len(4)+oid(4)+fmt_count(2)=11, then first fmt i16 at 11-12 + bytes[11] = 0; + bytes[12] = 2; // 2 + let err = FunctionCallFrame::from_bytes(&bytes).unwrap_err(); + matches!(err, FunctionCallFrameError::InvalidFormatCode(2)); + } + + #[test] + fn invalid_result_format() { + let mut bytes = make_text_frame().to_bytes().unwrap().to_vec(); + // Last i16 is result format, set to 3 + let last = bytes.len() - 1; + bytes[last - 1] = 0; + bytes[last] = 3; + let err = FunctionCallFrame::from_bytes(&bytes).unwrap_err(); + matches!(err, FunctionCallFrameError::InvalidFormatCode(3)); + } + + #[test] + fn unexpected_eof() { + let bytes = &[b'F', 0, 0, 0, 4][..]; // too short + let err = FunctionCallFrame::from_bytes(bytes).unwrap_err(); + matches!(err, FunctionCallFrameError::UnexpectedEof); + } +} + +// ----------------------------------------------------------------------------- +// ----------------------------------------------------------------------------- diff --git a/pgdog/src/wire_protocol/frontend/gss_response.rs b/pgdog/src/wire_protocol/frontend/gss_response.rs new file mode 100644 index 00000000..75d1b08c --- /dev/null +++ b/pgdog/src/wire_protocol/frontend/gss_response.rs @@ -0,0 +1,134 @@ +//! Module: wire_protocol::frontend::gss_response +//! +//! Provides parsing and serialization for the GSSResponse message ('p') used in GSSAPI authentication. +//! +//! This message is sent by the client in response to AuthenticationGSS or AuthenticationSSPI. +//! +//! - `GssResponseFrame`: represents a GSSResponse message carrying a chunk of GSS/SSPI data. +//! - `GssResponseError`: error types for parsing and encoding. +//! +//! Implements `WireSerializable` for easy conversion between raw bytes and `GssResponseFrame`. + +use bytes::{BufMut, Bytes, BytesMut}; +use std::{error::Error as StdError, fmt}; + +use crate::wire_protocol::WireSerializable; + +// ----------------------------------------------------------------------------- +// ----- ProtocolMessage ------------------------------------------------------- + +#[derive(Debug, Clone)] +pub struct GssResponseFrame<'a> { + pub data: &'a [u8], +} + +// ----------------------------------------------------------------------------- +// ----- Error ----------------------------------------------------------------- + +#[derive(Debug)] +pub enum GssResponseError { + UnexpectedTag(u8), + UnexpectedLength(u32), +} + +impl fmt::Display for GssResponseError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + GssResponseError::UnexpectedTag(t) => write!(f, "unexpected tag: {:#X}", t), + GssResponseError::UnexpectedLength(len) => write!(f, "unexpected length: {}", len), + } + } +} + +impl StdError for GssResponseError {} + +// ----------------------------------------------------------------------------- +// ----- WireSerializable ------------------------------------------------------ + +impl<'a> WireSerializable<'a> for GssResponseFrame<'a> { + type Error = GssResponseError; + + fn from_bytes(bytes: &'a [u8]) -> Result { + if bytes.len() < 5 { + return Err(GssResponseError::UnexpectedLength(bytes.len() as u32)); + } + + let tag = bytes[0]; + if tag != b'p' { + return Err(GssResponseError::UnexpectedTag(tag)); + } + + let len = u32::from_be_bytes([bytes[1], bytes[2], bytes[3], bytes[4]]); + if len as usize != bytes.len() - 1 { + return Err(GssResponseError::UnexpectedLength(len)); + } + + Ok(GssResponseFrame { data: &bytes[5..] }) + } + + fn to_bytes(&self) -> Result { + let total = 4 + self.data.len(); + let mut buf = BytesMut::with_capacity(1 + total); + buf.put_u8(b'p'); + buf.put_u32(total as u32); + buf.put_slice(self.data); + Ok(buf.freeze()) + } + + fn body_size(&self) -> usize { + self.data.len() + } +} + +// ----------------------------------------------------------------------------- +// ----- Tests ----------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + use bytes::{BufMut, BytesMut}; + + fn make_frame() -> GssResponseFrame<'static> { + GssResponseFrame { data: b"gsstoken" } + } + + #[test] + fn roundtrip() { + let frame = make_frame(); + let encoded = frame.to_bytes().unwrap(); + let decoded = GssResponseFrame::from_bytes(encoded.as_ref()).unwrap(); + assert_eq!(decoded.data, frame.data); + } + + #[test] + fn unexpected_tag() { + let mut buf = BytesMut::new(); + buf.put_u8(b'x'); // wrong tag, should be b'p' + buf.put_u32(4 + 4); + buf.put_slice(b"test"); + let raw = buf.freeze().to_vec(); + let err = GssResponseFrame::from_bytes(raw.as_ref()).unwrap_err(); + matches!(err, GssResponseError::UnexpectedTag(t) if t == b'x'); + } + + #[test] + fn unexpected_length_mismatch() { + let mut buf = BytesMut::new(); + buf.put_u8(b'p'); + buf.put_u32(10); + buf.put_slice(b"short"); + let raw = buf.freeze().to_vec(); + let err = GssResponseFrame::from_bytes(raw.as_ref()).unwrap_err(); + matches!(err, GssResponseError::UnexpectedLength(10)); + } + + #[test] + fn unexpected_length_short_buffer() { + let raw = b"p\x00\x00"; // too short to contain length + data + let err = GssResponseFrame::from_bytes(raw).unwrap_err(); + matches!(err, GssResponseError::UnexpectedLength(len) if len == raw.len() as u32); + } +} + +// ----------------------------------------------------------------------------- +// ----------------------------------------------------------------------------- diff --git a/pgdog/src/wire_protocol/frontend/gssenc_request.rs b/pgdog/src/wire_protocol/frontend/gssenc_request.rs new file mode 100644 index 00000000..07c7242a --- /dev/null +++ b/pgdog/src/wire_protocol/frontend/gssenc_request.rs @@ -0,0 +1,130 @@ +//! Module: wire_protocol::frontend::gssenc_request +//! +//! Provides parsing and serialization for the GSSENCRequest message in the protocol. +//! +//! Note: Unlike regular protocol messages, GSSENCRequest has no tag byte and is sent +//! by the client to request GSSAPI encryption during startup. +//! +//! - `GssencRequestFrame`: represents the GSSENCRequest message. +//! - `GssencRequestError`: error types for parsing and encoding. +//! +//! Implements `WireSerializable` for easy conversion between raw bytes and `GssencRequestFrame`. + +use bytes::{Buf, BufMut, Bytes, BytesMut}; +use std::{error::Error as StdError, fmt}; + +use crate::wire_protocol::WireSerializable; + +// ----------------------------------------------------------------------------- +// ----- ProtocolMessage ------------------------------------------------------- + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct GssencRequestFrame; + +// ----------------------------------------------------------------------------- +// ----- Error ----------------------------------------------------------------- + +#[derive(Debug)] +pub enum GssencRequestError { + UnexpectedLength(usize), + UnexpectedCode(i32), +} + +impl fmt::Display for GssencRequestError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + GssencRequestError::UnexpectedLength(len) => write!(f, "unexpected length: {len}"), + GssencRequestError::UnexpectedCode(code) => write!(f, "unexpected code: {code}"), + } + } +} + +impl StdError for GssencRequestError {} + +// ----------------------------------------------------------------------------- +// ----- WireSerializable ------------------------------------------------------ + +impl<'a> WireSerializable<'a> for GssencRequestFrame { + type Error = GssencRequestError; + + fn from_bytes(bytes: &'a [u8]) -> Result { + if bytes.len() != 8 { + return Err(GssencRequestError::UnexpectedLength(bytes.len())); + } + + let mut buf = bytes; + + let len = buf.get_i32(); + if len != 8 { + return Err(GssencRequestError::UnexpectedLength(len as usize)); + } + + let code = buf.get_i32(); + if code != 80877104 { + return Err(GssencRequestError::UnexpectedCode(code)); + } + + Ok(GssencRequestFrame) + } + + fn to_bytes(&self) -> Result { + let mut buf = BytesMut::with_capacity(8); + buf.put_i32(8); + buf.put_i32(80877104); + Ok(buf.freeze()) + } + + fn body_size(&self) -> usize { + 4 // code + } +} + +// ----------------------------------------------------------------------------- +// ----- Tests ----------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn roundtrip() { + let frame = GssencRequestFrame; + let encoded = frame.to_bytes().unwrap(); + let decoded = GssencRequestFrame::from_bytes(encoded.as_ref()).unwrap(); + // no state; just ensure no error + let _ = decoded; + } + + #[test] + fn unexpected_length() { + let mut buf = BytesMut::new(); + buf.put_i32(8); + // missing code + let raw = buf.freeze().to_vec(); + let err = GssencRequestFrame::from_bytes(&raw).unwrap_err(); + matches!(err, GssencRequestError::UnexpectedLength(4)); + } + + #[test] + fn unexpected_code() { + let mut buf = BytesMut::new(); + buf.put_i32(8); + buf.put_i32(999999); + let raw = buf.freeze().to_vec(); + let err = GssencRequestFrame::from_bytes(&raw).unwrap_err(); + matches!(err, GssencRequestError::UnexpectedCode(999999)); + } + + #[test] + fn unexpected_length_in_message() { + let mut buf = BytesMut::new(); + buf.put_i32(12); // wrong length + buf.put_i32(80877104); + let raw = buf.freeze().to_vec(); + let err = GssencRequestFrame::from_bytes(&raw).unwrap_err(); + matches!(err, GssencRequestError::UnexpectedLength(12)); + } +} + +// ----------------------------------------------------------------------------- +// ----------------------------------------------------------------------------- diff --git a/pgdog/src/wire_protocol/frontend/mod.rs b/pgdog/src/wire_protocol/frontend/mod.rs new file mode 100644 index 00000000..b7983475 --- /dev/null +++ b/pgdog/src/wire_protocol/frontend/mod.rs @@ -0,0 +1,116 @@ +pub mod bind; +pub mod cancel_request; +pub mod close; +pub mod copy_data; +pub mod copy_done; +pub mod copy_fail; +pub mod describe_portal; +pub mod describe_statement; +pub mod execute; +pub mod flush; +pub mod function_call; +pub mod gss_response; +pub mod gssenc_request; +pub mod parse; +pub mod password_message; +pub mod query; +pub mod sasl_initial_response; +pub mod ssl_request; +pub mod sspi_response; +pub mod startup; +pub mod sync; +pub mod terminate; + +use crate::wire_protocol::frontend::bind::BindFrame; +use crate::wire_protocol::frontend::cancel_request::CancelRequestFrame; +use crate::wire_protocol::frontend::close::CloseFrame; +use crate::wire_protocol::frontend::copy_data::CopyDataFrame; +use crate::wire_protocol::frontend::copy_done::CopyDoneFrame; +use crate::wire_protocol::frontend::copy_fail::CopyFailFrame; +use crate::wire_protocol::frontend::describe_portal::DescribePortalFrame; +use crate::wire_protocol::frontend::describe_statement::DescribeStatementFrame; +use crate::wire_protocol::frontend::execute::ExecuteFrame; +use crate::wire_protocol::frontend::flush::FlushFrame; +use crate::wire_protocol::frontend::function_call::FunctionCallFrame; +use crate::wire_protocol::frontend::gss_response::GssResponseFrame; +use crate::wire_protocol::frontend::gssenc_request::GssencRequestFrame; +use crate::wire_protocol::frontend::parse::ParseFrame; +use crate::wire_protocol::frontend::password_message::PasswordMessageFrame; +use crate::wire_protocol::frontend::query::QueryFrame; +use crate::wire_protocol::frontend::sasl_initial_response::SaslInitialResponseFrame; +use crate::wire_protocol::frontend::ssl_request::SslRequestFrame; +use crate::wire_protocol::frontend::sspi_response::SspiResponseFrame; +use crate::wire_protocol::frontend::startup::StartupFrame; +use crate::wire_protocol::frontend::sync::SyncFrame; +use crate::wire_protocol::frontend::terminate::TerminateFrame; + +/// Represents any frontend-initiated protocol message. +/// Bidirectional protocol messages are also included. +#[derive(Debug)] +pub enum FrontendProtocolMessage<'a> { + /// Extended-protocol Bind message + Bind(BindFrame<'a>), + + /// CancelRequest message for canceling queries + CancelRequest(CancelRequestFrame), + + /// Close message + Close(CloseFrame<'a>), + + /// CopyData message for COPY operations + CopyData(CopyDataFrame<'a>), + + /// CopyDone message for COPY operations + CopyDone(CopyDoneFrame), + + /// CopyFail message for COPY operations + CopyFail(CopyFailFrame<'a>), + + /// Extended-protocol DescribePortal message for describing portals + DescribePortal(DescribePortalFrame<'a>), + + /// Extended-protocol DescribeStatement message for describing prepared statements + DescribeStatement(DescribeStatementFrame<'a>), + + /// Extended-protocol Execute message for executing prepared statements + Execute(ExecuteFrame<'a>), + + /// Flush message for flushing data + Flush(FlushFrame), + + /// FunctionCall message for calling functions + FunctionCall(FunctionCallFrame<'a>), + + /// GssResponse message for GSSAPI authentication + GssResponse(GssResponseFrame<'a>), + + /// GssEncRequest message for GSSAPI encryption + GssEncRequest(GssencRequestFrame), + + /// Parse message for parsing SQL statements + Parse(ParseFrame<'a>), + + /// Password message for password authentication + Password(PasswordMessageFrame<'a>), + + /// Query message for executing SQL queries + Query(QueryFrame<'a>), + + /// SaslInitialResponse message for SASL authentication + SaslInitialResponse(SaslInitialResponseFrame<'a>), + + /// SslRequest message for SSL negotiation + SslRequest(SslRequestFrame), + + /// SspiResponse message for SSPI authentication + SspiResponse(SspiResponseFrame<'a>), + + /// Startup message for initializing the connection + Startup(StartupFrame<'a>), + + /// Sync message for synchronizing the connection + Sync(SyncFrame), + + /// Terminate message signaling session end + Terminate(TerminateFrame), +} diff --git a/pgdog/src/wire_protocol/frontend/parse.rs b/pgdog/src/wire_protocol/frontend/parse.rs new file mode 100644 index 00000000..aca01e40 --- /dev/null +++ b/pgdog/src/wire_protocol/frontend/parse.rs @@ -0,0 +1,233 @@ +//! Module: wire_protocol::frontend::parse +//! +//! Provides parsing and serialization for the Parse message ('P') in the extended protocol. +//! +//! - `ParseFrame`: represents a Parse message with statement name, query, and parameter type OIDs. +//! - `ParseFrameError`: error types for parsing and encoding. +//! +//! Implements `WireSerializable` for easy conversion between raw bytes and `ParseFrame`. + +use bytes::{Buf, BufMut, Bytes, BytesMut}; +use std::{error::Error as StdError, fmt, str}; + +use crate::wire_protocol::WireSerializable; + +// ----------------------------------------------------------------------------- +// ----- ProtocolMessage ------------------------------------------------------- + +#[derive(Debug)] +pub struct ParseFrame<'a> { + pub statement: &'a str, + pub query: &'a str, + pub param_types: Vec, +} + +// ----------------------------------------------------------------------------- +// ----- Error ----------------------------------------------------------------- + +#[derive(Debug)] +pub enum ParseFrameError { + Utf8Error(str::Utf8Error), + UnexpectedEof, + InvalidLength, + UnexpectedTag(u8), +} + +impl fmt::Display for ParseFrameError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + ParseFrameError::Utf8Error(e) => write!(f, "UTF-8 error: {e}"), + ParseFrameError::UnexpectedEof => write!(f, "unexpected EOF"), + ParseFrameError::InvalidLength => write!(f, "invalid length"), + ParseFrameError::UnexpectedTag(t) => write!(f, "unexpected tag: {t:#X}"), + } + } +} + +impl StdError for ParseFrameError { + fn source(&self) -> Option<&(dyn StdError + 'static)> { + match self { + ParseFrameError::Utf8Error(e) => Some(e), + _ => None, + } + } +} + +// ----------------------------------------------------------------------------- +// ----- Helpers --------------------------------------------------------------- + +fn read_cstr<'a>(buf: &mut &'a [u8]) -> Result<&'a str, ParseFrameError> { + let nul = buf + .iter() + .position(|b| *b == 0) + .ok_or(ParseFrameError::UnexpectedEof)?; + let (raw, rest) = buf.split_at(nul); + *buf = &rest[1..]; // skip NUL + Ok(str::from_utf8(raw).map_err(ParseFrameError::Utf8Error)?) +} + +// ----------------------------------------------------------------------------- +// ----- WireSerializable ------------------------------------------------------ + +impl<'a> WireSerializable<'a> for ParseFrame<'a> { + type Error = ParseFrameError; + + fn from_bytes(mut bytes: &'a [u8]) -> Result { + if bytes.remaining() < 5 { + return Err(ParseFrameError::UnexpectedEof); + } + + let tag = bytes.get_u8(); + if tag != b'P' { + return Err(ParseFrameError::UnexpectedTag(tag)); + } + + let _len = bytes.get_u32(); + + let statement = read_cstr(&mut bytes)?; + + let query = read_cstr(&mut bytes)?; + + let num_params = bytes.get_i16() as usize; + + let mut param_types = Vec::with_capacity(num_params); + for _ in 0..num_params { + if bytes.remaining() < 4 { + return Err(ParseFrameError::UnexpectedEof); + } + param_types.push(bytes.get_u32()); + } + + if bytes.has_remaining() { + return Err(ParseFrameError::InvalidLength); + } + + Ok(ParseFrame { + statement, + query, + param_types, + }) + } + + fn to_bytes(&self) -> Result { + let mut body = BytesMut::with_capacity(self.body_size()); + + body.extend_from_slice(self.statement.as_bytes()); + body.put_u8(0); + + body.extend_from_slice(self.query.as_bytes()); + body.put_u8(0); + + body.put_i16(self.param_types.len() as i16); + + for &oid in &self.param_types { + body.put_u32(oid); + } + + let mut frame = BytesMut::with_capacity(body.len() + 5); + frame.put_u8(b'P'); + frame.put_u32((body.len() + 4) as u32); + frame.extend_from_slice(&body); + + Ok(frame.freeze()) + } + + fn body_size(&self) -> usize { + let mut n = self.statement.len() + 1; + n += self.query.len() + 1; + n += 2; + n += self.param_types.len() * 4; + n + } +} + +// ----------------------------------------------------------------------------- +// ----- Tests ----------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + + fn make_frame_no_params() -> ParseFrame<'static> { + ParseFrame { + statement: "stmt", + query: "SELECT 1", + param_types: vec![], + } + } + + fn make_frame_with_params() -> ParseFrame<'static> { + ParseFrame { + statement: "", + query: "SELECT $1::text", + param_types: vec![25], + } + } + + #[test] + fn roundtrip_no_params() { + let frame = make_frame_no_params(); + let encoded = frame.to_bytes().unwrap(); + let decoded = ParseFrame::from_bytes(encoded.as_ref()).unwrap(); + assert_eq!(decoded.statement, frame.statement); + assert_eq!(decoded.query, frame.query); + assert_eq!(decoded.param_types, frame.param_types); + } + + #[test] + fn roundtrip_with_params() { + let frame = make_frame_with_params(); + let encoded = frame.to_bytes().unwrap(); + let decoded = ParseFrame::from_bytes(encoded.as_ref()).unwrap(); + assert_eq!(decoded.statement, frame.statement); + assert_eq!(decoded.query, frame.query); + assert_eq!(decoded.param_types, frame.param_types); + } + + #[test] + fn invalid_tag() { + let mut bytes = make_frame_no_params().to_bytes().unwrap().to_vec(); + bytes[0] = b'Q'; + let err = ParseFrame::from_bytes(&bytes).unwrap_err(); + matches!(err, ParseFrameError::UnexpectedTag(_)); + } + + #[test] + fn unexpected_eof_in_params() { + let mut body = BytesMut::new(); + body.extend_from_slice("stmt".as_bytes()); + body.put_u8(0); + body.extend_from_slice("SELECT 1".as_bytes()); + body.put_u8(0); + body.put_i16(1); // one param + // missing the u32 oid + let mut frame = BytesMut::new(); + frame.put_u8(b'P'); + frame.put_u32((body.len() + 4) as u32); + frame.extend_from_slice(&body); + let err = ParseFrame::from_bytes(frame.as_ref()).unwrap_err(); + matches!(err, ParseFrameError::UnexpectedEof); + } + + #[test] + fn extra_data() { + let mut bytes = make_frame_no_params().to_bytes().unwrap().to_vec(); + bytes.push(0); // extra byte + // but length is fixed, so adjust len to match + let len_bytes = (bytes.len() - 1) as u32; + bytes[1..5].copy_from_slice(&len_bytes.to_be_bytes()); + let err = ParseFrame::from_bytes(&bytes).unwrap_err(); + matches!(err, ParseFrameError::InvalidLength); + } + + #[test] + fn invalid_utf8() { + let mut bytes = make_frame_no_params().to_bytes().unwrap().to_vec(); + bytes[5] = 0xFF; // corrupt first byte of statement + let err = ParseFrame::from_bytes(&bytes).unwrap_err(); + matches!(err, ParseFrameError::Utf8Error(_)); + } +} + +// ----------------------------------------------------------------------------- +// ----------------------------------------------------------------------------- diff --git a/pgdog/src/wire_protocol/frontend/password_message.rs b/pgdog/src/wire_protocol/frontend/password_message.rs new file mode 100644 index 00000000..515c9294 --- /dev/null +++ b/pgdog/src/wire_protocol/frontend/password_message.rs @@ -0,0 +1,219 @@ +//! Module: wire_protocol::frontend::password +//! +//! Provides parsing and serialization for the PasswordMessage message ('p') used in password authentication. +//! +//! This message is sent by the client in response to AuthenticationCleartextPassword, AuthenticationMD5Password, etc. +//! +//! - `PasswordMessageFrame`: represents a PasswordMessage message carrying the password as a null-terminated string. +//! - `PasswordMessageError`: error types for parsing and encoding. +//! +//! Implements `WireSerializable` for easy conversion between raw bytes and `PasswordMessageFrame`. + +use bytes::{BufMut, Bytes, BytesMut}; +use std::{error::Error as StdError, fmt, str}; + +use crate::wire_protocol::WireSerializable; + +// ----------------------------------------------------------------------------- +// ----- ProtocolMessage ------------------------------------------------------- + +#[derive(Debug, Clone)] +pub struct PasswordMessageFrame<'a> { + pub password: &'a str, +} + +// ----------------------------------------------------------------------------- +// ----- Error ----------------------------------------------------------------- + +#[derive(Debug)] +pub enum PasswordMessageError { + UnexpectedTag(u8), + UnexpectedLength(u32), + Utf8Error(str::Utf8Error), + UnexpectedEof, + ExtraDataAfterNull, +} + +impl fmt::Display for PasswordMessageError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + PasswordMessageError::UnexpectedTag(t) => write!(f, "unexpected tag: {:#X}", t), + PasswordMessageError::UnexpectedLength(len) => write!(f, "unexpected length: {}", len), + PasswordMessageError::Utf8Error(e) => write!(f, "UTF-8 error: {e}"), + PasswordMessageError::UnexpectedEof => write!(f, "unexpected EOF"), + PasswordMessageError::ExtraDataAfterNull => { + write!(f, "extra data after null terminator") + } + } + } +} + +impl StdError for PasswordMessageError { + fn source(&self) -> Option<&(dyn StdError + 'static)> { + match self { + PasswordMessageError::Utf8Error(e) => Some(e), + _ => None, + } + } +} + +// ----------------------------------------------------------------------------- +// ----- Helpers --------------------------------------------------------------- + +fn read_cstr<'a>(bytes: &'a [u8]) -> Result<(&'a str, usize), PasswordMessageError> { + let nul_pos = bytes + .iter() + .position(|b| *b == 0) + .ok_or(PasswordMessageError::UnexpectedEof)?; + let raw = &bytes[..nul_pos]; + let s = str::from_utf8(raw).map_err(PasswordMessageError::Utf8Error)?; + Ok((s, nul_pos + 1)) +} + +// ----------------------------------------------------------------------------- +// ----- WireSerializable ------------------------------------------------------ + +impl<'a> WireSerializable<'a> for PasswordMessageFrame<'a> { + type Error = PasswordMessageError; + + fn from_bytes(bytes: &'a [u8]) -> Result { + if bytes.len() < 5 { + return Err(PasswordMessageError::UnexpectedLength(bytes.len() as u32)); + } + + let tag = bytes[0]; + if tag != b'p' { + return Err(PasswordMessageError::UnexpectedTag(tag)); + } + + let len = u32::from_be_bytes([bytes[1], bytes[2], bytes[3], bytes[4]]); + if len as usize != bytes.len() - 1 { + return Err(PasswordMessageError::UnexpectedLength(len)); + } + + let (password, consumed) = read_cstr(&bytes[5..])?; + if consumed != bytes.len() - 5 { + return Err(PasswordMessageError::ExtraDataAfterNull); + } + + Ok(PasswordMessageFrame { password }) + } + + fn to_bytes(&self) -> Result { + let body_len = self.password.len() + 1; // include \0 + let total_len = 4 + body_len; + + let mut buf = BytesMut::with_capacity(1 + total_len); + buf.put_u8(b'p'); + buf.put_u32(total_len as u32); + buf.extend_from_slice(self.password.as_bytes()); + buf.put_u8(0); + + Ok(buf.freeze()) + } + + fn body_size(&self) -> usize { + self.password.len() + 1 + } +} + +// ----------------------------------------------------------------------------- +// ----- Tests ----------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + use bytes::{BufMut, BytesMut}; + + fn make_frame() -> PasswordMessageFrame<'static> { + PasswordMessageFrame { password: "secret" } + } + + #[test] + fn roundtrip() { + let frame = make_frame(); + let encoded = frame.to_bytes().unwrap(); + let decoded = PasswordMessageFrame::from_bytes(encoded.as_ref()).unwrap(); + assert_eq!(decoded.password, frame.password); + } + + #[test] + fn serialize_with_null() { + let frame = make_frame(); + let bytes = frame.to_bytes().unwrap(); + let expected = b"p\x00\x00\x00\x0Bsecret\x00"; // length=11 (4 + 6 + 1) + assert_eq!(bytes.as_ref(), expected); + } + + #[test] + fn deserialize_with_null() { + let data = b"p\x00\x00\x00\x0Bsecret\x00"; + let frame = PasswordMessageFrame::from_bytes(data).unwrap(); + assert_eq!(frame.password, "secret"); + } + + #[test] + fn unexpected_tag() { + let mut buf = BytesMut::new(); + buf.put_u8(b'x'); // wrong tag + buf.put_u32(4 + 7); + buf.put_slice(b"secret\x00"); + let raw = buf.freeze().to_vec(); + let err = PasswordMessageFrame::from_bytes(&raw).unwrap_err(); + matches!(err, PasswordMessageError::UnexpectedTag(t) if t == b'x'); + } + + #[test] + fn unexpected_length_mismatch() { + let mut buf = BytesMut::new(); + buf.put_u8(b'p'); + buf.put_u32(10); // claims 10 (4+6), but body=5 ("short" no \0) + buf.put_slice(b"short"); + let raw = buf.freeze().to_vec(); + let err = PasswordMessageFrame::from_bytes(&raw).unwrap_err(); + matches!(err, PasswordMessageError::UnexpectedLength(10)); + } + + #[test] + fn unexpected_length_short_buffer() { + let raw = b"p\x00\x00"; // too short + let err = PasswordMessageFrame::from_bytes(raw).unwrap_err(); + matches!(err, PasswordMessageError::UnexpectedLength(len) if len == raw.len() as u32); + } + + #[test] + fn missing_null_terminator() { + let mut buf = BytesMut::new(); + buf.put_u8(b'p'); + buf.put_u32(4 + 6); // length for "secret" without \0 + buf.put_slice(b"secret"); + let raw = buf.freeze().to_vec(); + let err = PasswordMessageFrame::from_bytes(&raw).unwrap_err(); + matches!(err, PasswordMessageError::UnexpectedEof); + } + + #[test] + fn invalid_utf8() { + let mut buf = BytesMut::new(); + buf.put_u8(b'p'); + buf.put_u32(4 + 3); // invalid UTF-8 + \0 + buf.put_slice(&[0xFF, b'a', 0]); + let raw = buf.freeze().to_vec(); + let err = PasswordMessageFrame::from_bytes(&raw).unwrap_err(); + matches!(err, PasswordMessageError::Utf8Error(_)); + } + + #[test] + fn embedded_null() { + let mut buf = BytesMut::new(); + buf.put_u8(b'p'); + buf.put_u32(4 + 8); // length for full body with embedded \0 + buf.put_slice(b"sec\x00ret\x00"); + let raw = buf.freeze().to_vec(); + let err = PasswordMessageFrame::from_bytes(&raw).unwrap_err(); + matches!(err, PasswordMessageError::ExtraDataAfterNull); // errors on extra data after first \0 + } +} + +// ----------------------------------------------------------------------------- +// ----------------------------------------------------------------------------- diff --git a/pgdog/src/wire_protocol/frontend/query.rs b/pgdog/src/wire_protocol/frontend/query.rs new file mode 100644 index 00000000..bfc4fffa --- /dev/null +++ b/pgdog/src/wire_protocol/frontend/query.rs @@ -0,0 +1,187 @@ +//! Module: wire_protocol::frontend::query +//! +//! Provides parsing and serialization for the Query message ('Q') in the simple query protocol. +//! +//! - `QueryFrame`: represents a Query message with the SQL query string. +//! - `QueryError`: error types for parsing and encoding. +//! +//! Implements `WireSerializable` for easy conversion between raw bytes and `QueryFrame`. + +use bytes::{BufMut, Bytes, BytesMut}; + +use std::{error::Error as StdError, fmt, str}; + +use crate::wire_protocol::WireSerializable; + +// ----------------------------------------------------------------------------- +// ----- ProtocolMessage ------------------------------------------------------- + +#[derive(Debug, Clone)] +pub struct QueryFrame<'a> { + pub query: &'a str, +} + +// ----------------------------------------------------------------------------- +// ----- Error ----------------------------------------------------------------- + +#[derive(Debug)] +pub enum QueryError { + UnexpectedTag(u8), + UnexpectedLength(u32), + Utf8Error(str::Utf8Error), + UnexpectedEof, +} + +impl fmt::Display for QueryError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + QueryError::UnexpectedTag(t) => write!(f, "unexpected tag: {t:#X}"), + QueryError::UnexpectedLength(len) => write!(f, "unexpected length: {len}"), + QueryError::Utf8Error(e) => write!(f, "UTF-8 error: {e}"), + QueryError::UnexpectedEof => write!(f, "unexpected EOF"), + } + } +} + +impl StdError for QueryError { + fn source(&self) -> Option<&(dyn StdError + 'static)> { + match self { + QueryError::Utf8Error(e) => Some(e), + _ => None, + } + } +} + +// ----------------------------------------------------------------------------- +// ----- Helpers --------------------------------------------------------------- + +fn read_cstr<'a>(bytes: &'a [u8]) -> Result<(&'a str, usize), QueryError> { + let nul = bytes + .iter() + .position(|b| *b == 0) + .ok_or(QueryError::UnexpectedEof)?; + let raw = &bytes[..nul]; + let s = str::from_utf8(raw).map_err(QueryError::Utf8Error)?; + Ok((s, nul + 1)) +} + +// ----------------------------------------------------------------------------- +// ----- WireSerializable ------------------------------------------------------ + +impl<'a> WireSerializable<'a> for QueryFrame<'a> { + type Error = QueryError; + + fn from_bytes(bytes: &'a [u8]) -> Result { + if bytes.len() < 5 { + return Err(QueryError::UnexpectedLength(bytes.len() as u32)); + } + + let tag = bytes[0]; + if tag != b'Q' { + return Err(QueryError::UnexpectedTag(tag)); + } + + let len = u32::from_be_bytes([bytes[1], bytes[2], bytes[3], bytes[4]]); + if len as usize != bytes.len() - 1 { + return Err(QueryError::UnexpectedLength(len)); + } + + let (query, consumed) = read_cstr(&bytes[5..])?; + if consumed != bytes.len() - 5 { + return Err(QueryError::UnexpectedLength(len)); + } + + Ok(QueryFrame { query }) + } + + fn to_bytes(&self) -> Result { + let body_len = self.query.len() + 1; + let total_len = 4 + body_len; + + let mut buf = BytesMut::with_capacity(1 + total_len); + buf.put_u8(b'Q'); + buf.put_u32(total_len as u32); + buf.extend_from_slice(self.query.as_bytes()); + buf.put_u8(0); + + Ok(buf.freeze()) + } + + fn body_size(&self) -> usize { + self.query.len() + 1 + } +} + +// ----------------------------------------------------------------------------- +// ----- Tests ----------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + + fn make_frame() -> QueryFrame<'static> { + QueryFrame { query: "SELECT 1" } + } + + #[test] + fn serialize_query() { + let frame = make_frame(); + let bytes = frame.to_bytes().unwrap(); + let expected = b"Q\x00\x00\x00\x0DSELECT 1\x00"; + assert_eq!(bytes.as_ref(), expected); + } + + #[test] + fn deserialize_query() { + let data = b"Q\x00\x00\x00\x0DSELECT 1\x00"; + let frame = QueryFrame::from_bytes(data).unwrap(); + assert_eq!(frame.query, "SELECT 1"); + } + + #[test] + fn roundtrip_query() { + let original = make_frame(); + let bytes = original.to_bytes().unwrap(); + let decoded = QueryFrame::from_bytes(bytes.as_ref()).unwrap(); + assert_eq!(decoded.query, original.query); + } + + #[test] + fn invalid_tag() { + let data = b"P\x00\x00\x00\x0ASELECT 1\x00"; + let err = QueryFrame::from_bytes(data).unwrap_err(); + matches!(err, QueryError::UnexpectedTag(_)); + } + + #[test] + fn invalid_length() { + let data = b"Q\x00\x00\x00\x0BSELECT 1\x00"; + let err = QueryFrame::from_bytes(data).unwrap_err(); + matches!(err, QueryError::UnexpectedLength(_)); + } + + #[test] + fn missing_null_terminator() { + let data = b"Q\x00\x00\x00\x0ASELECT 1"; + let err = QueryFrame::from_bytes(data).unwrap_err(); + matches!(err, QueryError::UnexpectedEof); + } + + #[test] + fn extra_data_after_null() { + let data = b"Q\x00\x00\x00\x0ASELECT 1\x00extra"; + let err = QueryFrame::from_bytes(data).unwrap_err(); + matches!(err, QueryError::UnexpectedLength(_)); + } + + #[test] + fn invalid_utf8() { + let mut bytes = make_frame().to_bytes().unwrap().to_vec(); + bytes[5] = 0xFF; // corrupt first byte + let err = QueryFrame::from_bytes(&bytes).unwrap_err(); + matches!(err, QueryError::Utf8Error(_)); + } +} + +// ----------------------------------------------------------------------------- +// ----------------------------------------------------------------------------- diff --git a/pgdog/src/wire_protocol/frontend/sasl_initial_response.rs b/pgdog/src/wire_protocol/frontend/sasl_initial_response.rs new file mode 100644 index 00000000..6a968352 --- /dev/null +++ b/pgdog/src/wire_protocol/frontend/sasl_initial_response.rs @@ -0,0 +1,294 @@ +//! Module: wire_protocol::frontend::sasl_initial_response +//! +//! Provides parsing and serialization for the SASLInitialResponse message ('p') in the extended protocol. +//! +//! - `SaslInitialResponseFrame`: represents the initial SASL response with mechanism name and optional initial data. +//! - `SaslInitialResponseError`: error types for parsing and encoding. +//! +//! Implements `WireSerializable` for easy conversion between raw bytes and `SaslInitialResponseFrame`. + +use bytes::{Buf, BufMut, Bytes, BytesMut}; + +use std::{error::Error as StdError, fmt, str}; + +use crate::wire_protocol::shared_property_types::sasl_mechanism::{ + SaslMechanism, SaslMechanismError, +}; +use crate::wire_protocol::WireSerializable; + +// ----------------------------------------------------------------------------- +// ----- ProtocolMessage ------------------------------------------------------- + +#[derive(Debug, Clone)] +pub struct SaslInitialResponseFrame<'a> { + pub mechanism: SaslMechanism, + pub initial_data: Option<&'a [u8]>, +} + +// ----------------------------------------------------------------------------- +// ----- Error ----------------------------------------------------------------- + +#[derive(Debug)] +pub enum SaslInitialResponseError { + UnexpectedTag(u8), + UnexpectedLength(u32), + Utf8Error(str::Utf8Error), + UnexpectedEof, + InvalidDataLength, + MechanismError(SaslMechanismError), +} + +impl fmt::Display for SaslInitialResponseError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + SaslInitialResponseError::UnexpectedTag(t) => write!(f, "unexpected tag: {t:#X}"), + SaslInitialResponseError::UnexpectedLength(len) => { + write!(f, "unexpected length: {len}") + } + SaslInitialResponseError::Utf8Error(e) => write!(f, "UTF-8 error: {e}"), + SaslInitialResponseError::UnexpectedEof => write!(f, "unexpected EOF"), + SaslInitialResponseError::InvalidDataLength => write!(f, "invalid data length"), + SaslInitialResponseError::MechanismError(e) => write!(f, "SASL mechanism error: {e}"), + } + } +} + +impl StdError for SaslInitialResponseError { + fn source(&self) -> Option<&(dyn StdError + 'static)> { + match self { + SaslInitialResponseError::Utf8Error(e) => Some(e), + SaslInitialResponseError::MechanismError(e) => Some(e), + _ => None, + } + } +} + +// ----------------------------------------------------------------------------- +// ----- Helpers --------------------------------------------------------------- + +fn read_cstr<'a>(buf: &mut &'a [u8]) -> Result<&'a str, SaslInitialResponseError> { + let nul = buf + .iter() + .position(|b| *b == 0) + .ok_or(SaslInitialResponseError::UnexpectedEof)?; + + let (raw, rest) = buf.split_at(nul); + *buf = &rest[1..]; // skip NUL + + Ok(str::from_utf8(raw).map_err(SaslInitialResponseError::Utf8Error)?) +} + +// ----------------------------------------------------------------------------- +// ----- WireSerializable ------------------------------------------------------ + +impl<'a> WireSerializable<'a> for SaslInitialResponseFrame<'a> { + type Error = SaslInitialResponseError; + + fn from_bytes(mut bytes: &'a [u8]) -> Result { + if bytes.remaining() < 5 { + return Err(SaslInitialResponseError::UnexpectedEof); + } + + let tag = bytes.get_u8(); + if tag != b'p' { + return Err(SaslInitialResponseError::UnexpectedTag(tag)); + } + + let len = bytes.get_u32(); + if len as usize != bytes.remaining() + 4 { + // len includes itself + return Err(SaslInitialResponseError::UnexpectedLength(len)); + } + + let mechanism_str = read_cstr(&mut bytes)?; + + let mechanism = SaslMechanism::from_str(mechanism_str) + .map_err(SaslInitialResponseError::MechanismError)?; + + if bytes.remaining() < 4 { + return Err(SaslInitialResponseError::UnexpectedEof); + } + + let data_len_i32 = bytes.get_i32(); + + let initial_data = if data_len_i32 == -1 { + if bytes.has_remaining() { + return Err(SaslInitialResponseError::UnexpectedLength(len)); + } + None + } else { + if data_len_i32 < 0 { + return Err(SaslInitialResponseError::InvalidDataLength); + } + let data_len = data_len_i32 as usize; + if bytes.remaining() != data_len { + return Err(SaslInitialResponseError::UnexpectedLength(len)); + } + let data = &bytes[0..data_len]; + bytes.advance(data_len); + Some(data) + }; + + Ok(SaslInitialResponseFrame { + mechanism, + initial_data, + }) + } + + fn to_bytes(&self) -> Result { + let mut body = BytesMut::new(); + body.extend_from_slice(self.mechanism.as_str().as_bytes()); + body.put_u8(0); + match &self.initial_data { + Some(data) => { + body.put_i32(data.len() as i32); + body.extend_from_slice(data); + } + None => { + body.put_i32(-1); + } + } + + let mut frame = BytesMut::with_capacity(5 + body.len()); + frame.put_u8(b'p'); + frame.put_u32((4 + body.len()) as u32); + frame.extend_from_slice(&body); + + Ok(frame.freeze()) + } + + fn body_size(&self) -> usize { + self.mechanism.as_str().len() + 1 + 4 + self.initial_data.map_or(0, |d| d.len()) + } +} + +// ----------------------------------------------------------------------------- +// ----- Tests ----------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + + fn make_frame_with_data() -> SaslInitialResponseFrame<'static> { + SaslInitialResponseFrame { + mechanism: SaslMechanism::ScramSha256, + initial_data: Some(b"n,,n=user,r=3D3D3D"), + } + } + + fn make_frame_no_data() -> SaslInitialResponseFrame<'static> { + SaslInitialResponseFrame { + mechanism: SaslMechanism::ScramSha256Plus, + initial_data: None, + } + } + + #[test] + fn roundtrip_with_data() { + let frame = make_frame_with_data(); + let encoded = frame.to_bytes().unwrap(); + let decoded = SaslInitialResponseFrame::from_bytes(encoded.as_ref()).unwrap(); + assert_eq!(decoded.mechanism, frame.mechanism); + assert_eq!(decoded.initial_data, frame.initial_data); + } + + #[test] + fn roundtrip_no_data() { + let frame = make_frame_no_data(); + let encoded = frame.to_bytes().unwrap(); + let decoded = SaslInitialResponseFrame::from_bytes(encoded.as_ref()).unwrap(); + assert_eq!(decoded.mechanism, frame.mechanism); + assert_eq!(decoded.initial_data, frame.initial_data); + } + + #[test] + fn invalid_tag() { + let mut bytes = make_frame_with_data().to_bytes().unwrap().to_vec(); + bytes[0] = b'Q'; + let err = SaslInitialResponseFrame::from_bytes(&bytes).unwrap_err(); + matches!(err, SaslInitialResponseError::UnexpectedTag(_)); + } + + #[test] + fn unexpected_eof_missing_data_len() { + let mut body = BytesMut::new(); + body.extend_from_slice("SCRAM-SHA-256".as_bytes()); + body.put_u8(0); + // missing i32 + let mut frame = BytesMut::new(); + frame.put_u8(b'p'); + frame.put_u32((4 + body.len()) as u32); + frame.extend_from_slice(&body); + let err = SaslInitialResponseFrame::from_bytes(frame.as_ref()).unwrap_err(); + matches!(err, SaslInitialResponseError::UnexpectedEof); + } + + #[test] + fn invalid_data_length_negative_not_minus_one() { + let mut body = BytesMut::new(); + body.extend_from_slice("SCRAM-SHA-256".as_bytes()); + body.put_u8(0); + body.put_i32(-2); // invalid + let mut frame = BytesMut::new(); + frame.put_u8(b'p'); + frame.put_u32((4 + body.len()) as u32); + frame.extend_from_slice(&body); + let err = SaslInitialResponseFrame::from_bytes(frame.as_ref()).unwrap_err(); + matches!(err, SaslInitialResponseError::InvalidDataLength); + } + + #[test] + fn unexpected_length_data_mismatch() { + let mut body = BytesMut::new(); + body.extend_from_slice("SCRAM-SHA-256".as_bytes()); + body.put_u8(0); + body.put_i32(10); // claims 10 bytes + body.extend_from_slice(b"short"); // only 5 + let mut frame = BytesMut::new(); + frame.put_u8(b'p'); + frame.put_u32((4 + body.len()) as u32); + frame.extend_from_slice(&body); + let err = SaslInitialResponseFrame::from_bytes(frame.as_ref()).unwrap_err(); + matches!(err, SaslInitialResponseError::UnexpectedLength(_)); + } + + #[test] + fn extra_data_after_no_data() { + let mut body = BytesMut::new(); + body.extend_from_slice("SCRAM-SHA-256-PLUS".as_bytes()); + body.put_u8(0); + body.put_i32(-1); + body.put_u8(1); // extra + let mut frame = BytesMut::new(); + frame.put_u8(b'p'); + frame.put_u32((4 + body.len()) as u32); + frame.extend_from_slice(&body); + let err = SaslInitialResponseFrame::from_bytes(frame.as_ref()).unwrap_err(); + matches!(err, SaslInitialResponseError::UnexpectedLength(_)); + } + + #[test] + fn invalid_utf8_mechanism() { + let mut bytes = make_frame_with_data().to_bytes().unwrap().to_vec(); + bytes[5] = 0xFF; // corrupt mechanism byte + let err = SaslInitialResponseFrame::from_bytes(&bytes).unwrap_err(); + matches!(err, SaslInitialResponseError::Utf8Error(_)); + } + + #[test] + fn invalid_mechanism() { + let mut body = BytesMut::new(); + body.extend_from_slice("PLAIN".as_bytes()); + body.put_u8(0); + body.put_i32(-1); + let mut frame = BytesMut::new(); + frame.put_u8(b'p'); + frame.put_u32((4 + body.len()) as u32); + frame.extend_from_slice(&body); + let err = SaslInitialResponseFrame::from_bytes(frame.as_ref()).unwrap_err(); + matches!(err, SaslInitialResponseError::MechanismError(_)); + } +} + +// ----------------------------------------------------------------------------- +// ----------------------------------------------------------------------------- diff --git a/pgdog/src/wire_protocol/frontend/sasl_response.rs b/pgdog/src/wire_protocol/frontend/sasl_response.rs new file mode 100644 index 00000000..632c8419 --- /dev/null +++ b/pgdog/src/wire_protocol/frontend/sasl_response.rs @@ -0,0 +1,135 @@ +//! Module: wire_protocol::frontend::sasl_response +//! +//! Provides parsing and serialization for the SASLResponse message ('p') in the extended protocol. +//! This is used for continuation responses in SASL authentication. +//! +//! - `SaslResponseFrame`: represents a SASLResponse message carrying a chunk of response data. +//! - `SaslResponseError`: error types for parsing and encoding. +//! +//! Implements `WireSerializable` for easy conversion between raw bytes and `SaslResponseFrame`. +//! Note: This is distinct from SASLInitialResponse or plain password messages, which may use the same tag but different formats. + +use bytes::{BufMut, Bytes, BytesMut}; +use std::{error::Error as StdError, fmt}; + +use crate::wire_protocol::WireSerializable; + +// ----------------------------------------------------------------------------- +// ----- ProtocolMessage ------------------------------------------------------- + +#[derive(Debug, Clone)] +pub struct SaslResponseFrame<'a> { + pub data: &'a [u8], +} + +// ----------------------------------------------------------------------------- +// ----- Error ----------------------------------------------------------------- + +#[derive(Debug)] +pub enum SaslResponseError { + UnexpectedTag(u8), + UnexpectedLength(u32), +} + +impl fmt::Display for SaslResponseError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + SaslResponseError::UnexpectedTag(t) => write!(f, "unexpected tag: {:#X}", t), + SaslResponseError::UnexpectedLength(len) => write!(f, "unexpected length: {}", len), + } + } +} + +impl StdError for SaslResponseError {} + +// ----------------------------------------------------------------------------- +// ----- WireSerializable ------------------------------------------------------ + +impl<'a> WireSerializable<'a> for SaslResponseFrame<'a> { + type Error = SaslResponseError; + + fn from_bytes(bytes: &'a [u8]) -> Result { + if bytes.len() < 5 { + return Err(SaslResponseError::UnexpectedLength(bytes.len() as u32)); + } + + let tag = bytes[0]; + if tag != b'p' { + return Err(SaslResponseError::UnexpectedTag(tag)); + } + + let len = u32::from_be_bytes([bytes[1], bytes[2], bytes[3], bytes[4]]); + if len as usize != bytes.len() - 1 { + return Err(SaslResponseError::UnexpectedLength(len)); + } + + Ok(SaslResponseFrame { data: &bytes[5..] }) + } + + fn to_bytes(&self) -> Result { + let total = 4 + self.data.len(); + let mut buf = BytesMut::with_capacity(1 + total); + buf.put_u8(b'p'); + buf.put_u32(total as u32); + buf.put_slice(self.data); + Ok(buf.freeze()) + } + + fn body_size(&self) -> usize { + self.data.len() + } +} + +// ----------------------------------------------------------------------------- +// ----- Tests ----------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + + fn make_frame() -> SaslResponseFrame<'static> { + SaslResponseFrame { + data: b"bi=rO0ABXNyA", + } // example SCRAM response + } + + #[test] + fn roundtrip() { + let frame = make_frame(); + let encoded = frame.to_bytes().unwrap(); + let decoded = SaslResponseFrame::from_bytes(encoded.as_ref()).unwrap(); + assert_eq!(decoded.data, frame.data); + } + + #[test] + fn unexpected_tag() { + let mut buf = BytesMut::new(); + buf.put_u8(b'x'); // wrong tag + buf.put_u32(4 + 5); + buf.put_slice(b"test"); + let raw = buf.freeze().to_vec(); + let err = SaslResponseFrame::from_bytes(raw.as_ref()).unwrap_err(); + matches!(err, SaslResponseError::UnexpectedTag(t) if t == b'x'); + } + + #[test] + fn unexpected_length_mismatch() { + let mut buf = BytesMut::new(); + buf.put_u8(b'p'); + buf.put_u32(10); + buf.put_slice(b"short"); + let raw = buf.freeze().to_vec(); + let err = SaslResponseFrame::from_bytes(raw.as_ref()).unwrap_err(); + matches!(err, SaslResponseError::UnexpectedLength(10)); + } + + #[test] + fn unexpected_length_short_buffer() { + let raw = b"p\x00\x00"; // too short + let err = SaslResponseFrame::from_bytes(raw).unwrap_err(); + matches!(err, SaslResponseError::UnexpectedLength(len) if len == raw.len() as u32); + } +} + +// ----------------------------------------------------------------------------- +// ----------------------------------------------------------------------------- diff --git a/pgdog/src/wire_protocol/frontend/ssl_request.rs b/pgdog/src/wire_protocol/frontend/ssl_request.rs new file mode 100644 index 00000000..7ee83955 --- /dev/null +++ b/pgdog/src/wire_protocol/frontend/ssl_request.rs @@ -0,0 +1,129 @@ +//! Module: wire_protocol::frontend::ssl_request +//! +//! Provides parsing and serialization for the SSLRequest message in the protocol. +//! +//! Note: Unlike regular protocol messages, SSLRequest has no tag byte and is sent +//! by the client to request an SSL/TLS connection during startup. +//! +//! - `SslRequestFrame`: represents the SSLRequest message. +//! - `SslRequestError`: error types for parsing and encoding. +//! +//! Implements `WireSerializable` for easy conversion between raw bytes and `SslRequestFrame`. + +use crate::wire_protocol::WireSerializable; +use bytes::{Buf, BufMut, Bytes, BytesMut}; +use std::{error::Error as StdError, fmt}; + +// ----------------------------------------------------------------------------- +// ----- ProtocolMessage ------------------------------------------------------- + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct SslRequestFrame; + +// ----------------------------------------------------------------------------- +// ----- Error ----------------------------------------------------------------- + +#[derive(Debug)] +pub enum SslRequestError { + UnexpectedLength(usize), + UnexpectedCode(i32), +} + +impl fmt::Display for SslRequestError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + SslRequestError::UnexpectedLength(len) => write!(f, "unexpected length: {len}"), + SslRequestError::UnexpectedCode(code) => write!(f, "unexpected code: {code}"), + } + } +} + +impl StdError for SslRequestError {} + +// ----------------------------------------------------------------------------- +// ----- WireSerializable ------------------------------------------------------ + +impl<'a> WireSerializable<'a> for SslRequestFrame { + type Error = SslRequestError; + + fn from_bytes(bytes: &'a [u8]) -> Result { + if bytes.len() != 8 { + return Err(SslRequestError::UnexpectedLength(bytes.len())); + } + + let mut buf = bytes; + + let len = buf.get_i32(); + if len != 8 { + return Err(SslRequestError::UnexpectedLength(len as usize)); + } + + let code = buf.get_i32(); + if code != 80877103 { + return Err(SslRequestError::UnexpectedCode(code)); + } + + Ok(SslRequestFrame) + } + + fn to_bytes(&self) -> Result { + let mut buf = BytesMut::with_capacity(8); + buf.put_i32(8); + buf.put_i32(80877103); + Ok(buf.freeze()) + } + + fn body_size(&self) -> usize { + 4 // code + } +} + +// ----------------------------------------------------------------------------- +// ----- Tests ----------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn roundtrip() { + let frame = SslRequestFrame; + let encoded = frame.to_bytes().unwrap(); + let decoded = SslRequestFrame::from_bytes(encoded.as_ref()).unwrap(); + // no state; just ensure no error + let _ = decoded; + } + + #[test] + fn unexpected_length() { + let mut buf = BytesMut::new(); + buf.put_i32(8); + // missing code + let raw = buf.freeze().to_vec(); + let err = SslRequestFrame::from_bytes(&raw).unwrap_err(); + matches!(err, SslRequestError::UnexpectedLength(4)); + } + + #[test] + fn unexpected_code() { + let mut buf = BytesMut::new(); + buf.put_i32(8); + buf.put_i32(999999); + let raw = buf.freeze().to_vec(); + let err = SslRequestFrame::from_bytes(&raw).unwrap_err(); + matches!(err, SslRequestError::UnexpectedCode(999999)); + } + + #[test] + fn unexpected_length_in_message() { + let mut buf = BytesMut::new(); + buf.put_i32(12); // wrong length + buf.put_i32(80877103); + let raw = buf.freeze().to_vec(); + let err = SslRequestFrame::from_bytes(&raw).unwrap_err(); + matches!(err, SslRequestError::UnexpectedLength(12)); + } +} + +// ----------------------------------------------------------------------------- +// ----------------------------------------------------------------------------- diff --git a/pgdog/src/wire_protocol/frontend/sspi_response.rs b/pgdog/src/wire_protocol/frontend/sspi_response.rs new file mode 100644 index 00000000..0b153e8f --- /dev/null +++ b/pgdog/src/wire_protocol/frontend/sspi_response.rs @@ -0,0 +1,135 @@ +//! Module: wire_protocol::frontend::sspi_response +//! +//! Provides parsing and serialization for the SSPIResponse message ('p') in the extended protocol. +//! This is used for continuation responses in SSPI authentication. +//! +//! - `SspiResponseFrame`: represents a SSPIResponse message carrying a chunk of response data. +//! - `SspiResponseError`: error types for parsing and encoding. +//! +//! Implements `WireSerializable` for easy conversion between raw bytes and `SspiResponseFrame`. +//! Note: This is similar to SASLResponse, using the same tag 'p' but in the context of SSPI auth. + +use bytes::{BufMut, Bytes, BytesMut}; +use std::{error::Error as StdError, fmt}; + +use crate::wire_protocol::WireSerializable; + +// ----------------------------------------------------------------------------- +// ----- ProtocolMessage ------------------------------------------------------- + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct SspiResponseFrame<'a> { + pub data: &'a [u8], +} + +// ----------------------------------------------------------------------------- +// ----- Error ----------------------------------------------------------------- + +#[derive(Debug)] +pub enum SspiResponseError { + UnexpectedTag(u8), + UnexpectedLength(u32), +} + +impl fmt::Display for SspiResponseError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + SspiResponseError::UnexpectedTag(t) => write!(f, "unexpected tag: {:#X}", t), + SspiResponseError::UnexpectedLength(len) => write!(f, "unexpected length: {}", len), + } + } +} + +impl StdError for SspiResponseError {} + +// ----------------------------------------------------------------------------- +// ----- WireSerializable ------------------------------------------------------ + +impl<'a> WireSerializable<'a> for SspiResponseFrame<'a> { + type Error = SspiResponseError; + + fn from_bytes(bytes: &'a [u8]) -> Result { + if bytes.len() < 5 { + return Err(SspiResponseError::UnexpectedLength(bytes.len() as u32)); + } + + let tag = bytes[0]; + if tag != b'p' { + return Err(SspiResponseError::UnexpectedTag(tag)); + } + + let len = u32::from_be_bytes([bytes[1], bytes[2], bytes[3], bytes[4]]); + if len as usize != bytes.len() - 1 { + return Err(SspiResponseError::UnexpectedLength(len)); + } + + Ok(SspiResponseFrame { data: &bytes[5..] }) + } + + fn to_bytes(&self) -> Result { + let total = 4 + self.data.len(); + let mut buf = BytesMut::with_capacity(1 + total); + buf.put_u8(b'p'); + buf.put_u32(total as u32); + buf.put_slice(self.data); + Ok(buf.freeze()) + } + + fn body_size(&self) -> usize { + self.data.len() + } +} + +// ----------------------------------------------------------------------------- +// ----- Tests ----------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + + fn make_frame() -> SspiResponseFrame<'static> { + SspiResponseFrame { + data: b"\x01\x02\x03\x04", + } // example SSPI token + } + + #[test] + fn roundtrip() { + let frame = make_frame(); + let encoded = frame.to_bytes().unwrap(); + let decoded = SspiResponseFrame::from_bytes(encoded.as_ref()).unwrap(); + assert_eq!(decoded.data, frame.data); + } + + #[test] + fn unexpected_tag() { + let mut buf = BytesMut::new(); + buf.put_u8(b'x'); // wrong tag + buf.put_u32(4 + 5); + buf.put_slice(b"test"); + let raw = buf.freeze().to_vec(); + let err = SspiResponseFrame::from_bytes(raw.as_ref()).unwrap_err(); + matches!(err, SspiResponseError::UnexpectedTag(t) if t == b'x'); + } + + #[test] + fn unexpected_length_mismatch() { + let mut buf = BytesMut::new(); + buf.put_u8(b'p'); + buf.put_u32(10); + buf.put_slice(b"short"); + let raw = buf.freeze().to_vec(); + let err = SspiResponseFrame::from_bytes(raw.as_ref()).unwrap_err(); + matches!(err, SspiResponseError::UnexpectedLength(10)); + } + + #[test] + fn unexpected_length_short_buffer() { + let raw = b"p\x00\x00"; // too short + let err = SspiResponseFrame::from_bytes(raw).unwrap_err(); + matches!(err, SspiResponseError::UnexpectedLength(len) if len == raw.len() as u32); + } +} + +// ----------------------------------------------------------------------------- +// ----------------------------------------------------------------------------- diff --git a/pgdog/src/wire_protocol/frontend/startup.rs b/pgdog/src/wire_protocol/frontend/startup.rs new file mode 100644 index 00000000..d53e3380 --- /dev/null +++ b/pgdog/src/wire_protocol/frontend/startup.rs @@ -0,0 +1,228 @@ +//! Module: wire_protocol::frontend::startup +//! +//! Provides parsing and serialization for the Startup message in the PostgreSQL protocol. +//! +//! - `StartupFrame`: represents the initial Startup message with protocol version and connection parameters. +//! - `StartupError`: error types for parsing and encoding. +//! +//! Implements `WireSerializable` for easy conversion between raw bytes and `StartupFrame`. +//! Note: This handles the regular startup message (version typically 196608 for 3.0). Special messages like SSLRequest or CancelRequest are handled separately. + +use bytes::{Buf, BufMut, Bytes, BytesMut}; +use std::{error::Error as StdError, fmt, str}; + +use crate::wire_protocol::WireSerializable; + +// ----------------------------------------------------------------------------- +// ----- ProtocolMessage ------------------------------------------------------- + +#[derive(Debug)] +pub struct StartupFrame<'a> { + pub version: i32, + pub parameters: Vec<(&'a str, &'a str)>, +} + +// ----------------------------------------------------------------------------- +// ----- Error ----------------------------------------------------------------- + +#[derive(Debug)] +pub enum StartupError { + Utf8Error(str::Utf8Error), + UnexpectedEof, + InvalidLength, +} + +impl fmt::Display for StartupError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + StartupError::Utf8Error(e) => write!(f, "UTF-8 error: {e}"), + StartupError::UnexpectedEof => write!(f, "unexpected EOF"), + StartupError::InvalidLength => write!(f, "invalid length"), + } + } +} + +impl StdError for StartupError { + fn source(&self) -> Option<&(dyn StdError + 'static)> { + match self { + StartupError::Utf8Error(e) => Some(e), + _ => None, + } + } +} + +// ----------------------------------------------------------------------------- +// ----- Helpers --------------------------------------------------------------- + +fn read_cstr<'a>(buf: &mut &'a [u8]) -> Result<&'a str, StartupError> { + let nul = buf + .iter() + .position(|b| *b == 0) + .ok_or(StartupError::UnexpectedEof)?; + let (raw, rest) = buf.split_at(nul); + *buf = &rest[1..]; // skip NUL + Ok(str::from_utf8(raw).map_err(StartupError::Utf8Error)?) +} + +// ----------------------------------------------------------------------------- +// ----- WireSerializable ------------------------------------------------------ + +impl<'a> WireSerializable<'a> for StartupFrame<'a> { + type Error = StartupError; + + fn from_bytes(mut bytes: &'a [u8]) -> Result { + let initial_remaining = bytes.remaining(); + if initial_remaining < 8 { + return Err(StartupError::UnexpectedEof); + } + + let msg_len = bytes.get_i32() as usize; + if msg_len != initial_remaining { + return Err(StartupError::InvalidLength); + } + + let version = bytes.get_i32(); + + let mut parameters = Vec::new(); + loop { + let key = read_cstr(&mut bytes)?; + if key.is_empty() { + break; + } + let value = read_cstr(&mut bytes)?; + parameters.push((key, value)); + } + + if bytes.has_remaining() { + return Err(StartupError::InvalidLength); + } + + Ok(StartupFrame { + version, + parameters, + }) + } + + fn to_bytes(&self) -> Result { + let mut body = BytesMut::with_capacity(self.body_size()); + + body.put_i32(self.version); + + for &(key, value) in &self.parameters { + body.extend_from_slice(key.as_bytes()); + body.put_u8(0); + body.extend_from_slice(value.as_bytes()); + body.put_u8(0); + } + + body.put_u8(0); + + let mut frame = BytesMut::with_capacity(body.len() + 4); + frame.put_i32((body.len() + 4) as i32); + frame.extend_from_slice(&body); + + Ok(frame.freeze()) + } + + fn body_size(&self) -> usize { + 4 + self + .parameters + .iter() + .map(|(k, v)| k.len() + 1 + v.len() + 1) + .sum::() + + 1 + } +} + +// ----------------------------------------------------------------------------- +// ----- Tests ----------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + + fn make_frame() -> StartupFrame<'static> { + StartupFrame { + version: 196608, // 3.0 + parameters: vec![("user", "postgres"), ("database", "mydb")], + } + } + + fn make_empty_params_frame() -> StartupFrame<'static> { + StartupFrame { + version: 196608, + parameters: vec![], + } + } + + #[test] + fn roundtrip() { + let frame = make_frame(); + let encoded = frame.to_bytes().unwrap(); + let decoded = StartupFrame::from_bytes(encoded.as_ref()).unwrap(); + assert_eq!(decoded.version, frame.version); + assert_eq!(decoded.parameters, frame.parameters); + } + + #[test] + fn roundtrip_empty_params() { + let frame = make_empty_params_frame(); + let encoded = frame.to_bytes().unwrap(); + let decoded = StartupFrame::from_bytes(encoded.as_ref()).unwrap(); + assert_eq!(decoded.version, frame.version); + assert_eq!(decoded.parameters, frame.parameters); + } + + #[test] + fn invalid_length_mismatch() { + let mut bytes = make_frame().to_bytes().unwrap().to_vec(); + // Corrupt length to be larger + let corrupt_len = (bytes.len() + 1) as i32; + bytes[0..4].copy_from_slice(&corrupt_len.to_be_bytes()); + let err = StartupFrame::from_bytes(&bytes).unwrap_err(); + matches!(err, StartupError::InvalidLength); + } + + #[test] + fn unexpected_eof_short_buffer() { + let raw = &[0u8; 4]; // too short + let err = StartupFrame::from_bytes(raw).unwrap_err(); + matches!(err, StartupError::UnexpectedEof); + } + + #[test] + fn unexpected_eof_missing_terminator() { + let mut body = BytesMut::new(); + body.put_i32(196608); + body.extend_from_slice(b"user\0postgres\0database\0mydb"); // missing final \0 + let mut frame = BytesMut::new(); + frame.put_i32((body.len() + 4) as i32); + frame.extend_from_slice(&body); + let err = StartupFrame::from_bytes(frame.as_ref()).unwrap_err(); + matches!(err, StartupError::UnexpectedEof); + } + + #[test] + fn extra_data_after_terminator() { + let mut bytes = make_frame().to_bytes().unwrap().to_vec(); + bytes.push(1); // extra byte + // Adjust length to match new size + let corrupt_len = bytes.len() as i32; + bytes[0..4].copy_from_slice(&corrupt_len.to_be_bytes()); + let err = StartupFrame::from_bytes(&bytes).unwrap_err(); + matches!(err, StartupError::InvalidLength); + } + + #[test] + fn invalid_utf8() { + let mut bytes = make_frame().to_bytes().unwrap().to_vec(); + // Corrupt a byte in "user" + let user_pos = 8; // after length(4) + version(4) + bytes[user_pos] = 0xFF; + let err = StartupFrame::from_bytes(&bytes).unwrap_err(); + matches!(err, StartupError::Utf8Error(_)); + } +} + +// ----------------------------------------------------------------------------- +// ----------------------------------------------------------------------------- diff --git a/pgdog/src/wire_protocol/frontend/sync.rs b/pgdog/src/wire_protocol/frontend/sync.rs new file mode 100644 index 00000000..93105779 --- /dev/null +++ b/pgdog/src/wire_protocol/frontend/sync.rs @@ -0,0 +1,120 @@ +//! Module: wire_protocol::frontend::sync +//! +//! Provides parsing and serialization for the Sync message ('S') in the extended protocol. +//! +//! - `SyncFrame`: represents a Sync message sent by the client to synchronize after extended query messages. +//! +//! Implements `WireSerializable` for conversion between raw bytes and `SyncFrame`. + +use bytes::Bytes; +use std::{error::Error as StdError, fmt}; + +use crate::wire_protocol::WireSerializable; + +// ----------------------------------------------------------------------------- +// ----- ProtocolMessage ------------------------------------------------------- + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct SyncFrame; + +// ----------------------------------------------------------------------------- +// ----- Error ----------------------------------------------------------------- + +#[derive(Debug)] +pub enum SyncError { + UnexpectedTag(u8), + UnexpectedLength(u32), +} + +impl fmt::Display for SyncError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + SyncError::UnexpectedTag(t) => write!(f, "unexpected tag: {t:#X}"), + SyncError::UnexpectedLength(len) => write!(f, "unexpected length: {len}"), + } + } +} + +impl StdError for SyncError {} + +// ----------------------------------------------------------------------------- +// ----- WireSerializable ------------------------------------------------------ + +impl<'a> WireSerializable<'a> for SyncFrame { + type Error = SyncError; + + fn from_bytes(bytes: &'a [u8]) -> Result { + if bytes.len() < 5 { + return Err(SyncError::UnexpectedLength(bytes.len() as u32)); + } + + let tag = bytes[0]; + if tag != b'S' { + return Err(SyncError::UnexpectedTag(tag)); + } + + let len = u32::from_be_bytes([bytes[1], bytes[2], bytes[3], bytes[4]]); + if len != 4 { + return Err(SyncError::UnexpectedLength(len)); + } + + Ok(SyncFrame) + } + + fn to_bytes(&self) -> Result { + Ok(Bytes::from_static(b"S\0\0\0\x04")) + } + + fn body_size(&self) -> usize { + 0 + } +} + +// ----------------------------------------------------------------------------- +// ----- Tests ----------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn serialize_sync() { + let sync = SyncFrame; + let bytes = sync.to_bytes().unwrap(); + let expected_bytes = Bytes::from_static(&[b'S', 0, 0, 0, 4]); + assert_eq!(bytes, expected_bytes); + } + + #[test] + fn deserialize_sync() { + let data = &[b'S', 0, 0, 0, 4][..]; + let sync = SyncFrame::from_bytes(data).unwrap(); + // no state; just ensure no error + let _ = sync; + } + + #[test] + fn roundtrip_sync() { + let original = SyncFrame; + let bytes = original.to_bytes().unwrap(); + let decoded = SyncFrame::from_bytes(bytes.as_ref()).unwrap(); + assert_eq!(original, decoded); + } + + #[test] + fn invalid_tag() { + let data = &[b'Q', 0, 0, 0, 4][..]; + let err = SyncFrame::from_bytes(data).unwrap_err(); + matches!(err, SyncError::UnexpectedTag(_)); + } + + #[test] + fn invalid_length() { + let data = &[b'S', 0, 0, 0, 5][..]; + let err = SyncFrame::from_bytes(data).unwrap_err(); + matches!(err, SyncError::UnexpectedLength(5)); + } +} + +// ----------------------------------------------------------------------------- +// ----------------------------------------------------------------------------- diff --git a/pgdog/src/wire_protocol/frontend/terminate.rs b/pgdog/src/wire_protocol/frontend/terminate.rs new file mode 100644 index 00000000..fcb4c53d --- /dev/null +++ b/pgdog/src/wire_protocol/frontend/terminate.rs @@ -0,0 +1,118 @@ +//! Module: wire_protocol::frontend::terminate +//! +//! Provides parsing and serialization for the Terminate message ('X') in the extended protocol. +//! +//! - `Terminate`: represents a Terminate message sent by the client to close the connection. +//! +//! Implements `WireSerializable` for conversion between raw bytes and `Terminate`. + +use bytes::Bytes; +use std::{error::Error as StdError, fmt}; + +use crate::wire_protocol::WireSerializable; + +// ----------------------------------------------------------------------------- +// ----- ProtocolMessage ------------------------------------------------------- + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct TerminateFrame; + +// ----------------------------------------------------------------------------- +// ----- Error ----------------------------------------------------------------- + +#[derive(Debug)] +pub enum TerminateError { + UnexpectedTag(u8), + UnexpectedLength(u32), +} + +impl fmt::Display for TerminateError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + TerminateError::UnexpectedTag(t) => write!(f, "unexpected tag: {t:#X}"), + TerminateError::UnexpectedLength(len) => write!(f, "unexpected length: {len}"), + } + } +} + +impl StdError for TerminateError {} + +// ----------------------------------------------------------------------------- +// ----- WireSerializable ------------------------------------------------------ + +impl<'a> WireSerializable<'a> for TerminateFrame { + type Error = TerminateError; + + fn from_bytes(bytes: &'a [u8]) -> Result { + if bytes.len() < 5 { + return Err(TerminateError::UnexpectedLength(bytes.len() as u32)); + } + + let tag = bytes[0]; + if tag != b'X' { + return Err(TerminateError::UnexpectedTag(tag)); + } + let len = u32::from_be_bytes([bytes[1], bytes[2], bytes[3], bytes[4]]); + if len != 4 { + return Err(TerminateError::UnexpectedLength(len)); + } + Ok(TerminateFrame) + } + + fn to_bytes(&self) -> Result { + Ok(Bytes::from_static(b"X\0\0\0\x04")) + } + + fn body_size(&self) -> usize { + 0 + } +} + +// ----------------------------------------------------------------------------- +// ----- Tests ----------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn serialize_terminate() { + let term = TerminateFrame; + let bytes = term.to_bytes().unwrap(); + let expected_bytes = Bytes::from_static(&[b'X', 0, 0, 0, 4]); + assert_eq!(bytes, expected_bytes); + } + + #[test] + fn deserialize_terminate() { + let data = &[b'X', 0, 0, 0, 4][..]; + let term = TerminateFrame::from_bytes(data).unwrap(); + // no state; just ensure no error + let _ = term; + } + + #[test] + fn roundtrip_terminate() { + let original = TerminateFrame; + let bytes = original.to_bytes().unwrap(); + let decoded = TerminateFrame::from_bytes(bytes.as_ref()).unwrap(); + assert_eq!(original, decoded); + } + + #[test] + fn invalid_tag() { + let data = &[b'Q', 0, 0, 0, 4][..]; + let err = TerminateFrame::from_bytes(data).unwrap_err(); + matches!(err, TerminateError::UnexpectedTag(_)); + } + + #[test] + fn invalid_length() { + let data = &[b'X', 0, 0, 0, 5][..]; + let err = TerminateFrame::from_bytes(data).unwrap_err(); + matches!(err, TerminateError::UnexpectedLength(5)); + } +} + +// ----------------------------------------------------------------------------- +// ----------------------------------------------------------------------------- diff --git a/pgdog/src/wire_protocol/mod.rs b/pgdog/src/wire_protocol/mod.rs new file mode 100644 index 00000000..7c7e6bc1 --- /dev/null +++ b/pgdog/src/wire_protocol/mod.rs @@ -0,0 +1,9 @@ +pub mod backend; +pub mod bidirectional; +pub mod frontend; +pub mod shared_property_types; +pub mod wire_serializable; + +pub use backend::BackendProtocolMessage; +pub use frontend::FrontendProtocolMessage; +pub use wire_serializable::WireSerializable; diff --git a/pgdog/src/wire_protocol/shared_property_types/mod.rs b/pgdog/src/wire_protocol/shared_property_types/mod.rs new file mode 100644 index 00000000..c5b25dc5 --- /dev/null +++ b/pgdog/src/wire_protocol/shared_property_types/mod.rs @@ -0,0 +1,7 @@ +pub mod parameter; +pub mod result_format; +pub mod sasl_mechanism; + +pub use self::parameter::Parameter; +pub use self::result_format::ResultFormat; +pub use self::sasl_mechanism::{SaslMechanism, SaslMechanismError}; diff --git a/pgdog/src/wire_protocol/shared_property_types/parameter.rs b/pgdog/src/wire_protocol/shared_property_types/parameter.rs new file mode 100644 index 00000000..3140af54 --- /dev/null +++ b/pgdog/src/wire_protocol/shared_property_types/parameter.rs @@ -0,0 +1,16 @@ +#[derive(Debug, PartialEq, Eq)] +pub enum Parameter<'a> { + Text(&'a str), + Binary(&'a [u8]), +} + +impl<'a> Parameter<'a> { + /// Decode this parameter as a UTF-8 string, if possible. + /// Returns `Some(&str)` for text or valid UTF-8 binary, else `None`. + pub fn as_str(&self) -> Option<&'a str> { + match self { + Parameter::Text(s) => Some(s), + Parameter::Binary(b) => std::str::from_utf8(b).ok(), + } + } +} diff --git a/pgdog/src/wire_protocol/shared_property_types/result_format.rs b/pgdog/src/wire_protocol/shared_property_types/result_format.rs new file mode 100644 index 00000000..745254ec --- /dev/null +++ b/pgdog/src/wire_protocol/shared_property_types/result_format.rs @@ -0,0 +1,5 @@ +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +pub enum ResultFormat { + Text, + Binary, +} diff --git a/pgdog/src/wire_protocol/shared_property_types/sasl_mechanism.rs b/pgdog/src/wire_protocol/shared_property_types/sasl_mechanism.rs new file mode 100644 index 00000000..b10cf167 --- /dev/null +++ b/pgdog/src/wire_protocol/shared_property_types/sasl_mechanism.rs @@ -0,0 +1,54 @@ +use std::error::Error as StdError; +use std::fmt; + +// ----------------------------------------------------------------------------- +// ----- Property -------------------------------------------------------------- + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum SaslMechanism { + ScramSha256, + ScramSha256Plus, +} + +// ----------------------------------------------------------------------------- +// ----- Encoding/Decoding ----------------------------------------------------- + +impl SaslMechanism { + pub fn from_str(s: &str) -> Result { + match s { + "SCRAM-SHA-256" => Ok(Self::ScramSha256), + "SCRAM-SHA-256-PLUS" => Ok(Self::ScramSha256Plus), + other => Err(SaslMechanismError::UnsupportedMechanism(other.to_string())), + } + } + + pub fn as_str(&self) -> &'static str { + match self { + Self::ScramSha256 => "SCRAM-SHA-256", + Self::ScramSha256Plus => "SCRAM-SHA-256-PLUS", + } + } +} + +// ----------------------------------------------------------------------------- +// ----- Error ----------------------------------------------------------------- + +#[derive(Debug)] +pub enum SaslMechanismError { + UnsupportedMechanism(String), +} + +impl fmt::Display for SaslMechanismError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + SaslMechanismError::UnsupportedMechanism(m) => { + write!(f, "unsupported SASL mechanism: {m}") + } + } + } +} + +impl StdError for SaslMechanismError {} + +// ----------------------------------------------------------------------------- +// ----------------------------------------------------------------------------- diff --git a/pgdog/src/wire_protocol/wire_serializable.rs b/pgdog/src/wire_protocol/wire_serializable.rs new file mode 100644 index 00000000..43eb84ae --- /dev/null +++ b/pgdog/src/wire_protocol/wire_serializable.rs @@ -0,0 +1,14 @@ +use bytes::Bytes; +use std::error::Error as StdError; + +pub trait WireSerializable<'a>: Sized { + type Error: StdError + Send + Sync + 'static; + + /// Serialize the object into bytes for wire transmission. + fn to_bytes(&self) -> Result; + + /// Deserialize from bytes into the object. + fn from_bytes(bytes: &'a [u8]) -> Result; + + fn body_size(&self) -> usize; +}