From 398d8a1d5de6033dc461c3b6c4cc4eaecd1212bd Mon Sep 17 00:00:00 2001 From: Param Dhanoya Date: Fri, 8 Nov 2024 19:28:59 -0800 Subject: [PATCH] Added support to parse ipv6 address --- src/client/options.rs | 97 +++++++++++++++++++++++++------------------ 1 file changed, 56 insertions(+), 41 deletions(-) diff --git a/src/client/options.rs b/src/client/options.rs index c58be4e4c..518b93a01 100644 --- a/src/client/options.rs +++ b/src/client/options.rs @@ -23,6 +23,7 @@ use serde::{de::Unexpected, Deserialize, Deserializer, Serialize}; use serde_with::skip_serializing_none; use strsim::jaro_winkler; use typed_builder::TypedBuilder; +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; #[cfg(any( feature = "zstd-compression", @@ -197,64 +198,78 @@ impl ServerAddress { }); } } - let mut parts = address.split(':'); - let hostname = match parts.next() { - Some(part) => { - if part.is_empty() { + + // Check if the address is IPv6, indicated by square brackets + let (hostname, port_str) = if address.starts_with('[') { + match address.rfind(']') { + Some(bracket_pos) => { + let addr_without_brackets = &address[1..bracket_pos]; + let port = if bracket_pos + 1 < address.len() { + if !address[bracket_pos + 1..].starts_with(':') { + return Err(ErrorKind::InvalidArgument { + message: format!("invalid IPv6 address format: expected ':' after ']' in \"{}\"", address), + }.into()); + } + Some(&address[bracket_pos + 2..]) + } else { + None + }; + (addr_without_brackets, port) + }, + None => { return Err(ErrorKind::InvalidArgument { - message: format!( - "invalid server address: \"{}\"; hostname cannot be empty", - address - ), - } - .into()); + message: format!("invalid IPv6 address format: missing closing bracket in \"{}\"", address), + }.into()); } - part } - None => { - return Err(ErrorKind::InvalidArgument { - message: format!("invalid server address: \"{}\"", address), + } else { + match address.rsplit_once(':') { + Some((host, port)) => (host, Some(port)), + None => { + return Err(ErrorKind::InvalidArgument { + message: format!("invalid server address: \"{}\"; port is required", address), + }.into()) } - .into()) } }; - let port = match parts.next() { - Some(part) => { - let port = u16::from_str(part).map_err(|_| ErrorKind::InvalidArgument { - message: format!( - "port must be valid 16-bit unsigned integer, instead got: {}", - part - ), - })?; + // Validate that the hostname is either a valid IPv4 or IPv6 address + let is_ipv4 = Ipv4Addr::from_str(hostname).is_ok(); + let is_ipv6 = Ipv6Addr::from_str(hostname).is_ok(); + let localhost_v4 = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)); + let localhost_v6 = IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)); + if !is_ipv4 && !is_ipv6 && "127.0.0.1".parse() != Ok(localhost_v4) && "::1".parse() != Ok(localhost_v6) { + return Err(ErrorKind::InvalidArgument { + message: format!("invalid hostname: \"{}\"", hostname), + } + .into()) + } - if port == 0 { + // Validate port + let port = match port_str { + None => { + return Err(ErrorKind::InvalidArgument { + message: format!("invalid server address: \"{}\"; port is required", address), + }.into()) + } + Some(port_str) => match port_str.parse::() { + Ok(0) => { return Err(ErrorKind::InvalidArgument { - message: format!( - "invalid server address: \"{}\"; port must be non-zero", - address - ), - } - .into()); + message: format!("invalid server address: \"{}\"; port cannot be 0", address), + }.into()) } - if parts.next().is_some() { + Ok(port) => port, + Err(_) => { return Err(ErrorKind::InvalidArgument { - message: format!( - "address \"{}\" contains more than one unescaped ':'", - address - ), - } - .into()); + message: format!("invalid server address: \"{}\"; invalid port number", address), + }.into()) } - - Some(port) } - None => None, }; Ok(ServerAddress::Tcp { host: hostname.to_lowercase(), - port, + port: Some(port), }) }