Skip to content

Commit b870bfd

Browse files
authored
RUST-830 Consolidate error label storage to Error::labels (#350)
1 parent 906efae commit b870bfd

File tree

9 files changed

+82
-94
lines changed

9 files changed

+82
-94
lines changed

src/cmap/conn/command.rs

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
1-
use serde::de::DeserializeOwned;
1+
use serde::{de::DeserializeOwned, Deserialize};
22

33
use super::wire::Message;
44
use crate::{
55
bson::{Bson, Document},
66
bson_util,
77
client::{options::ServerApi, ClusterTime},
8-
error::{CommandError, ErrorKind, Result},
8+
error::{CommandError, Error, ErrorKind, Result},
99
options::ServerAddress,
1010
selection_criteria::ReadPreference,
1111
ClientSession,
@@ -136,16 +136,19 @@ impl CommandResponse {
136136
}
137137
}
138138

139-
/// Retunrs a result indicating whether this response corresponds to a command failure.
139+
/// Returns a result indicating whether this response corresponds to a command failure.
140140
pub(crate) fn validate(&self) -> Result<()> {
141141
if !self.is_success() {
142-
let command_error: CommandError =
142+
let error_response: CommandErrorResponse =
143143
bson::from_bson(Bson::Document(self.raw_response.clone())).map_err(|_| {
144144
ErrorKind::InvalidResponse {
145145
message: "invalid server response".to_string(),
146146
}
147147
})?;
148-
Err(ErrorKind::Command(command_error).into())
148+
Err(Error::new(
149+
ErrorKind::Command(error_response.command_error),
150+
error_response.error_labels,
151+
))
149152
} else {
150153
Ok(())
151154
}
@@ -172,3 +175,12 @@ impl CommandResponse {
172175
&self.source
173176
}
174177
}
178+
179+
#[derive(Deserialize, Debug)]
180+
struct CommandErrorResponse {
181+
#[serde(rename = "errorLabels")]
182+
error_labels: Option<Vec<String>>,
183+
184+
#[serde(flatten)]
185+
command_error: CommandError,
186+
}

src/coll/mod.rs

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
mod batch;
22
pub mod options;
33

4-
use std::{borrow::Borrow, fmt, fmt::Debug, sync::Arc};
4+
use std::{borrow::Borrow, collections::HashSet, fmt, fmt::Debug, sync::Arc};
55

66
use futures_util::stream::StreamExt;
77
use serde::{
8-
de::{DeserializeOwned, Error},
8+
de::{DeserializeOwned, Error as DeError},
99
Deserialize,
1010
Deserializer,
1111
Serialize,
@@ -17,7 +17,7 @@ use crate::{
1717
bson_util,
1818
client::session::TransactionState,
1919
concern::{ReadConcern, WriteConcern},
20-
error::{convert_bulk_errors, BulkWriteError, BulkWriteFailure, ErrorKind, Result},
20+
error::{convert_bulk_errors, BulkWriteError, BulkWriteFailure, Error, ErrorKind, Result},
2121
operation::{
2222
Aggregate,
2323
Count,
@@ -689,6 +689,7 @@ where
689689
let ordered = options.as_ref().and_then(|o| o.ordered).unwrap_or(true);
690690

691691
let mut cumulative_failure: Option<BulkWriteFailure> = None;
692+
let mut error_labels: HashSet<String> = Default::default();
692693
let mut cumulative_result: Option<InsertManyResult> = None;
693694

694695
let mut n_attempted = 0;
@@ -736,11 +737,15 @@ where
736737
failure_ref.write_concern_error = Some(write_concern_error.clone());
737738
}
738739

740+
error_labels.extend(e.labels);
741+
739742
if ordered {
740-
return Err(ErrorKind::BulkWrite(
741-
cumulative_failure.unwrap_or_else(BulkWriteFailure::new),
742-
)
743-
.into());
743+
return Err(Error::new(
744+
ErrorKind::BulkWrite(
745+
cumulative_failure.unwrap_or_else(BulkWriteFailure::new),
746+
),
747+
Some(error_labels),
748+
));
744749
}
745750
}
746751
_ => return Err(e),
@@ -749,7 +754,10 @@ where
749754
}
750755

751756
match cumulative_failure {
752-
Some(failure) => Err(ErrorKind::BulkWrite(failure).into()),
757+
Some(failure) => Err(Error::new(
758+
ErrorKind::BulkWrite(failure),
759+
Some(error_labels),
760+
)),
753761
None => Ok(cumulative_result.unwrap_or_else(InsertManyResult::new)),
754762
}
755763
}

src/error.rs

Lines changed: 16 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,6 @@
11
//! Contains the `Error` and `Result` types that `mongodb` uses.
22
3-
use std::{
4-
fmt::{self, Debug},
5-
sync::Arc,
6-
};
3+
use std::{collections::HashSet, fmt::{self, Debug}, sync::Arc};
74

85
use serde::Deserialize;
96
use thiserror::Error;
@@ -43,10 +40,17 @@ pub type Result<T> = std::result::Result<T, Error>;
4340
pub struct Error {
4441
/// The type of error that occurred.
4542
pub kind: Box<ErrorKind>,
46-
labels: Vec<String>,
43+
pub(crate) labels: HashSet<String>,
4744
}
4845

4946
impl Error {
47+
pub(crate) fn new(kind: ErrorKind, labels: Option<impl IntoIterator<Item=String>>) -> Self {
48+
Self {
49+
kind: Box::new(kind),
50+
labels: labels.map(|labels| labels.into_iter().collect()).unwrap_or_default(),
51+
}
52+
}
53+
5054
pub(crate) fn pool_cleared_error(address: &ServerAddress) -> Self {
5155
ErrorKind::ConnectionPoolCleared {
5256
message: format!(
@@ -159,19 +163,8 @@ impl Error {
159163
}
160164

161165
/// Returns the labels for this error.
162-
pub fn labels(&self) -> &[String] {
163-
match self.kind.as_ref() {
164-
ErrorKind::Command(ref err) => &err.labels,
165-
ErrorKind::Write(ref err) => match err {
166-
WriteFailure::WriteError(_) => &self.labels,
167-
WriteFailure::WriteConcernError(ref err) => &err.labels,
168-
},
169-
ErrorKind::BulkWrite(ref err) => match err.write_concern_error {
170-
Some(ref err) => &err.labels,
171-
None => &self.labels,
172-
},
173-
_ => &self.labels,
174-
}
166+
pub fn labels(&self) -> &HashSet<String> {
167+
&self.labels
175168
}
176169

177170
/// Whether this error contains the specified label.
@@ -184,30 +177,7 @@ impl Error {
184177
/// Adds the given label to this error.
185178
pub(crate) fn add_label<T: AsRef<str>>(&mut self, label: T) {
186179
let label = label.as_ref().to_string();
187-
match self.kind.as_mut() {
188-
ErrorKind::Command(err) => {
189-
err.labels.push(label);
190-
}
191-
ErrorKind::Write(err) => match err {
192-
WriteFailure::WriteError(_) => {
193-
self.labels.push(label);
194-
}
195-
WriteFailure::WriteConcernError(err) => {
196-
err.labels.push(label);
197-
}
198-
},
199-
ErrorKind::BulkWrite(err) => match err.write_concern_error.as_mut() {
200-
Some(write_concern_error) => {
201-
write_concern_error.labels.push(label);
202-
}
203-
None => {
204-
self.labels.push(label);
205-
}
206-
},
207-
_ => {
208-
self.labels.push(label);
209-
}
210-
}
180+
self.labels.insert(label);
211181
}
212182

213183
pub(crate) fn from_resolve_error(error: trust_dns_resolver::error::ResolveError) -> Self {
@@ -325,7 +295,7 @@ where
325295
fn from(err: E) -> Self {
326296
Self {
327297
kind: Box::new(err.into()),
328-
labels: Vec::new(),
298+
labels: Default::default(),
329299
}
330300
}
331301
}
@@ -447,10 +417,6 @@ pub struct CommandError {
447417
/// A description of the error that occurred.
448418
#[serde(rename = "errmsg")]
449419
pub message: String,
450-
451-
/// The error labels that the server returned.
452-
#[serde(rename = "errorLabels", default)]
453-
pub labels: Vec<String>,
454420
}
455421

456422
impl fmt::Display for CommandError {
@@ -477,10 +443,6 @@ pub struct WriteConcernError {
477443
/// A document identifying the write concern setting related to the error.
478444
#[serde(rename = "errInfo")]
479445
pub details: Option<Document>,
480-
481-
/// The error labels that the server returned.
482-
#[serde(rename = "errorLabels", default)]
483-
pub labels: Vec<String>,
484446
}
485447

486448
/// An error that occurred during a write operation that wasn't due to being unable to satisfy a
@@ -584,9 +546,9 @@ impl WriteFailure {
584546
/// untouched.
585547
pub(crate) fn convert_bulk_errors(error: Error) -> Error {
586548
match *error.kind {
587-
ErrorKind::BulkWrite(ref bulk_failure) => {
588-
match WriteFailure::from_bulk_failure(bulk_failure.clone()) {
589-
Ok(failure) => ErrorKind::Write(failure).into(),
549+
ErrorKind::BulkWrite(bulk_failure) => {
550+
match WriteFailure::from_bulk_failure(bulk_failure) {
551+
Ok(failure) => Error::new(ErrorKind::Write(failure), Some(error.labels)),
590552
Err(e) => e,
591553
}
592554
}

src/operation/delete/test.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,6 @@ async fn handle_write_concern_failure() {
192192
"wtimeout": 0,
193193
"provenance": "clientSupplied"
194194
} }),
195-
labels: Vec::new(),
196195
};
197196
assert_eq!(wc_error, &expected_wc_err);
198197
}

src/operation/insert/test.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,6 @@ async fn handle_write_failure() {
219219
"wtimeout": 0,
220220
"provenance": "clientSupplied"
221221
} }),
222-
labels: Vec::new(),
223222
};
224223
assert_eq!(write_concern_error, expected_wc_err);
225224
}

src/operation/mod.rs

Lines changed: 12 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -148,14 +148,18 @@ struct EmptyBody {}
148148
struct WriteConcernOnlyBody {
149149
#[serde(rename = "writeConcernError")]
150150
write_concern_error: Option<WriteConcernError>,
151+
152+
#[serde(rename = "errorLabels")]
153+
labels: Option<Vec<String>>,
151154
}
152155

153156
impl WriteConcernOnlyBody {
154157
fn validate(&self) -> Result<()> {
155158
match self.write_concern_error {
156-
Some(ref wc_error) => {
157-
Err(ErrorKind::Write(WriteFailure::WriteConcernError(wc_error.clone())).into())
158-
}
159+
Some(ref wc_error) => Err(Error::new(
160+
ErrorKind::Write(WriteFailure::WriteConcernError(wc_error.clone())),
161+
self.labels.clone(),
162+
)),
159163
None => Ok(()),
160164
}
161165
}
@@ -184,24 +188,15 @@ impl<T> WriteResponseBody<T> {
184188
return Ok(());
185189
};
186190

187-
// Error labels for WriteConcernErrors are sent from the server in a separate field.
188-
let write_concern_error = match self.write_concern_error {
189-
Some(ref write_concern_error) => {
190-
let mut write_concern_error = write_concern_error.clone();
191-
if let Some(ref labels) = self.labels {
192-
write_concern_error.labels.append(&mut labels.clone());
193-
}
194-
Some(write_concern_error)
195-
}
196-
None => None,
197-
};
198-
199191
let failure = BulkWriteFailure {
200192
write_errors: self.write_errors.clone(),
201-
write_concern_error,
193+
write_concern_error: self.write_concern_error.clone(),
202194
};
203195

204-
Err(ErrorKind::BulkWrite(failure).into())
196+
Err(Error::new(
197+
ErrorKind::BulkWrite(failure),
198+
self.labels.clone(),
199+
))
205200
}
206201
}
207202

src/operation/update/test.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -276,7 +276,6 @@ async fn handle_write_concern_failure() {
276276
"wtimeout": 0,
277277
"provenance": "clientSupplied"
278278
} }),
279-
labels: Vec::new(),
280279
};
281280
assert_eq!(wc_error, &expected_wc_err);
282281
}

src/test/spec/retryable_writes/mod.rs

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ mod test_file;
22

33
use std::time::Duration;
44

5+
use bson::Bson;
56
use futures::stream::TryStreamExt;
67
use semver::VersionReq;
78
use tokio::sync::{RwLockReadGuard, RwLockWriteGuard};
@@ -27,7 +28,7 @@ use crate::{
2728
#[cfg_attr(feature = "async-std-runtime", async_std::test)]
2829
async fn run_spec_tests() {
2930
async fn run_test(test_file: TestFile) {
30-
for test_case in test_file.tests {
31+
for mut test_case in test_file.tests {
3132
if test_case.operation.name == "bulkWrite" {
3233
continue;
3334
}
@@ -63,7 +64,15 @@ async fn run_spec_tests() {
6364
.expect(&test_case.description);
6465
}
6566

66-
if let Some(ref fail_point) = test_case.fail_point {
67+
if let Some(ref mut fail_point) = test_case.fail_point {
68+
// TODO: DRIVERS-1385 remove this logic for moving errorLabels to the top level.
69+
if let Some(Bson::Document(data)) = fail_point.get_mut("data") {
70+
if let Some(Bson::Document(wc_error)) = data.get_mut("writeConcernError") {
71+
if let Some(labels) = wc_error.remove("errorLabels") {
72+
data.insert("errorLabels", labels);
73+
}
74+
}
75+
}
6776
client
6877
.database("admin")
6978
.run_command(fail_point.clone(), None)
@@ -89,9 +98,14 @@ async fn run_spec_tests() {
8998
match expected_result {
9099
Result::Value(value) => {
91100
let description = &test_case.description;
92-
let result = result.unwrap().unwrap_or_else(|| {
93-
panic!("{:?}: operation should succeed", description)
94-
});
101+
let result = result
102+
.unwrap_or_else(|e| {
103+
panic!(
104+
"{:?}: operation should succeed, got error: {}",
105+
description, e
106+
)
107+
})
108+
.unwrap();
95109
assert_matches(&result, &value, Some(description));
96110
}
97111
Result::Labels(expected_labels) => {

src/test/spec/v2_runner/mod.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -259,13 +259,13 @@ pub async fn run_v2_test(test_file: TestFile) {
259259
assert_eq!(error_code_name, code_name);
260260
}
261261
if let Some(error_labels_contain) = operation_error.error_labels_contain {
262-
let labels = error.labels().to_vec();
262+
let labels = error.labels();
263263
error_labels_contain
264264
.iter()
265265
.for_each(|label| assert!(labels.contains(label)));
266266
}
267267
if let Some(error_labels_omit) = operation_error.error_labels_omit {
268-
let labels = error.labels().to_vec();
268+
let labels = error.labels();
269269
error_labels_omit
270270
.iter()
271271
.for_each(|label| assert!(!labels.contains(label)));

0 commit comments

Comments
 (0)