Skip to content

Commit 8152971

Browse files
committed
PYTHON-1699 Add database level aggregate helper
1 parent f85a9f9 commit 8152971

File tree

8 files changed

+468
-138
lines changed

8 files changed

+468
-138
lines changed

doc/changelog.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ Version 3.9 adds support for MongoDB 4.2. Highlights include:
4141
- Support zstandard for wire protocol compression.
4242
- Support for periodically polling DNS SRV records to update the mongos proxy
4343
list without having to change client configuration.
44+
- New method :meth:`pymongo.database.Database.aggregate` to support running
45+
database level aggregations.
4446

4547
Now that supported operations are retried automatically and transparently,
4648
users should consider adjusting any custom retry logic to prevent

pymongo/aggregation.py

Lines changed: 194 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,194 @@
1+
# Copyright 2019-present MongoDB, Inc.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"); you
4+
# may not use this file except in compliance with the License. You
5+
# may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
12+
# implied. See the License for the specific language governing
13+
# permissions and limitations under the License.
14+
15+
"""Perform aggregation operations on a collection or database."""
16+
17+
from bson.son import SON
18+
19+
from pymongo import common
20+
from pymongo.collation import validate_collation_or_none
21+
from pymongo.errors import ConfigurationError
22+
23+
24+
class _AggregationCommand(object):
25+
"""The internal abstract base class for aggregation cursors.
26+
27+
Should not be called directly by application developers. Use
28+
:meth:`pymongo.collection.Collection.aggregate`, or
29+
:meth:`pymongo.database.Database.aggregate` instead.
30+
"""
31+
def __init__(self, target, cursor_class, pipeline, options,
32+
explicit_session, user_fields=None, result_processor=None):
33+
if "explain" in options:
34+
raise ConfigurationError("The explain option is not supported. "
35+
"Use Database.command instead.")
36+
37+
self._target = target
38+
39+
common.validate_list('pipeline', pipeline)
40+
self._pipeline = pipeline
41+
42+
common.validate_is_mapping('options', options)
43+
self._options = options
44+
45+
self._cursor_class = cursor_class
46+
self._explicit_session = explicit_session
47+
self._user_fields = user_fields
48+
self._result_processor = result_processor
49+
50+
self._collation = validate_collation_or_none(
51+
options.pop('collation', None))
52+
53+
self._max_await_time_ms = options.pop('maxAwaitTimeMS', None)
54+
self._batch_size = common.validate_non_negative_integer_or_none(
55+
"batchSize", options.pop("batchSize", None))
56+
57+
self._dollar_out = (self._pipeline and
58+
'$out' in self._pipeline[-1])
59+
60+
@property
61+
def _aggregation_target(self):
62+
"""The argument to pass to the aggregate command."""
63+
raise NotImplementedError
64+
65+
@property
66+
def _cursor_namespace(self):
67+
"""The namespace in which the aggregate command is run."""
68+
raise NotImplementedError
69+
70+
@property
71+
def _database(self):
72+
"""The database against which the aggregation command is run."""
73+
raise NotImplementedError
74+
75+
@staticmethod
76+
def _check_compat(sock_info):
77+
"""Check whether the server version in-use supports aggregation."""
78+
pass
79+
80+
def _process_result(self, result, session, server, sock_info, slave_ok):
81+
if self._result_processor:
82+
self._result_processor(
83+
result, session, server, sock_info, slave_ok)
84+
85+
def get_cursor(self, session, server, sock_info, slave_ok):
86+
# Ensure command compatibility.
87+
self._check_compat(sock_info)
88+
89+
# Serialize command.
90+
cmd = SON([("aggregate", self._aggregation_target),
91+
("pipeline", self._pipeline)])
92+
cmd.update(self._options)
93+
94+
# Cache read preference for easy access.
95+
read_preference = self._target._read_preference_for(session)
96+
97+
# Apply this target's read concern if:
98+
# readConcern has not been specified as a kwarg and either
99+
# - server version is >= 4.2 or
100+
# - server version is >= 3.2 and pipeline doesn't use $out
101+
if (('readConcern' not in cmd) and
102+
((sock_info.max_wire_version >= 4 and not self._dollar_out) or
103+
(sock_info.max_wire_version >= 8))):
104+
read_concern = self._target.read_concern
105+
else:
106+
read_concern = None
107+
108+
# Apply this target's write concern if:
109+
# writeConcern has not been specified as a kwarg and pipeline doesn't
110+
# use $out
111+
if 'writeConcern' not in cmd and self._dollar_out:
112+
write_concern = self._target._write_concern_for(session)
113+
else:
114+
write_concern = None
115+
116+
# Run command.
117+
result = sock_info.command(
118+
self._database.name,
119+
cmd,
120+
slave_ok,
121+
read_preference,
122+
self._target.codec_options,
123+
parse_write_concern_error=True,
124+
read_concern=read_concern,
125+
write_concern=write_concern,
126+
collation=self._collation,
127+
session=session,
128+
client=self._database.client,
129+
user_fields=self._user_fields)
130+
131+
self._process_result(result, session, server, sock_info, slave_ok)
132+
133+
# Extract cursor from result or mock/fake one if necessary.
134+
if 'cursor' in result:
135+
cursor = result['cursor']
136+
else:
137+
# Pre-MongoDB 2.6 or unacknowledged write. Fake a cursor.
138+
cursor = {
139+
"id": 0,
140+
"firstBatch": result.get("result", []),
141+
"ns": self._cursor_namespace,
142+
}
143+
144+
# Get collection to target with cursor.
145+
ns = cursor["ns"]
146+
_, collname = ns.split(".", 1)
147+
aggregation_collection = self._database.get_collection(
148+
collname, codec_options=self._target.codec_options,
149+
read_preference=read_preference,
150+
write_concern=self._target.write_concern,
151+
read_concern=self._target.read_concern)
152+
153+
# Create and return cursor instance.
154+
return self._cursor_class(
155+
aggregation_collection, cursor, sock_info.address,
156+
batch_size=self._batch_size or 0,
157+
max_await_time_ms=self._max_await_time_ms,
158+
session=session, explicit_session=self._explicit_session)
159+
160+
161+
class _CollectionAggregationCommand(_AggregationCommand):
162+
@property
163+
def _aggregation_target(self):
164+
return self._target.name
165+
166+
@property
167+
def _cursor_namespace(self):
168+
return self._target.full_name
169+
170+
@property
171+
def _database(self):
172+
return self._target.database
173+
174+
175+
class _DatabaseAggregationCommand(_AggregationCommand):
176+
@property
177+
def _aggregation_target(self):
178+
return 1
179+
180+
@property
181+
def _cursor_namespace(self):
182+
return "%s.%s.aggregate" % (self._target.name, "$cmd")
183+
184+
@property
185+
def _database(self):
186+
return self._target
187+
188+
@staticmethod
189+
def _check_compat(sock_info):
190+
# Older server version don't raise a descriptive error, so we raise
191+
# one instead.
192+
if not sock_info.max_wire_version >= 6:
193+
err_msg = "Database.aggregation is only supported on MongoDB 3.6+."
194+
raise ConfigurationError(err_msg)

pymongo/change_stream.py

Lines changed: 42 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,10 @@
1818

1919
from bson import _bson_to_dict
2020
from bson.raw_bson import RawBSONDocument
21-
from bson.son import SON
2221

2322
from pymongo import common
23+
from pymongo.aggregation import (_CollectionAggregationCommand,
24+
_DatabaseAggregationCommand)
2425
from pymongo.collation import validate_collation_or_none
2526
from pymongo.command_cursor import CommandCursor
2627
from pymongo.errors import (ConnectionFailure,
@@ -86,17 +87,17 @@ def __init__(self, target, pipeline, full_document, resume_after,
8687
self._cursor = self._create_cursor()
8788

8889
@property
89-
def _aggregation_target(self):
90-
"""The argument to pass to the aggregate command."""
90+
def _aggregation_command_class(self):
91+
"""The aggregation command class to be used."""
9192
raise NotImplementedError
9293

9394
@property
94-
def _database(self):
95-
"""The database against which the aggregation commands for
95+
def _client(self):
96+
"""The client against which the aggregation commands for
9697
this ChangeStream will be run. """
9798
raise NotImplementedError
9899

99-
def _pipeline_options(self):
100+
def _change_stream_options(self):
100101
options = {}
101102
if self._full_document is not None:
102103
options['fullDocument'] = self._full_document
@@ -108,69 +109,45 @@ def _pipeline_options(self):
108109
options['startAtOperationTime'] = self._start_at_operation_time
109110
return options
110111

111-
def _full_pipeline(self):
112+
def _command_options(self):
113+
options = {'cursor': {}}
114+
if self._max_await_time_ms is not None:
115+
options["maxAwaitTimeMS"] = self._max_await_time_ms
116+
return options
117+
118+
def _aggregation_pipeline(self):
112119
"""Return the full aggregation pipeline for this ChangeStream."""
113-
options = self._pipeline_options()
120+
options = self._change_stream_options()
114121
full_pipeline = [{'$changeStream': options}]
115122
full_pipeline.extend(self._pipeline)
116123
return full_pipeline
117124

118-
def _run_aggregation_cmd(self, session, explicit_session):
119-
"""Run the full aggregation pipeline for this ChangeStream and return
120-
the corresponding CommandCursor.
121-
"""
122-
read_preference = self._target._read_preference_for(session)
123-
client = self._database.client
124-
125-
def _cmd(session, server, sock_info, slave_ok):
126-
pipeline = self._full_pipeline()
127-
cmd = SON([("aggregate", self._aggregation_target),
128-
("pipeline", pipeline),
129-
("cursor", {})])
130-
131-
result = sock_info.command(
132-
self._database.name,
133-
cmd,
134-
slave_ok,
135-
read_preference,
136-
self._target.codec_options,
137-
parse_write_concern_error=True,
138-
read_concern=self._target.read_concern,
139-
collation=self._collation,
140-
session=session,
141-
client=self._database.client)
142-
143-
cursor = result["cursor"]
144-
145-
if (self._start_at_operation_time is None and
125+
def _process_result(self, result, session, server, sock_info, slave_ok):
126+
"""Callback that records a change stream cursor's operationTime."""
127+
if (self._start_at_operation_time is None and
146128
self._resume_token is None and
147129
self._start_after is None and
148130
sock_info.max_wire_version >= 7):
149-
self._start_at_operation_time = result["operationTime"]
131+
self._start_at_operation_time = result["operationTime"]
150132

151-
ns = cursor["ns"]
152-
_, collname = ns.split(".", 1)
153-
aggregation_collection = self._database.get_collection(
154-
collname, codec_options=self._target.codec_options,
155-
read_preference=read_preference,
156-
write_concern=self._target.write_concern,
157-
read_concern=self._target.read_concern
158-
)
159-
160-
return CommandCursor(
161-
aggregation_collection, cursor, sock_info.address,
162-
batch_size=self._batch_size or 0,
163-
max_await_time_ms=self._max_await_time_ms,
164-
session=session, explicit_session=explicit_session)
133+
def _run_aggregation_cmd(self, session, explicit_session):
134+
"""Run the full aggregation pipeline for this ChangeStream and return
135+
the corresponding CommandCursor.
136+
"""
137+
cmd = self._aggregation_command_class(
138+
self._target, CommandCursor, self._aggregation_pipeline(),
139+
self._command_options(), explicit_session,
140+
result_processor=self._process_result)
165141

166-
return client._retryable_read(_cmd, read_preference, session)
142+
return self._client._retryable_read(
143+
cmd.get_cursor, self._target._read_preference_for(session),
144+
session)
167145

168146
def _create_cursor(self):
169-
with self._database.client._tmp_session(self._session, close=False) as s:
147+
with self._client._tmp_session(self._session, close=False) as s:
170148
return self._run_aggregation_cmd(
171149
session=s,
172-
explicit_session=self._session is not None
173-
)
150+
explicit_session=self._session is not None)
174151

175152
def _resume(self):
176153
"""Reestablish this change stream after a resumable error."""
@@ -302,12 +279,12 @@ class CollectionChangeStream(ChangeStream):
302279
.. versionadded:: 3.7
303280
"""
304281
@property
305-
def _aggregation_target(self):
306-
return self._target.name
282+
def _aggregation_command_class(self):
283+
return _CollectionAggregationCommand
307284

308285
@property
309-
def _database(self):
310-
return self._target.database
286+
def _client(self):
287+
return self._target.database.client
311288

312289

313290
class DatabaseChangeStream(ChangeStream):
@@ -319,12 +296,12 @@ class DatabaseChangeStream(ChangeStream):
319296
.. versionadded:: 3.7
320297
"""
321298
@property
322-
def _aggregation_target(self):
323-
return 1
299+
def _aggregation_command_class(self):
300+
return _DatabaseAggregationCommand
324301

325302
@property
326-
def _database(self):
327-
return self._target
303+
def _client(self):
304+
return self._target.client
328305

329306

330307
class ClusterChangeStream(DatabaseChangeStream):
@@ -335,7 +312,7 @@ class ClusterChangeStream(DatabaseChangeStream):
335312
336313
.. versionadded:: 3.7
337314
"""
338-
def _pipeline_options(self):
339-
options = super(ClusterChangeStream, self)._pipeline_options()
315+
def _change_stream_options(self):
316+
options = super(ClusterChangeStream, self)._change_stream_options()
340317
options["allChangesForCluster"] = True
341318
return options

0 commit comments

Comments
 (0)