Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions src/client/options.rs
Original file line number Diff line number Diff line change
Expand Up @@ -910,11 +910,16 @@ impl Default for HostInfo {
}

impl HostInfo {
async fn resolve(self, resolver_config: Option<ResolverConfig>) -> Result<ResolvedHostInfo> {
async fn resolve(
self,
resolver_config: Option<ResolverConfig>,
srv_service_name: Option<String>,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it make sense to bundle srv_service_name into resolver_config? Mechanically, they're always passed around as a pair but I could see going either way on whether it's the right thing conceptually.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looked into this a bit, decided to keep it as is to reduce complexity in how the srv_service_name is passed through the codepath (from being parsed as part of the connection string down to the srv resolver- adding the parameter to ResolverConfig disrupts this a little bit). Thank you for pointing this out!

) -> Result<ResolvedHostInfo> {
Ok(match self {
Self::HostIdentifiers(hosts) => ResolvedHostInfo::HostIdentifiers(hosts),
Self::DnsRecord(hostname) => {
let mut resolver = SrvResolver::new(resolver_config.clone(), None).await?;
let mut resolver =
SrvResolver::new(resolver_config.clone(), srv_service_name).await?;
let config = resolver.resolve_client_options(&hostname).await?;
ResolvedHostInfo::DnsRecord { hostname, config }
}
Expand Down
4 changes: 3 additions & 1 deletion src/client/options/parse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@ impl Action for ParseConnectionString {
options.resolver_config.clone_from(&self.resolver_config);
}

let resolved = host_info.resolve(self.resolver_config).await?;
let resolved = host_info
.resolve(self.resolver_config, options.srv_service_name.clone())
.await?;
options.hosts = match resolved {
ResolvedHostInfo::HostIdentifiers(hosts) => hosts,
ResolvedHostInfo::DnsRecord {
Expand Down
2 changes: 0 additions & 2 deletions src/client/options/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@ static SKIPPED_TESTS: Lazy<Vec<&'static str>> = Lazy::new(|| {
"maxPoolSize=0 does not error",
// TODO RUST-226: unskip this test
"Valid tlsCertificateKeyFilePassword is parsed correctly",
// TODO RUST-911: unskip this test
"SRV URI with custom srvServiceName",
// TODO RUST-229: unskip the following tests
"Single IP literal host without port",
"Single IP literal host with port",
Expand Down
2 changes: 1 addition & 1 deletion src/sdam/srv_polling.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ impl SrvPollingMonitor {

let resolver = SrvResolver::new(
self.client_options.resolver_config().cloned(),
Option::Some(self.client_options.clone()),
self.client_options.srv_service_name.clone(),
)
.await?;

Expand Down
27 changes: 27 additions & 0 deletions src/sdam/srv_polling/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -186,3 +186,30 @@ async fn srv_max_hosts_random() {
assert_eq!(2, actual.len());
assert!(actual.contains(&localhost_test_build_10gen(27017)));
}

#[tokio::test]
async fn srv_service_name() {
if get_client_options().await.srv_service_name.is_none() {
log_uncaptured("skipping srv_service_name due to no custom srvServiceName");
return;
}
let mut options = ClientOptions::new_srv();
let hosts = vec![
localhost_test_build_10gen(27019),
localhost_test_build_10gen(27020),
];
let rescan_interval = options.original_srv_info.as_ref().cloned().unwrap().min_ttl;
options.hosts.clone_from(&hosts);
options.srv_service_name = Some("customname".to_string());
options.test_options_mut().mock_lookup_hosts = Some(make_lookup_hosts(vec![
localhost_test_build_10gen(27019),
localhost_test_build_10gen(27020),
]));
let mut topology = Topology::new(options).unwrap();
topology.watch().wait_until_initialized().await;
tokio::time::sleep(rescan_interval * 2).await;
assert_eq!(
hosts.into_iter().collect::<HashSet<_>>(),
topology.server_addresses()
);
}
26 changes: 10 additions & 16 deletions src/srv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,7 @@ use std::time::Duration;

#[cfg(feature = "dns-resolver")]
use crate::error::ErrorKind;
use crate::{
client::options::ResolverConfig,
error::Result,
options::{ClientOptions, ServerAddress},
};
use crate::{client::options::ResolverConfig, error::Result, options::ServerAddress};

#[derive(Debug)]
pub(crate) struct ResolvedConfig {
Expand Down Expand Up @@ -94,20 +90,20 @@ pub(crate) enum DomainMismatch {
#[cfg(feature = "dns-resolver")]
pub(crate) struct SrvResolver {
resolver: crate::runtime::AsyncResolver,
client_options: Option<ClientOptions>,
srv_service_name: Option<String>,
}

#[cfg(feature = "dns-resolver")]
impl SrvResolver {
pub(crate) async fn new(
config: Option<ResolverConfig>,
client_options: Option<ClientOptions>,
srv_service_name: Option<String>,
) -> Result<Self> {
let resolver = crate::runtime::AsyncResolver::new(config.map(|c| c.inner)).await?;

Ok(Self {
resolver,
client_options,
srv_service_name,
})
}

Expand Down Expand Up @@ -160,15 +156,13 @@ impl SrvResolver {
original_hostname: &str,
dm: DomainMismatch,
) -> Result<LookupHosts> {
let default_service_name = "mongodb".to_string();
let service_name = match &self.client_options {
None => default_service_name,
Some(opts) => opts
.srv_service_name
let lookup_hostname = format!(
"_{}._tcp.{}",
self.srv_service_name
.clone()
.unwrap_or(default_service_name),
};
let lookup_hostname = format!("_{}._tcp.{}", service_name, original_hostname);
.unwrap_or("mongodb".to_string()),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

tiny style nit: we can do the following to avoid allocating an extra string here

self.srv_service_name
    .as_deref() // converts into an Option<&str> to get a reference to the inner string
    .unwrap_or("mongodb")

original_hostname
);
self.get_srv_hosts_unvalidated(&lookup_hostname)
.await?
.validate(original_hostname, dm)
Expand Down
18 changes: 0 additions & 18 deletions src/test/spec/initial_dns_seedlist_discovery.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,24 +62,6 @@ struct ParsedOptions {
}

async fn run_test(mut test_file: TestFile) {
if let Some(ref options) = test_file.options {
// TODO RUST-933: Remove this skip.
let skip = if options.srv_service_name.is_some() {
Some("srvServiceName")
} else {
None
};

if let Some(skip) = skip {
log_uncaptured(format!(
"skipping initial_dns_seedlist_discovery test case due to unsupported connection \
string option: {}",
skip,
));
return;
}
}

// "encoded-userinfo-and-db.json" specifies a database name with a question mark which is
// disallowed on Windows. See
// <https://www.mongodb.com/docs/manual/reference/limits/#restrictions-on-db-names>
Expand Down