Skip to content

Commit a4410a1

Browse files
WaVEVtimgraham
authored andcommitted
Add having.
1 parent de5baab commit a4410a1

File tree

2 files changed

+30
-7
lines changed

2 files changed

+30
-7
lines changed

django_mongodb/compiler.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,11 @@ class SQLCompiler(compiler.SQLCompiler):
2222
def pre_sql_setup(self, with_col_aliases=False):
2323
pre_setup = super().pre_sql_setup(with_col_aliases=with_col_aliases)
2424
self.annotations = {}
25+
# mongo_having = self.having.copy() if self.having else None
2526
group = {}
2627
group_expressions = set()
2728
aggregation_idx = 1
29+
all_replacements = {}
2830
for target, expr in self.query.annotation_select.items():
2931
if not expr.contains_aggregate:
3032
result_expr = expr
@@ -42,7 +44,7 @@ def pre_sql_setup(self, with_col_aliases=False):
4244
column_target.set_attributes_from_name(alias)
4345
replacements[sub_expr] = Col(self.collection_name, column_target)
4446
result_expr = expr.replace_expressions(replacements)
45-
47+
all_replacements.update(replacements)
4648
self.annotations[target] = result_expr
4749
if group:
4850
order_by = self.get_order_by()
@@ -62,7 +64,9 @@ def pre_sql_setup(self, with_col_aliases=False):
6264
if not group_expressions
6365
else {
6466
col.target.column: col.as_mql(self, self.connection)
67+
# expression aren't needed in the group by clouse ()
6568
for col in group_expressions
69+
if isinstance(col, Col)
6670
}
6771
)
6872
group["_id"] = ids
@@ -72,7 +76,17 @@ def pre_sql_setup(self, with_col_aliases=False):
7276
{"$addFields": {key: f"$_id.{value[1:]}" for key, value in ids.items()}}
7377
)
7478
if "_id" not in ids:
75-
pipeline.append({"$unSet": "$_id"})
79+
pipeline.append({"$unset": "_id"})
80+
if self.having:
81+
pipeline.append(
82+
{
83+
"$match": {
84+
"$expr": self.having.replace_expressions(all_replacements).as_mql(
85+
self, self.connection
86+
)
87+
}
88+
}
89+
)
7690

7791
self._group_pipeline = pipeline
7892
else:
@@ -238,7 +252,9 @@ def build_query(self, columns=None):
238252
query.lookup_pipeline = self.get_lookup_pipeline()
239253
query.project_fields = self.get_project_fields(columns)
240254
try:
241-
query.mongo_query = {"$expr": self.query.where.as_mql(self, self.connection)}
255+
query.mongo_query = (
256+
{"$expr": self.where.as_mql(self, self.connection)} if self.where else None
257+
)
242258
except FullResultSet:
243259
query.mongo_query = {}
244260
query.order_by(self._get_ordering())
@@ -489,7 +505,8 @@ def check_query(self):
489505
class SQLAggregateCompiler(SQLCompiler):
490506
def build_query(self, columns=None):
491507
query = self.query_class(self)
492-
query.project_fields = self.get_project_fields(tuple(self.query.annotation_select.items()))
508+
query.project_fields = self.get_project_fields(tuple(self.annotations.items()))
509+
query.aggregation_stage = self.get_aggregation_pipeline()
493510

494511
compiler = self.query.inner_query.get_compiler(
495512
self.using,
@@ -498,3 +515,6 @@ def build_query(self, columns=None):
498515
compiler.pre_sql_setup(with_col_aliases=False)
499516
query.sub_query = compiler.build_query()
500517
return query
518+
519+
def _make_result(self, result, columns=None):
520+
return [result[k] for k in self.query.annotation_select]

django_mongodb/functions.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -121,10 +121,13 @@ def count(self, compiler, connection, **extra_context): # noqa: ARG001
121121
source_expressions = copy.get_source_expressions()
122122
condition = When(self.filter, then=Value(1))
123123
copy.set_source_expressions([Case(condition)] + source_expressions[1:])
124-
lhs_mql = process_lhs(copy, compiler, connection)
124+
node = copy
125125
else:
126-
lhs_mql = Value(1).as_mql(compiler, connection)
127-
return {"$sum": lhs_mql}
126+
node = self
127+
# lhs_mql = process_lhs(self, compiler, connection)
128+
lhs_mql = process_lhs(node, compiler, connection)
129+
null_cond = {"$in": [{"$type": lhs_mql}, ["missing", "null"]]}
130+
return {"$sum": {"$cond": {"if": null_cond, "then": 0, "else": 1}}}
128131

129132

130133
def extract(self, compiler, connection):

0 commit comments

Comments
 (0)