Skip to content

Commit 27defc3

Browse files
committed
Support distinct keyword.
1 parent d5c3f7a commit 27defc3

File tree

2 files changed

+46
-18
lines changed

2 files changed

+46
-18
lines changed

django_mongodb/compiler.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,11 +45,22 @@ def pre_sql_setup(self, with_col_aliases=False):
4545
aggregation_idx += 1
4646
else:
4747
alias = target
48-
group[alias] = sub_expr.as_mql(self, self.connection)
48+
4949
column_target = expr.output_field.__class__()
5050
column_target.db_column = alias
5151
column_target.set_attributes_from_name(alias)
52-
replacements[sub_expr] = Col(self.collection_name, column_target)
52+
inner_column = Col(self.collection_name, column_target)
53+
if sub_expr.distinct:
54+
inner_expr = sub_expr.as_mql(self, self.connection, force_filters=True)
55+
rhs = next(iter(inner_expr.values()))
56+
group[alias] = {"$addToSet": rhs}
57+
replacing_expr = sub_expr.copy()
58+
replacing_expr.set_source_expressions([inner_column])
59+
else:
60+
group[alias] = sub_expr.as_mql(self, self.connection)
61+
replacing_expr = inner_column
62+
63+
replacements[sub_expr] = replacing_expr
5364
result_expr = expr.replace_expressions(replacements)
5465
all_replacements.update(replacements)
5566
group_expressions |= set(expr.get_group_by_cols())
@@ -93,7 +104,7 @@ def _ccc(col):
93104
}
94105
)
95106
pipeline = []
96-
if ids is None:
107+
if not ids:
97108
group["_id"] = None
98109
pipeline.append({"$facet": {"group": [{"$group": group}]}})
99110
pipeline.append(

django_mongodb/functions.py

Lines changed: 32 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from copy import deepcopy
2+
13
from django.db import NotSupportedError
24
from django.db.models.aggregates import Aggregate, Avg, Count, Max, Min, StdDev, Sum, Variance
35
from django.db.models.expressions import Case, Func, Star, Value, When
@@ -32,6 +34,8 @@
3234
Trim,
3335
Upper,
3436
)
37+
from django.db.models.lookups import Exact
38+
from django.db.models.query_utils import Q
3539

3640
from .query_utils import process_lhs
3741

@@ -115,21 +119,34 @@ def cot(self, compiler, connection):
115119
return {"$divide": [1, {"$tan": lhs_mql}]}
116120

117121

118-
def count(self, compiler, connection, **extra_context): # noqa: ARG001
119-
if self.filter:
120-
copy = self.copy()
121-
copy.filter = None
122-
source_expressions = copy.get_source_expressions()
123-
condition = When(self.filter, then=Value(1))
124-
copy.set_source_expressions([Case(condition)] + source_expressions[1:])
125-
node = copy
126-
cond = process_lhs(node, compiler, connection)
127-
else:
128-
node = self
129-
lhs_mql = process_lhs(self, compiler, connection)
130-
null_cond = {"$in": [{"$type": lhs_mql}, ["missing", "null"]]}
131-
cond = {"$cond": {"if": null_cond, "then": 0, "else": 1}}
132-
return {"$sum": cond}
122+
def count(self, compiler, connection, **extra_context):
123+
if not self.distinct or extra_context.get("force_filters"):
124+
if self.filter:
125+
copy = self.copy()
126+
copy.filter = None
127+
source_expressions = copy.get_source_expressions()
128+
filter_ = deepcopy(self.filter)
129+
filter_.add(~Q(Exact(source_expressions[0], Value(None))), filter_.default)
130+
condition = When(filter_, then=Value(1))
131+
copy.set_source_expressions([Case(condition)] + source_expressions[1:])
132+
node = copy
133+
cond = process_lhs(node, compiler, connection)
134+
else:
135+
node = self
136+
lhs_mql = process_lhs(self, compiler, connection)
137+
null_cond = {"$in": [{"$type": lhs_mql}, ["missing", "null"]]}
138+
cond = {
139+
"$cond": {"if": null_cond, "then": None, "else": lhs_mql if self.distinct else 1}
140+
}
141+
operator = "$sum"
142+
return {operator: cond}
143+
144+
operator = "$size"
145+
lhs_mql = process_lhs(self, compiler, connection)
146+
147+
exits_null = {"$cond": {"if": {"$in": [{"$literal": None}, lhs_mql]}, "then": -1, "else": 0}}
148+
149+
return {"$add": [{"$size": lhs_mql}, exits_null]}
133150

134151

135152
def extract(self, compiler, connection):

0 commit comments

Comments
 (0)