diff --git a/postgres-protocol/src/authentication/sasl.rs b/postgres-protocol/src/authentication/sasl.rs index f2200a40c..19aa3c1e9 100644 --- a/postgres-protocol/src/authentication/sasl.rs +++ b/postgres-protocol/src/authentication/sasl.rs @@ -117,7 +117,7 @@ enum Credentials { /// A regular password as a vector of bytes. Password(Vec), /// A precomputed pair of keys. - Keys(ScramKeys), + Keys(Box>), } enum State { @@ -176,7 +176,7 @@ impl ScramSha256 { /// Constructs a new instance which will use the provided key pair for authentication. pub fn new_with_keys(keys: ScramKeys<32>, channel_binding: ChannelBinding) -> ScramSha256 { - let password = Credentials::Keys(keys); + let password = Credentials::Keys(keys.into()); ScramSha256::new_inner(password, channel_binding, nonce()) } diff --git a/postgres-protocol/src/message/frontend.rs b/postgres-protocol/src/message/frontend.rs index dabed0bab..5d0a8ff8c 100644 --- a/postgres-protocol/src/message/frontend.rs +++ b/postgres-protocol/src/message/frontend.rs @@ -271,66 +271,6 @@ where }) } -#[inline] -pub fn startup_message_cstr( - parameters: &StartupMessageParams, - buf: &mut BytesMut, -) -> io::Result<()> { - write_body(buf, |buf| { - // postgres protocol version 3.0(196608) in bigger-endian - buf.put_i32(0x00_03_00_00); - buf.put_slice(¶meters.params); - buf.put_u8(0); - Ok(()) - }) -} - -#[derive(Debug, Clone, Default, PartialEq, Eq)] -pub struct StartupMessageParams { - params: BytesMut, -} - -impl StartupMessageParams { - /// Set parameter's value by its name. - pub fn insert(&mut self, name: &str, value: &str) -> Result<(), io::Error> { - if name.contains('\0') | value.contains('\0') { - return Err(io::Error::new( - io::ErrorKind::InvalidInput, - "string contains embedded null", - )); - } - self.params.put(name.as_bytes()); - self.params.put(&b"\0"[..]); - self.params.put(value.as_bytes()); - self.params.put(&b"\0"[..]); - Ok(()) - } - - pub fn str_iter(&self) -> impl Iterator { - let params = - std::str::from_utf8(&self.params).expect("should be validated as utf8 already"); - StrParamsIter(params) - } - - /// Get parameter's value by its name. - pub fn get(&self, name: &str) -> Option<&str> { - self.str_iter().find_map(|(k, v)| (k == name).then_some(v)) - } -} - -struct StrParamsIter<'a>(&'a str); - -impl<'a> Iterator for StrParamsIter<'a> { - type Item = (&'a str, &'a str); - - fn next(&mut self) -> Option { - let (key, r) = self.0.split_once('\0')?; - let (value, r) = r.split_once('\0')?; - self.0 = r; - Some((key, value)) - } -} - #[inline] pub fn sync(buf: &mut BytesMut) { buf.put_u8(b'S'); diff --git a/postgres/src/config.rs b/postgres/src/config.rs index ccbbe7c51..44e4bec3a 100644 --- a/postgres/src/config.rs +++ b/postgres/src/config.rs @@ -145,6 +145,12 @@ impl Config { self } + /// Gets the password to authenticate with, if one has been configured with + /// the `password` method. + pub fn get_password(&self) -> Option<&[u8]> { + self.config.get_password() + } + /// Sets precomputed protocol-specific keys to authenticate with. /// When set, this option will override `password`. /// See [`AuthKeys`] for more information. @@ -153,6 +159,12 @@ impl Config { self } + /// Gets precomputed protocol-specific keys to authenticate with. + /// if one has been configured with the `auth_keys` method. + pub fn get_auth_keys(&self) -> Option { + self.config.get_auth_keys() + } + /// Sets the name of the database to connect to. /// /// Defaults to the user. @@ -173,12 +185,24 @@ impl Config { self } + /// Gets the command line options used to configure the server, if the + /// options have been set with the `options` method. + pub fn get_options(&self) -> Option<&str> { + self.config.get_options() + } + /// Sets the value of the `application_name` runtime parameter. pub fn application_name(&mut self, application_name: &str) -> &mut Config { self.config.application_name(application_name); self } + /// Gets the value of the `application_name` runtime parameter, if it has + /// been set with the `application_name` method. + pub fn get_application_name(&self) -> Option<&str> { + self.config.get_application_name() + } + /// Sets the SSL configuration. /// /// Defaults to `prefer`. diff --git a/tokio-postgres/src/config.rs b/tokio-postgres/src/config.rs index f6cff7bb0..fdb5e6359 100644 --- a/tokio-postgres/src/config.rs +++ b/tokio-postgres/src/config.rs @@ -10,7 +10,6 @@ use crate::tls::TlsConnect; #[cfg(feature = "runtime")] use crate::Socket; use crate::{Client, Connection, Error}; -use postgres_protocol::message::frontend::StartupMessageParams; use std::borrow::Cow; #[cfg(unix)] use std::ffi::OsStr; @@ -171,7 +170,12 @@ pub enum AuthKeys { /// ``` #[derive(Clone, PartialEq, Eq)] pub struct Config { - pub(crate) auth: Option, + pub(crate) user: Option, + pub(crate) password: Option>, + pub(crate) auth_keys: Option>, + pub(crate) dbname: Option, + pub(crate) options: Option, + pub(crate) application_name: Option, pub(crate) ssl_mode: SslMode, pub(crate) host: Vec, pub(crate) port: Vec, @@ -180,18 +184,8 @@ pub struct Config { pub(crate) keepalive_config: KeepaliveConfig, pub(crate) target_session_attrs: TargetSessionAttrs, pub(crate) channel_binding: ChannelBinding, + pub(crate) replication_mode: Option, pub(crate) max_backend_message_size: Option, - pub(crate) server_settings: StartupMessageParams, -} - -#[derive(Clone, PartialEq, Eq)] -#[non_exhaustive] -/// What auth info to use when authenticating -pub enum Auth { - /// password based auth - Password(Vec), - /// precomputed scram based auth - AuthKeys(AuthKeys), } impl Default for Config { @@ -209,7 +203,12 @@ impl Config { retries: None, }; Config { - auth: None, + user: None, + password: None, + auth_keys: None, + dbname: None, + options: None, + application_name: None, ssl_mode: SslMode::Prefer, host: vec![], port: vec![], @@ -218,8 +217,8 @@ impl Config { keepalive_config, target_session_attrs: TargetSessionAttrs::Any, channel_binding: ChannelBinding::Prefer, + replication_mode: None, max_backend_message_size: None, - server_settings: StartupMessageParams::default(), } } @@ -227,14 +226,14 @@ impl Config { /// /// Required. pub fn user(&mut self, user: &str) -> &mut Config { - self.server_settings.insert("user", user).unwrap(); + self.user = Some(user.to_string()); self } /// Gets the user to authenticate with, if one has been configured with /// the `user` method. pub fn get_user(&self) -> Option<&str> { - self.server_settings.get("user") + self.user.as_deref() } /// Sets the password to authenticate with. @@ -242,60 +241,68 @@ impl Config { where T: AsRef<[u8]>, { - self.auth = Some(Auth::Password(password.as_ref().to_vec())); + self.password = Some(password.as_ref().to_vec()); self } /// Gets the password to authenticate with, if one has been configured with /// the `password` method. - pub fn get_auth(&self) -> Option { - self.auth.clone() + pub fn get_password(&self) -> Option<&[u8]> { + self.password.as_deref() } /// Sets precomputed protocol-specific keys to authenticate with. /// When set, this option will override `password`. /// See [`AuthKeys`] for more information. - pub fn auth(&mut self, keys: Auth) -> &mut Config { - self.auth = Some(keys); + pub fn auth_keys(&mut self, keys: AuthKeys) -> &mut Config { + self.auth_keys = Some(Box::new(keys)); self } - /// Sets precomputed protocol-specific keys to authenticate with. - /// When set, this option will override `password`. - /// See [`AuthKeys`] for more information. - pub fn auth_keys(&mut self, keys: AuthKeys) -> &mut Config { - self.auth = Some(Auth::AuthKeys(keys)); - self + /// Gets precomputed protocol-specific keys to authenticate with. + /// if one has been configured with the `auth_keys` method. + pub fn get_auth_keys(&self) -> Option { + self.auth_keys.as_deref().copied() } /// Sets the name of the database to connect to. /// /// Defaults to the user. pub fn dbname(&mut self, dbname: &str) -> &mut Config { - self.server_settings.insert("database", dbname).unwrap(); + self.dbname = Some(dbname.to_string()); self } /// Gets the name of the database to connect to, if one has been configured /// with the `dbname` method. pub fn get_dbname(&self) -> Option<&str> { - self.server_settings.get("database") + self.dbname.as_deref() } /// Sets command line options used to configure the server. pub fn options(&mut self, options: &str) -> &mut Config { - self.server_settings.insert("options", options).unwrap(); + self.options = Some(options.to_string()); self } + /// Gets the command line options used to configure the server, if the + /// options have been set with the `options` method. + pub fn get_options(&self) -> Option<&str> { + self.options.as_deref() + } + /// Sets the value of the `application_name` runtime parameter. pub fn application_name(&mut self, application_name: &str) -> &mut Config { - self.server_settings - .insert("application_name", application_name) - .unwrap(); + self.application_name = Some(application_name.to_string()); self } + /// Gets the value of the `application_name` runtime parameter, if it has + /// been set with the `application_name` method. + pub fn get_application_name(&self) -> Option<&str> { + self.application_name.as_deref() + } + /// Sets the SSL configuration. /// /// Defaults to `prefer`. @@ -458,18 +465,15 @@ impl Config { /// Set replication mode. pub fn replication_mode(&mut self, replication_mode: ReplicationMode) -> &mut Config { - match replication_mode { - ReplicationMode::Physical => { - self.server_settings.insert("replication", "true").unwrap() - } - ReplicationMode::Logical => self - .server_settings - .insert("replication", "database") - .unwrap(), - } + self.replication_mode = Some(replication_mode); self } + /// Get replication mode. + pub fn get_replication_mode(&self) -> Option { + self.replication_mode + } + /// Set limit for backend messages size. pub fn max_backend_message_size(&mut self, max_backend_message_size: usize) -> &mut Config { self.max_backend_message_size = Some(max_backend_message_size); @@ -481,8 +485,7 @@ impl Config { self.max_backend_message_size } - /// Set an arbitrary param - pub fn param(&mut self, key: &str, value: &str) -> Result<(), Error> { + fn param(&mut self, key: &str, value: &str) -> Result<(), Error> { match key { "user" => { self.user(value); @@ -493,6 +496,12 @@ impl Config { "dbname" => { self.dbname(value); } + "options" => { + self.options(value); + } + "application_name" => { + self.application_name(value); + } "sslmode" => { let mode = match value { "disable" => SslMode::Disable, @@ -579,6 +588,17 @@ impl Config { }; self.channel_binding(channel_binding); } + "replication" => { + let mode = match value { + "off" => None, + "true" => Some(ReplicationMode::Physical), + "database" => Some(ReplicationMode::Logical), + _ => return Err(Error::config_parse(Box::new(InvalidValue("replication")))), + }; + if let Some(mode) = mode { + self.replication_mode(mode); + } + } "max_backend_message_size" => { let limit = value.parse::().map_err(|_| { Error::config_parse(Box::new(InvalidValue("max_backend_message_size"))) @@ -588,9 +608,9 @@ impl Config { } } key => { - self.server_settings - .insert(key, value) - .map_err(|e| Error::config_parse(e.into()))?; + return Err(Error::config_parse(Box::new(UnknownOption( + key.to_string(), + )))); } } @@ -645,8 +665,12 @@ impl fmt::Debug for Config { } } - let mut f = f.debug_struct("Config"); - f.field("auth", &self.auth.as_ref().map(|_| Redaction {})) + f.debug_struct("Config") + .field("user", &self.user) + .field("password", &self.password.as_ref().map(|_| Redaction {})) + .field("dbname", &self.dbname) + .field("options", &self.options) + .field("application_name", &self.application_name) .field("ssl_mode", &self.ssl_mode) .field("host", &self.host) .field("port", &self.port) @@ -656,16 +680,23 @@ impl fmt::Debug for Config { .field("keepalives_interval", &self.keepalive_config.interval) .field("keepalives_retries", &self.keepalive_config.retries) .field("target_session_attrs", &self.target_session_attrs) - .field("channel_binding", &self.channel_binding); + .field("channel_binding", &self.channel_binding) + .field("replication", &self.replication_mode) + .finish() + } +} - for (k, v) in self.server_settings.str_iter() { - f.field(k, &v); - } +#[derive(Debug)] +struct UnknownOption(String); - f.finish() +impl fmt::Display for UnknownOption { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(fmt, "unknown option `{}`", self.0) } } +impl error::Error for UnknownOption {} + #[derive(Debug)] struct InvalidValue(&'static str); diff --git a/tokio-postgres/src/connect_raw.rs b/tokio-postgres/src/connect_raw.rs index 4d3f58b78..8e788984a 100644 --- a/tokio-postgres/src/connect_raw.rs +++ b/tokio-postgres/src/connect_raw.rs @@ -1,5 +1,5 @@ use crate::codec::{BackendMessage, BackendMessages, FrontendMessage, PostgresCodec}; -use crate::config::{self, Auth, AuthKeys, Config}; +use crate::config::{self, AuthKeys, Config, ReplicationMode}; use crate::connect_tls::connect_tls; use crate::maybe_tls_stream::MaybeTlsStream; use crate::tls::{TlsConnect, TlsStream}; @@ -116,14 +116,28 @@ where S: AsyncRead + AsyncWrite + Unpin, T: AsyncRead + AsyncWrite + Unpin, { - // leave for user to provide: - // let mut params = config.server_settings.clone(); - // params - // .insert("client_encoding", "UTF8") - // .map_err(Error::encode)?; + let mut params = vec![("client_encoding", "UTF8")]; + if let Some(user) = &config.user { + params.push(("user", &**user)); + } + if let Some(dbname) = &config.dbname { + params.push(("database", &**dbname)); + } + if let Some(options) = &config.options { + params.push(("options", &**options)); + } + if let Some(application_name) = &config.application_name { + params.push(("application_name", &**application_name)); + } + if let Some(replication_mode) = &config.replication_mode { + match replication_mode { + ReplicationMode::Physical => params.push(("replication", "true")), + ReplicationMode::Logical => params.push(("replication", "database")), + } + } let mut buf = BytesMut::new(); - frontend::startup_message_cstr(&config.server_settings, &mut buf).map_err(Error::encode)?; + frontend::startup_message(params, &mut buf).map_err(Error::encode)?; stream .send(FrontendMessage::Raw(buf.freeze())) @@ -144,25 +158,27 @@ where Some(Message::AuthenticationCleartextPassword) => { can_skip_channel_binding(config)?; - match &config.auth { - Some(Auth::Password(pass)) => authenticate_password(stream, pass).await?, - _ => return Err(Error::config("password missing".into())), - } + let pass = config + .password + .as_ref() + .ok_or_else(|| Error::config("password missing".into()))?; + + authenticate_password(stream, pass).await?; } Some(Message::AuthenticationMd5Password(body)) => { can_skip_channel_binding(config)?; let user = config - .get_user() + .user + .as_ref() .ok_or_else(|| Error::config("user missing".into()))?; + let pass = config + .password + .as_ref() + .ok_or_else(|| Error::config("password missing".into()))?; - match &config.auth { - Some(Auth::Password(pass)) => { - let output = authentication::md5_hash(user.as_bytes(), pass, body.salt()); - authenticate_password(stream, output.as_bytes()).await?; - } - _ => return Err(Error::config("password missing".into())), - } + let output = authentication::md5_hash(user.as_bytes(), pass, body.salt()); + authenticate_password(stream, output.as_bytes()).await?; } Some(Message::AuthenticationSasl(body)) => { authenticate_sasl(stream, body, config).await?; @@ -260,12 +276,12 @@ where can_skip_channel_binding(config)?; } - let mut scram = match &config.auth { - Some(Auth::AuthKeys(AuthKeys::ScramSha256(keys))) => { - ScramSha256::new_with_keys(*keys, channel_binding) - } - Some(Auth::Password(password)) => ScramSha256::new(password, channel_binding), - None => return Err(Error::config("password or auth keys missing".into())), + let mut scram = if let Some(AuthKeys::ScramSha256(keys)) = config.get_auth_keys() { + ScramSha256::new_with_keys(keys, channel_binding) + } else if let Some(password) = config.get_password() { + ScramSha256::new(password, channel_binding) + } else { + return Err(Error::config("password or auth keys missing".into())); }; let mut buf = BytesMut::new(); diff --git a/tokio-postgres/tests/test/main.rs b/tokio-postgres/tests/test/main.rs index c074bb0d1..772612de6 100644 --- a/tokio-postgres/tests/test/main.rs +++ b/tokio-postgres/tests/test/main.rs @@ -166,20 +166,6 @@ async fn pipelined_prepare() { assert_eq!(statement2.columns()[0].type_(), &Type::INT8); } -// regression: https://github.com/neondatabase/neon/issues/1287#issuecomment-1251922486 -#[tokio::test] -#[cfg(feature = "with-serde_json-1")] -async fn custom_params() { - let client = connect("user=postgres IntervalStyle=iso_8601").await; - - let row = client - .query_one("select to_json('0 seconds'::interval)", &[]) - .await - .unwrap(); - - assert_eq!(row.get::<_, serde_json_1::Value>(0), "PT0S"); -} - #[tokio::test] async fn insert_select() { let client = connect("user=postgres").await; diff --git a/tokio-postgres/tests/test/replication.rs b/tokio-postgres/tests/test/replication.rs index b510d8879..c176a4104 100644 --- a/tokio-postgres/tests/test/replication.rs +++ b/tokio-postgres/tests/test/replication.rs @@ -10,7 +10,6 @@ use tokio_postgres::NoTls; use tokio_postgres::SimpleQueryMessage::Row; #[tokio::test] -#[ignore = "replication"] async fn test_replication() { // form SQL connection let conninfo = "host=127.0.0.1 port=5433 user=postgres replication=database";