
Commit a528112

Refactor.
1 parent b4061d5 commit a528112

2 files changed: +125 -159 lines

django_mongodb/compiler.py

Lines changed: 108 additions & 142 deletions
@@ -1,3 +1,5 @@
+import itertools
+
 from django.core.exceptions import EmptyResultSet, FullResultSet
 from django.db import DatabaseError, IntegrityError, NotSupportedError
 from django.db.models import Count, Expression
@@ -17,22 +19,33 @@ class SQLCompiler(compiler.SQLCompiler):
     """Base class for all Mongo compilers."""

     query_class = MongoQuery
-    _group_pipeline = None
-    aggregation_idx = 0
-
-    def _get_colum_from_expression(self, expr, alias):
+    SEPARATOR = "10__MESSI__3"
+
+    def _get_group_alias_column(self, col, annotation_group_idx):
+        """Generate alias and replacement for group columns."""
+        replacement = None
+        if not isinstance(col, Col):
+            alias = f"__annotation_group{next(annotation_group_idx)}"
+            col_expr = self._get_column_from_expression(col, alias)
+            replacement = col_expr
+            col = col_expr
+        if self.collection_name == col.alias:
+            return col.target.column, replacement
+        return f"{col.alias}{self.SEPARATOR}{col.target.column}", replacement
+
+    def _get_column_from_expression(self, expr, alias):
+        """Get column target from expression."""
         column_target = expr.output_field.__class__()
         column_target.db_column = alias
         column_target.set_attributes_from_name(alias)
         return Col(self.collection_name, column_target)

-    def _prepare_expressions_for_pipeline(self, expression, target):
+    def _prepare_expressions_for_pipeline(self, expression, target, count):
+        """Prepare expressions for the MongoDB aggregation pipeline."""
         replacements = {}
         group = {}
         for sub_expr in self._get_aggregate_expressions(expression):
-            alias = f"__aggregation{self.aggregation_idx}" if sub_expr != expression else target
-            self.aggregation_idx += 1
-
+            alias = f"__aggregation{next(count)}" if sub_expr != expression else target
             column_target = sub_expr.output_field.__class__()
             column_target.db_column = alias
             column_target.set_attributes_from_name(alias)
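Note on this hunk: the stateful class attribute aggregation_idx is replaced by an itertools.count iterator that callers pass in explicitly, and the fixed class constant SEPARATOR = "10__MESSI__3" is used by _get_group_alias_column to join a table alias and a column name into one flat group key. A minimal sketch of how such keys round-trip (the alias and column names below are invented for illustration; only the separator value comes from the diff):

    # Illustration only: encoding and decoding separator-based group keys.
    import itertools

    SEPARATOR = "10__MESSI__3"                       # same constant as in the diff
    annotation_group_idx = itertools.count(start=1)  # stand-in for the passed-in counter

    def make_key(table_alias, column, collection_name="book"):
        # Columns of the main collection keep their plain name; columns pulled
        # in from a joined table are prefixed with the table alias.
        if table_alias == collection_name:
            return column
        return f"{table_alias}{SEPARATOR}{column}"

    key = make_key("author", "name")            # 'author10__MESSI__3name'
    subtable, field = key.split(SEPARATOR)      # back to ('author', 'name')
    print(next(annotation_group_idx), key, subtable, field)

Compared with the removed _random_separtor() approach, the fixed string trades collision checking for simplicity; it only has to be a token that never appears in real alias or column names.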
@@ -55,127 +68,109 @@ def _prepare_expressions_for_pipeline(self, expression, target):
             replacements[sub_expr] = replacing_expr
         return replacements, group

-    @staticmethod
-    def _random_separtor():
-        import random
-        import string
-
-        size = 6
-        chars = string.ascii_uppercase + string.digits
-        return "".join(random.choice(chars) for _ in range(size))  # noqa: S311
-
-    def pre_sql_setup(self, with_col_aliases=False):
-        pre_setup = super().pre_sql_setup(with_col_aliases=with_col_aliases)
-        self.annotations = {}
+    def _prepare_annotations_for_group_pipeline(self):
+        """Prepare annotations for the MongoDB aggregation pipeline."""
+        replacements = {}
         group = {}
-        group_expressions = set()
-        all_replacements = {}
-        self.aggregation_idx = 0
+        count = itertools.count(start=1)
         for target, expr in self.query.annotation_select.items():
             if expr.contains_aggregate:
-                replacements, expr_group = self._prepare_expressions_for_pipeline(expr, target)
-                all_replacements.update(replacements)
+                new_replacements, expr_group = self._prepare_expressions_for_pipeline(
+                    expr, target, count
+                )
+                replacements.update(new_replacements)
                 group.update(expr_group)
-                group_expressions |= set(expr.get_group_by_cols())

         having_replacements, having_group = self._prepare_expressions_for_pipeline(
-            self.having, None
+            self.having, None, count
         )
-        all_replacements.update(having_replacements)
+        replacements.update(having_replacements)
         group.update(having_group)
+        return group, replacements

-        if group or self.query.group_by:
-            order_by = self.get_order_by()
-            for expr, (_, _, is_ref) in order_by:
-                # Skip references to the SELECT clause, as all expressions in
-                # the SELECT clause are already part of the GROUP BY.
-                if not is_ref:
-                    group_expressions |= set(expr.get_group_by_cols())
-
-            for expr, *_ in self.select:
+    def _get_group_id_expressions(self):
+        """Generate group ID expressions for the aggregation pipeline."""
+        group_expressions = set()
+        replacements = {}
+        order_by = self.get_order_by()
+        for expr, (_, _, is_ref) in order_by:
+            if not is_ref:
                 group_expressions |= set(expr.get_group_by_cols())

-            having_group_by = self.having.get_group_by_cols() if self.having else ()
-            for expr in having_group_by:
-                group_expressions.add(expr)
-            if isinstance(self.query.group_by, tuple | list):
-                group_expressions |= set(self.query.group_by)
-            elif self.query.group_by is None:
-                group_expressions = set()
+        for expr, *_ in self.select:
+            group_expressions |= set(expr.get_group_by_cols())

-            all_strings = "".join(
-                str(col.as_mql(self, self.connection)) for col in group_expressions
-            )
+        having_group_by = self.having.get_group_by_cols() if self.having else ()
+        for expr in having_group_by:
+            group_expressions.add(expr)
+        if isinstance(self.query.group_by, tuple | list):
+            group_expressions |= set(self.query.group_by)
+        elif self.query.group_by is None:
+            group_expressions = set()

-            while True:
-                random_string = self._random_separtor()
-                if random_string not in all_strings:
-                    break
-            SEPARATOR = f"__{random_string}__"
-
-            annotation_group_idx = 0
-
-            def _ccc(col):
-                nonlocal annotation_group_idx
-
-                if not isinstance(col, Col):
-                    annotation_group_idx += 1
-                    alias = f"__annotation_group_{annotation_group_idx}"
-                    col_expr = self._get_colum_from_expression(col, alias)
-                    all_replacements[col] = col_expr
-                    col = col_expr
-                if self.collection_name == col.alias:
-                    return col.target.column
-                return f"{col.alias}{SEPARATOR}{col.target.column}"
-
-            ids = (
-                None
-                if not group_expressions
-                else {
-                    _ccc(col): col.as_mql(self, self.connection)
-                    # expression aren't needed in the group by clouse ()
-                    for col in group_expressions
-                }
-            )
-            self.annotations = {
-                target: expr.replace_expressions(all_replacements)
-                for target, expr in self.query.annotation_select.items()
-            }
-            pipeline = []
-            if not ids:
-                group["_id"] = None
-                pipeline.append({"$facet": {"group": [{"$group": group}]}})
-                pipeline.append(
-                    {
-                        "$addFields": {
-                            key: {
-                                "$getField": {
-                                    "input": {"$arrayElemAt": ["$group", 0]},
-                                    "field": key,
-                                }
+        if not group_expressions:
+            ids = None
+        else:
+            annotation_group_idx = itertools.count(start=1)
+            ids = {}
+            for col in group_expressions:
+                alias, replacement = self._get_group_alias_column(col, annotation_group_idx)
+                ids[alias] = col.as_mql(self, self.connection)
+                if replacement is not None:
+                    replacements[col] = replacement
+
+        return ids, replacements
+
+    def _build_group_pipeline(self, ids, group):
+        """Build the aggregation pipeline for grouping."""
+        pipeline = []
+        if not ids:
+            group["_id"] = None
+            pipeline.append({"$facet": {"group": [{"$group": group}]}})
+            pipeline.append(
+                {
+                    "$addFields": {
+                        key: {
+                            "$getField": {
+                                "input": {"$arrayElemAt": ["$group", 0]},
+                                "field": key,
                             }
-                            for key in group
                         }
+                        for key in group
                     }
-                )
-            else:
-                group["_id"] = ids
-                pipeline.append({"$group": group})
-                sets = {}
-                for key in ids:
-                    value = f"$_id.{key}"
-                    if SEPARATOR in key:
-                        subtable, field = key.split(SEPARATOR)
-                        if subtable not in sets:
-                            sets[subtable] = {}
-                        sets[subtable][field] = value
-                    else:
-                        sets[key] = value
-
-                pipeline.append({"$addFields": sets})
-                if "_id" not in sets:
-                    pipeline.append({"$unset": "_id"})
+                }
+            )
+        else:
+            group["_id"] = ids
+            pipeline.append({"$group": group})
+            sets = {}
+            for key in ids:
+                value = f"$_id.{key}"
+                if self.SEPARATOR in key:
+                    subtable, field = key.split(self.SEPARATOR)
+                    if subtable not in sets:
+                        sets[subtable] = {}
+                    sets[subtable][field] = value
+                else:
+                    sets[key] = value
+
+            pipeline.append({"$addFields": sets})
+            if "_id" not in sets:
+                pipeline.append({"$unset": "_id"})

+        return pipeline
+
+    def pre_sql_setup(self, with_col_aliases=False):
+        pre_setup = super().pre_sql_setup(with_col_aliases=with_col_aliases)
+        group, all_replacements = self._prepare_annotations_for_group_pipeline()
+
+        # The query.group_by is either None (no GROUP BY at all), True
+        # (group by select fields), or a list of expressions to be added
+        # to the group by.
+        if group or self.query.group_by:
+            ids, replacements = self._get_group_id_expressions()
+            all_replacements.update(replacements)
+            pipeline = self._build_group_pipeline(ids, group)
             if self.having:
                 pipeline.append(
                     {
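For orientation, this is roughly the pipeline shape that _build_group_pipeline emits when grouping keys exist; the field names and the $sum accumulator below are hypothetical, only the stage layout mirrors the code above (a grouping key from a joined table is flattened with the separator and then re-nested by $addFields):

    # Hypothetical output of _build_group_pipeline for one join column and one aggregate.
    SEPARATOR = "10__MESSI__3"
    ids = {f"author{SEPARATOR}name": "$author.name"}    # built by _get_group_id_expressions
    group = {"__aggregation1": {"$sum": "$price"}, "_id": ids}

    pipeline = [
        {"$group": group},
        # Re-nest the flattened _id key into {"author": {"name": ...}}.
        {"$addFields": {"author": {"name": f"$_id.author{SEPARATOR}name"}}},
        {"$unset": "_id"},
    ]
    print(pipeline)

When there are no grouping keys at all, the code instead groups everything under _id: None inside a $facet stage and copies the accumulators back out with $getField, as shown in the hunk.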
@@ -186,7 +181,6 @@ def _ccc(col):
                         }
                     }
                 )
-
            self._group_pipeline = pipeline
         else:
             self._group_pipeline = None
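The comment in pre_sql_setup spells out the three possible states of query.group_by; the `if group or self.query.group_by:` check below it builds a group pipeline only when there are aggregate accumulators or some form of GROUP BY. A small runnable sketch of that decision with stand-in values (the values are illustrative, not taken from the commit):

    # Sketch of the branch decision in pre_sql_setup() with stand-in values.
    def needs_group_stage(group, group_by):
        # group: accumulators collected from aggregate annotations;
        # group_by: Django's Query.group_by (None, True, or a list of expressions).
        return bool(group or group_by)

    print(needs_group_stage({}, None))                             # False: no $group stage
    print(needs_group_stage({"total": {"$sum": "$price"}}, None))  # True: aggregate annotation
    print(needs_group_stage({}, True))                             # True: group by select fields
    print(needs_group_stage({}, ["some_col_expression"]))          # True: explicit group by list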
@@ -201,7 +195,6 @@ def _ccc(col):
     def execute_sql(
         self, result_type=MULTI, chunked_fetch=False, chunk_size=GET_ITERATOR_CHUNK_SIZE
     ):
-        # QuerySet.count()
         self.pre_sql_setup()
         columns = self.get_columns()
         try:
@@ -291,34 +284,6 @@ def check_query(self):
             if any(key.startswith("_prefetch_related_") for key in self.query.extra):
                 raise NotSupportedError("QuerySet.prefetch_related() is not supported on MongoDB.")
             raise NotSupportedError("QuerySet.extra() is not supported on MongoDB.")
-        if any(
-            isinstance(a, Aggregate) and not isinstance(a, Count)
-            for a in self.query.annotations.values()
-        ):
-            # raise NotSupportedError("QuerySet.aggregate() isn't supported on MongoDB.")
-            pass
-
-    def get_count(self, check_exists=False):
-        """
-        Count objects matching the current filters / constraints.
-
-        If `check_exists` is True, only check if any object matches.
-        """
-        kwargs = {}
-        # If this query is sliced, the limits will be set on the subquery.
-        inner_query = getattr(self.query, "inner_query", None)
-        low_mark = inner_query.low_mark if inner_query else 0
-        high_mark = inner_query.high_mark if inner_query else None
-        if low_mark > 0:
-            kwargs["skip"] = low_mark
-        if check_exists:
-            kwargs["limit"] = 1
-        elif high_mark is not None:
-            kwargs["limit"] = high_mark - low_mark
-        try:
-            return self.build_query().count(**kwargs)
-        except EmptyResultSet:
-            return 0

     def build_query(self, columns=None):
         """Check if the query is supported and prepare a MongoQuery."""
@@ -511,6 +476,7 @@ def insert(self, docs, returning_fields=None):

 class SQLDeleteCompiler(compiler.SQLDeleteCompiler, SQLCompiler):
     def execute_sql(self, result_type=MULTI):
+        self.pre_sql_setup()
         cursor = Cursor()
         cursor.rowcount = self.build_query().delete()
         return cursor

django_mongodb/query.py

Lines changed: 17 additions & 17 deletions
@@ -93,6 +93,23 @@ def delete(self):
         options = self.connection.operation_flags.get("delete", {})
         return self.collection.delete_many(self.mongo_query, **options).deleted_count

+    @wrap_database_errors
+    def get_cursor(self, count=False, limit=None, skip=None):
+        """
+        Return a pymongo CommandCursor that can be iterated on to give the
+        results of the query.
+
+        If `count` is True, return a single document with the number of
+        documents that match the query.
+
+        Use `limit` or `skip` to override those options of the query.
+        """
+        if self.query.low_mark == self.query.high_mark:
+            return []
+
+        pipeline = self.get_pipeline()
+        return self.collection.aggregate(pipeline)
+
     def get_pipeline(self, count=False, limit=None, skip=None):
         pipeline = [] if self.subquery is None else self.subquery.get_pipeline()
         if self.lookup_pipeline:
@@ -118,23 +135,6 @@ def get_pipeline(self, count=False, limit=None, skip=None):

         return pipeline

-    @wrap_database_errors
-    def get_cursor(self, count=False, limit=None, skip=None):
-        """
-        Return a pymongo CommandCursor that can be iterated on to give the
-        results of the query.
-
-        If `count` is True, return a single document with the number of
-        documents that match the query.
-
-        Use `limit` or `skip` to override those options of the query.
-        """
-        if self.query.low_mark == self.query.high_mark:
-            return []
-
-        pipeline = self.get_pipeline()
-        return self.collection.aggregate(pipeline)
-

     def join(self, compiler, connection):
         lookup_pipeline = []
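The get_cursor() method is moved above get_pipeline() otherwise unchanged: it returns [] for an empty slice (low_mark == high_mark) and otherwise runs collection.aggregate() on the assembled pipeline. A tiny sketch of the short-circuit condition (the slice values below are illustrative):

    # Sketch of get_cursor()'s empty-slice short-circuit. A slice like qs[3:3]
    # yields low_mark == high_mark, so no aggregate() round-trip is needed.
    def should_skip_aggregate(low_mark, high_mark):
        return low_mark == high_mark

    print(should_skip_aggregate(3, 3))     # True: return [] without touching MongoDB
    print(should_skip_aggregate(0, None))  # False: run collection.aggregate(pipeline)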
