diff --git a/src/client.rs b/src/client.rs index b34b6edf8..1a4e984af 100644 --- a/src/client.rs +++ b/src/client.rs @@ -35,6 +35,7 @@ use crate::{ error::{Error, ErrorKind, Result}, event::command::CommandEvent, id_set::IdSet, + operation::OverrideCriteriaFn, options::{ClientOptions, DatabaseOptions, ReadPreference, SelectionCriteria, ServerAddress}, sdam::{ server_selection::{self, attempt_to_select_server}, @@ -446,8 +447,8 @@ impl Client { &self, criteria: Option<&SelectionCriteria>, ) -> Result { - let server = self - .select_server(criteria, "Test select server", None) + let (server, _) = self + .select_server(criteria, "Test select server", None, |_, _| None) .await?; Ok(server.address.clone()) } @@ -460,7 +461,8 @@ impl Client { #[allow(unused_variables)] // we only use the operation_name for tracing. operation_name: &str, deprioritized: Option<&ServerAddress>, - ) -> Result { + override_criteria: OverrideCriteriaFn, + ) -> Result<(SelectedServer, SelectionCriteria)> { let criteria = criteria.unwrap_or(&SelectionCriteria::ReadPreference(ReadPreference::Primary)); @@ -488,9 +490,16 @@ impl Client { let mut watcher = self.inner.topology.watch(); loop { let state = watcher.observe_latest(); - + let override_slot; + let effective_criteria = + if let Some(oc) = override_criteria(criteria, &state.description) { + override_slot = oc; + &override_slot + } else { + criteria + }; let result = server_selection::attempt_to_select_server( - criteria, + effective_criteria, &state.description, &state.servers(), deprioritized, @@ -507,7 +516,7 @@ impl Client { #[cfg(feature = "tracing-unstable")] event_emitter.emit_succeeded_event(&state.description, &server); - return Ok(server); + return Ok((server, effective_criteria.clone())); } else { #[cfg(feature = "tracing-unstable")] if !emitted_waiting_message { diff --git a/src/client/executor.rs b/src/client/executor.rs index 17b5debd2..87a0f7209 100644 --- a/src/client/executor.rs +++ b/src/client/executor.rs @@ -59,7 +59,7 @@ use crate::{ Retryability, }, options::{ChangeStreamOptions, SelectionCriteria}, - sdam::{HandshakePhase, SelectedServer, ServerType, TopologyType, TransactionSupportStatus}, + sdam::{HandshakePhase, ServerType, TopologyType, TransactionSupportStatus}, selection_criteria::ReadPreference, tracking_arc::TrackingArc, ClusterTime, @@ -318,15 +318,16 @@ impl Client { .and_then(|s| s.transaction.pinned_mongos()) .or_else(|| op.selection_criteria()); - let server = match self + let (server, effective_criteria) = match self .select_server( selection_criteria, op.name(), retry.as_ref().map(|r| &r.first_server), + op.override_criteria(), ) .await { - Ok(server) => server, + Ok(out) => out, Err(mut err) => { retry.first_error()?; @@ -398,6 +399,7 @@ impl Client { &mut session, txn_number, retryability, + effective_criteria, ) .await { @@ -471,127 +473,21 @@ impl Client { session: &mut Option<&mut ClientSession>, txn_number: Option, retryability: Retryability, + effective_criteria: SelectionCriteria, ) -> Result { loop { - let stream_description = connection.stream_description()?; - let is_sharded = stream_description.initial_server_type == ServerType::Mongos; - let mut cmd = op.build(stream_description)?; - self.inner.topology.update_command_with_read_pref( - connection.address(), - &mut cmd, - op.selection_criteria(), - ); - - match session { - Some(ref mut session) if op.supports_sessions() && op.is_acknowledged() => { - cmd.set_session(session); - if let Some(txn_number) = txn_number { - cmd.set_txn_number(txn_number); - } - if session - .options() - .and_then(|opts| opts.snapshot) - .unwrap_or(false) - { - if connection - .stream_description()? - .max_wire_version - .unwrap_or(0) - < 13 - { - let labels: Option> = None; - return Err(Error::new( - ErrorKind::IncompatibleServer { - message: "Snapshot reads require MongoDB 5.0 or later".into(), - }, - labels, - )); - } - cmd.set_snapshot_read_concern(session); - } - // If this is a causally consistent session, set `readConcern.afterClusterTime`. - // Causal consistency defaults to true, unless snapshot is true. - else if session.causal_consistency() - && matches!( - session.transaction.state, - TransactionState::None | TransactionState::Starting - ) - && op.supports_read_concern(stream_description) - { - cmd.set_after_cluster_time(session); - } - - match session.transaction.state { - TransactionState::Starting => { - cmd.set_start_transaction(); - cmd.set_autocommit(); - if session.causal_consistency() { - cmd.set_after_cluster_time(session); - } - - if let Some(ref options) = session.transaction.options { - if let Some(ref read_concern) = options.read_concern { - cmd.set_read_concern_level(read_concern.level.clone()); - } - } - if self.is_load_balanced() { - session.pin_connection(connection.pin()?); - } else if is_sharded { - session.pin_mongos(connection.address().clone()); - } - session.transaction.state = TransactionState::InProgress; - } - TransactionState::InProgress => cmd.set_autocommit(), - TransactionState::Committed { .. } | TransactionState::Aborted => { - cmd.set_autocommit(); - - // Append the recovery token to the command if we are committing or - // aborting on a sharded transaction. - if is_sharded { - if let Some(ref recovery_token) = session.transaction.recovery_token - { - cmd.set_recovery_token(recovery_token); - } - } - } - _ => {} - } - session.update_last_use(); - } - Some(ref session) if !op.supports_sessions() && !session.is_implicit() => { - return Err(ErrorKind::InvalidArgument { - message: format!("{} does not support sessions", cmd.name), - } - .into()); - } - Some(ref session) if !op.is_acknowledged() && !session.is_implicit() => { - return Err(ErrorKind::InvalidArgument { - message: "Cannot use ClientSessions with unacknowledged write concern" - .to_string(), - } - .into()); - } - _ => {} - } - - let session_cluster_time = session.as_ref().and_then(|session| session.cluster_time()); - let client_cluster_time = self.inner.topology.cluster_time(); - let max_cluster_time = - std::cmp::max(session_cluster_time, client_cluster_time.as_ref()); - if let Some(cluster_time) = max_cluster_time { - cmd.set_cluster_time(cluster_time); - } + let cmd = self.build_command( + op, + connection, + session, + txn_number, + effective_criteria.clone(), + )?; let connection_info = connection.info(); let service_id = connection.service_id(); let request_id = next_request_id(); - - if let Some(ref server_api) = self.inner.options.server_api { - cmd.set_server_api(server_api); - } - let should_redact = cmd.should_redact(); - let cmd_name = cmd.name.clone(); let target_db = cmd.target_db.clone(); @@ -630,8 +526,9 @@ impl Client { let start_time = Instant::now(); let command_result = match connection.send_message(message).await { Ok(response) => { - self.handle_response(op, session, is_sharded, response) - .await + let is_sharded = + connection.stream_description()?.initial_server_type == ServerType::Mongos; + self.parse_response(op, session, is_sharded, response).await } Err(err) => Err(err), }; @@ -706,6 +603,7 @@ impl Client { let context = ExecutionContext { connection, session: session.as_deref_mut(), + effective_criteria: effective_criteria.clone(), }; match op.handle_response(response, context).await { @@ -737,6 +635,128 @@ impl Client { } } + fn build_command( + &self, + op: &mut T, + connection: &mut PooledConnection, + session: &mut Option<&mut ClientSession>, + txn_number: Option, + effective_criteria: SelectionCriteria, + ) -> Result { + let stream_description = connection.stream_description()?; + let is_sharded = stream_description.initial_server_type == ServerType::Mongos; + let mut cmd = op.build(stream_description)?; + self.inner.topology.update_command_with_read_pref( + connection.address(), + &mut cmd, + &effective_criteria, + ); + + match session { + Some(ref mut session) if op.supports_sessions() && op.is_acknowledged() => { + cmd.set_session(session); + if let Some(txn_number) = txn_number { + cmd.set_txn_number(txn_number); + } + if session + .options() + .and_then(|opts| opts.snapshot) + .unwrap_or(false) + { + if connection + .stream_description()? + .max_wire_version + .unwrap_or(0) + < 13 + { + let labels: Option> = None; + return Err(Error::new( + ErrorKind::IncompatibleServer { + message: "Snapshot reads require MongoDB 5.0 or later".into(), + }, + labels, + )); + } + cmd.set_snapshot_read_concern(session); + } + // If this is a causally consistent session, set `readConcern.afterClusterTime`. + // Causal consistency defaults to true, unless snapshot is true. + else if session.causal_consistency() + && matches!( + session.transaction.state, + TransactionState::None | TransactionState::Starting + ) + && op.supports_read_concern(stream_description) + { + cmd.set_after_cluster_time(session); + } + + match session.transaction.state { + TransactionState::Starting => { + cmd.set_start_transaction(); + cmd.set_autocommit(); + if session.causal_consistency() { + cmd.set_after_cluster_time(session); + } + + if let Some(ref options) = session.transaction.options { + if let Some(ref read_concern) = options.read_concern { + cmd.set_read_concern_level(read_concern.level.clone()); + } + } + if self.is_load_balanced() { + session.pin_connection(connection.pin()?); + } else if is_sharded { + session.pin_mongos(connection.address().clone()); + } + session.transaction.state = TransactionState::InProgress; + } + TransactionState::InProgress => cmd.set_autocommit(), + TransactionState::Committed { .. } | TransactionState::Aborted => { + cmd.set_autocommit(); + + // Append the recovery token to the command if we are committing or aborting + // on a sharded transaction. + if is_sharded { + if let Some(ref recovery_token) = session.transaction.recovery_token { + cmd.set_recovery_token(recovery_token); + } + } + } + _ => {} + } + session.update_last_use(); + } + Some(ref session) if !op.supports_sessions() && !session.is_implicit() => { + return Err(ErrorKind::InvalidArgument { + message: format!("{} does not support sessions", cmd.name), + } + .into()); + } + Some(ref session) if !op.is_acknowledged() && !session.is_implicit() => { + return Err(ErrorKind::InvalidArgument { + message: "Cannot use ClientSessions with unacknowledged write concern" + .to_string(), + } + .into()); + } + _ => {} + } + + let session_cluster_time = session.as_ref().and_then(|session| session.cluster_time()); + let client_cluster_time = self.inner.topology.cluster_time(); + let max_cluster_time = std::cmp::max(session_cluster_time, client_cluster_time.as_ref()); + if let Some(cluster_time) = max_cluster_time { + cmd.set_cluster_time(cluster_time); + } + + if let Some(ref server_api) = self.inner.options.server_api { + cmd.set_server_api(server_api); + } + + Ok(cmd) + } + #[cfg(feature = "in-use-encryption")] fn auto_encrypt<'a>( &'a self, @@ -789,7 +809,7 @@ impl Client { .await } - async fn handle_response( + async fn parse_response( &self, op: &T, session: &mut Option<&mut ClientSession>, @@ -864,8 +884,8 @@ impl Client { (matches!(topology_type, TopologyType::Single) && server_type.is_available()) || server_type.is_data_bearing() })); - let _: SelectedServer = self - .select_server(Some(&criteria), operation_name, None) + let _ = self + .select_server(Some(&criteria), operation_name, None, |_, _| None) .await?; Ok(()) } diff --git a/src/operation.rs b/src/operation.rs index f0af1b1f6..287d9aebe 100644 --- a/src/operation.rs +++ b/src/operation.rs @@ -76,6 +76,7 @@ pub(crate) use update::{Update, UpdateOrReplace}; const SERVER_4_2_0_WIRE_VERSION: i32 = 8; const SERVER_4_4_0_WIRE_VERSION: i32 = 9; +const SERVER_5_0_0_WIRE_VERSION: i32 = 13; const SERVER_8_0_0_WIRE_VERSION: i32 = 25; // The maximum number of bytes that may be included in a write payload when auto-encryption is // enabled. @@ -88,6 +89,7 @@ const OP_MSG_OVERHEAD_BYTES: usize = 1_000; pub(crate) struct ExecutionContext<'a> { pub(crate) connection: &'a mut PooledConnection, pub(crate) session: Option<&'a mut ClientSession>, + pub(crate) effective_criteria: SelectionCriteria, } #[derive(Debug, PartialEq, Clone, Copy)] @@ -148,11 +150,18 @@ pub(crate) trait Operation { /// Updates this operation as needed for a retry. fn update_for_retry(&mut self); + /// Returns a function handle to potentially override selection criteria based on server + /// topology. + fn override_criteria(&self) -> OverrideCriteriaFn; + fn pinned_connection(&self) -> Option<&PinnedConnectionHandle>; fn name(&self) -> &str; } +pub(crate) type OverrideCriteriaFn = + fn(&SelectionCriteria, &crate::sdam::TopologyDescription) -> Option; + // A mirror of the `Operation` trait, with default behavior where appropriate. Should only be // implemented by operation types that do not delegate to other operations. pub(crate) trait OperationWithDefaults: Send + Sync { @@ -235,6 +244,12 @@ pub(crate) trait OperationWithDefaults: Send + Sync { /// Updates this operation as needed for a retry. fn update_for_retry(&mut self) {} + /// Returns a function handle to potentially override selection criteria based on server + /// topology. + fn override_criteria(&self) -> OverrideCriteriaFn { + |_, _| None + } + fn pinned_connection(&self) -> Option<&PinnedConnectionHandle> { None } @@ -287,6 +302,9 @@ where fn update_for_retry(&mut self) { self.update_for_retry() } + fn override_criteria(&self) -> OverrideCriteriaFn { + self.override_criteria() + } fn pinned_connection(&self) -> Option<&PinnedConnectionHandle> { self.pinned_connection() } diff --git a/src/operation/aggregate.rs b/src/operation/aggregate.rs index dd7568523..17b0277d8 100644 --- a/src/operation/aggregate.rs +++ b/src/operation/aggregate.rs @@ -7,7 +7,7 @@ use crate::{ cursor::CursorSpecification, error::Result, operation::{append_options, remove_empty_write_concern, Retryability}, - options::{AggregateOptions, SelectionCriteria, WriteConcern}, + options::{AggregateOptions, ReadPreference, SelectionCriteria, WriteConcern}, Namespace, }; @@ -134,6 +134,27 @@ impl OperationWithDefaults for Aggregate { Retryability::Read } } + + fn override_criteria(&self) -> super::OverrideCriteriaFn { + if !self.is_out_or_merge() { + return |_, _| None; + } + |criteria, topology| { + if criteria == &SelectionCriteria::ReadPreference(ReadPreference::Primary) + || topology.topology_type() == crate::TopologyType::LoadBalanced + { + return None; + } + for server in topology.servers.values() { + if let Ok(Some(v)) = server.max_wire_version() { + if v < super::SERVER_5_0_0_WIRE_VERSION { + return Some(SelectionCriteria::ReadPreference(ReadPreference::Primary)); + } + } + } + None + } + } } impl Aggregate { diff --git a/src/operation/aggregate/change_stream.rs b/src/operation/aggregate/change_stream.rs index 00d56e56d..4cb67cb85 100644 --- a/src/operation/aggregate/change_stream.rs +++ b/src/operation/aggregate/change_stream.rs @@ -94,6 +94,7 @@ impl OperationWithDefaults for ChangeStreamAggregate { let inner_context = ExecutionContext { connection: context.connection, session: context.session.as_deref_mut(), + effective_criteria: context.effective_criteria, }; let spec = self.inner.handle_response(response, inner_context)?; diff --git a/src/operation/bulk_write.rs b/src/operation/bulk_write.rs index 2f7c7b1bb..2b5c9c9ed 100644 --- a/src/operation/bulk_write.rs +++ b/src/operation/bulk_write.rs @@ -114,6 +114,7 @@ where &mut context.session, txn_number, Retryability::None, + context.effective_criteria.clone(), ) .await; @@ -135,6 +136,7 @@ where &mut context.session, txn_number, Retryability::None, + context.effective_criteria.clone(), ) .await; } diff --git a/src/operation/raw_output.rs b/src/operation/raw_output.rs index ef725a26c..b3ece677e 100644 --- a/src/operation/raw_output.rs +++ b/src/operation/raw_output.rs @@ -68,6 +68,10 @@ impl Operation for RawOutput { self.0.update_for_retry() } + fn override_criteria(&self) -> super::OverrideCriteriaFn { + self.0.override_criteria() + } + fn pinned_connection(&self) -> Option<&crate::cmap::conn::PinnedConnectionHandle> { self.0.pinned_connection() } diff --git a/src/operation/run_cursor_command.rs b/src/operation/run_cursor_command.rs index b675aca16..781d84d7e 100644 --- a/src/operation/run_cursor_command.rs +++ b/src/operation/run_cursor_command.rs @@ -79,6 +79,10 @@ impl Operation for RunCursorCommand<'_> { self.run_command.update_for_retry() } + fn override_criteria(&self) -> super::OverrideCriteriaFn { + self.run_command.override_criteria() + } + fn pinned_connection(&self) -> Option<&PinnedConnectionHandle> { self.run_command.pinned_connection() } diff --git a/src/sdam/description/topology.rs b/src/sdam/description/topology.rs index a3e1e699d..2131893b9 100644 --- a/src/sdam/description/topology.rs +++ b/src/sdam/description/topology.rs @@ -205,7 +205,7 @@ impl TopologyDescription { &self, address: &ServerAddress, command: &mut Command, - criteria: Option<&SelectionCriteria>, + criteria: &SelectionCriteria, ) { let server_type = self .get_server_description(address) @@ -220,8 +220,7 @@ impl TopologyDescription { } (TopologyType::Single, ServerType::Standalone) => {} (TopologyType::Single, _) => { - let specified_read_pref = - criteria.and_then(SelectionCriteria::as_read_pref).cloned(); + let specified_read_pref = criteria.as_read_pref().cloned(); let resolved_read_pref = match specified_read_pref { Some(ReadPreference::Primary) | None => ReadPreference::PrimaryPreferred { @@ -235,11 +234,10 @@ impl TopologyDescription { } _ => { let read_pref = match criteria { - Some(SelectionCriteria::ReadPreference(rp)) => rp.clone(), - Some(SelectionCriteria::Predicate(_)) => ReadPreference::PrimaryPreferred { + SelectionCriteria::ReadPreference(rp) => rp.clone(), + SelectionCriteria::Predicate(_) => ReadPreference::PrimaryPreferred { options: Default::default(), }, - None => ReadPreference::Primary, }; if read_pref != ReadPreference::Primary { command.set_read_preference(read_pref) @@ -251,10 +249,10 @@ impl TopologyDescription { fn update_command_read_pref_for_mongos( &self, command: &mut Command, - criteria: Option<&SelectionCriteria>, + criteria: &SelectionCriteria, ) { let read_preference = match criteria { - Some(SelectionCriteria::ReadPreference(rp)) => rp, + SelectionCriteria::ReadPreference(rp) => rp, _ => return, }; match read_preference { diff --git a/src/sdam/topology.rs b/src/sdam/topology.rs index 94ee0c0bc..66ff07129 100644 --- a/src/sdam/topology.rs +++ b/src/sdam/topology.rs @@ -200,7 +200,7 @@ impl Topology { &self, server_address: &ServerAddress, command: &mut Command, - criteria: Option<&SelectionCriteria>, + criteria: &SelectionCriteria, ) { self.watcher .peek_latest() diff --git a/src/test/spec/crud.rs b/src/test/spec/crud.rs index 6fc1fdb4c..1ec16555c 100644 --- a/src/test/spec/crud.rs +++ b/src/test/spec/crud.rs @@ -43,13 +43,6 @@ async fn run_unified() { pre-5.0 server", "Requesting unacknowledged write with verboseResults is a client-side error", "Requesting unacknowledged write with ordered is a client-side error", - // TODO RUST-663: Unskip these tests. - "Aggregate with $out includes read preference for 5.0+ server", - "Aggregate with $out omits read preference for pre-5.0 server", - "Aggregate with $merge includes read preference for 5.0+ server", - "Aggregate with $merge omits read preference for pre-5.0 server", - "Database-level aggregate with $out omits read preference for pre-5.0 server", - "Database-level aggregate with $merge omits read preference for pre-5.0 server", ]; // TODO: remove this manual skip when this test is fixed to skip on serverless if *SERVERLESS {