Skip to content

Commit b1490b5

Browse files
RUST-2109 Fix comparison of IPv6 addresses when updating the topology (#1254)
1 parent bcff155 commit b1490b5

File tree

6 files changed

+73
-38
lines changed

6 files changed

+73
-38
lines changed

src/sdam/description/server.rs

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -192,15 +192,15 @@ impl PartialEq for ServerDescription {
192192
}
193193

194194
impl ServerDescription {
195-
pub(crate) fn new(address: ServerAddress) -> Self {
195+
pub(crate) fn new(address: &ServerAddress) -> Self {
196196
Self {
197197
address: match address {
198198
ServerAddress::Tcp { host, port } => ServerAddress::Tcp {
199199
host: host.to_lowercase(),
200-
port,
200+
port: *port,
201201
},
202202
#[cfg(unix)]
203-
ServerAddress::Unix { path } => ServerAddress::Unix { path },
203+
ServerAddress::Unix { path } => ServerAddress::Unix { path: path.clone() },
204204
},
205205
server_type: Default::default(),
206206
last_update_time: None,
@@ -214,7 +214,7 @@ impl ServerDescription {
214214
mut reply: HelloReply,
215215
average_rtt: Duration,
216216
) -> Self {
217-
let mut description = Self::new(address);
217+
let mut description = Self::new(&address);
218218
description.average_round_trip_time = Some(average_rtt);
219219
description.last_update_time = Some(DateTime::now());
220220

@@ -259,7 +259,7 @@ impl ServerDescription {
259259
}
260260

261261
pub(crate) fn new_from_error(address: ServerAddress, error: Error) -> Self {
262-
let mut description = Self::new(address);
262+
let mut description = Self::new(&address);
263263
description.last_update_time = Some(DateTime::now());
264264
description.average_round_trip_time = None;
265265
description.reply = Err(error);
@@ -310,7 +310,7 @@ impl ServerDescription {
310310
Ok(set_name)
311311
}
312312

313-
pub(crate) fn known_hosts(&self) -> Result<impl Iterator<Item = &String>> {
313+
pub(crate) fn known_hosts(&self) -> Result<Vec<ServerAddress>> {
314314
let known_hosts = self
315315
.reply
316316
.as_ref()
@@ -328,7 +328,11 @@ impl ServerDescription {
328328
.chain(arbiters.into_iter().flatten())
329329
});
330330

331-
Ok(known_hosts.into_iter().flatten())
331+
known_hosts
332+
.into_iter()
333+
.flatten()
334+
.map(ServerAddress::parse)
335+
.collect()
332336
}
333337

334338
pub(crate) fn invalid_me(&self) -> Result<bool> {

src/sdam/description/topology.rs

Lines changed: 14 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,7 @@ impl TopologyDescription {
170170
};
171171

172172
for address in options.hosts.iter() {
173-
let description = ServerDescription::new(address.clone());
173+
let description = ServerDescription::new(address);
174174
self.servers.insert(address.to_owned(), description);
175175
}
176176

@@ -387,7 +387,7 @@ impl TopologyDescription {
387387
let mut new = vec![];
388388
for host in hosts {
389389
if !self.servers.contains_key(&host) {
390-
new.push((host.clone(), ServerDescription::new(host)));
390+
new.push((host.clone(), ServerDescription::new(&host)));
391391
}
392392
}
393393
if let Some(max) = self.srv_max_hosts {
@@ -599,7 +599,7 @@ impl TopologyDescription {
599599
return Ok(());
600600
}
601601

602-
self.add_new_servers(server_description.known_hosts()?)?;
602+
self.add_new_servers(server_description.known_hosts()?);
603603

604604
if server_description.invalid_me()? {
605605
self.servers.remove(&server_description.address);
@@ -655,7 +655,7 @@ impl TopologyDescription {
655655
{
656656
self.servers.insert(
657657
server_description.address.clone(),
658-
ServerDescription::new(server_description.address),
658+
ServerDescription::new(&server_description.address),
659659
);
660660
self.record_primary_state();
661661
return Ok(());
@@ -688,16 +688,16 @@ impl TopologyDescription {
688688
}
689689

690690
if let ServerType::RsPrimary = self.servers.get(&address).unwrap().server_type {
691-
self.servers
692-
.insert(address.clone(), ServerDescription::new(address));
691+
let description = ServerDescription::new(&address);
692+
self.servers.insert(address, description);
693693
}
694694
}
695695

696-
self.add_new_servers(server_description.known_hosts()?)?;
697-
let known_hosts: HashSet<_> = server_description.known_hosts()?.collect();
696+
let known_hosts = server_description.known_hosts()?;
697+
self.add_new_servers(known_hosts.clone());
698698

699699
for address in addresses {
700-
if !known_hosts.contains(&address.to_string()) {
700+
if !known_hosts.contains(&address) {
701701
self.servers.remove(&address);
702702
}
703703
}
@@ -724,23 +724,11 @@ impl TopologyDescription {
724724
}
725725

726726
/// Create a new ServerDescription for each address and add it to the topology.
727-
fn add_new_servers<'a>(&mut self, servers: impl Iterator<Item = &'a String>) -> Result<()> {
728-
let servers: Result<Vec<_>> = servers.map(ServerAddress::parse).collect();
729-
730-
self.add_new_servers_from_addresses(servers?.iter());
731-
Ok(())
732-
}
733-
734-
/// Create a new ServerDescription for each address and add it to the topology.
735-
fn add_new_servers_from_addresses<'a>(
736-
&mut self,
737-
servers: impl Iterator<Item = &'a ServerAddress>,
738-
) {
739-
for server in servers {
740-
if !self.servers.contains_key(server) {
741-
self.servers
742-
.insert(server.clone(), ServerDescription::new(server.clone()));
743-
}
727+
fn add_new_servers(&mut self, addresses: impl IntoIterator<Item = ServerAddress>) {
728+
for address in addresses {
729+
self.servers
730+
.entry(address.clone())
731+
.or_insert_with(|| ServerDescription::new(&address));
744732
}
745733
}
746734
}

src/sdam/description/topology/server_selection/test.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ impl TestServerDescription {
103103
reply,
104104
avg_rtt_ms.map(f64_ms_as_duration).unwrap(),
105105
),
106-
None => ServerDescription::new(server_address),
106+
None => ServerDescription::new(&server_address),
107107
};
108108
server_desc.last_update_time = self
109109
.last_update_time

src/sdam/topology.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -335,7 +335,7 @@ impl TopologyWorker {
335335
self.update_topology(new_description).await;
336336

337337
if self.options.load_balanced == Some(true) {
338-
let base = ServerDescription::new(self.options.hosts[0].clone());
338+
let base = ServerDescription::new(&self.options.hosts[0]);
339339
self.update_server(ServerDescription {
340340
server_type: ServerType::LoadBalancer,
341341
average_round_trip_time: None,
@@ -374,7 +374,9 @@ impl TopologyWorker {
374374
UpdateMessage::SyncHosts(hosts) => {
375375
self.sync_hosts(hosts).await
376376
}
377-
UpdateMessage::ServerUpdate(sd) => self.update_server(*sd).await,
377+
UpdateMessage::ServerUpdate(sd) => {
378+
self.update_server(*sd).await
379+
}
378380
UpdateMessage::MonitorError { address, error } => {
379381
self.handle_monitor_error(address, error).await
380382
}

src/test/client.rs

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use std::{borrow::Cow, collections::HashMap, future::IntoFuture, time::Duration};
1+
use std::{borrow::Cow, collections::HashMap, future::IntoFuture, net::Ipv6Addr, time::Duration};
22

33
use bson::Document;
44
use serde::{Deserialize, Serialize};
@@ -982,3 +982,44 @@ async fn end_sessions_on_shutdown() {
982982
client2.into_client().shutdown().await;
983983
assert_eq!(get_end_session_event_count(&mut event_stream).await, 0);
984984
}
985+
986+
#[tokio::test]
987+
async fn ipv6_connect() {
988+
let ipv6_localhost = Ipv6Addr::LOCALHOST.to_string();
989+
990+
let client = Client::for_test().await;
991+
// The hello command returns the hostname as "localhost". However, whatsmyuri returns an
992+
// IP-literal, which allows us to detect whether we can re-construct the client with an IPv6
993+
// address.
994+
let is_ipv6_localhost = client
995+
.database("admin")
996+
.run_command(doc! { "whatsmyuri": 1 })
997+
.await
998+
.ok()
999+
.and_then(|response| {
1000+
response
1001+
.get_str("you")
1002+
.ok()
1003+
.map(|you| you.contains(&ipv6_localhost))
1004+
})
1005+
.unwrap_or(false);
1006+
if !is_ipv6_localhost {
1007+
log_uncaptured("skipping ipv6_connect due to non-ipv6-localhost configuration");
1008+
return;
1009+
}
1010+
1011+
let mut options = get_client_options().await.clone();
1012+
for address in options.hosts.iter_mut() {
1013+
if let ServerAddress::Tcp { host, .. } = address {
1014+
*host = ipv6_localhost.clone();
1015+
}
1016+
}
1017+
let client = Client::with_options(options).unwrap();
1018+
1019+
let result = client
1020+
.database("admin")
1021+
.run_command(doc! { "ping": 1 })
1022+
.await
1023+
.unwrap();
1024+
assert_eq!(result.get_f64("ok"), Ok(1.0));
1025+
}

src/test/spec/trace.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -459,7 +459,7 @@ fn topology_description_tracing_representation() {
459459
let mut servers = HashMap::new();
460460
servers.insert(
461461
ServerAddress::default(),
462-
ServerDescription::new(ServerAddress::default()),
462+
ServerDescription::new(&ServerAddress::default()),
463463
);
464464

465465
let oid = bson::oid::ObjectId::new();

0 commit comments

Comments
 (0)