Skip to content

Commit 6f409db

Browse files
WaVEVtimgraham
authored andcommitted
add support for QuerySet.aggregate()
Also add support for Count() in QuerySet.annotate().
1 parent 9bca50f commit 6f409db

File tree

10 files changed

+453
-161
lines changed

10 files changed

+453
-161
lines changed

.github/workflows/test-python.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ jobs:
6767
- name: Run tests
6868
run: >
6969
python3 django_repo/tests/runtests.py --settings mongodb_settings -v 2
70+
aggregation
7071
annotations
7172
auth_tests.test_models.UserManagerTestCase
7273
backends.base.test_base.DatabaseWrapperTests

README.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,6 @@ Migrations for 'admin':
109109
## Known issues and limitations
110110

111111
- The following `QuerySet` methods aren't supported:
112-
- `aggregate()`
113112
- `bulk_update()`
114113
- `dates()`
115114
- `datetimes()`

django_mongodb/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,14 @@
66

77
check_django_compatability()
88

9+
from .aggregates import register_aggregates # noqa: E402
910
from .expressions import register_expressions # noqa: E402
1011
from .fields import register_fields # noqa: E402
1112
from .functions import register_functions # noqa: E402
1213
from .lookups import register_lookups # noqa: E402
1314
from .query import register_nodes # noqa: E402
1415

16+
register_aggregates()
1517
register_expressions()
1618
register_fields()
1719
register_functions()

django_mongodb/aggregates.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
from copy import deepcopy
2+
3+
from django.db.models.aggregates import Aggregate, Count, StdDev, Variance
4+
from django.db.models.expressions import Case, Value, When
5+
from django.db.models.lookups import Exact
6+
from django.db.models.sql.where import WhereNode
7+
8+
from .query_utils import process_lhs
9+
10+
# Aggregates whose MongoDB aggregation name differ from Aggregate.function.lower().
11+
MONGO_AGGREGATIONS = {Count: "sum"}
12+
13+
14+
def aggregate(
15+
self,
16+
compiler,
17+
connection,
18+
operator=None,
19+
resolve_inner_expression=False,
20+
**extra_context, # noqa: ARG001
21+
):
22+
if self.filter:
23+
node = self.copy()
24+
node.filter = None
25+
source_expressions = node.get_source_expressions()
26+
condition = When(self.filter, then=source_expressions[0])
27+
node.set_source_expressions([Case(condition)] + source_expressions[1:])
28+
else:
29+
node = self
30+
lhs_mql = process_lhs(node, compiler, connection)
31+
if resolve_inner_expression:
32+
return lhs_mql
33+
operator = operator or MONGO_AGGREGATIONS.get(self.__class__, self.function.lower())
34+
return {f"${operator}": lhs_mql}
35+
36+
37+
def count(self, compiler, connection, resolve_inner_expression=False, **extra_context): # noqa: ARG001
38+
"""
39+
When resolve_inner_expression=True, return the MQL that resolves as a
40+
value. This is used to count different elements, so the inner values are
41+
returned to be pushed into a set.
42+
"""
43+
if not self.distinct or resolve_inner_expression:
44+
if self.filter:
45+
node = self.copy()
46+
node.filter = None
47+
source_expressions = node.get_source_expressions()
48+
filter_ = deepcopy(self.filter)
49+
filter_.add(
50+
WhereNode([Exact(source_expressions[0], Value(None))], negated=True),
51+
filter_.default,
52+
)
53+
condition = When(filter_, then=Value(1))
54+
node.set_source_expressions([Case(condition)] + source_expressions[1:])
55+
inner_expression = process_lhs(node, compiler, connection)
56+
else:
57+
lhs_mql = process_lhs(self, compiler, connection)
58+
null_cond = {"$in": [{"$type": lhs_mql}, ["missing", "null"]]}
59+
inner_expression = {
60+
"$cond": {"if": null_cond, "then": None, "else": lhs_mql if self.distinct else 1}
61+
}
62+
if resolve_inner_expression:
63+
return inner_expression
64+
return {"$sum": inner_expression}
65+
# If distinct=True or resolve_inner_expression=False, sum the size of the
66+
# set.
67+
lhs_mql = process_lhs(self, compiler, connection)
68+
# None shouldn't be counted, so subtract 1 if it's present.
69+
exits_null = {"$cond": {"if": {"$in": [{"$literal": None}, lhs_mql]}, "then": -1, "else": 0}}
70+
return {"$add": [{"$size": lhs_mql}, exits_null]}
71+
72+
73+
def stddev_variance(self, compiler, connection, **extra_context):
74+
if self.function.endswith("_SAMP"):
75+
operator = "stdDevSamp"
76+
elif self.function.endswith("_POP"):
77+
operator = "stdDevPop"
78+
return aggregate(self, compiler, connection, operator=operator, **extra_context)
79+
80+
81+
def register_aggregates():
82+
Aggregate.as_mql = aggregate
83+
Count.as_mql = count
84+
StdDev.as_mql = stddev_variance
85+
Variance.as_mql = stddev_variance

0 commit comments

Comments
 (0)