Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 26 additions & 2 deletions src/client/options.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ const URI_OPTIONS: &[&str] = &[
"waitqueuetimeoutms",
"wtimeoutms",
"zlibcompressionlevel",
"srvservicename",
];

/// Reserved characters as defined by [Section 2.2 of RFC-3986](https://tools.ietf.org/html/rfc3986#section-2.2).
Expand Down Expand Up @@ -521,6 +522,9 @@ pub struct ClientOptions {
/// By default, no default database is specified.
pub default_database: Option<String>,

/// Overrides the default "mongodb" service name for SRV lookup in both discovery and polling
pub srv_service_name: Option<String>,

#[builder(setter(skip))]
#[derivative(Debug = "ignore")]
pub(crate) socket_timeout: Option<Duration>,
Expand Down Expand Up @@ -676,6 +680,8 @@ impl Serialize for ClientOptions {
loadbalanced: &'a Option<bool>,

srvmaxhosts: Option<i32>,

srvservicename: &'a Option<String>,
}

let client_options = ClientOptionsHelper {
Expand Down Expand Up @@ -709,6 +715,7 @@ impl Serialize for ClientOptions {
.map(|v| v.try_into())
.transpose()
.map_err(serde::ser::Error::custom)?,
srvservicename: &self.srv_service_name,
};

client_options.serialize(serializer)
Expand Down Expand Up @@ -865,6 +872,9 @@ pub struct ConnectionString {
/// Limit on the number of mongos connections that may be created for sharded topologies.
pub srv_max_hosts: Option<u32>,

/// Overrides the default "mongodb" service name for SRV lookup in both discovery and polling
pub srv_service_name: Option<String>,

wait_queue_timeout: Option<Duration>,
tls_insecure: Option<bool>,

Expand Down Expand Up @@ -900,11 +910,16 @@ impl Default for HostInfo {
}

impl HostInfo {
async fn resolve(self, resolver_config: Option<ResolverConfig>) -> Result<ResolvedHostInfo> {
async fn resolve(
self,
resolver_config: Option<ResolverConfig>,
srv_service_name: Option<String>,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it make sense to bundle srv_service_name into resolver_config? Mechanically, they're always passed around as a pair but I could see going either way on whether it's the right thing conceptually.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looked into this a bit, decided to keep it as is to reduce complexity in how the srv_service_name is passed through the codepath (from being parsed as part of the connection string down to the srv resolver- adding the parameter to ResolverConfig disrupts this a little bit). Thank you for pointing this out!

) -> Result<ResolvedHostInfo> {
Ok(match self {
Self::HostIdentifiers(hosts) => ResolvedHostInfo::HostIdentifiers(hosts),
Self::DnsRecord(hostname) => {
let mut resolver = SrvResolver::new(resolver_config.clone()).await?;
let mut resolver =
SrvResolver::new(resolver_config.clone(), srv_service_name).await?;
let config = resolver.resolve_client_options(&hostname).await?;
ResolvedHostInfo::DnsRecord { hostname, config }
}
Expand Down Expand Up @@ -1486,6 +1501,12 @@ impl ConnectionString {
ConnectionStringParts::default()
};

if conn_str.srv_service_name.is_some() && !srv {
return Err(Error::invalid_argument(
"srvServiceName cannot be specified with a non-SRV URI",
));
}

if let Some(srv_max_hosts) = conn_str.srv_max_hosts {
if !srv {
return Err(Error::invalid_argument(
Expand Down Expand Up @@ -1976,6 +1997,9 @@ impl ConnectionString {
k @ "srvmaxhosts" => {
self.srv_max_hosts = Some(get_u32!(value, k));
}
"srvservicename" => {
self.srv_service_name = Some(value.to_string());
}
k @ "tls" | k @ "ssl" => {
let tls = get_bool!(value, k);

Expand Down
5 changes: 4 additions & 1 deletion src/client/options/parse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@ impl Action for ParseConnectionString {
options.resolver_config.clone_from(&self.resolver_config);
}

let resolved = host_info.resolve(self.resolver_config).await?;
let resolved = host_info
.resolve(self.resolver_config, options.srv_service_name.clone())
.await?;
options.hosts = match resolved {
ResolvedHostInfo::HostIdentifiers(hosts) => hosts,
ResolvedHostInfo::DnsRecord {
Expand Down Expand Up @@ -159,6 +161,7 @@ impl ClientOptions {
#[cfg(feature = "tracing-unstable")]
tracing_max_document_length_bytes: None,
srv_max_hosts: conn_str.srv_max_hosts,
srv_service_name: conn_str.srv_service_name,
}
}
}
2 changes: 0 additions & 2 deletions src/client/options/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@ static SKIPPED_TESTS: Lazy<Vec<&'static str>> = Lazy::new(|| {
"maxPoolSize=0 does not error",
// TODO RUST-226: unskip this test
"Valid tlsCertificateKeyFilePassword is parsed correctly",
// TODO RUST-911: unskip this test
"SRV URI with custom srvServiceName",
// TODO RUST-229: unskip the following tests
"Single IP literal host without port",
"Single IP literal host with port",
Expand Down
6 changes: 5 additions & 1 deletion src/sdam/srv_polling.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,11 @@ impl SrvPollingMonitor {
return Ok(resolver);
}

let resolver = SrvResolver::new(self.client_options.resolver_config().cloned()).await?;
let resolver = SrvResolver::new(
self.client_options.resolver_config().cloned(),
self.client_options.srv_service_name.clone(),
)
.await?;

// Since the connection was not `Some` above, this will always insert the new connection and
// return a reference to it.
Expand Down
48 changes: 42 additions & 6 deletions src/sdam/srv_polling/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,26 +26,29 @@ static DEFAULT_HOSTS: Lazy<Vec<ServerAddress>> = Lazy::new(|| {
});

async fn run_test(new_hosts: Result<Vec<ServerAddress>>, expected_hosts: HashSet<ServerAddress>) {
run_test_srv(None, new_hosts, expected_hosts).await
run_test_srv(None, new_hosts, expected_hosts, None).await
}

async fn run_test_srv(
max_hosts: Option<u32>,
new_hosts: Result<Vec<ServerAddress>>,
expected_hosts: HashSet<ServerAddress>,
srv_service_name: Option<String>,
) {
let actual = run_test_extra(max_hosts, new_hosts).await;
let actual = run_test_extra(max_hosts, new_hosts, srv_service_name).await;
assert_eq!(expected_hosts, actual);
}

async fn run_test_extra(
max_hosts: Option<u32>,
new_hosts: Result<Vec<ServerAddress>>,
srv_service_name: Option<String>,
) -> HashSet<ServerAddress> {
let mut options = ClientOptions::new_srv();
options.hosts.clone_from(&DEFAULT_HOSTS);
options.test_options_mut().disable_monitoring_threads = true;
options.srv_max_hosts = max_hosts;
options.srv_service_name = srv_service_name;
let mut topology = Topology::new(options.clone()).unwrap();
topology.watch().wait_until_initialized().await;
let mut monitor =
Expand Down Expand Up @@ -156,8 +159,20 @@ async fn srv_max_hosts_zero() {
localhost_test_build_10gen(27020),
];

run_test_srv(None, Ok(hosts.clone()), hosts.clone().into_iter().collect()).await;
run_test_srv(Some(0), Ok(hosts.clone()), hosts.into_iter().collect()).await;
run_test_srv(
None,
Ok(hosts.clone()),
hosts.clone().into_iter().collect(),
None,
)
.await;
run_test_srv(
Some(0),
Ok(hosts.clone()),
hosts.into_iter().collect(),
None,
)
.await;
}

// SRV polling with srvMaxHosts MongoClient option: All DNS records are selected (srvMaxHosts >=
Expand All @@ -169,7 +184,13 @@ async fn srv_max_hosts_gt_actual() {
localhost_test_build_10gen(27020),
];

run_test_srv(Some(2), Ok(hosts.clone()), hosts.into_iter().collect()).await;
run_test_srv(
Some(2),
Ok(hosts.clone()),
hosts.into_iter().collect(),
None,
)
.await;
}

// SRV polling with srvMaxHosts MongoClient option: New DNS records are randomly selected
Expand All @@ -182,7 +203,22 @@ async fn srv_max_hosts_random() {
localhost_test_build_10gen(27020),
];

let actual = run_test_extra(Some(2), Ok(hosts)).await;
let actual = run_test_extra(Some(2), Ok(hosts), None).await;
assert_eq!(2, actual.len());
assert!(actual.contains(&localhost_test_build_10gen(27017)));
}

#[tokio::test]
async fn srv_service_name() {
let hosts = vec![
localhost_test_build_10gen(27019),
localhost_test_build_10gen(27020),
];
run_test_srv(
None,
Ok(hosts.clone()),
hosts.into_iter().collect(),
Some("customname".to_string()),
)
.await;
}
19 changes: 16 additions & 3 deletions src/srv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,14 +90,21 @@ pub(crate) enum DomainMismatch {
#[cfg(feature = "dns-resolver")]
pub(crate) struct SrvResolver {
resolver: crate::runtime::AsyncResolver,
srv_service_name: Option<String>,
}

#[cfg(feature = "dns-resolver")]
impl SrvResolver {
pub(crate) async fn new(config: Option<ResolverConfig>) -> Result<Self> {
pub(crate) async fn new(
config: Option<ResolverConfig>,
srv_service_name: Option<String>,
) -> Result<Self> {
let resolver = crate::runtime::AsyncResolver::new(config.map(|c| c.inner)).await?;

Ok(Self { resolver })
Ok(Self {
resolver,
srv_service_name,
})
}

pub(crate) async fn resolve_client_options(
Expand Down Expand Up @@ -149,7 +156,13 @@ impl SrvResolver {
original_hostname: &str,
dm: DomainMismatch,
) -> Result<LookupHosts> {
let lookup_hostname = format!("_mongodb._tcp.{}", original_hostname);
let lookup_hostname = format!(
"_{}._tcp.{}",
self.srv_service_name
.clone()
.unwrap_or("mongodb".to_string()),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

tiny style nit: we can do the following to avoid allocating an extra string here

self.srv_service_name
    .as_deref() // converts into an Option<&str> to get a reference to the inner string
    .unwrap_or("mongodb")

original_hostname
);
self.get_srv_hosts_unvalidated(&lookup_hostname)
.await?
.validate(original_hostname, dm)
Expand Down
18 changes: 0 additions & 18 deletions src/test/spec/initial_dns_seedlist_discovery.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,24 +62,6 @@ struct ParsedOptions {
}

async fn run_test(mut test_file: TestFile) {
if let Some(ref options) = test_file.options {
// TODO RUST-933: Remove this skip.
let skip = if options.srv_service_name.is_some() {
Some("srvServiceName")
} else {
None
};

if let Some(skip) = skip {
log_uncaptured(format!(
"skipping initial_dns_seedlist_discovery test case due to unsupported connection \
string option: {}",
skip,
));
return;
}
}

// "encoded-userinfo-and-db.json" specifies a database name with a question mark which is
// disallowed on Windows. See
// <https://www.mongodb.com/docs/manual/reference/limits/#restrictions-on-db-names>
Expand Down
Loading