Skip to content

Commit 43ff544

Browse files
committed
Implementing aggregate compiler.
1 parent 02863b3 commit 43ff544

File tree

3 files changed

+47
-38
lines changed

3 files changed

+47
-38
lines changed

django_mongodb/compiler.py

Lines changed: 23 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,10 @@ class SQLCompiler(compiler.SQLCompiler):
1717
"""Base class for all Mongo compilers."""
1818

1919
query_class = MongoQuery
20+
_group_pipeline = None
2021

21-
def pre_sql_setup(self, *args, **kargs):
22-
pre_setup = super().pre_sql_setup(*args, **kargs)
22+
def pre_sql_setup(self, with_col_aliases=False):
23+
pre_setup = super().pre_sql_setup(with_col_aliases=with_col_aliases)
2324
self.annotations = {}
2425
group = {}
2526
group_expressions = set()
@@ -30,17 +31,20 @@ def pre_sql_setup(self, *args, **kargs):
3031
else:
3132
replacements = {}
3233
for sub_expr in self._get_aggregate_expressions(expr):
33-
alias = f"__aggregation{aggregation_idx}"
34+
if sub_expr != expr:
35+
alias = f"__aggregation{aggregation_idx}"
36+
aggregation_idx += 1
37+
else:
38+
alias = target
39+
group_expressions |= set(sub_expr.get_group_by_cols())
3440
group[alias] = sub_expr.as_mql(self, self.connection)
35-
aggregation_idx += 1
3641
column_target = expr.output_field.__class__()
3742
column_target.set_attributes_from_name(alias)
3843
replacements[sub_expr] = Col(self.collection_name, column_target)
3944
result_expr = expr.replace_expressions(replacements)
4045

4146
self.annotations[target] = result_expr
4247
if group:
43-
"""
4448
order_by = self.get_order_by()
4549
for expr, (_, _, is_ref) in order_by:
4650
# Skip references to the SELECT clause, as all expressions in
@@ -50,7 +54,8 @@ def pre_sql_setup(self, *args, **kargs):
5054
having_group_by = self.having.get_group_by_cols() if self.having else ()
5155
for expr in having_group_by:
5256
group_expressions.add(expr)
53-
"""
57+
if isinstance(self.query.group_by, tuple | list):
58+
group_expressions |= set(self.query.group_by)
5459

5560
ids = (
5661
None
@@ -61,7 +66,6 @@ def pre_sql_setup(self, *args, **kargs):
6166
}
6267
)
6368
group["_id"] = ids
64-
6569
pipeline = [{"$group": group}]
6670
if ids:
6771
pipeline.append(
@@ -79,8 +83,8 @@ def pre_sql_setup(self, *args, **kargs):
7983
def execute_sql(
8084
self, result_type=MULTI, chunked_fetch=False, chunk_size=GET_ITERATOR_CHUNK_SIZE
8185
):
82-
self.pre_sql_setup()
8386
# QuerySet.count()
87+
self.pre_sql_setup()
8488
if self.query.annotations == {"__count": Count("*")}:
8589
return [self.get_count()]
8690

@@ -339,17 +343,6 @@ def get_lookup_pipeline(self):
339343
result += self.query.alias_map[alias].as_mql(self, self.connection)
340344
return result
341345

342-
def _get_aggregate_expressions2(self, expr):
343-
stack = [(None, expr)]
344-
while stack:
345-
parent, expr = stack.pop()
346-
if isinstance(expr, Aggregate):
347-
yield parent
348-
elif hasattr(expr, "get_source_expressions"):
349-
stack.extend(
350-
[((expr, idx), se) for idx, se in enumerate(expr.get_source_expressions())]
351-
)
352-
353346
def _get_aggregate_expressions(self, expr):
354347
stack = [expr]
355348
while stack:
@@ -496,4 +489,14 @@ def check_query(self):
496489

497490

498491
class SQLAggregateCompiler(SQLCompiler):
499-
pass
492+
def build_query(self, columns=None):
493+
query = self.query_class(self)
494+
query.project_fields = self.get_project_fields(tuple(self.query.annotation_select.items()))
495+
496+
compiler = self.query.inner_query.get_compiler(
497+
self.using,
498+
elide_empty=self.elide_empty,
499+
)
500+
compiler.pre_sql_setup(with_col_aliases=False)
501+
query.sub_query = compiler.build_query()
502+
return query

django_mongodb/expressions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def query(self, compiler, connection): # noqa: ARG001
7676

7777

7878
def ref(self, compiler, connection): # noqa: ARG001
79-
return self.refs
79+
return f"${self.refs}"
8080

8181

8282
def subquery(self, compiler, connection): # noqa: ARG001

django_mongodb/query.py

Lines changed: 23 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ def __init__(self, compiler):
4747
self.collection = self.compiler.get_collection()
4848
self.collection_name = self.compiler.collection_name
4949
self.mongo_query = getattr(compiler.query, "raw_query", {})
50+
self.sub_query = None
5051
self.lookup_pipeline = None
5152
self.annotation_stage = None
5253
self.project_fields = None
@@ -88,21 +89,8 @@ def delete(self):
8889
options = self.connection.operation_flags.get("delete", {})
8990
return self.collection.delete_many(self.mongo_query, **options).deleted_count
9091

91-
@wrap_database_errors
92-
def get_cursor(self, count=False, limit=None, skip=None):
93-
"""
94-
Return a pymongo CommandCursor that can be iterated on to give the
95-
results of the query.
96-
97-
If `count` is True, return a single document with the number of
98-
documents that match the query.
99-
100-
Use `limit` or `skip` to override those options of the query.
101-
"""
102-
if self.query.low_mark == self.query.high_mark:
103-
return []
104-
# Construct the query pipeline.
105-
pipeline = []
92+
def get_pipeline(self, count=False, limit=None, skip=None):
93+
pipeline = [] if self.sub_query is None else self.sub_query.get_pipeline()
10694
if self.lookup_pipeline:
10795
pipeline.extend(self.lookup_pipeline)
10896
if self.mongo_query:
@@ -113,16 +101,34 @@ def get_cursor(self, count=False, limit=None, skip=None):
113101
pipeline.append({"$project": self.project_fields})
114102
if self.ordering:
115103
pipeline.append({"$sort": dict(self.ordering)})
104+
116105
if skip is not None:
117106
pipeline.append({"$skip": skip})
118107
elif self.query.low_mark > 0:
119108
pipeline.append({"$skip": self.query.low_mark})
109+
120110
if limit is not None:
121111
pipeline.append({"$limit": limit})
122112
elif self.query.high_mark is not None:
123113
pipeline.append({"$limit": self.query.high_mark - self.query.low_mark})
124-
if count:
125-
pipeline.append({"$group": {"_id": None, "__count": {"$sum": 1}}})
114+
115+
return pipeline
116+
117+
@wrap_database_errors
118+
def get_cursor(self, count=False, limit=None, skip=None):
119+
"""
120+
Return a pymongo CommandCursor that can be iterated on to give the
121+
results of the query.
122+
123+
If `count` is True, return a single document with the number of
124+
documents that match the query.
125+
126+
Use `limit` or `skip` to override those options of the query.
127+
"""
128+
if self.query.low_mark == self.query.high_mark:
129+
return []
130+
131+
pipeline = self.get_pipeline()
126132
return self.collection.aggregate(pipeline)
127133

128134

0 commit comments

Comments
 (0)