Skip to content

Commit 8d37bc8

Browse files
committed
Handle group by.
1 parent 3ff8dd2 commit 8d37bc8

File tree

3 files changed

+35
-8
lines changed

3 files changed

+35
-8
lines changed

django_mongodb/compiler.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from django.db.models import Count, Expression
44
from django.db.models.aggregates import Aggregate
55
from django.db.models.constants import LOOKUP_SEP
6+
from django.db.models.expressions import Value
67
from django.db.models.sql import compiler
78
from django.db.models.sql.constants import GET_ITERATOR_CHUNK_SIZE, MULTI
89
from django.utils.functional import cached_property
@@ -139,7 +140,7 @@ def build_query(self, columns=None):
139140
query = self.query_class(self)
140141
query.project_fields = self.get_project_fields(columns)
141142
query.lookup_pipeline = self.get_lookup_pipeline()
142-
query.annotation_stage = self.get_group_pipeline()
143+
query.aggregation_stage = self.get_aggregation_pipeline()
143144
try:
144145
query.mongo_query = {"$expr": self.query.where.as_mql(self, self.connection)}
145146
except FullResultSet:
@@ -238,7 +239,7 @@ def get_lookup_pipeline(self):
238239
result += self.query.alias_map[alias].as_mql(self, self.connection)
239240
return result
240241

241-
def get_group_pipeline(self):
242+
def get_aggregation_pipeline(self):
242243
pipeline = None
243244
if any(isinstance(a, Aggregate) for a in self.query.annotations.values()):
244245
result = {}

django_mongodb/functions.py

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from django.db import NotSupportedError
22
from django.db.models.aggregates import Aggregate, Avg, Count, Max, Min, StdDev, Sum, Variance
3-
from django.db.models.expressions import Func
3+
from django.db.models.expressions import Case, Func, Value, When
44
from django.db.models.functions.comparison import Cast, Coalesce, Greatest, Least, NullIf
55
from django.db.models.functions.datetime import (
66
Extract,
@@ -41,7 +41,7 @@
4141
Min: "min",
4242
StdDev: "stddev",
4343
Sum: "sum",
44-
Variance: "variance",
44+
Variance: "stdDevPop",
4545
}
4646

4747

@@ -69,8 +69,19 @@
6969
}
7070

7171

72-
def aggregate(self, compiler, connection): # noqa: ARG001
73-
pass
72+
def aggregate(self, compiler, connection, **extra_context): # noqa: ARG001
73+
if self.filter:
74+
copy = self.copy()
75+
copy.filter = None
76+
source_expressions = copy.get_source_expressions()
77+
condition = When(self.filter, then=source_expressions[0])
78+
copy.set_source_expressions([Case(condition)] + source_expressions[1:])
79+
node = copy
80+
else:
81+
node = self
82+
lhs_mql = process_lhs(node, compiler, connection)
83+
operator = MONGO_AGGREGATION.get(self.__class__)
84+
return {f"${operator}": lhs_mql}
7485

7586

7687
def cast(self, compiler, connection):
@@ -103,6 +114,19 @@ def cot(self, compiler, connection):
103114
return {"$divide": [1, {"$tan": lhs_mql}]}
104115

105116

117+
def count(self, compiler, connection, **extra_context): # noqa: ARG001
118+
if self.filter:
119+
copy = self.copy()
120+
copy.filter = None
121+
source_expressions = copy.get_source_expressions()
122+
condition = When(self.filter, then=Value(1))
123+
copy.set_source_expressions([Case(condition)] + source_expressions[1:])
124+
lhs_mql = process_lhs(copy, compiler, connection)
125+
else:
126+
lhs_mql = Value(1).as_mql(compiler, connection)
127+
return {"$sum": lhs_mql}
128+
129+
106130
def extract(self, compiler, connection):
107131
lhs_mql = process_lhs(self, compiler, connection)
108132
operator = EXTRACT_OPERATORS.get(self.lookup_name)
@@ -208,6 +232,7 @@ def register_functions():
208232
Concat.as_mql = concat
209233
ConcatPair.as_mql = concat_pair
210234
Cot.as_mql = cot
235+
Count.as_mql = count
211236
Extract.as_mql = extract
212237
Func.as_mql = func
213238
Left.as_mql = left

django_mongodb/query.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ def __init__(self, compiler):
5050
self.lookup_pipeline = None
5151
self.annotation_stage = None
5252
self.project_fields = None
53+
self.aggregation_stage = None
5354

5455
def __repr__(self):
5556
return f"<MongoQuery: {self.mongo_query!r} ORDER {self.ordering!r}>"
@@ -105,8 +106,8 @@ def get_cursor(self):
105106
pipeline.extend(self.lookup_pipeline)
106107
if self.mongo_query:
107108
pipeline.append({"$match": self.mongo_query})
108-
if self.annotation_stage:
109-
pipeline.extend(self.annotation_stage)
109+
if self.aggregation_stage:
110+
pipeline.extend(self.aggregation_stage)
110111
if self.project_fields:
111112
pipeline.append({"$project": self.project_fields})
112113
if self.ordering:

0 commit comments

Comments
 (0)