Skip to content

Commit 5766390

Browse files
backport IPV6 address parsing
1 parent 8eac3bc commit 5766390

File tree

10 files changed

+504
-275
lines changed

10 files changed

+504
-275
lines changed

src/client/options.rs

Lines changed: 99 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ use std::{
1010
convert::TryFrom,
1111
fmt::{self, Display, Formatter, Write},
1212
hash::{Hash, Hasher},
13+
net::Ipv6Addr,
1314
path::PathBuf,
1415
str::FromStr,
1516
sync::Arc,
@@ -124,9 +125,29 @@ impl<'de> Deserialize<'de> for ServerAddress {
124125
where
125126
D: Deserializer<'de>,
126127
{
127-
let s: String = Deserialize::deserialize(deserializer)?;
128-
Self::parse(s.as_str())
129-
.map_err(|e| <D::Error as serde::de::Error>::custom(format!("{}", e)))
128+
#[derive(Deserialize)]
129+
#[serde(untagged)]
130+
enum ServerAddressHelper {
131+
String(String),
132+
Object { host: String, port: Option<u16> },
133+
}
134+
135+
let helper = ServerAddressHelper::deserialize(deserializer)?;
136+
match helper {
137+
ServerAddressHelper::String(string) => {
138+
Self::parse(string).map_err(serde::de::Error::custom)
139+
}
140+
ServerAddressHelper::Object { host, port } => {
141+
#[cfg(unix)]
142+
if host.ends_with("sock") {
143+
return Ok(Self::Unix {
144+
path: PathBuf::from(host),
145+
});
146+
}
147+
148+
Ok(Self::Tcp { host, port })
149+
}
150+
}
130151
}
131152
}
132153

@@ -184,71 +205,92 @@ impl ServerAddress {
184205
/// Parses an address string into a `ServerAddress`.
185206
pub fn parse(address: impl AsRef<str>) -> Result<Self> {
186207
let address = address.as_ref();
187-
// checks if the address is a unix domain socket
188-
#[cfg(unix)]
189-
{
190-
if address.ends_with(".sock") {
191-
return Ok(ServerAddress::Unix {
208+
209+
if address.ends_with(".sock") {
210+
#[cfg(unix)]
211+
{
212+
let address = percent_decode(address, "unix domain sockets must be URL-encoded")?;
213+
return Ok(Self::Unix {
192214
path: PathBuf::from(address),
193215
});
194216
}
217+
#[cfg(not(unix))]
218+
return Err(ErrorKind::InvalidArgument {
219+
message: "unix domain sockets are not supported on this platform".to_string(),
220+
}
221+
.into());
195222
}
196-
let mut parts = address.split(':');
197-
let hostname = match parts.next() {
198-
Some(part) => {
199-
if part.is_empty() {
200-
return Err(ErrorKind::InvalidArgument {
201-
message: format!(
202-
"invalid server address: \"{}\"; hostname cannot be empty",
203-
address
204-
),
205-
}
206-
.into());
223+
224+
let (hostname, port) = if let Some(ip_literal) = address.strip_prefix("[") {
225+
let Some((hostname, port)) = ip_literal.split_once("]") else {
226+
return Err(ErrorKind::InvalidArgument {
227+
message: format!(
228+
"invalid server address {}: missing closing ']' in IP literal hostname",
229+
address
230+
),
207231
}
208-
part
209-
}
210-
None => {
232+
.into());
233+
};
234+
235+
if let Err(parse_error) = Ipv6Addr::from_str(hostname) {
211236
return Err(ErrorKind::InvalidArgument {
212-
message: format!("invalid server address: \"{}\"", address),
237+
message: format!("invalid server address {}: {}", address, parse_error),
213238
}
214-
.into())
239+
.into());
215240
}
216-
};
217241

218-
let port = match parts.next() {
219-
Some(part) => {
220-
let port = u16::from_str(part).map_err(|_| ErrorKind::InvalidArgument {
242+
let port = if port.is_empty() {
243+
None
244+
} else if let Some(port) = port.strip_prefix(":") {
245+
Some(port)
246+
} else {
247+
return Err(ErrorKind::InvalidArgument {
221248
message: format!(
222-
"port must be valid 16-bit unsigned integer, instead got: {}",
223-
part
249+
"invalid server address {}: the hostname can only be followed by a port \
250+
prefixed with ':', got {}",
251+
address, port
224252
),
225-
})?;
226-
227-
if port == 0 {
228-
return Err(ErrorKind::InvalidArgument {
229-
message: format!(
230-
"invalid server address: \"{}\"; port must be non-zero",
231-
address
232-
),
233-
}
234-
.into());
235253
}
236-
if parts.next().is_some() {
254+
.into());
255+
};
256+
257+
(hostname, port)
258+
} else {
259+
match address.split_once(":") {
260+
Some((hostname, port)) => (hostname, Some(port)),
261+
None => (address, None),
262+
}
263+
};
264+
265+
if hostname.is_empty() {
266+
return Err(ErrorKind::InvalidArgument {
267+
message: format!(
268+
"invalid server address {}: the hostname cannot be empty",
269+
address
270+
),
271+
}
272+
.into());
273+
}
274+
275+
let port = if let Some(port) = port {
276+
match u16::from_str(port) {
277+
Ok(0) | Err(_) => {
237278
return Err(ErrorKind::InvalidArgument {
238279
message: format!(
239-
"address \"{}\" contains more than one unescaped ':'",
240-
address
280+
"invalid server address {}: the port must be an integer between 1 and \
281+
65535, got {}",
282+
address, port
241283
),
242284
}
243-
.into());
285+
.into())
244286
}
245-
246-
Some(port)
287+
Ok(port) => Some(port),
247288
}
248-
None => None,
289+
} else {
290+
None
249291
};
250292

251-
Ok(ServerAddress::Tcp {
293+
Ok(Self::Tcp {
252294
host: hostname.to_lowercase(),
253295
port,
254296
})
@@ -1689,37 +1731,21 @@ impl ConnectionString {
16891731
None => (None, None),
16901732
};
16911733

1692-
let mut host_list = Vec::with_capacity(hosts_section.len());
1693-
for host in hosts_section.split(',') {
1694-
let address = if host.ends_with(".sock") {
1695-
#[cfg(unix)]
1696-
{
1697-
ServerAddress::parse(percent_decode(
1698-
host,
1699-
"Unix domain sockets must be URL-encoded",
1700-
)?)
1701-
}
1702-
#[cfg(not(unix))]
1703-
return Err(ErrorKind::InvalidArgument {
1704-
message: "Unix domain sockets are not supported on this platform".to_string(),
1705-
}
1706-
.into());
1707-
} else {
1708-
ServerAddress::parse(host)
1709-
}?;
1710-
host_list.push(address);
1711-
}
1734+
let hosts = hosts_section
1735+
.split(',')
1736+
.map(ServerAddress::parse)
1737+
.collect::<Result<Vec<ServerAddress>>>()?;
17121738

1713-
let hosts = if srv {
1714-
if host_list.len() != 1 {
1739+
let host_info = if srv {
1740+
if hosts.len() != 1 {
17151741
return Err(ErrorKind::InvalidArgument {
17161742
message: "exactly one host must be specified with 'mongodb+srv'".into(),
17171743
}
17181744
.into());
17191745
}
17201746

17211747
// Unwrap safety: the `len` check above guarantees this can't fail.
1722-
match host_list.into_iter().next().unwrap() {
1748+
match hosts.into_iter().next().unwrap() {
17231749
ServerAddress::Tcp { host, port } => {
17241750
if port.is_some() {
17251751
return Err(ErrorKind::InvalidArgument {
@@ -1738,11 +1764,11 @@ impl ConnectionString {
17381764
}
17391765
}
17401766
} else {
1741-
HostInfo::HostIdentifiers(host_list)
1767+
HostInfo::HostIdentifiers(hosts)
17421768
};
17431769

17441770
let mut conn_str = ConnectionString {
1745-
host_info: hosts,
1771+
host_info,
17461772
#[cfg(test)]
17471773
original_uri: s.into(),
17481774
..Default::default()

0 commit comments

Comments
 (0)