Skip to content

Commit 4742745

Browse files
committed
Implementing aggregate compiler.
1 parent f86a0a2 commit 4742745

File tree

3 files changed

+34
-27
lines changed

3 files changed

+34
-27
lines changed

django_mongodb/compiler.py

Lines changed: 23 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,10 @@ class SQLCompiler(compiler.SQLCompiler):
1616
"""Base class for all Mongo compilers."""
1717

1818
query_class = MongoQuery
19+
_group_pipeline = None
1920

20-
def pre_sql_setup(self, *args, **kargs):
21-
pre_setup = super().pre_sql_setup(*args, **kargs)
21+
def pre_sql_setup(self, with_col_aliases=False):
22+
pre_setup = super().pre_sql_setup(with_col_aliases=with_col_aliases)
2223
self.annotations = {}
2324
group = {}
2425
group_expressions = set()
@@ -29,17 +30,20 @@ def pre_sql_setup(self, *args, **kargs):
2930
else:
3031
replacements = {}
3132
for sub_expr in self._get_aggregate_expressions(expr):
32-
alias = f"__aggregation{aggregation_idx}"
33+
if sub_expr != expr:
34+
alias = f"__aggregation{aggregation_idx}"
35+
aggregation_idx += 1
36+
else:
37+
alias = target
38+
group_expressions |= set(sub_expr.get_group_by_cols())
3339
group[alias] = sub_expr.as_mql(self, self.connection)
34-
aggregation_idx += 1
3540
column_target = expr.output_field.__class__()
3641
column_target.set_attributes_from_name(alias)
3742
replacements[sub_expr] = Col(self.collection_name, column_target)
3843
result_expr = expr.replace_expressions(replacements)
3944

4045
self.annotations[target] = result_expr
4146
if group:
42-
"""
4347
order_by = self.get_order_by()
4448
for expr, (_, _, is_ref) in order_by:
4549
# Skip references to the SELECT clause, as all expressions in
@@ -49,7 +53,8 @@ def pre_sql_setup(self, *args, **kargs):
4953
having_group_by = self.having.get_group_by_cols() if self.having else ()
5054
for expr in having_group_by:
5155
group_expressions.add(expr)
52-
"""
56+
if isinstance(self.query.group_by, tuple | list):
57+
group_expressions |= set(self.query.group_by)
5358

5459
ids = (
5560
None
@@ -60,7 +65,6 @@ def pre_sql_setup(self, *args, **kargs):
6065
}
6166
)
6267
group["_id"] = ids
63-
6468
pipeline = [{"$group": group}]
6569
if ids:
6670
pipeline.append(
@@ -78,8 +82,8 @@ def pre_sql_setup(self, *args, **kargs):
7882
def execute_sql(
7983
self, result_type=MULTI, chunked_fetch=False, chunk_size=GET_ITERATOR_CHUNK_SIZE
8084
):
81-
self.pre_sql_setup()
8285
# QuerySet.count()
86+
self.pre_sql_setup()
8387
if self.query.annotations == {"__count": Count("*")}:
8488
return [self.get_count()]
8589

@@ -299,17 +303,6 @@ def get_lookup_pipeline(self):
299303
result += self.query.alias_map[alias].as_mql(self, self.connection)
300304
return result
301305

302-
def _get_aggregate_expressions2(self, expr):
303-
stack = [(None, expr)]
304-
while stack:
305-
parent, expr = stack.pop()
306-
if isinstance(expr, Aggregate):
307-
yield parent
308-
elif hasattr(expr, "get_source_expressions"):
309-
stack.extend(
310-
[((expr, idx), se) for idx, se in enumerate(expr.get_source_expressions())]
311-
)
312-
313306
def _get_aggregate_expressions(self, expr):
314307
stack = [expr]
315308
while stack:
@@ -436,4 +429,14 @@ def execute_update(self, update_spec, **kwargs):
436429

437430

438431
class SQLAggregateCompiler(SQLCompiler):
439-
pass
432+
def build_query(self, columns=None):
433+
query = self.query_class(self)
434+
query.project_fields = self.get_project_fields(tuple(self.query.annotation_select.items()))
435+
436+
compiler = self.query.inner_query.get_compiler(
437+
self.using,
438+
elide_empty=self.elide_empty,
439+
)
440+
compiler.pre_sql_setup(with_col_aliases=False)
441+
query.sub_query = compiler.build_query()
442+
return query

django_mongodb/expressions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def query(self, compiler, connection): # noqa: ARG001
7070

7171

7272
def ref(self, compiler, connection): # noqa: ARG001
73-
return self.refs
73+
return f"${self.refs}"
7474

7575

7676
def subquery(self, compiler, connection): # noqa: ARG001

django_mongodb/query.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ def __init__(self, compiler):
4747
self.collection = self.compiler.get_collection()
4848
self.collection_name = self.compiler.collection_name
4949
self.mongo_query = getattr(compiler.query, "raw_query", {})
50+
self.sub_query = None
5051
self.lookup_pipeline = None
5152
self.annotation_stage = None
5253
self.project_fields = None
@@ -96,12 +97,8 @@ def delete(self):
9697
options = self.connection.operation_flags.get("delete", {})
9798
return self.collection.delete_many(self.mongo_query, **options).deleted_count
9899

99-
@wrap_database_errors
100-
def get_cursor(self):
101-
if self.query.low_mark == self.query.high_mark:
102-
return []
103-
# Construct the query pipeline.
104-
pipeline = []
100+
def get_pipeline(self):
101+
pipeline = [] if self.sub_query is None else self.sub_query.get_pipeline()
105102
if self.lookup_pipeline:
106103
pipeline.extend(self.lookup_pipeline)
107104
if self.mongo_query:
@@ -116,6 +113,13 @@ def get_cursor(self):
116113
pipeline.append({"$skip": self.query.low_mark})
117114
if self.query.high_mark is not None:
118115
pipeline.append({"$limit": self.query.high_mark - self.query.low_mark})
116+
return pipeline
117+
118+
@wrap_database_errors
119+
def get_cursor(self):
120+
if self.query.low_mark == self.query.high_mark:
121+
return []
122+
pipeline = self.get_pipeline()
119123
return self.collection.aggregate(pipeline)
120124

121125

0 commit comments

Comments
 (0)