Skip to content

Commit 09d4ce6

Browse files
authored
RUST-104 Poll SRV records for mongos discovery (#161)
1 parent c747515 commit 09d4ce6

File tree

9 files changed

+483
-46
lines changed

9 files changed

+483
-46
lines changed

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ sha-1 = "0.8.1"
3838
sha2 = "0.8.0"
3939
stringprep = "0.1.2"
4040
time = "0.1.42"
41-
trust-dns-proto = "0.19.0"
41+
trust-dns-proto = "0.19.4"
4242
trust-dns-resolver = "0.19.0"
4343
typed-builder = "0.3.0"
4444
version_check = "0.9.1"

src/client/options/mod.rs

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,9 @@ pub struct ClientOptions {
304304
#[builder(default)]
305305
pub(crate) zlib_compression: Option<i32>,
306306

307+
#[builder(default)]
308+
original_srv_hostname: Option<String>,
309+
307310
#[builder(default)]
308311
original_uri: Option<String>,
309312
}
@@ -492,12 +495,22 @@ impl From<ClientOptionsParser> for ClientOptions {
492495
credential: parser.credential,
493496
cmap_event_handler: None,
494497
command_event_handler: None,
498+
original_srv_hostname: None,
495499
original_uri: Some(parser.original_uri),
496500
}
497501
}
498502
}
499503

500504
impl ClientOptions {
505+
/// Creates a new ClientOptions with the `original_srv_hostname` field set to the testing value
506+
/// used in the SRV tests.
507+
#[cfg(test)]
508+
pub(crate) fn new_srv() -> Self {
509+
let mut options = Self::default();
510+
options.original_srv_hostname = Some("localhost.test.test.build.10gen.cc".into());
511+
options
512+
}
513+
501514
/// Parses a MongoDB connection string into a ClientOptions struct. If the string is malformed
502515
/// or one of the options has an invalid value, an error will be returned.
503516
///
@@ -555,11 +568,14 @@ impl ClientOptions {
555568
let mut options: Self = parser.into();
556569

557570
if srv {
558-
let resolver = SrvResolver::new().await?;
571+
let mut resolver = SrvResolver::new().await?;
559572
let mut config = resolver
560573
.resolve_client_options(&options.hosts[0].hostname)
561574
.await?;
562575

576+
// Save the original SRV hostname to allow mongos polling.
577+
options.original_srv_hostname = Some(options.hosts[0].hostname.clone());
578+
563579
// Set the ClientOptions hosts to those found during the SRV lookup.
564580
options.hosts = config.hosts;
565581

@@ -593,6 +609,11 @@ impl ClientOptions {
593609
Ok(options)
594610
}
595611

612+
/// Gets the original SRV hostname specified when this ClientOptions was parsed from a URI.
613+
pub(crate) fn original_srv_hostname(&self) -> Option<&String> {
614+
self.original_srv_hostname.as_ref()
615+
}
616+
596617
pub(crate) fn tls_options(&self) -> Option<TlsOptions> {
597618
match self.tls {
598619
Some(Tls::Enabled(ref opts)) => Some(opts.clone()),

src/sdam/description/topology/mod.rs

Lines changed: 47 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,27 @@ impl PartialEq for TopologyDescription {
8484
}
8585

8686
impl TopologyDescription {
87+
/// Creates a new TopologyDescription with the set of servers initialized to the addresses
88+
/// specified in `hosts` and each other field set to its default value.
89+
#[cfg(test)]
90+
pub(crate) fn new_from_hosts(hosts: Vec<StreamAddress>) -> Self {
91+
Self {
92+
single_seed: false,
93+
topology_type: TopologyType::Unknown,
94+
set_name: None,
95+
max_set_version: None,
96+
max_election_id: None,
97+
compatibility_error: None,
98+
logical_session_timeout_minutes: None,
99+
local_threshold: None,
100+
heartbeat_freq: None,
101+
servers: hosts
102+
.into_iter()
103+
.map(|address| (address.clone(), ServerDescription::new(address, None)))
104+
.collect(),
105+
}
106+
}
107+
87108
pub(crate) fn new(options: ClientOptions) -> Result<Self> {
88109
verify_max_staleness(
89110
options
@@ -124,6 +145,11 @@ impl TopologyDescription {
124145
})
125146
}
126147

148+
/// Gets the topology type of the cluster.
149+
pub(crate) fn topology_type(&self) -> TopologyType {
150+
self.topology_type
151+
}
152+
127153
pub(crate) fn server_addresses(&self) -> impl Iterator<Item = &StreamAddress> {
128154
self.servers.keys()
129155
}
@@ -272,6 +298,14 @@ impl TopologyDescription {
272298
})
273299
}
274300

301+
/// Syncs the set of servers in the description to those in `hosts`. Servers in the set not
302+
/// already present in the cluster will be added, and servers in the cluster not present in the
303+
/// set will be removed.
304+
pub(crate) fn sync_hosts(&mut self, hosts: &HashSet<StreamAddress>) {
305+
self.add_new_servers_from_addresses(hosts.iter());
306+
self.servers.retain(|host, _| hosts.contains(host));
307+
}
308+
275309
/// Update the topology based on the new information about the topology contained by the
276310
/// ServerDescription.
277311
pub(crate) fn update(&mut self, mut server_description: ServerDescription) -> Result<()> {
@@ -532,17 +566,24 @@ impl TopologyDescription {
532566
}
533567

534568
/// Create a new ServerDescription for each address and add it to the topology.
535-
fn add_new_servers<'a>(&'a mut self, servers: impl Iterator<Item = &'a String>) -> Result<()> {
536-
for server in servers {
537-
let server = StreamAddress::parse(&server)?;
569+
fn add_new_servers<'a>(&mut self, servers: impl Iterator<Item = &'a String>) -> Result<()> {
570+
let servers: Result<Vec<_>> = servers.map(|server| StreamAddress::parse(server)).collect();
571+
572+
self.add_new_servers_from_addresses(servers?.iter());
573+
Ok(())
574+
}
538575

576+
/// Create a new ServerDescription for each address and add it to the topology.
577+
fn add_new_servers_from_addresses<'a>(
578+
&mut self,
579+
servers: impl Iterator<Item = &'a StreamAddress>,
580+
) {
581+
for server in servers {
539582
if !self.servers.contains_key(&server) {
540583
self.servers
541-
.insert(server.clone(), ServerDescription::new(server, None));
584+
.insert(server.clone(), ServerDescription::new(server.clone(), None));
542585
}
543586
}
544-
545-
Ok(())
546587
}
547588
}
548589

src/sdam/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ mod description;
33
mod message_manager;
44
mod monitor;
55
pub mod public;
6+
mod srv_polling;
67
mod state;
78

89
pub use self::public::{ServerInfo, ServerType};

src/sdam/monitor.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ use crate::{
1818
RUNTIME,
1919
};
2020

21-
const DEFAULT_HEARTBEAT_FREQUENCY: Duration = Duration::from_secs(10);
21+
pub(super) const DEFAULT_HEARTBEAT_FREQUENCY: Duration = Duration::from_secs(10);
2222

2323
pub(crate) const MIN_HEARTBEAT_FREQUENCY: Duration = Duration::from_millis(500);
2424

src/sdam/srv_polling/mod.rs

Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
#[cfg(test)]
2+
mod test;
3+
4+
use std::time::Duration;
5+
6+
use super::{
7+
monitor::DEFAULT_HEARTBEAT_FREQUENCY,
8+
state::{Topology, TopologyState, WeakTopology},
9+
};
10+
use crate::{
11+
error::{Error, Result},
12+
options::{ClientOptions, StreamAddress},
13+
srv::SrvResolver,
14+
RUNTIME,
15+
};
16+
17+
const DEFAULT_RESCAN_SRV_INTERVAL: Duration = Duration::from_secs(60);
18+
19+
pub(crate) struct SrvPollingMonitor {
20+
initial_hostname: String,
21+
resolver: Option<SrvResolver>,
22+
topology: WeakTopology,
23+
rescan_interval: Option<Duration>,
24+
client_options: ClientOptions,
25+
}
26+
27+
struct LookupHosts {
28+
hosts: Vec<StreamAddress>,
29+
min_ttl: Option<Duration>,
30+
}
31+
32+
impl SrvPollingMonitor {
33+
pub(crate) fn new(topology: WeakTopology) -> Option<Self> {
34+
let client_options = topology.client_options().clone();
35+
36+
let initial_hostname = match client_options.original_srv_hostname() {
37+
Some(hostname) => hostname.clone(),
38+
None => return None,
39+
};
40+
41+
Some(Self {
42+
initial_hostname,
43+
resolver: None,
44+
topology,
45+
rescan_interval: None,
46+
client_options,
47+
})
48+
}
49+
50+
/// Starts a monitoring task that periodically performs SRV record lookups to determine if the
51+
/// set of mongos in the cluster have changed. A weak reference is used to ensure that the
52+
/// monitoring task doesn't keep the topology alive after the client has been dropped.
53+
pub(super) fn start(topology: WeakTopology) {
54+
RUNTIME.execute(async move {
55+
if let Some(mut monitor) = Self::new(topology) {
56+
monitor.execute().await;
57+
}
58+
});
59+
}
60+
61+
async fn execute(&mut self) {
62+
while let Some(topology) = self.topology.upgrade() {
63+
let state = topology.clone_state().await;
64+
65+
if state.is_sharded() || state.is_unknown() {
66+
let hosts = self.lookup_hosts().await;
67+
self.update_hosts(hosts, topology, state).await;
68+
}
69+
70+
RUNTIME
71+
.delay_for(self.rescan_interval.unwrap_or(DEFAULT_RESCAN_SRV_INTERVAL))
72+
.await;
73+
}
74+
}
75+
76+
async fn update_hosts(
77+
&mut self,
78+
lookup: Result<LookupHosts>,
79+
topology: Topology,
80+
mut topology_state: TopologyState,
81+
) {
82+
let lookup = match lookup {
83+
Ok(LookupHosts { hosts, .. }) if hosts.is_empty() => {
84+
self.no_valid_hosts(None);
85+
86+
return;
87+
}
88+
Ok(lookup) => lookup,
89+
Err(err) => {
90+
self.no_valid_hosts(Some(err));
91+
92+
return;
93+
}
94+
};
95+
96+
self.rescan_interval = lookup.min_ttl;
97+
98+
let diff =
99+
topology_state.update_hosts(&lookup.hosts.into_iter().collect(), &self.client_options);
100+
topology.update_state(diff, topology_state).await;
101+
}
102+
103+
async fn lookup_hosts(&mut self) -> Result<LookupHosts> {
104+
let initial_hostname = self.initial_hostname.clone();
105+
let resolver = self.get_or_create_srv_resolver().await?;
106+
let mut new_hosts = Vec::new();
107+
108+
for host in resolver.get_srv_hosts(&initial_hostname).await? {
109+
#[allow(clippy::single_match)]
110+
match host {
111+
Ok(host) => new_hosts.push(host),
112+
Err(_) => {
113+
// TODO RUST-230: Log error with host that was returned.
114+
}
115+
}
116+
}
117+
118+
Ok(LookupHosts {
119+
hosts: new_hosts,
120+
min_ttl: resolver
121+
.min_ttl()
122+
.map(|ttl| Duration::from_secs(ttl as u64)),
123+
})
124+
}
125+
126+
async fn get_or_create_srv_resolver(&mut self) -> Result<&mut SrvResolver> {
127+
if let Some(ref mut resolver) = self.resolver {
128+
return Ok(resolver);
129+
}
130+
131+
let resolver = SrvResolver::new().await?;
132+
133+
// Since the connection was not `Some` above, this will always insert the new connection and
134+
// return a reference to it.
135+
Ok(self.resolver.get_or_insert(resolver))
136+
}
137+
138+
fn no_valid_hosts(&mut self, _error: Option<Error>) {
139+
// TODO RUST-230: Log error/lack of valid results.
140+
141+
self.rescan_interval = Some(self.heartbeat_freq());
142+
}
143+
144+
fn heartbeat_freq(&self) -> Duration {
145+
self.client_options
146+
.heartbeat_freq
147+
.unwrap_or(DEFAULT_HEARTBEAT_FREQUENCY)
148+
}
149+
}

0 commit comments

Comments
 (0)