Skip to content

Commit 419b37c

Browse files
committed
Some fixes.
1 parent 6e04ee9 commit 419b37c

File tree

1 file changed

+44
-33
lines changed

1 file changed

+44
-33
lines changed

django_mongodb/compiler.py

Lines changed: 44 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,29 @@ class SQLCompiler(compiler.SQLCompiler):
1919
query_class = MongoQuery
2020
_group_pipeline = None
2121

22+
def _prepare_expressions_for_pipeline(self, expression, target, aggregation_idx):
23+
replacements = {}
24+
group = {}
25+
for sub_expr in self._get_aggregate_expressions(expression):
26+
alias = f"__aggregation{aggregation_idx}" if sub_expr != expression else target
27+
28+
column_target = sub_expr.output_field.__class__()
29+
column_target.db_column = alias
30+
column_target.set_attributes_from_name(alias)
31+
inner_column = Col(self.collection_name, column_target)
32+
if sub_expr.distinct:
33+
inner_expr = sub_expr.as_mql(self, self.connection, force_filters=True)
34+
rhs = next(iter(inner_expr.values()))
35+
group[alias] = {"$addToSet": rhs}
36+
replacing_expr = sub_expr.copy()
37+
replacing_expr.set_source_expressions([inner_column])
38+
else:
39+
group[alias] = sub_expr.as_mql(self, self.connection)
40+
replacing_expr = inner_column
41+
42+
replacements[sub_expr] = replacing_expr
43+
return replacements, group
44+
2245
@staticmethod
2346
def _random_separtor():
2447
import random
@@ -33,39 +56,24 @@ def pre_sql_setup(self, with_col_aliases=False):
3356
self.annotations = {}
3457
group = {}
3558
group_expressions = set()
36-
aggregation_idx = 1
3759
all_replacements = {}
38-
for target, expr in self.query.annotation_select.items():
39-
if not expr.contains_aggregate:
40-
result_expr = expr
41-
else:
42-
replacements = {}
43-
for sub_expr in self._get_aggregate_expressions(expr):
44-
if sub_expr != expr:
45-
alias = f"__aggregation{aggregation_idx}"
46-
aggregation_idx += 1
47-
else:
48-
alias = target
49-
50-
column_target = expr.output_field.__class__()
51-
column_target.db_column = alias
52-
column_target.set_attributes_from_name(alias)
53-
inner_column = Col(self.collection_name, column_target)
54-
if sub_expr.distinct:
55-
inner_expr = sub_expr.as_mql(self, self.connection, force_filters=True)
56-
rhs = next(iter(inner_expr.values()))
57-
group[alias] = {"$addToSet": rhs}
58-
replacing_expr = sub_expr.copy()
59-
replacing_expr.set_source_expressions([inner_column])
60-
else:
61-
group[alias] = sub_expr.as_mql(self, self.connection)
62-
replacing_expr = inner_column
63-
64-
replacements[sub_expr] = replacing_expr
60+
for idx, (target, expr) in enumerate(self.query.annotation_select.items()):
61+
if expr.contains_aggregate:
62+
replacements, expr_group = self._prepare_expressions_for_pipeline(expr, target, idx)
6563
result_expr = expr.replace_expressions(replacements)
6664
all_replacements.update(replacements)
65+
group.update(expr_group)
66+
else:
67+
result_expr = expr
6768
group_expressions |= set(expr.get_group_by_cols())
6869
self.annotations[target] = result_expr
70+
71+
having_replacements, having_group = self._prepare_expressions_for_pipeline(
72+
self.having, None, len(self.query.annotation_select)
73+
)
74+
all_replacements.update(having_replacements)
75+
group.update(having_group)
76+
6977
if group or self.query.group_by:
7078
order_by = self.get_order_by()
7179
for expr, (_, _, is_ref) in order_by:
@@ -95,7 +103,14 @@ def pre_sql_setup(self, with_col_aliases=False):
95103
break
96104
SEPARATOR = f"__{random_string}__"
97105

106+
annotation_group_idx = 0
107+
98108
def _ccc(col):
109+
nonlocal annotation_group_idx
110+
111+
if not isinstance(col, Col):
112+
annotation_group_idx += 1
113+
return "__annotation_group_1"
99114
if self.collection_name == col.alias:
100115
return col.target.column
101116
return f"{col.alias}{SEPARATOR}{col.target.column}"
@@ -107,7 +122,6 @@ def _ccc(col):
107122
_ccc(col): col.as_mql(self, self.connection)
108123
# expression aren't needed in the group by clouse ()
109124
for col in group_expressions
110-
if isinstance(col, Col)
111125
}
112126
)
113127
pipeline = []
@@ -141,10 +155,7 @@ def _ccc(col):
141155
else:
142156
sets[key] = value
143157

144-
pipeline.append(
145-
# {"$addFields": {key: f"$_id.{value[1:]}" for key, value in ids.items()}}
146-
{"$addFields": sets}
147-
)
158+
pipeline.append({"$addFields": sets})
148159
if "_id" not in sets:
149160
pipeline.append({"$unset": "_id"})
150161

0 commit comments

Comments
 (0)