Skip to content

Commit 723c07e

Browse files
committed
Implementing aggregate compiler.
1 parent a5dec57 commit 723c07e

File tree

3 files changed

+47
-38
lines changed

3 files changed

+47
-38
lines changed

django_mongodb/compiler.py

Lines changed: 23 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,10 @@ class SQLCompiler(compiler.SQLCompiler):
1616
"""Base class for all Mongo compilers."""
1717

1818
query_class = MongoQuery
19+
_group_pipeline = None
1920

20-
def pre_sql_setup(self, *args, **kargs):
21-
pre_setup = super().pre_sql_setup(*args, **kargs)
21+
def pre_sql_setup(self, with_col_aliases=False):
22+
pre_setup = super().pre_sql_setup(with_col_aliases=with_col_aliases)
2223
self.annotations = {}
2324
group = {}
2425
group_expressions = set()
@@ -29,17 +30,20 @@ def pre_sql_setup(self, *args, **kargs):
2930
else:
3031
replacements = {}
3132
for sub_expr in self._get_aggregate_expressions(expr):
32-
alias = f"__aggregation{aggregation_idx}"
33+
if sub_expr != expr:
34+
alias = f"__aggregation{aggregation_idx}"
35+
aggregation_idx += 1
36+
else:
37+
alias = target
38+
group_expressions |= set(sub_expr.get_group_by_cols())
3339
group[alias] = sub_expr.as_mql(self, self.connection)
34-
aggregation_idx += 1
3540
column_target = expr.output_field.__class__()
3641
column_target.set_attributes_from_name(alias)
3742
replacements[sub_expr] = Col(self.collection_name, column_target)
3843
result_expr = expr.replace_expressions(replacements)
3944

4045
self.annotations[target] = result_expr
4146
if group:
42-
"""
4347
order_by = self.get_order_by()
4448
for expr, (_, _, is_ref) in order_by:
4549
# Skip references to the SELECT clause, as all expressions in
@@ -49,7 +53,8 @@ def pre_sql_setup(self, *args, **kargs):
4953
having_group_by = self.having.get_group_by_cols() if self.having else ()
5054
for expr in having_group_by:
5155
group_expressions.add(expr)
52-
"""
56+
if isinstance(self.query.group_by, tuple | list):
57+
group_expressions |= set(self.query.group_by)
5358

5459
ids = (
5560
None
@@ -60,7 +65,6 @@ def pre_sql_setup(self, *args, **kargs):
6065
}
6166
)
6267
group["_id"] = ids
63-
6468
pipeline = [{"$group": group}]
6569
if ids:
6670
pipeline.append(
@@ -78,8 +82,8 @@ def pre_sql_setup(self, *args, **kargs):
7882
def execute_sql(
7983
self, result_type=MULTI, chunked_fetch=False, chunk_size=GET_ITERATOR_CHUNK_SIZE
8084
):
81-
self.pre_sql_setup()
8285
# QuerySet.count()
86+
self.pre_sql_setup()
8387
if self.query.annotations == {"__count": Count("*")}:
8488
return [self.get_count()]
8589

@@ -305,17 +309,6 @@ def get_lookup_pipeline(self):
305309
result += self.query.alias_map[alias].as_mql(self, self.connection)
306310
return result
307311

308-
def _get_aggregate_expressions2(self, expr):
309-
stack = [(None, expr)]
310-
while stack:
311-
parent, expr = stack.pop()
312-
if isinstance(expr, Aggregate):
313-
yield parent
314-
elif hasattr(expr, "get_source_expressions"):
315-
stack.extend(
316-
[((expr, idx), se) for idx, se in enumerate(expr.get_source_expressions())]
317-
)
318-
319312
def _get_aggregate_expressions(self, expr):
320313
stack = [expr]
321314
while stack:
@@ -442,4 +435,14 @@ def execute_update(self, update_spec, **kwargs):
442435

443436

444437
class SQLAggregateCompiler(SQLCompiler):
445-
pass
438+
def build_query(self, columns=None):
439+
query = self.query_class(self)
440+
query.project_fields = self.get_project_fields(tuple(self.query.annotation_select.items()))
441+
442+
compiler = self.query.inner_query.get_compiler(
443+
self.using,
444+
elide_empty=self.elide_empty,
445+
)
446+
compiler.pre_sql_setup(with_col_aliases=False)
447+
query.sub_query = compiler.build_query()
448+
return query

django_mongodb/expressions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def query(self, compiler, connection): # noqa: ARG001
7070

7171

7272
def ref(self, compiler, connection): # noqa: ARG001
73-
return self.refs
73+
return f"${self.refs}"
7474

7575

7676
def subquery(self, compiler, connection): # noqa: ARG001

django_mongodb/query.py

Lines changed: 23 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ def __init__(self, compiler):
4747
self.collection = self.compiler.get_collection()
4848
self.collection_name = self.compiler.collection_name
4949
self.mongo_query = getattr(compiler.query, "raw_query", {})
50+
self.sub_query = None
5051
self.lookup_pipeline = None
5152
self.annotation_stage = None
5253
self.project_fields = None
@@ -92,21 +93,8 @@ def delete(self):
9293
options = self.connection.operation_flags.get("delete", {})
9394
return self.collection.delete_many(self.mongo_query, **options).deleted_count
9495

95-
@wrap_database_errors
96-
def get_cursor(self, count=False, limit=None, skip=None):
97-
"""
98-
Return a pymongo CommandCursor that can be iterated on to give the
99-
results of the query.
100-
101-
If `count` is True, return a single document with the number of
102-
documents that match the query.
103-
104-
Use `limit` or `skip` to override those options of the query.
105-
"""
106-
if self.query.low_mark == self.query.high_mark:
107-
return []
108-
# Construct the query pipeline.
109-
pipeline = []
96+
def get_pipeline(self, count=False, limit=None, skip=None):
97+
pipeline = [] if self.sub_query is None else self.sub_query.get_pipeline()
11098
if self.lookup_pipeline:
11199
pipeline.extend(self.lookup_pipeline)
112100
if self.mongo_query:
@@ -117,16 +105,34 @@ def get_cursor(self, count=False, limit=None, skip=None):
117105
pipeline.append({"$project": self.project_fields})
118106
if self.ordering:
119107
pipeline.append({"$sort": dict(self.ordering)})
108+
120109
if skip is not None:
121110
pipeline.append({"$skip": skip})
122111
elif self.query.low_mark > 0:
123112
pipeline.append({"$skip": self.query.low_mark})
113+
124114
if limit is not None:
125115
pipeline.append({"$limit": limit})
126116
elif self.query.high_mark is not None:
127117
pipeline.append({"$limit": self.query.high_mark - self.query.low_mark})
128-
if count:
129-
pipeline.append({"$group": {"_id": None, "__count": {"$sum": 1}}})
118+
119+
return pipeline
120+
121+
@wrap_database_errors
122+
def get_cursor(self, count=False, limit=None, skip=None):
123+
"""
124+
Return a pymongo CommandCursor that can be iterated on to give the
125+
results of the query.
126+
127+
If `count` is True, return a single document with the number of
128+
documents that match the query.
129+
130+
Use `limit` or `skip` to override those options of the query.
131+
"""
132+
if self.query.low_mark == self.query.high_mark:
133+
return []
134+
135+
pipeline = self.get_pipeline()
130136
return self.collection.aggregate(pipeline)
131137

132138

0 commit comments

Comments
 (0)