Skip to content

Commit 7e80c1a

Browse files
committed
Refactor.
1 parent be32416 commit 7e80c1a

File tree

2 files changed

+125
-161
lines changed

2 files changed

+125
-161
lines changed

django_mongodb/compiler.py

Lines changed: 108 additions & 144 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from itertools import chain
1+
import itertools
22

33
from django.core.exceptions import EmptyResultSet, FullResultSet
44
from django.db import DatabaseError, IntegrityError, NotSupportedError
@@ -19,22 +19,33 @@ class SQLCompiler(compiler.SQLCompiler):
1919
"""Base class for all Mongo compilers."""
2020

2121
query_class = MongoQuery
22-
_group_pipeline = None
23-
aggregation_idx = 0
24-
25-
def _get_colum_from_expression(self, expr, alias):
22+
SEPARATOR = "10__MESSI__3"
23+
24+
def _get_group_alias_column(self, col, annotation_group_idx):
25+
"""Generate alias and replacement for group columns."""
26+
replacement = None
27+
if not isinstance(col, Col):
28+
alias = f"__annotation_group{next(annotation_group_idx)}"
29+
col_expr = self._get_column_from_expression(col, alias)
30+
replacement = col_expr
31+
col = col_expr
32+
if self.collection_name == col.alias:
33+
return col.target.column, replacement
34+
return f"{col.alias}{self.SEPARATOR}{col.target.column}", replacement
35+
36+
def _get_column_from_expression(self, expr, alias):
37+
"""Get column target from expression."""
2638
column_target = expr.output_field.__class__()
2739
column_target.db_column = alias
2840
column_target.set_attributes_from_name(alias)
2941
return Col(self.collection_name, column_target)
3042

31-
def _prepare_expressions_for_pipeline(self, expression, target):
43+
def _prepare_expressions_for_pipeline(self, expression, target, count):
44+
"""Prepare expressions for the MongoDB aggregation pipeline."""
3245
replacements = {}
3346
group = {}
3447
for sub_expr in self._get_aggregate_expressions(expression):
35-
alias = f"__aggregation{self.aggregation_idx}" if sub_expr != expression else target
36-
self.aggregation_idx += 1
37-
48+
alias = f"__aggregation{next(count)}" if sub_expr != expression else target
3849
column_target = sub_expr.output_field.__class__()
3950
column_target.db_column = alias
4051
column_target.set_attributes_from_name(alias)
@@ -57,127 +68,109 @@ def _prepare_expressions_for_pipeline(self, expression, target):
5768
replacements[sub_expr] = replacing_expr
5869
return replacements, group
5970

60-
@staticmethod
61-
def _random_separtor():
62-
import random
63-
import string
64-
65-
size = 6
66-
chars = string.ascii_uppercase + string.digits
67-
return "".join(random.choice(chars) for _ in range(size)) # noqa: S311
68-
69-
def pre_sql_setup(self, with_col_aliases=False):
70-
pre_setup = super().pre_sql_setup(with_col_aliases=with_col_aliases)
71-
self.annotations = {}
71+
def _prepare_annotations_for_group_pipeline(self):
72+
"""Prepare annotations for the MongoDB aggregation pipeline."""
73+
replacements = {}
7274
group = {}
73-
group_expressions = set()
74-
all_replacements = {}
75-
self.aggregation_idx = 0
75+
count = itertools.count(start=1)
7676
for target, expr in self.query.annotation_select.items():
7777
if expr.contains_aggregate:
78-
replacements, expr_group = self._prepare_expressions_for_pipeline(expr, target)
79-
all_replacements.update(replacements)
78+
new_replacements, expr_group = self._prepare_expressions_for_pipeline(
79+
expr, target, count
80+
)
81+
replacements.update(new_replacements)
8082
group.update(expr_group)
81-
group_expressions |= set(expr.get_group_by_cols())
8283

8384
having_replacements, having_group = self._prepare_expressions_for_pipeline(
84-
self.having, None
85+
self.having, None, count
8586
)
86-
all_replacements.update(having_replacements)
87+
replacements.update(having_replacements)
8788
group.update(having_group)
89+
return group, replacements
8890

89-
if group or self.query.group_by:
90-
order_by = self.get_order_by()
91-
for expr, (_, _, is_ref) in order_by:
92-
# Skip references to the SELECT clause, as all expressions in
93-
# the SELECT clause are already part of the GROUP BY.
94-
if not is_ref:
95-
group_expressions |= set(expr.get_group_by_cols())
96-
97-
for expr, *_ in self.select:
91+
def _get_group_id_expressions(self):
92+
"""Generate group ID expressions for the aggregation pipeline."""
93+
group_expressions = set()
94+
replacements = {}
95+
order_by = self.get_order_by()
96+
for expr, (_, _, is_ref) in order_by:
97+
if not is_ref:
9898
group_expressions |= set(expr.get_group_by_cols())
9999

100-
having_group_by = self.having.get_group_by_cols() if self.having else ()
101-
for expr in having_group_by:
102-
group_expressions.add(expr)
103-
if isinstance(self.query.group_by, tuple | list):
104-
group_expressions |= set(self.query.group_by)
105-
elif self.query.group_by is None:
106-
group_expressions = set()
100+
for expr, *_ in self.select:
101+
group_expressions |= set(expr.get_group_by_cols())
107102

108-
all_strings = "".join(
109-
str(col.as_mql(self, self.connection)) for col in group_expressions
110-
)
103+
having_group_by = self.having.get_group_by_cols() if self.having else ()
104+
for expr in having_group_by:
105+
group_expressions.add(expr)
106+
if isinstance(self.query.group_by, tuple | list):
107+
group_expressions |= set(self.query.group_by)
108+
elif self.query.group_by is None:
109+
group_expressions = set()
111110

112-
while True:
113-
random_string = self._random_separtor()
114-
if random_string not in all_strings:
115-
break
116-
SEPARATOR = f"__{random_string}__"
117-
118-
annotation_group_idx = 0
119-
120-
def _ccc(col):
121-
nonlocal annotation_group_idx
122-
123-
if not isinstance(col, Col):
124-
annotation_group_idx += 1
125-
alias = f"__annotation_group_{annotation_group_idx}"
126-
col_expr = self._get_colum_from_expression(col, alias)
127-
all_replacements[col] = col_expr
128-
col = col_expr
129-
if self.collection_name == col.alias:
130-
return col.target.column
131-
return f"{col.alias}{SEPARATOR}{col.target.column}"
132-
133-
ids = (
134-
None
135-
if not group_expressions
136-
else {
137-
_ccc(col): col.as_mql(self, self.connection)
138-
# expression aren't needed in the group by clouse ()
139-
for col in group_expressions
140-
}
141-
)
142-
self.annotations = {
143-
target: expr.replace_expressions(all_replacements)
144-
for target, expr in self.query.annotation_select.items()
145-
}
146-
pipeline = []
147-
if not ids:
148-
group["_id"] = None
149-
pipeline.append({"$facet": {"group": [{"$group": group}]}})
150-
pipeline.append(
151-
{
152-
"$addFields": {
153-
key: {
154-
"$getField": {
155-
"input": {"$arrayElemAt": ["$group", 0]},
156-
"field": key,
157-
}
111+
if not group_expressions:
112+
ids = None
113+
else:
114+
annotation_group_idx = itertools.count(start=1)
115+
ids = {}
116+
for col in group_expressions:
117+
alias, replacement = self._get_group_alias_column(col, annotation_group_idx)
118+
ids[alias] = col.as_mql(self, self.connection)
119+
if replacement is not None:
120+
replacements[col] = replacement
121+
122+
return ids, replacements
123+
124+
def _build_group_pipeline(self, ids, group):
125+
"""Build the aggregation pipeline for grouping."""
126+
pipeline = []
127+
if not ids:
128+
group["_id"] = None
129+
pipeline.append({"$facet": {"group": [{"$group": group}]}})
130+
pipeline.append(
131+
{
132+
"$addFields": {
133+
key: {
134+
"$getField": {
135+
"input": {"$arrayElemAt": ["$group", 0]},
136+
"field": key,
158137
}
159-
for key in group
160138
}
139+
for key in group
161140
}
162-
)
163-
else:
164-
group["_id"] = ids
165-
pipeline.append({"$group": group})
166-
sets = {}
167-
for key in ids:
168-
value = f"$_id.{key}"
169-
if SEPARATOR in key:
170-
subtable, field = key.split(SEPARATOR)
171-
if subtable not in sets:
172-
sets[subtable] = {}
173-
sets[subtable][field] = value
174-
else:
175-
sets[key] = value
176-
177-
pipeline.append({"$addFields": sets})
178-
if "_id" not in sets:
179-
pipeline.append({"$unset": "_id"})
141+
}
142+
)
143+
else:
144+
group["_id"] = ids
145+
pipeline.append({"$group": group})
146+
sets = {}
147+
for key in ids:
148+
value = f"$_id.{key}"
149+
if self.SEPARATOR in key:
150+
subtable, field = key.split(self.SEPARATOR)
151+
if subtable not in sets:
152+
sets[subtable] = {}
153+
sets[subtable][field] = value
154+
else:
155+
sets[key] = value
156+
157+
pipeline.append({"$addFields": sets})
158+
if "_id" not in sets:
159+
pipeline.append({"$unset": "_id"})
180160

161+
return pipeline
162+
163+
def pre_sql_setup(self, with_col_aliases=False):
164+
pre_setup = super().pre_sql_setup(with_col_aliases=with_col_aliases)
165+
group, all_replacements = self._prepare_annotations_for_group_pipeline()
166+
167+
# The query.group_by is either None (no GROUP BY at all), True
168+
# (group by select fields), or a list of expressions to be added
169+
# to the group by.
170+
if group or self.query.group_by:
171+
ids, replacements = self._get_group_id_expressions()
172+
all_replacements.update(replacements)
173+
pipeline = self._build_group_pipeline(ids, group)
181174
if self.having:
182175
pipeline.append(
183176
{
@@ -188,7 +181,6 @@ def _ccc(col):
188181
}
189182
}
190183
)
191-
192184
self._group_pipeline = pipeline
193185
else:
194186
self._group_pipeline = None
@@ -203,7 +195,6 @@ def _ccc(col):
203195
def execute_sql(
204196
self, result_type=MULTI, chunked_fetch=False, chunk_size=GET_ITERATOR_CHUNK_SIZE
205197
):
206-
# QuerySet.count()
207198
self.pre_sql_setup()
208199
columns = self.get_columns()
209200
try:
@@ -256,7 +247,7 @@ def results_iter(
256247

257248
fields = [s[0] for s in self.select[0 : self.col_count]]
258249
converters = self.get_converters(fields)
259-
rows = chain.from_iterable(results)
250+
rows = itertools.chain.from_iterable(results)
260251
if converters:
261252
rows = self.apply_converters(rows, converters)
262253
if tuple_expected:
@@ -320,34 +311,6 @@ def check_query(self):
320311
if any(key.startswith("_prefetch_related_") for key in self.query.extra):
321312
raise NotSupportedError("QuerySet.prefetch_related() is not supported on MongoDB.")
322313
raise NotSupportedError("QuerySet.extra() is not supported on MongoDB.")
323-
if any(
324-
isinstance(a, Aggregate) and not isinstance(a, Count)
325-
for a in self.query.annotations.values()
326-
):
327-
# raise NotSupportedError("QuerySet.aggregate() isn't supported on MongoDB.")
328-
pass
329-
330-
def get_count(self, check_exists=False):
331-
"""
332-
Count objects matching the current filters / constraints.
333-
334-
If `check_exists` is True, only check if any object matches.
335-
"""
336-
kwargs = {}
337-
# If this query is sliced, the limits will be set on the subquery.
338-
inner_query = getattr(self.query, "inner_query", None)
339-
low_mark = inner_query.low_mark if inner_query else 0
340-
high_mark = inner_query.high_mark if inner_query else None
341-
if low_mark > 0:
342-
kwargs["skip"] = low_mark
343-
if check_exists:
344-
kwargs["limit"] = 1
345-
elif high_mark is not None:
346-
kwargs["limit"] = high_mark - low_mark
347-
try:
348-
return self.build_query().count(**kwargs)
349-
except EmptyResultSet:
350-
return 0
351314

352315
def build_query(self, columns=None):
353316
"""Check if the query is supported and prepare a MongoQuery."""
@@ -540,6 +503,7 @@ def insert(self, docs, returning_fields=None):
540503

541504
class SQLDeleteCompiler(compiler.SQLDeleteCompiler, SQLCompiler):
542505
def execute_sql(self, result_type=MULTI):
506+
self.pre_sql_setup()
543507
cursor = Cursor()
544508
cursor.rowcount = self.build_query().delete()
545509
return cursor

django_mongodb/query.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,23 @@ def delete(self):
8989
options = self.connection.operation_flags.get("delete", {})
9090
return self.collection.delete_many(self.mongo_query, **options).deleted_count
9191

92+
@wrap_database_errors
93+
def get_cursor(self, count=False, limit=None, skip=None):
94+
"""
95+
Return a pymongo CommandCursor that can be iterated on to give the
96+
results of the query.
97+
98+
If `count` is True, return a single document with the number of
99+
documents that match the query.
100+
101+
Use `limit` or `skip` to override those options of the query.
102+
"""
103+
if self.query.low_mark == self.query.high_mark:
104+
return []
105+
106+
pipeline = self.get_pipeline()
107+
return self.collection.aggregate(pipeline)
108+
92109
def get_pipeline(self, count=False, limit=None, skip=None):
93110
pipeline = [] if self.subquery is None else self.subquery.get_pipeline()
94111
if self.lookup_pipeline:
@@ -114,23 +131,6 @@ def get_pipeline(self, count=False, limit=None, skip=None):
114131

115132
return pipeline
116133

117-
@wrap_database_errors
118-
def get_cursor(self, count=False, limit=None, skip=None):
119-
"""
120-
Return a pymongo CommandCursor that can be iterated on to give the
121-
results of the query.
122-
123-
If `count` is True, return a single document with the number of
124-
documents that match the query.
125-
126-
Use `limit` or `skip` to override those options of the query.
127-
"""
128-
if self.query.low_mark == self.query.high_mark:
129-
return []
130-
131-
pipeline = self.get_pipeline()
132-
return self.collection.aggregate(pipeline)
133-
134134

135135
def join(self, compiler, connection):
136136
lookup_pipeline = []

0 commit comments

Comments
 (0)