Skip to content

Commit f1a7035

Browse files
committed
RUST-658 Change estimatedDocumentCount() to use the $collStats agg stage
1 parent 8fd04d2 commit f1a7035

File tree

38 files changed

+262
-103
lines changed

38 files changed

+262
-103
lines changed

src/client/executor.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -322,7 +322,7 @@ impl Client {
322322
handler.handle_command_succeeded_event(command_succeeded_event);
323323
});
324324

325-
op.handle_response(response)
325+
op.handle_response(response, connection.stream_description()?)
326326
}
327327
}
328328
}

src/coll/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,7 @@ where
265265
resolve_options!(self, options, [read_concern, selection_criteria]);
266266

267267
let op = Count::new(self.namespace(), options);
268+
268269
self.client().execute_operation(op, None).await
269270
}
270271

src/operation/aggregate/mod.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,11 @@ impl Operation for Aggregate {
6363
))
6464
}
6565

66-
fn handle_response(&self, response: CommandResponse) -> Result<Self::O> {
66+
fn handle_response(
67+
&self,
68+
response: CommandResponse,
69+
_description: &StreamDescription,
70+
) -> Result<Self::O> {
6771
let body: CursorBody = response.body()?;
6872

6973
if self.is_out_or_merge() {

src/operation/aggregate/test.rs

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -200,10 +200,10 @@ async fn handle_success() {
200200
"ok": 1.0
201201
};
202202

203-
let result = aggregate.handle_response(CommandResponse::with_document_and_address(
204-
address.clone(),
205-
response.clone(),
206-
));
203+
let result = aggregate.handle_response(
204+
CommandResponse::with_document_and_address(address.clone(), response.clone()),
205+
&Default::default(),
206+
);
207207
assert!(result.is_ok());
208208

209209
let cursor_spec = result.unwrap();
@@ -228,10 +228,10 @@ async fn handle_success() {
228228
.build(),
229229
),
230230
);
231-
let result = aggregate.handle_response(CommandResponse::with_document_and_address(
232-
address.clone(),
233-
response,
234-
));
231+
let result = aggregate.handle_response(
232+
CommandResponse::with_document_and_address(address.clone(), response),
233+
&Default::default(),
234+
);
235235
assert!(result.is_ok());
236236

237237
let cursor_spec = result.unwrap();
@@ -266,7 +266,7 @@ async fn handle_max_await_time() {
266266
let aggregate = Aggregate::empty();
267267

268268
let spec = aggregate
269-
.handle_response(response.clone())
269+
.handle_response(response.clone(), &Default::default())
270270
.expect("handle should succeed");
271271
assert!(spec.max_time().is_none());
272272

@@ -276,7 +276,7 @@ async fn handle_max_await_time() {
276276
.build();
277277
let aggregate = Aggregate::new(Namespace::empty(), Vec::new(), Some(options));
278278
let spec = aggregate
279-
.handle_response(response)
279+
.handle_response(response, &Default::default())
280280
.expect("handle_should_succeed");
281281
assert_eq!(spec.max_time(), Some(max_await));
282282
}
@@ -308,7 +308,7 @@ async fn handle_write_concern_error() {
308308
);
309309

310310
let error = aggregate
311-
.handle_response(response)
311+
.handle_response(response, &Default::default())
312312
.expect_err("should get wc error");
313313
match error.kind {
314314
ErrorKind::WriteError(WriteFailure::WriteConcernError(_)) => {}
@@ -323,7 +323,7 @@ async fn handle_invalid_response() {
323323

324324
let garbled = doc! { "asdfasf": "ASdfasdf" };
325325
assert!(aggregate
326-
.handle_response(CommandResponse::with_document(garbled))
326+
.handle_response(CommandResponse::with_document(garbled), &Default::default())
327327
.is_err());
328328

329329
let missing_cursor_field = doc! {
@@ -333,6 +333,9 @@ async fn handle_invalid_response() {
333333
}
334334
};
335335
assert!(aggregate
336-
.handle_response(CommandResponse::with_document(missing_cursor_field))
336+
.handle_response(
337+
CommandResponse::with_document(missing_cursor_field),
338+
&Default::default()
339+
)
337340
.is_err());
338341
}

src/operation/count/mod.rs

Lines changed: 60 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,16 @@ mod test;
44
use serde::Deserialize;
55

66
use crate::{
7-
bson::{doc, Document},
7+
bson::doc,
88
cmap::{Command, CommandResponse, StreamDescription},
99
coll::{options::EstimatedDocumentCountOptions, Namespace},
10-
error::Result,
11-
operation::{append_options, Operation, Retryability},
10+
error::{Error, ErrorKind, Result},
11+
operation::{append_options, CursorBody, Operation, Retryability},
1212
selection_criteria::SelectionCriteria,
1313
};
1414

15+
const SERVER_4_9_0_WIRE_VERSION: i32 = 12;
16+
1517
pub(crate) struct Count {
1618
ns: Namespace,
1719
options: Option<EstimatedDocumentCountOptions>,
@@ -38,9 +40,30 @@ impl Operation for Count {
3840
type O = i64;
3941
const NAME: &'static str = "count";
4042

41-
fn build(&self, _description: &StreamDescription) -> Result<Command> {
42-
let mut body: Document = doc! {
43-
Self::NAME: self.ns.coll.clone(),
43+
fn build(&self, description: &StreamDescription) -> Result<Command> {
44+
let mut body = match description.max_wire_version {
45+
Some(v) if v >= SERVER_4_9_0_WIRE_VERSION => {
46+
doc! {
47+
"aggregate": self.ns.coll.clone(),
48+
"pipeline": [
49+
{
50+
"$collStats": { "count": {} },
51+
},
52+
{
53+
"$group": {
54+
"_id": 1,
55+
"n": { "$sum": "$count" },
56+
},
57+
},
58+
],
59+
"cursor": {},
60+
}
61+
}
62+
_ => {
63+
doc! {
64+
Self::NAME: self.ns.coll.clone(),
65+
}
66+
}
4467
};
4568

4669
append_options(&mut body, self.options.as_ref())?;
@@ -52,8 +75,37 @@ impl Operation for Count {
5275
))
5376
}
5477

55-
fn handle_response(&self, response: CommandResponse) -> Result<Self::O> {
56-
response.body::<ResponseBody>().map(|body| body.n)
78+
fn handle_response(
79+
&self,
80+
response: CommandResponse,
81+
description: &StreamDescription,
82+
) -> Result<Self::O> {
83+
let response_body: ResponseBody = match description.max_wire_version {
84+
Some(v) if v >= SERVER_4_9_0_WIRE_VERSION => {
85+
let CursorBody { mut cursor } = response.body()?;
86+
87+
cursor
88+
.first_batch
89+
.pop_front()
90+
.and_then(|doc| bson::from_document(doc).ok())
91+
.ok_or_else(|| {
92+
Error::from(ErrorKind::ResponseError {
93+
message: "invalid server response to count operation".into(),
94+
})
95+
})?
96+
}
97+
_ => response.body()?,
98+
};
99+
100+
Ok(response_body.n)
101+
}
102+
103+
fn handle_error(&self, error: Error) -> Result<Self::O> {
104+
if error.is_ns_not_found() {
105+
Ok(0)
106+
} else {
107+
Err(error)
108+
}
57109
}
58110

59111
fn selection_criteria(&self) -> Option<&SelectionCriteria> {

src/operation/count/test.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ async fn handle_success() {
8282
let response = CommandResponse::with_document(doc! { "n" : n, "ok" : 1 });
8383

8484
let actual_values = count_op
85-
.handle_response(response)
85+
.handle_response(response, &Default::default())
8686
.expect("supposed to succeed");
8787

8888
assert_eq!(actual_values, n);
@@ -95,7 +95,7 @@ async fn handle_response_no_n() {
9595

9696
let response = CommandResponse::with_document(doc! { "ok" : 1 });
9797

98-
let result = count_op.handle_response(response);
98+
let result = count_op.handle_response(response, &Default::default());
9999
match result.as_ref().map_err(|e| &e.kind) {
100100
Err(ErrorKind::ResponseError { .. }) => {}
101101
other => panic!("expected response error, but got {:?}", other),

src/operation/count_documents/mod.rs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,10 +71,14 @@ impl Operation for CountDocuments {
7171
self.aggregate.build(description)
7272
}
7373

74-
fn handle_response(&self, response: CommandResponse) -> Result<Self::O> {
74+
fn handle_response(
75+
&self,
76+
response: CommandResponse,
77+
description: &StreamDescription,
78+
) -> Result<Self::O> {
7579
let result = self
7680
.aggregate
77-
.handle_response(response)
81+
.handle_response(response, description)
7882
.map(|mut spec| spec.initial_buffer.pop_front())?;
7983

8084
let result_doc = match result {

src/operation/count_documents/test.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ async fn handle_success() {
117117
});
118118

119119
let actual_values = count_op
120-
.handle_response(response)
120+
.handle_response(response, &Default::default())
121121
.expect("supposed to succeed");
122122

123123
assert_eq!(actual_values, n);

src/operation/create/mod.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,11 @@ impl Operation for Create {
5050
))
5151
}
5252

53-
fn handle_response(&self, response: CommandResponse) -> Result<Self::O> {
53+
fn handle_response(
54+
&self,
55+
response: CommandResponse,
56+
_description: &StreamDescription,
57+
) -> Result<Self::O> {
5458
response.body::<WriteConcernOnlyBody>()?.validate()
5559
}
5660

src/operation/create/test.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -80,9 +80,9 @@ async fn handle_success() {
8080
let op = Create::empty();
8181

8282
let ok_response = CommandResponse::with_document(doc! { "ok": 1.0 });
83-
assert!(op.handle_response(ok_response).is_ok());
83+
assert!(op.handle_response(ok_response, &Default::default()).is_ok());
8484
let ok_extra = CommandResponse::with_document(doc! { "ok": 1.0, "hello": "world" });
85-
assert!(op.handle_response(ok_extra).is_ok());
85+
assert!(op.handle_response(ok_extra, &Default::default()).is_ok());
8686
}
8787

8888
#[cfg_attr(feature = "tokio-runtime", tokio::test)]
@@ -99,7 +99,7 @@ async fn handle_write_concern_error() {
9999
"ok": 1
100100
});
101101

102-
let result = op.handle_response(response);
102+
let result = op.handle_response(response, &Default::default());
103103
assert!(result.is_err());
104104

105105
match result.unwrap_err().kind {

0 commit comments

Comments
 (0)