Skip to content

Commit 2b7175e

Browse files
authored
RUST-950 Enable connection to a load balancer (#415)
1 parent 7a305e8 commit 2b7175e

File tree

22 files changed

+524
-144
lines changed

22 files changed

+524
-144
lines changed

src/client/options/mod.rs

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -556,6 +556,16 @@ pub struct ClientOptions {
556556
#[builder(default)]
557557
#[cfg(test)]
558558
pub(crate) heartbeat_freq_test: Option<Duration>,
559+
560+
/// Allow use of the `load_balanced` option.
561+
// TODO RUST-653 Remove this when load balancer work is ready for release.
562+
#[builder(default, setter(skip))]
563+
#[serde(skip)]
564+
pub(crate) allow_load_balanced: bool,
565+
566+
/// Whether or not the client is connecting to a MongoDB cluster through a load balancer.
567+
#[builder(default, setter(skip))]
568+
pub(crate) load_balanced: Option<bool>,
559569
}
560570

561571
fn default_hosts() -> Vec<ServerAddress> {
@@ -689,6 +699,7 @@ struct ClientOptionsParser {
689699
auth_mechanism_properties: Option<Document>,
690700
read_preference: Option<ReadPreference>,
691701
read_preference_tags: Option<Vec<TagSet>>,
702+
load_balanced: Option<bool>,
692703
original_uri: String,
693704
}
694705

@@ -921,6 +932,8 @@ impl From<ClientOptionsParser> for ClientOptions {
921932
server_api: None,
922933
#[cfg(test)]
923934
heartbeat_freq_test: None,
935+
allow_load_balanced: false,
936+
load_balanced: parser.load_balanced,
924937
}
925938
}
926939
}
@@ -1086,6 +1099,10 @@ impl ClientOptions {
10861099
options.repl_set_name = Some(replica_set);
10871100
}
10881101
}
1102+
1103+
if options.load_balanced.is_none() {
1104+
options.load_balanced = config.load_balanced;
1105+
}
10891106
}
10901107

10911108
options.validate()?;
@@ -1108,7 +1125,7 @@ impl ClientOptions {
11081125
}
11091126
}
11101127

1111-
/// Ensure the options set are valid, returning an error descirbing the problem if they are not.
1128+
/// Ensure the options set are valid, returning an error describing the problem if they are not.
11121129
pub(crate) fn validate(&self) -> Result<()> {
11131130
if let Some(true) = self.direct_connection {
11141131
if self.hosts.len() > 1 {
@@ -1122,6 +1139,36 @@ impl ClientOptions {
11221139
if let Some(ref write_concern) = self.write_concern {
11231140
write_concern.validate()?;
11241141
}
1142+
1143+
if !self.allow_load_balanced && self.load_balanced.is_some() {
1144+
return Err(ErrorKind::InvalidArgument {
1145+
message: "loadBalanced is not supported".to_string(),
1146+
}
1147+
.into());
1148+
}
1149+
1150+
if self.load_balanced.unwrap_or(false) {
1151+
if self.hosts.len() > 1 {
1152+
return Err(ErrorKind::InvalidArgument {
1153+
message: "cannot specify multiple seeds with loadBalanced=true".to_string(),
1154+
}
1155+
.into());
1156+
}
1157+
if self.repl_set_name.is_some() {
1158+
return Err(ErrorKind::InvalidArgument {
1159+
message: "cannot specify replicaSet with loadBalanced=true".to_string(),
1160+
}
1161+
.into());
1162+
}
1163+
if self.direct_connection == Some(true) {
1164+
return Err(ErrorKind::InvalidArgument {
1165+
message: "cannot specify directConnection=true with loadBalanced=true"
1166+
.to_string(),
1167+
}
1168+
.into());
1169+
}
1170+
}
1171+
11251172
Ok(())
11261173
}
11271174

@@ -1677,6 +1724,9 @@ impl ClientOptionsParser {
16771724
let mut write_concern = self.write_concern.get_or_insert_with(Default::default);
16781725
write_concern.journal = Some(get_bool!(value, k));
16791726
}
1727+
k @ "loadbalanced" => {
1728+
self.load_balanced = Some(get_bool!(value, k));
1729+
}
16801730
k @ "localthresholdms" => {
16811731
self.local_threshold = Some(Duration::from_millis(get_duration!(value, k)))
16821732
}

src/cmap/conn/mod.rs

Lines changed: 49 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,12 @@ use derivative::Derivative;
1212
use self::wire::Message;
1313
use super::manager::PoolManager;
1414
use crate::{
15-
cmap::options::{ConnectionOptions, StreamOptions},
16-
error::{ErrorKind, Result},
15+
bson::oid::ObjectId,
16+
cmap::{
17+
options::{ConnectionOptions, StreamOptions},
18+
PoolGeneration,
19+
},
20+
error::{load_balanced_mode_mismatch, ErrorKind, Result},
1721
event::cmap::{
1822
CmapEventHandler,
1923
ConnectionCheckedInEvent,
@@ -46,7 +50,7 @@ pub struct ConnectionInfo {
4650
pub(crate) struct Connection {
4751
pub(super) id: u32,
4852
pub(super) address: ServerAddress,
49-
pub(crate) generation: u32,
53+
pub(crate) generation: ConnectionGeneration,
5054

5155
/// The cached StreamDescription from the connection's handshake.
5256
pub(super) stream_description: Option<StreamDescription>,
@@ -90,7 +94,7 @@ impl Connection {
9094

9195
let conn = Self {
9296
id,
93-
generation,
97+
generation: ConnectionGeneration::Normal(generation),
9498
pool_manager: None,
9599
command_executing: false,
96100
ready_and_available_time: None,
@@ -106,10 +110,16 @@ impl Connection {
106110

107111
/// Constructs and connects a new connection.
108112
pub(super) async fn connect(pending_connection: PendingConnection) -> Result<Self> {
113+
let generation = match pending_connection.generation {
114+
PoolGeneration::Normal(gen) => gen,
115+
PoolGeneration::LoadBalanced(_) => 0, /* Placeholder; will be overwritten in
116+
* `ConnectionEstablisher::
117+
* establish_connection`. */
118+
};
109119
Self::new(
110120
pending_connection.id,
111121
pending_connection.address.clone(),
112-
pending_connection.generation,
122+
generation,
113123
pending_connection.options,
114124
)
115125
.await
@@ -181,11 +191,6 @@ impl Connection {
181191
.unwrap_or(false)
182192
}
183193

184-
/// Checks if the connection is stale.
185-
pub(super) fn is_stale(&self, current_generation: u32) -> bool {
186-
self.generation != current_generation
187-
}
188-
189194
/// Checks if the connection is currently executing an operation.
190195
pub(super) fn is_executing(&self) -> bool {
191196
self.command_executing
@@ -300,7 +305,7 @@ impl Connection {
300305
Connection {
301306
id: self.id,
302307
address: self.address.clone(),
303-
generation: self.generation,
308+
generation: self.generation.clone(),
304309
stream: std::mem::replace(&mut self.stream, AsyncStream::Null),
305310
handler: self.handler.take(),
306311
stream_description: self.stream_description.take(),
@@ -335,6 +340,38 @@ impl Drop for Connection {
335340
}
336341
}
337342

343+
#[derive(Debug, Clone)]
344+
pub(crate) enum ConnectionGeneration {
345+
Normal(u32),
346+
LoadBalanced {
347+
generation: u32,
348+
service_id: ObjectId,
349+
},
350+
}
351+
352+
impl ConnectionGeneration {
353+
pub(crate) fn service_id(&self) -> Option<ObjectId> {
354+
match self {
355+
ConnectionGeneration::Normal(_) => None,
356+
ConnectionGeneration::LoadBalanced { service_id, .. } => Some(*service_id),
357+
}
358+
}
359+
360+
pub(crate) fn is_stale(&self, current_generation: &PoolGeneration) -> bool {
361+
match (self, current_generation) {
362+
(ConnectionGeneration::Normal(cgen), PoolGeneration::Normal(pgen)) => cgen != pgen,
363+
(
364+
ConnectionGeneration::LoadBalanced {
365+
generation: cgen,
366+
service_id,
367+
},
368+
PoolGeneration::LoadBalanced(gen_map),
369+
) => cgen != gen_map.get(service_id).unwrap_or(&0),
370+
_ => load_balanced_mode_mismatch!(false),
371+
}
372+
}
373+
}
374+
338375
/// Struct encapsulating the information needed to establish a `Connection`.
339376
///
340377
/// Creating a `PendingConnection` contributes towards the total connection count of a pool, despite
@@ -344,7 +381,7 @@ impl Drop for Connection {
344381
pub(super) struct PendingConnection {
345382
pub(super) id: u32,
346383
pub(super) address: ServerAddress,
347-
pub(super) generation: u32,
384+
pub(super) generation: PoolGeneration,
348385
pub(super) options: Option<ConnectionOptions>,
349386
}
350387

src/cmap/establish/handshake/mod.rs

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ use crate::{
88
bson::{doc, Bson, Document},
99
client::auth::{ClientFirst, FirstRound},
1010
cmap::{options::ConnectionPoolOptions, Command, Connection, StreamDescription},
11-
error::Result,
11+
error::{ErrorKind, Result},
1212
is_master::{is_master_command, run_is_master, IsMasterReply},
1313
options::{AuthMechanism, ClientOptions, Credential, DriverInfo, ServerApi},
1414
};
@@ -177,6 +177,10 @@ impl Handshaker {
177177
command.target_db = cred.resolved_source().to_string();
178178
credential = Some(cred);
179179
}
180+
181+
if options.load_balanced {
182+
command.body.insert("loadBalanced", true);
183+
}
180184
}
181185

182186
command.body.insert("client", metadata);
@@ -194,6 +198,16 @@ impl Handshaker {
194198
let client_first = set_speculative_auth_info(&mut command.body, self.credential.as_ref())?;
195199

196200
let mut is_master_reply = run_is_master(command, conn).await?;
201+
if self.command.body.contains_key("loadBalanced")
202+
&& is_master_reply.command_response.service_id.is_none()
203+
{
204+
return Err(ErrorKind::IncompatibleServer {
205+
message: "Driver attempted to initialize in load balancing mode, but the server \
206+
does not support this mode."
207+
.to_string(),
208+
}
209+
.into());
210+
}
197211
conn.stream_description = Some(StreamDescription::from_is_master(is_master_reply.clone()));
198212

199213
// Record the client's message and the server's response from speculative authentication if
@@ -232,6 +246,7 @@ pub(crate) struct HandshakerOptions {
232246
credential: Option<Credential>,
233247
driver_info: Option<DriverInfo>,
234248
server_api: Option<ServerApi>,
249+
load_balanced: bool,
235250
}
236251

237252
impl From<ConnectionPoolOptions> for HandshakerOptions {
@@ -241,6 +256,7 @@ impl From<ConnectionPoolOptions> for HandshakerOptions {
241256
credential: options.credential,
242257
driver_info: options.driver_info,
243258
server_api: options.server_api,
259+
load_balanced: options.load_balanced.unwrap_or(false),
244260
}
245261
}
246262
}
@@ -252,6 +268,7 @@ impl From<ClientOptions> for HandshakerOptions {
252268
credential: options.credential,
253269
driver_info: options.driver_info,
254270
server_api: options.server_api,
271+
load_balanced: options.load_balanced.unwrap_or(false),
255272
}
256273
}
257274
}

src/cmap/establish/mod.rs

Lines changed: 62 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,17 @@ pub(super) mod handshake;
33
mod test;
44

55
use self::handshake::Handshaker;
6-
use super::{conn::PendingConnection, options::ConnectionPoolOptions, Connection};
6+
use super::{
7+
conn::{ConnectionGeneration, PendingConnection},
8+
options::ConnectionPoolOptions,
9+
Connection,
10+
PoolGeneration,
11+
};
712
use crate::{
813
client::{auth::Credential, options::ServerApi},
9-
error::Result,
14+
error::{Error as MongoError, ErrorKind},
1015
runtime::HttpClient,
16+
sdam::HandshakePhase,
1117
};
1218

1319
/// Contains the logic to establish a connection, including handshaking, authenticating, and
@@ -38,26 +44,73 @@ impl ConnectionEstablisher {
3844
pub(super) async fn establish_connection(
3945
&self,
4046
pending_connection: PendingConnection,
41-
) -> Result<Connection> {
42-
let mut connection = Connection::connect(pending_connection).await?;
47+
) -> std::result::Result<Connection, EstablishError> {
48+
let pool_gen = pending_connection.generation.clone();
49+
let mut connection = Connection::connect(pending_connection)
50+
.await
51+
.map_err(|e| EstablishError::pre_hello(e, pool_gen.clone()))?;
4352

44-
let first_round = self
53+
let handshake = self
4554
.handshaker
4655
.handshake(&mut connection)
47-
.await?
48-
.first_round;
56+
.await
57+
.map_err(|e| EstablishError::pre_hello(e, pool_gen.clone()))?;
58+
let service_id = handshake.is_master_reply.command_response.service_id;
59+
60+
// If the handshake response had a `serviceId` field, this is a connection to a load
61+
// balancer and must derive its generation from the service_generations map.
62+
match (pool_gen, service_id) {
63+
(PoolGeneration::Normal(_), _) => {}
64+
(PoolGeneration::LoadBalanced(gen_map), Some(service_id)) => {
65+
connection.generation = ConnectionGeneration::LoadBalanced {
66+
generation: *gen_map.get(&service_id).unwrap_or(&0),
67+
service_id,
68+
};
69+
}
70+
_ => {
71+
return Err(EstablishError::post_hello(
72+
ErrorKind::Internal {
73+
message: "load-balanced mode mismatch".to_string(),
74+
}
75+
.into(),
76+
connection.generation.clone(),
77+
));
78+
}
79+
}
4980

5081
if let Some(ref credential) = self.credential {
5182
credential
5283
.authenticate_stream(
5384
&mut connection,
5485
&self.http_client,
5586
self.server_api.as_ref(),
56-
first_round,
87+
handshake.first_round,
5788
)
58-
.await?;
89+
.await
90+
.map_err(|e| EstablishError::post_hello(e, connection.generation.clone()))?
5991
}
6092

6193
Ok(connection)
6294
}
6395
}
96+
97+
#[derive(Debug, Clone)]
98+
pub(crate) struct EstablishError {
99+
pub(crate) cause: MongoError,
100+
pub(crate) handshake_phase: HandshakePhase,
101+
}
102+
103+
impl EstablishError {
104+
fn pre_hello(cause: MongoError, generation: PoolGeneration) -> Self {
105+
Self {
106+
cause,
107+
handshake_phase: HandshakePhase::PreHello { generation },
108+
}
109+
}
110+
fn post_hello(cause: MongoError, generation: ConnectionGeneration) -> Self {
111+
Self {
112+
cause,
113+
handshake_phase: HandshakePhase::PostHello { generation },
114+
}
115+
}
116+
}

0 commit comments

Comments
 (0)