Skip to content

Commit 17dfa68

Browse files
committed
Add having.
1 parent 4742745 commit 17dfa68

File tree

2 files changed

+30
-7
lines changed

2 files changed

+30
-7
lines changed

django_mongodb/compiler.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,11 @@ class SQLCompiler(compiler.SQLCompiler):
2121
def pre_sql_setup(self, with_col_aliases=False):
2222
pre_setup = super().pre_sql_setup(with_col_aliases=with_col_aliases)
2323
self.annotations = {}
24+
# mongo_having = self.having.copy() if self.having else None
2425
group = {}
2526
group_expressions = set()
2627
aggregation_idx = 1
28+
all_replacements = {}
2729
for target, expr in self.query.annotation_select.items():
2830
if not expr.contains_aggregate:
2931
result_expr = expr
@@ -41,7 +43,7 @@ def pre_sql_setup(self, with_col_aliases=False):
4143
column_target.set_attributes_from_name(alias)
4244
replacements[sub_expr] = Col(self.collection_name, column_target)
4345
result_expr = expr.replace_expressions(replacements)
44-
46+
all_replacements.update(replacements)
4547
self.annotations[target] = result_expr
4648
if group:
4749
order_by = self.get_order_by()
@@ -61,7 +63,9 @@ def pre_sql_setup(self, with_col_aliases=False):
6163
if not group_expressions
6264
else {
6365
col.target.column: col.as_mql(self, self.connection)
66+
# expression aren't needed in the group by clouse ()
6467
for col in group_expressions
68+
if isinstance(col, Col)
6569
}
6670
)
6771
group["_id"] = ids
@@ -71,7 +75,17 @@ def pre_sql_setup(self, with_col_aliases=False):
7175
{"$addFields": {key: f"$_id.{value[1:]}" for key, value in ids.items()}}
7276
)
7377
if "_id" not in ids:
74-
pipeline.append({"$unSet": "$_id"})
78+
pipeline.append({"$unset": "_id"})
79+
if self.having:
80+
pipeline.append(
81+
{
82+
"$match": {
83+
"$expr": self.having.replace_expressions(all_replacements).as_mql(
84+
self, self.connection
85+
)
86+
}
87+
}
88+
)
7589

7690
self._group_pipeline = pipeline
7791
else:
@@ -206,7 +220,9 @@ def build_query(self, columns=None):
206220
query.lookup_pipeline = self.get_lookup_pipeline()
207221
query.project_fields = self.get_project_fields(columns)
208222
try:
209-
query.mongo_query = {"$expr": self.query.where.as_mql(self, self.connection)}
223+
query.mongo_query = (
224+
{"$expr": self.where.as_mql(self, self.connection)} if self.where else None
225+
)
210226
except FullResultSet:
211227
query.mongo_query = {}
212228
query.order_by(self._get_ordering())
@@ -431,7 +447,8 @@ def execute_update(self, update_spec, **kwargs):
431447
class SQLAggregateCompiler(SQLCompiler):
432448
def build_query(self, columns=None):
433449
query = self.query_class(self)
434-
query.project_fields = self.get_project_fields(tuple(self.query.annotation_select.items()))
450+
query.project_fields = self.get_project_fields(tuple(self.annotations.items()))
451+
query.aggregation_stage = self.get_aggregation_pipeline()
435452

436453
compiler = self.query.inner_query.get_compiler(
437454
self.using,
@@ -440,3 +457,6 @@ def build_query(self, columns=None):
440457
compiler.pre_sql_setup(with_col_aliases=False)
441458
query.sub_query = compiler.build_query()
442459
return query
460+
461+
def _make_result(self, result, columns=None):
462+
return [result[k] for k in self.query.annotation_select]

django_mongodb/functions.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -121,10 +121,13 @@ def count(self, compiler, connection, **extra_context): # noqa: ARG001
121121
source_expressions = copy.get_source_expressions()
122122
condition = When(self.filter, then=Value(1))
123123
copy.set_source_expressions([Case(condition)] + source_expressions[1:])
124-
lhs_mql = process_lhs(copy, compiler, connection)
124+
node = copy
125125
else:
126-
lhs_mql = Value(1).as_mql(compiler, connection)
127-
return {"$sum": lhs_mql}
126+
node = self
127+
# lhs_mql = process_lhs(self, compiler, connection)
128+
lhs_mql = process_lhs(node, compiler, connection)
129+
null_cond = {"$in": [{"$type": lhs_mql}, ["missing", "null"]]}
130+
return {"$sum": {"$cond": {"if": null_cond, "then": 0, "else": 1}}}
128131

129132

130133
def extract(self, compiler, connection):

0 commit comments

Comments
 (0)