Skip to content

Commit 0359a69

Browse files
committed
move aggregates to separate file
1 parent e2e16b9 commit 0359a69

File tree

4 files changed

+84
-74
lines changed

4 files changed

+84
-74
lines changed

django_mongodb/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,14 @@
66

77
check_django_compatability()
88

9+
from .aggregates import register_aggregates # noqa: E402
910
from .expressions import register_expressions # noqa: E402
1011
from .fields import register_fields # noqa: E402
1112
from .functions import register_functions # noqa: E402
1213
from .lookups import register_lookups # noqa: E402
1314
from .query import register_nodes # noqa: E402
1415

16+
register_aggregates()
1517
register_expressions()
1618
register_fields()
1719
register_functions()

django_mongodb/aggregates.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
from copy import deepcopy
2+
3+
from django.db.models.aggregates import Aggregate, Count, StdDev, Variance
4+
from django.db.models.expressions import Case, Value, When
5+
from django.db.models.lookups import Exact
6+
from django.db.models.sql.where import WhereNode
7+
8+
from .query_utils import process_lhs
9+
10+
MONGO_AGGREGATIONS = {
11+
Count: "sum",
12+
StdDev: "stdDev", # Samp or Pop suffix added in aggregate().
13+
Variance: "stdDev", # Likewise.
14+
}
15+
16+
17+
def aggregate(self, compiler, connection, **extra_context): # noqa: ARG001
18+
if self.filter:
19+
node = self.copy()
20+
node.filter = None
21+
source_expressions = node.get_source_expressions()
22+
condition = When(self.filter, then=source_expressions[0])
23+
node.set_source_expressions([Case(condition)] + source_expressions[1:])
24+
else:
25+
node = self
26+
lhs_mql = process_lhs(node, compiler, connection)
27+
operator = MONGO_AGGREGATIONS.get(self.__class__, self.function.lower())
28+
# Add suffixes to StdDev/Variance.
29+
if self.function.endswith("_SAMP"):
30+
operator += "Samp"
31+
elif self.function.endswith("_POP"):
32+
operator += "Pop"
33+
return {f"${operator}": lhs_mql}
34+
35+
36+
def count(self, compiler, connection, resolve_inner_expression=False, **extra_context): # noqa: ARG001
37+
"""
38+
When resolve_inner_expression is True, return the argument as MQL that
39+
resolves as a value. This is used to count different elements, so the inner
40+
values are returned to be pushed into a set.
41+
"""
42+
if not self.distinct or resolve_inner_expression:
43+
if self.filter:
44+
node = self.copy()
45+
node.filter = None
46+
source_expressions = node.get_source_expressions()
47+
filter_ = deepcopy(self.filter)
48+
filter_.add(
49+
WhereNode([Exact(source_expressions[0], Value(None))], negated=True),
50+
filter_.default,
51+
)
52+
condition = When(filter_, then=Value(1))
53+
node.set_source_expressions([Case(condition)] + source_expressions[1:])
54+
inner_expression = process_lhs(node, compiler, connection)
55+
else:
56+
lhs_mql = process_lhs(self, compiler, connection)
57+
null_cond = {"$in": [{"$type": lhs_mql}, ["missing", "null"]]}
58+
inner_expression = {
59+
"$cond": {"if": null_cond, "then": None, "else": lhs_mql if self.distinct else 1}
60+
}
61+
if resolve_inner_expression:
62+
return inner_expression
63+
return {"$sum": inner_expression}
64+
# If distinct=True or resolve_inner_expression=False, sum the size
65+
# of the set.
66+
lhs_mql = process_lhs(self, compiler, connection)
67+
# Subtract 1 if None is in the set (it shouldn't have been counted).
68+
exits_null = {"$cond": {"if": {"$in": [{"$literal": None}, lhs_mql]}, "then": -1, "else": 0}}
69+
return {"$add": [{"$size": lhs_mql}, exits_null]}
70+
71+
72+
def register_aggregates():
73+
Aggregate.as_mql = aggregate
74+
Count.as_mql = count

django_mongodb/expressions.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
NegatedExpression,
1414
Ref,
1515
ResolvedOuterRef,
16+
Star,
1617
Subquery,
1718
Value,
1819
When,
@@ -79,6 +80,10 @@ def ref(self, compiler, connection): # noqa: ARG001
7980
return f"${self.refs}"
8081

8182

83+
def star(self, compiler, connection): # noqa: ARG001
84+
return {"$literal": True}
85+
86+
8287
def subquery(self, compiler, connection): # noqa: ARG001
8388
raise NotSupportedError(f"{self.__class__.__name__} is not supported on MongoDB.")
8489

@@ -113,6 +118,7 @@ def register_expressions():
113118
Query.as_mql = query
114119
Ref.as_mql = ref
115120
ResolvedOuterRef.as_mql = ResolvedOuterRef.as_sql
121+
Star.as_mql = star
116122
Subquery.as_mql = subquery
117123
When.as_mql = when
118124
Value.as_mql = value

django_mongodb/functions.py

Lines changed: 2 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,5 @@
1-
from copy import deepcopy
2-
31
from django.db import NotSupportedError
4-
from django.db.models.aggregates import Aggregate, Count, StdDev, Variance
5-
from django.db.models.expressions import Case, Func, Star, Value, When
6-
from django.db.models.functions import Now
2+
from django.db.models.expressions import Func
73
from django.db.models.functions.comparison import Cast, Coalesce, Greatest, Least, NullIf
84
from django.db.models.functions.datetime import (
95
Extract,
@@ -17,6 +13,7 @@
1713
ExtractWeek,
1814
ExtractWeekDay,
1915
ExtractYear,
16+
Now,
2017
TruncBase,
2118
)
2219
from django.db.models.functions.math import Ceil, Cot, Degrees, Log, Power, Radians, Random, Round
@@ -34,16 +31,9 @@
3431
Trim,
3532
Upper,
3633
)
37-
from django.db.models.lookups import Exact
38-
from django.db.models.sql.where import WhereNode
3934

4035
from .query_utils import process_lhs
4136

42-
MONGO_AGGREGATIONS = {
43-
Count: "sum",
44-
StdDev: "stdDev", # Samp or Pop suffix added in aggregate().
45-
Variance: "stdDev", # Likewise.
46-
}
4737
MONGO_OPERATORS = {
4838
Ceil: "ceil",
4939
Coalesce: "ifNull",
@@ -68,25 +58,6 @@
6858
}
6959

7060

71-
def aggregate(self, compiler, connection, **extra_context): # noqa: ARG001
72-
if self.filter:
73-
node = self.copy()
74-
node.filter = None
75-
source_expressions = node.get_source_expressions()
76-
condition = When(self.filter, then=source_expressions[0])
77-
node.set_source_expressions([Case(condition)] + source_expressions[1:])
78-
else:
79-
node = self
80-
lhs_mql = process_lhs(node, compiler, connection)
81-
operator = MONGO_AGGREGATIONS.get(self.__class__, self.function.lower())
82-
# Add suffixes to StdDev/Variance.
83-
if self.function.endswith("_SAMP"):
84-
operator += "Samp"
85-
elif self.function.endswith("_POP"):
86-
operator += "Pop"
87-
return {f"${operator}": lhs_mql}
88-
89-
9061
def cast(self, compiler, connection):
9162
output_type = connection.data_types[self.output_field.get_internal_type()]
9263
lhs_mql = process_lhs(self, compiler, connection)[0]
@@ -117,42 +88,6 @@ def cot(self, compiler, connection):
11788
return {"$divide": [1, {"$tan": lhs_mql}]}
11889

11990

120-
def count(self, compiler, connection, resolve_inner_expression=False, **extra_context): # noqa: ARG001
121-
"""
122-
When resolve_inner_expression is True, return the argument as MQL that
123-
resolves as a value. This is used to count different elements, so the inner
124-
values are returned to be pushed into a set.
125-
"""
126-
if not self.distinct or resolve_inner_expression:
127-
if self.filter:
128-
node = self.copy()
129-
node.filter = None
130-
source_expressions = node.get_source_expressions()
131-
filter_ = deepcopy(self.filter)
132-
filter_.add(
133-
WhereNode([Exact(source_expressions[0], Value(None))], negated=True),
134-
filter_.default,
135-
)
136-
condition = When(filter_, then=Value(1))
137-
node.set_source_expressions([Case(condition)] + source_expressions[1:])
138-
inner_expression = process_lhs(node, compiler, connection)
139-
else:
140-
lhs_mql = process_lhs(self, compiler, connection)
141-
null_cond = {"$in": [{"$type": lhs_mql}, ["missing", "null"]]}
142-
inner_expression = {
143-
"$cond": {"if": null_cond, "then": None, "else": lhs_mql if self.distinct else 1}
144-
}
145-
if resolve_inner_expression:
146-
return inner_expression
147-
return {"$sum": inner_expression}
148-
# If distinct=True or resolve_inner_expression=False, sum the size
149-
# of the set.
150-
lhs_mql = process_lhs(self, compiler, connection)
151-
# Subtract 1 if None is in the set (it shouldn't have been counted).
152-
exits_null = {"$cond": {"if": {"$in": [{"$literal": None}, lhs_mql]}, "then": -1, "else": 0}}
153-
return {"$add": [{"$size": lhs_mql}, exits_null]}
154-
155-
15691
def extract(self, compiler, connection):
15792
lhs_mql = process_lhs(self, compiler, connection)
15893
operator = EXTRACT_OPERATORS.get(self.lookup_name)
@@ -223,10 +158,6 @@ def round_(self, compiler, connection):
223158
return {"$round": [expr.as_mql(compiler, connection) for expr in self.get_source_expressions()]}
224159

225160

226-
def star(self, compiler, connection): # noqa: ARG001
227-
return {"$literal": True}
228-
229-
230161
def str_index(self, compiler, connection):
231162
lhs = process_lhs(self, compiler, connection)
232163
# StrIndex should be 0-indexed (not found) but it's -1-indexed on MongoDB.
@@ -261,12 +192,10 @@ def trunc(self, compiler, connection):
261192

262193

263194
def register_functions():
264-
Aggregate.as_mql = aggregate
265195
Cast.as_mql = cast
266196
Concat.as_mql = concat
267197
ConcatPair.as_mql = concat_pair
268198
Cot.as_mql = cot
269-
Count.as_mql = count
270199
Extract.as_mql = extract
271200
Func.as_mql = func
272201
Left.as_mql = left
@@ -279,7 +208,6 @@ def register_functions():
279208
Replace.as_mql = replace
280209
Round.as_mql = round_
281210
RTrim.as_mql = trim("rtrim")
282-
Star.as_mql = star
283211
StrIndex.as_mql = str_index
284212
Substr.as_mql = substr
285213
Trim.as_mql = trim("trim")

0 commit comments

Comments
 (0)