Skip to content

RUST-2078 Support QE with bulk write #1445

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 13 commits into
base: main
Choose a base branch
from
17 changes: 2 additions & 15 deletions src/action/bulk_write.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,20 +117,6 @@ where
}

async fn execute_inner(mut self) -> Result<R> {
#[cfg(feature = "in-use-encryption")]
if self.client.should_auto_encrypt().await {
use mongocrypt::error::{Error as EncryptionError, ErrorKind as EncryptionErrorKind};

let error = EncryptionError {
kind: EncryptionErrorKind::Client,
code: None,
message: Some(
"bulkWrite does not currently support automatic encryption".to_string(),
),
};
return Err(ErrorKind::Encryption(error).into());
}

resolve_write_concern_with_session!(
self.client,
self.options,
Expand All @@ -148,7 +134,8 @@ where
&self.models[total_attempted..],
total_attempted,
self.options.as_ref(),
);
)
.await;
let result = self
.client
.execute_operation::<BulkWriteOperation<R>>(
Expand Down
45 changes: 45 additions & 0 deletions src/bson_util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@ use crate::{
RawBsonRef,
RawDocumentBuf,
},
bson_compat::CStr,
checked::Checked,
cmap::Command,
error::{Error, ErrorKind, Result},
runtime::SyncLittleEndianRead,
};
Expand Down Expand Up @@ -246,6 +248,49 @@ pub(crate) fn get_or_prepend_id_field(doc: &mut RawDocumentBuf) -> Result<Bson>
}
}

/// A helper trait for working with collections of raw documents. This is useful for unifying
/// command-building implementations that conditionally construct both document sequences and a
/// single command document.
///
/// This module implements the trait for `Vec<RawDocumentBuf>`, which attaches itself to a command
/// via `Command::add_document_sequence`, and for `RawArrayBuf`, which appends itself to the
/// command body under the given identifier.
pub(crate) trait RawDocumentCollection: Default {
/// Calculates the total number of bytes that would be added to a collection of this type by the
/// given document.
///
/// `index` is presumably the position the document would occupy in the collection (TODO
/// confirm against callers). The `Vec<RawDocumentBuf>` implementation ignores it, while the
/// `RawArrayBuf` implementation forwards it to `array_entry_size_bytes`.
///
/// Returns an error if the size cannot be computed.
fn bytes_added(index: usize, doc: &RawDocumentBuf) -> Result<usize>;

/// Adds the given document to the collection.
fn push(&mut self, doc: RawDocumentBuf);

/// Adds the collection of raw documents to the provided command.
///
/// Consumes the collection. `identifier` names the collection within the command: it is the
/// document sequence identifier for `Vec<RawDocumentBuf>` and the body field key for
/// `RawArrayBuf`.
fn add_to_command(self, identifier: &'static CStr, command: &mut Command);
}

/// Collects documents into a plain vector that is attached to the command as a document
/// sequence.
impl RawDocumentCollection for Vec<RawDocumentBuf> {
    /// Each document contributes exactly its own encoded length; the index is not needed.
    fn bytes_added(_index: usize, doc: &RawDocumentBuf) -> Result<usize> {
        let encoded = doc.as_bytes();
        Ok(encoded.len())
    }

    /// Appends the document to the end of the vector.
    fn push(&mut self, doc: RawDocumentBuf) {
        // Name the inherent method explicitly so this cannot be misread as a recursive call to
        // the trait method of the same name.
        Vec::push(self, doc);
    }

    /// Hands the accumulated documents to the command as a document sequence named
    /// `identifier`.
    fn add_to_command(self, identifier: &'static CStr, command: &mut Command) {
        let documents = self;
        command.add_document_sequence(identifier, documents);
    }
}

/// Collects documents into a raw BSON array that is embedded directly in the command body.
impl RawDocumentCollection for RawArrayBuf {
    /// Delegates to `array_entry_size_bytes`, passing the entry's index along with the
    /// document's encoded length.
    fn bytes_added(index: usize, doc: &RawDocumentBuf) -> Result<usize> {
        let doc_len = doc.as_bytes().len();
        array_entry_size_bytes(index, doc_len)
    }

    /// Appends the document to the array.
    fn push(&mut self, doc: RawDocumentBuf) {
        // Inherent methods take precedence over trait methods during method-call resolution, so
        // this is expected to reach `RawArrayBuf`'s own `push` rather than recurse.
        self.push(doc);
    }

    /// Stores the array in the command body under the key `identifier`.
    fn add_to_command(self, identifier: &'static CStr, command: &mut Command) {
        let array = self;
        command.body.append(identifier, array);
    }
}

#[cfg(test)]
mod test {
use crate::bson_util::num_decimal_digits;
Expand Down
3 changes: 2 additions & 1 deletion src/client/csfle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,8 @@ impl ClientState {
.kms_providers(&opts.kms_providers.credentials_doc()?)?
.use_need_kms_credentials_state()
.retry_kms(true)?
.use_range_v2()?;
.use_range_v2()?
.use_need_mongo_collinfo_with_db_state();
if let Some(m) = &opts.schema_map {
builder = builder.schema_map(&crate::bson_compat::serialize_to_document(m)?)?;
}
Expand Down
33 changes: 18 additions & 15 deletions src/client/csfle/state_machine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,7 @@ use std::{
time::Duration,
};

use crate::{
bson::{rawdoc, Document, RawDocument, RawDocumentBuf},
bson_compat::{cstr, CString},
};
use crate::bson::{rawdoc, Document, RawDocument, RawDocumentBuf};
use futures_util::{stream, TryStreamExt};
use mongocrypt::ctx::{Ctx, KmsCtx, KmsProviderType, State};
use rayon::ThreadPool;
Expand Down Expand Up @@ -95,6 +92,13 @@ impl CryptExecutor {
self.mongocryptd_client.is_some()
}

fn metadata_client(&self, state: &State) -> Result<Client> {
self.metadata_client
.as_ref()
.and_then(|w| w.upgrade())
.ok_or_else(|| Error::internal(format!("metadata client required for {state:?}")))
}

pub(crate) async fn run_ctx(&self, ctx: Ctx, db: Option<&str>) -> Result<RawDocumentBuf> {
let mut result = None;
// This needs to be a `Result` so that the `Ctx` can be temporarily owned by the processing
Expand All @@ -104,16 +108,10 @@ impl CryptExecutor {
loop {
let state = result_ref(&ctx)?.state()?;
match state {
State::NeedMongoCollinfo => {
State::NeedMongoCollinfo | State::NeedMongoCollinfoWithDb => {
let ctx = result_mut(&mut ctx)?;
let filter = raw_to_doc(ctx.mongo_op()?)?;
let metadata_client = self
.metadata_client
.as_ref()
.and_then(|w| w.upgrade())
.ok_or_else(|| {
Error::internal("metadata_client required for NeedMongoCollinfo state")
})?;
let metadata_client = self.metadata_client(&state)?;
let db = metadata_client.database(db.as_ref().ok_or_else(|| {
Error::internal("db required for NeedMongoCollinfo state")
})?);
Expand Down Expand Up @@ -245,7 +243,9 @@ impl CryptExecutor {
continue;
}

let prov_name: CString = provider.as_string().try_into()?;
#[cfg(any(feature = "aws-auth", feature = "azure-kms"))]
let prov_name: crate::bson_compat::CString =
provider.as_string().try_into()?;
match provider.provider_type() {
KmsProviderType::Aws => {
#[cfg(feature = "aws-auth")]
Expand All @@ -263,7 +263,10 @@ impl CryptExecutor {
"secretAccessKey": aws_creds.secret_access_key().to_string(),
};
if let Some(token) = aws_creds.session_token() {
creds.append(cstr!("sessionToken"), token);
creds.append(
crate::bson_compat::cstr!("sessionToken"),
token,
);
}
kms_providers.append(prov_name, creds);
}
Expand Down Expand Up @@ -326,7 +329,7 @@ impl CryptExecutor {
.await
.map_err(|e| kms_error(e.to_string()))?;
kms_providers.append(
cstr!("gcp"),
crate::bson_compat::cstr!("gcp"),
rawdoc! { "accessToken": response.access_token },
);
}
Expand Down
48 changes: 28 additions & 20 deletions src/client/options/bulk_write.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,14 @@ use typed_builder::TypedBuilder;

use crate::{
bson::{rawdoc, Array, Bson, Document, RawDocumentBuf},
bson_compat::cstr,
bson_util::{get_or_prepend_id_field, replacement_document_check, update_document_check},
error::Result,
bson_compat::{cstr, serialize_to_raw_document_buf},
bson_util::{
extend_raw_document_buf,
get_or_prepend_id_field,
replacement_document_check,
update_document_check,
},
error::{Error, Result},
options::{UpdateModifications, WriteConcern},
serde_util::{serialize_bool_or_true, write_concern_is_empty},
Collection,
Expand Down Expand Up @@ -371,9 +376,15 @@ impl WriteModel {
}
}

/// Returns the operation-specific fields that should be included in this model's entry in the
/// ops array. Also returns an inserted ID if this is an insert operation.
pub(crate) fn get_ops_document_contents(&self) -> Result<(RawDocumentBuf, Option<Bson>)> {
/// Constructs the ops document for this write model given the nsInfo array index.
pub(crate) fn get_ops_document(
&self,
ns_info_index: usize,
) -> Result<(RawDocumentBuf, Option<Bson>)> {
let index = i32::try_from(ns_info_index)
.map_err(|_| Error::internal("nsInfo index exceeds i32::MAX"))?;
let mut ops_document = rawdoc! { self.operation_name(): index };

if let Self::UpdateOne(UpdateOneModel { update, .. })
| Self::UpdateMany(UpdateManyModel { update, .. }) = self
{
Expand All @@ -384,22 +395,19 @@ impl WriteModel {
replacement_document_check(replacement)?;
}

let (mut model_document, inserted_id) = match self {
Self::InsertOne(model) => {
let mut insert_document = RawDocumentBuf::try_from(&model.document)?;
let inserted_id = get_or_prepend_id_field(&mut insert_document)?;
(rawdoc! { "document": insert_document }, Some(inserted_id))
}
_ => {
let model_document = crate::bson_compat::serialize_to_raw_document_buf(&self)?;
(model_document, None)
}
};

if let Some(multi) = self.multi() {
model_document.append(cstr!("multi"), multi);
ops_document.append(cstr!("multi"), multi);
}

Ok((model_document, inserted_id))
if let Self::InsertOne(model) = self {
let mut insert_document = RawDocumentBuf::try_from(&model.document)?;
let inserted_id = get_or_prepend_id_field(&mut insert_document)?;
ops_document.append(cstr!("document"), insert_document);
Ok((ops_document, Some(inserted_id)))
} else {
let model = serialize_to_raw_document_buf(&self)?;
extend_raw_document_buf(&mut ops_document, model)?;
Ok((ops_document, None))
}
}
}
10 changes: 8 additions & 2 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,14 @@ pub type Result<T> = std::result::Result<T, Error>;
/// [`ErrorKind`](enum.ErrorKind.html) is wrapped in a `Box` to allow the errors to be
/// cloned.
#[derive(Clone, Debug, Error)]
#[cfg_attr(test, error("Kind: {kind}, labels: {labels:?}, backtrace: {bt}"))]
#[cfg_attr(not(test), error("Kind: {kind}, labels: {labels:?}"))]
#[cfg_attr(
test,
error("Kind: {kind}, labels: {labels:?}, source: {source:?}, backtrace: {bt}")
)]
#[cfg_attr(
not(test),
error("Kind: {kind}, labels: {labels:?}, source: {source:?}")
)]
#[non_exhaustive]
pub struct Error {
/// The type of error that occurred.
Expand Down
Loading