Skip to content

Commit 926078c

Browse files
WaVEVtimgraham
authored andcommitted
Initial commit.
1 parent 9bca50f commit 926078c

File tree

4 files changed

+100
-28
lines changed

4 files changed

+100
-28
lines changed

django_mongodb/compiler.py

Lines changed: 72 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,8 @@ def check_query(self):
143143
isinstance(a, Aggregate) and not isinstance(a, Count)
144144
for a in self.query.annotations.values()
145145
):
146-
raise NotSupportedError("QuerySet.aggregate() isn't supported on MongoDB.")
146+
# raise NotSupportedError("QuerySet.aggregate() isn't supported on MongoDB.")
147+
pass
147148

148149
def get_count(self, check_exists=False):
149150
"""
@@ -170,8 +171,10 @@ def get_count(self, check_exists=False):
170171
def build_query(self, columns=None):
171172
"""Check if the query is supported and prepare a MongoQuery."""
172173
self.check_query()
173-
query = self.query_class(self, columns)
174+
query = self.query_class(self)
175+
query.project_fields = self._get_project_fields(columns)
174176
query.lookup_pipeline = self.get_lookup_pipeline()
177+
query.annotation_stage = self._get_group_pipeline()
175178
try:
176179
query.mongo_query = {"$expr": self.query.where.as_mql(self, self.connection)}
177180
except FullResultSet:
@@ -276,6 +279,73 @@ def get_lookup_pipeline(self):
276279
result += self.query.alias_map[alias].as_mql(self, self.connection)
277280
return result
278281

282+
def _get_group_pipeline(self):
283+
pipeline = None
284+
if any(isinstance(a, Aggregate) for a in self.query.annotations.values()):
285+
result = {}
286+
# self.get_group_by(self.select, [])
287+
for alias, annotation in self.query.annotation_select.items():
288+
value = annotation.as_mql(self, self.connection)
289+
if isinstance(value, list):
290+
value = value[0]
291+
result[alias] = value
292+
293+
expressions = set()
294+
for expr, *_ in self.select:
295+
expressions |= set(expr.get_group_by_cols())
296+
order_by = self.get_order_by()
297+
for expr, (_, _, is_ref) in order_by:
298+
# Skip references to the SELECT clause, as all expressions in
299+
# the SELECT clause are already part of the GROUP BY.
300+
if not is_ref:
301+
expressions |= set(expr.get_group_by_cols())
302+
having_group_by = self.having.get_group_by_cols() if self.having else ()
303+
for expr in having_group_by:
304+
expressions.add(expr)
305+
306+
ids = (
307+
None
308+
if not expressions
309+
else {col.target.column: col.as_mql(self, self.connection) for col in expressions}
310+
)
311+
result["_id"] = ids
312+
313+
pipeline = [{"$group": result}]
314+
if ids:
315+
pipeline.append(
316+
{"$addFields": {key: f"$_id.{value[1:]}" for key, value in ids.items()}}
317+
)
318+
if "_id" not in ids:
319+
pipeline.append({"$unSet": "$_id"})
320+
321+
return pipeline
322+
323+
def _get_project_fields(self, columns=None):
324+
fields = {}
325+
for name, expr in columns or []:
326+
try:
327+
column = name if isinstance(expr, Aggregate) else expr.target.column
328+
except AttributeError:
329+
# Generate the MQL for an annotation.
330+
try:
331+
fields[name] = expr.as_mql(self, self.connection)
332+
except EmptyResultSet:
333+
fields[name] = Value(False).as_mql(self, self.connection)
334+
except FullResultSet:
335+
fields[name] = Value(True).as_mql(self, self.connection)
336+
else:
337+
# If name != column, then this is an annotatation referencing
338+
# another column.
339+
fields[name] = 1 if name == column else f"${column}"
340+
341+
if fields:
342+
# Add related fields.
343+
for alias in self.query.alias_map:
344+
if self.query.alias_refcount[alias] and self.collection_name != alias:
345+
fields[alias] = 1
346+
347+
return fields
348+
279349

280350
class SQLInsertCompiler(SQLCompiler):
281351
def execute_sql(self, returning_fields=None):

django_mongodb/functions.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from django.db import NotSupportedError
2+
from django.db.models.aggregates import Aggregate, Avg, Count, Max, Min, StdDev, Sum, Variance
23
from django.db.models.expressions import Func
34
from django.db.models.functions.comparison import Cast, Coalesce, Greatest, Least, NullIf
45
from django.db.models.functions.datetime import (
@@ -33,6 +34,17 @@
3334

3435
from .query_utils import process_lhs
3536

37+
MONGO_AGGREGATION = {
38+
Avg: "avg",
39+
Count: "sum",
40+
Max: "max",
41+
Min: "min",
42+
StdDev: "stddev",
43+
Sum: "sum",
44+
Variance: "variance",
45+
}
46+
47+
3648
MONGO_OPERATORS = {
3749
Ceil: "ceil",
3850
Coalesce: "ifNull",
@@ -57,6 +69,10 @@
5769
}
5870

5971

72+
def aggregate(self, compiler, connection): # noqa: ARG001
73+
pass
74+
75+
6076
def cast(self, compiler, connection):
6177
output_type = connection.data_types[self.output_field.get_internal_type()]
6278
lhs_mql = process_lhs(self, compiler, connection)[0]
@@ -187,6 +203,7 @@ def trunc(self, compiler, connection):
187203

188204

189205
def register_functions():
206+
Aggregate.as_mql = aggregate
190207
Cast.as_mql = cast
191208
Concat.as_mql = concat
192209
ConcatPair.as_mql = concat_pair

django_mongodb/query.py

Lines changed: 8 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from django.core.exceptions import EmptyResultSet, FullResultSet
55
from django.db import DatabaseError, IntegrityError
6-
from django.db.models.expressions import Case, Value, When
6+
from django.db.models.expressions import Case, When
77
from django.db.models.functions import Mod
88
from django.db.models.lookups import Exact
99
from django.db.models.sql.constants import INNER
@@ -37,18 +37,19 @@ class MongoQuery:
3737
built by Django to a "representation" more suitable for MongoDB.
3838
"""
3939

40-
def __init__(self, compiler, columns):
40+
def __init__(self, compiler):
4141
self.compiler = compiler
4242
self.connection = compiler.connection
4343
self.ops = compiler.connection.ops
4444
self.query = compiler.query
45-
self.columns = columns
4645
self._negated = False
4746
self.ordering = []
4847
self.collection = self.compiler.get_collection()
4948
self.collection_name = self.compiler.collection_name
5049
self.mongo_query = getattr(compiler.query, "raw_query", {})
5150
self.lookup_pipeline = None
51+
self.annotation_stage = None
52+
self.project_fields = None
5253

5354
def __repr__(self):
5455
return f"<MongoQuery: {self.mongo_query!r} ORDER {self.ordering!r}>"
@@ -97,35 +98,16 @@ def get_cursor(self, count=False, limit=None, skip=None):
9798
9899
Use `limit` or `skip` to override those options of the query.
99100
"""
100-
fields = {}
101-
for name, expr in self.columns or []:
102-
try:
103-
column = expr.target.column
104-
except AttributeError:
105-
# Generate the MQL for an annotation.
106-
try:
107-
fields[name] = expr.as_mql(self.compiler, self.connection)
108-
except EmptyResultSet:
109-
fields[name] = Value(False).as_mql(self.compiler, self.connection)
110-
except FullResultSet:
111-
fields[name] = Value(True).as_mql(self.compiler, self.connection)
112-
else:
113-
# If name != column, then this is an annotatation referencing
114-
# another column.
115-
fields[name] = 1 if name == column else f"${column}"
116-
if fields:
117-
# Add related fields.
118-
for alias in self.query.alias_map:
119-
if self.query.alias_refcount[alias] and self.collection_name != alias:
120-
fields[alias] = 1
121101
# Construct the query pipeline.
122102
pipeline = []
123103
if self.lookup_pipeline:
124104
pipeline.extend(self.lookup_pipeline)
125105
if self.mongo_query:
126106
pipeline.append({"$match": self.mongo_query})
127-
if fields:
128-
pipeline.append({"$project": fields})
107+
if self.annotation_stage:
108+
pipeline.extend(self.annotation_stage)
109+
if self.project_fields:
110+
pipeline.append({"$project": self.project_fields})
129111
if self.ordering:
130112
pipeline.append({"$sort": dict(self.ordering)})
131113
if skip is not None:

django_mongodb/query_utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from django.core.exceptions import FullResultSet
2+
from django.db.models.aggregates import Aggregate
23
from django.db.models.expressions import Value
34

45

@@ -15,6 +16,8 @@ def process_lhs(node, compiler, connection):
1516
result.append(expr.as_mql(compiler, connection))
1617
except FullResultSet:
1718
result.append(Value(True).as_mql(compiler, connection))
19+
if isinstance(node, Aggregate):
20+
return result[0]
1821
return result
1922
# node is a Transform with just one source expression, aliased as "lhs".
2023
if is_direct_value(node.lhs):

0 commit comments

Comments
 (0)