Skip to content

Commit 330ced1

Browse files
RUST-2078 Support QE with bulk write (#1445)
1 parent d3a752d commit 330ced1

File tree

17 files changed

+1120
-179
lines changed

17 files changed

+1120
-179
lines changed

src/action/bulk_write.rs

Lines changed: 2 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -117,20 +117,6 @@ where
117117
}
118118

119119
async fn execute_inner(mut self) -> Result<R> {
120-
#[cfg(feature = "in-use-encryption")]
121-
if self.client.should_auto_encrypt().await {
122-
use mongocrypt::error::{Error as EncryptionError, ErrorKind as EncryptionErrorKind};
123-
124-
let error = EncryptionError {
125-
kind: EncryptionErrorKind::Client,
126-
code: None,
127-
message: Some(
128-
"bulkWrite does not currently support automatic encryption".to_string(),
129-
),
130-
};
131-
return Err(ErrorKind::Encryption(error).into());
132-
}
133-
134120
resolve_write_concern_with_session!(
135121
self.client,
136122
self.options,
@@ -148,7 +134,8 @@ where
148134
&self.models[total_attempted..],
149135
total_attempted,
150136
self.options.as_ref(),
151-
);
137+
)
138+
.await;
152139
let result = self
153140
.client
154141
.execute_operation::<BulkWriteOperation<R>>(

src/bson_util.rs

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,9 @@ use crate::{
1919
RawBsonRef,
2020
RawDocumentBuf,
2121
},
22+
bson_compat::CStr,
2223
checked::Checked,
24+
cmap::Command,
2325
error::{Error, ErrorKind, Result},
2426
runtime::SyncLittleEndianRead,
2527
};
@@ -246,6 +248,49 @@ pub(crate) fn get_or_prepend_id_field(doc: &mut RawDocumentBuf) -> Result<Bson>
246248
}
247249
}
248250

251+
/// A helper trait for working with collections of raw documents. This is useful for unifying
252+
/// command-building implementations that conditionally construct either document sequences or a
253+
/// single command document.
254+
pub(crate) trait RawDocumentCollection: Default {
255+
/// Calculates the total number of bytes that would be added to a collection of this type by the
256+
/// given document.
257+
fn bytes_added(index: usize, doc: &RawDocumentBuf) -> Result<usize>;
258+
259+
/// Adds the given document to the collection.
260+
fn push(&mut self, doc: RawDocumentBuf);
261+
262+
/// Adds the collection of raw documents to the provided command.
263+
fn add_to_command(self, identifier: &CStr, command: &mut Command);
264+
}
265+
266+
impl RawDocumentCollection for Vec<RawDocumentBuf> {
267+
fn bytes_added(_index: usize, doc: &RawDocumentBuf) -> Result<usize> {
268+
Ok(doc.as_bytes().len())
269+
}
270+
271+
fn push(&mut self, doc: RawDocumentBuf) {
272+
self.push(doc);
273+
}
274+
275+
fn add_to_command(self, identifier: &CStr, command: &mut Command) {
276+
command.add_document_sequence(identifier, self);
277+
}
278+
}
279+
280+
impl RawDocumentCollection for RawArrayBuf {
281+
fn bytes_added(index: usize, doc: &RawDocumentBuf) -> Result<usize> {
282+
array_entry_size_bytes(index, doc.as_bytes().len())
283+
}
284+
285+
fn push(&mut self, doc: RawDocumentBuf) {
286+
self.push(doc);
287+
}
288+
289+
fn add_to_command(self, identifier: &CStr, command: &mut Command) {
290+
command.body.append(identifier, self);
291+
}
292+
}
293+
249294
#[cfg(test)]
250295
mod test {
251296
use crate::bson_util::num_decimal_digits;

src/client/csfle.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,8 @@ impl ClientState {
100100
.kms_providers(&opts.kms_providers.credentials_doc()?)?
101101
.use_need_kms_credentials_state()
102102
.retry_kms(true)?
103-
.use_range_v2()?;
103+
.use_range_v2()?
104+
.use_need_mongo_collinfo_with_db_state();
104105
if let Some(m) = &opts.schema_map {
105106
builder = builder.schema_map(&crate::bson_compat::serialize_to_document(m)?)?;
106107
}

src/client/csfle/state_machine.rs

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,7 @@ use std::{
55
time::Duration,
66
};
77

8-
use crate::{
9-
bson::{rawdoc, Document, RawDocument, RawDocumentBuf},
10-
bson_compat::{cstr, CString},
11-
};
8+
use crate::bson::{rawdoc, Document, RawDocument, RawDocumentBuf};
129
use futures_util::{stream, TryStreamExt};
1310
use mongocrypt::ctx::{Ctx, KmsCtx, KmsProviderType, State};
1411
use rayon::ThreadPool;
@@ -95,6 +92,13 @@ impl CryptExecutor {
9592
self.mongocryptd_client.is_some()
9693
}
9794

95+
fn metadata_client(&self, state: &State) -> Result<Client> {
96+
self.metadata_client
97+
.as_ref()
98+
.and_then(|w| w.upgrade())
99+
.ok_or_else(|| Error::internal(format!("metadata client required for {state:?}")))
100+
}
101+
98102
pub(crate) async fn run_ctx(&self, ctx: Ctx, db: Option<&str>) -> Result<RawDocumentBuf> {
99103
let mut result = None;
100104
// This needs to be a `Result` so that the `Ctx` can be temporarily owned by the processing
@@ -104,16 +108,10 @@ impl CryptExecutor {
104108
loop {
105109
let state = result_ref(&ctx)?.state()?;
106110
match state {
107-
State::NeedMongoCollinfo => {
111+
State::NeedMongoCollinfo | State::NeedMongoCollinfoWithDb => {
108112
let ctx = result_mut(&mut ctx)?;
109113
let filter = raw_to_doc(ctx.mongo_op()?)?;
110-
let metadata_client = self
111-
.metadata_client
112-
.as_ref()
113-
.and_then(|w| w.upgrade())
114-
.ok_or_else(|| {
115-
Error::internal("metadata_client required for NeedMongoCollinfo state")
116-
})?;
114+
let metadata_client = self.metadata_client(&state)?;
117115
let db = metadata_client.database(db.as_ref().ok_or_else(|| {
118116
Error::internal("db required for NeedMongoCollinfo state")
119117
})?);
@@ -245,7 +243,9 @@ impl CryptExecutor {
245243
continue;
246244
}
247245

248-
let prov_name: CString = provider.as_string().try_into()?;
246+
#[cfg(any(feature = "aws-auth", feature = "azure-kms"))]
247+
let prov_name: crate::bson_compat::CString =
248+
provider.as_string().try_into()?;
249249
match provider.provider_type() {
250250
KmsProviderType::Aws => {
251251
#[cfg(feature = "aws-auth")]
@@ -263,7 +263,10 @@ impl CryptExecutor {
263263
"secretAccessKey": aws_creds.secret_access_key().to_string(),
264264
};
265265
if let Some(token) = aws_creds.session_token() {
266-
creds.append(cstr!("sessionToken"), token);
266+
creds.append(
267+
crate::bson_compat::cstr!("sessionToken"),
268+
token,
269+
);
267270
}
268271
kms_providers.append(prov_name, creds);
269272
}
@@ -326,7 +329,7 @@ impl CryptExecutor {
326329
.await
327330
.map_err(|e| kms_error(e.to_string()))?;
328331
kms_providers.append(
329-
cstr!("gcp"),
332+
crate::bson_compat::cstr!("gcp"),
330333
rawdoc! { "accessToken": response.access_token },
331334
);
332335
}

src/client/options/bulk_write.rs

Lines changed: 30 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,14 @@ use typed_builder::TypedBuilder;
77

88
use crate::{
99
bson::{rawdoc, Array, Bson, Document, RawDocumentBuf},
10-
bson_compat::cstr,
11-
bson_util::{get_or_prepend_id_field, replacement_document_check, update_document_check},
12-
error::Result,
10+
bson_compat::{cstr, serialize_to_raw_document_buf},
11+
bson_util::{
12+
extend_raw_document_buf,
13+
get_or_prepend_id_field,
14+
replacement_document_check,
15+
update_document_check,
16+
},
17+
error::{Error, Result},
1318
options::{UpdateModifications, WriteConcern},
1419
serde_util::{serialize_bool_or_true, write_concern_is_empty},
1520
Collection,
@@ -371,9 +376,17 @@ impl WriteModel {
371376
}
372377
}
373378

374-
/// Returns the operation-specific fields that should be included in this model's entry in the
375-
/// ops array. Also returns an inserted ID if this is an insert operation.
376-
pub(crate) fn get_ops_document_contents(&self) -> Result<(RawDocumentBuf, Option<Bson>)> {
379+
/// Constructs the ops document for this write model given the nsInfo array index.
380+
pub(crate) fn get_ops_document(
381+
&self,
382+
ns_info_index: usize,
383+
) -> Result<(RawDocumentBuf, Option<Bson>)> {
384+
// The maximum number of namespaces allowed in a bulkWrite command is much lower than
385+
// i32::MAX, so this should never fail.
386+
let index = i32::try_from(ns_info_index)
387+
.map_err(|_| Error::internal("nsInfo index exceeds i32::MAX"))?;
388+
let mut ops_document = rawdoc! { self.operation_name(): index };
389+
377390
if let Self::UpdateOne(UpdateOneModel { update, .. })
378391
| Self::UpdateMany(UpdateManyModel { update, .. }) = self
379392
{
@@ -384,22 +397,19 @@ impl WriteModel {
384397
replacement_document_check(replacement)?;
385398
}
386399

387-
let (mut model_document, inserted_id) = match self {
388-
Self::InsertOne(model) => {
389-
let mut insert_document = RawDocumentBuf::try_from(&model.document)?;
390-
let inserted_id = get_or_prepend_id_field(&mut insert_document)?;
391-
(rawdoc! { "document": insert_document }, Some(inserted_id))
392-
}
393-
_ => {
394-
let model_document = crate::bson_compat::serialize_to_raw_document_buf(&self)?;
395-
(model_document, None)
396-
}
397-
};
398-
399400
if let Some(multi) = self.multi() {
400-
model_document.append(cstr!("multi"), multi);
401+
ops_document.append(cstr!("multi"), multi);
401402
}
402403

403-
Ok((model_document, inserted_id))
404+
if let Self::InsertOne(model) = self {
405+
let mut insert_document = RawDocumentBuf::try_from(&model.document)?;
406+
let inserted_id = get_or_prepend_id_field(&mut insert_document)?;
407+
ops_document.append(cstr!("document"), insert_document);
408+
Ok((ops_document, Some(inserted_id)))
409+
} else {
410+
let model = serialize_to_raw_document_buf(&self)?;
411+
extend_raw_document_buf(&mut ops_document, model)?;
412+
Ok((ops_document, None))
413+
}
404414
}
405415
}

src/error.rs

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,14 @@ pub type Result<T> = std::result::Result<T, Error>;
5050
/// [`ErrorKind`](enum.ErrorKind.html) is wrapped in a `Box` to allow the errors to be
5151
/// cloned.
5252
#[derive(Clone, Debug, Error)]
53-
#[cfg_attr(test, error("Kind: {kind}, labels: {labels:?}, backtrace: {bt}"))]
54-
#[cfg_attr(not(test), error("Kind: {kind}, labels: {labels:?}"))]
53+
#[cfg_attr(
54+
test,
55+
error("Kind: {kind}, labels: {labels:?}, source: {source:?}, backtrace: {bt}")
56+
)]
57+
#[cfg_attr(
58+
not(test),
59+
error("Kind: {kind}, labels: {labels:?}, source: {source:?}")
60+
)]
5561
#[non_exhaustive]
5662
pub struct Error {
5763
/// The type of error that occurred.

0 commit comments

Comments
 (0)