Skip to content

Commit 37bc7c6

Browse files
committed
Some fixes.
1 parent 2baa01d commit 37bc7c6

File tree

1 file changed

+44
-33
lines changed

1 file changed

+44
-33
lines changed

django_mongodb/compiler.py

Lines changed: 44 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,29 @@ class SQLCompiler(compiler.SQLCompiler):
1818
query_class = MongoQuery
1919
_group_pipeline = None
2020

21+
def _prepare_expressions_for_pipeline(self, expression, target, aggregation_idx):
22+
replacements = {}
23+
group = {}
24+
for sub_expr in self._get_aggregate_expressions(expression):
25+
alias = f"__aggregation{aggregation_idx}" if sub_expr != expression else target
26+
27+
column_target = sub_expr.output_field.__class__()
28+
column_target.db_column = alias
29+
column_target.set_attributes_from_name(alias)
30+
inner_column = Col(self.collection_name, column_target)
31+
if sub_expr.distinct:
32+
inner_expr = sub_expr.as_mql(self, self.connection, force_filters=True)
33+
rhs = next(iter(inner_expr.values()))
34+
group[alias] = {"$addToSet": rhs}
35+
replacing_expr = sub_expr.copy()
36+
replacing_expr.set_source_expressions([inner_column])
37+
else:
38+
group[alias] = sub_expr.as_mql(self, self.connection)
39+
replacing_expr = inner_column
40+
41+
replacements[sub_expr] = replacing_expr
42+
return replacements, group
43+
2144
@staticmethod
2245
def _random_separtor():
2346
import random
@@ -32,39 +55,24 @@ def pre_sql_setup(self, with_col_aliases=False):
3255
self.annotations = {}
3356
group = {}
3457
group_expressions = set()
35-
aggregation_idx = 1
3658
all_replacements = {}
37-
for target, expr in self.query.annotation_select.items():
38-
if not expr.contains_aggregate:
39-
result_expr = expr
40-
else:
41-
replacements = {}
42-
for sub_expr in self._get_aggregate_expressions(expr):
43-
if sub_expr != expr:
44-
alias = f"__aggregation{aggregation_idx}"
45-
aggregation_idx += 1
46-
else:
47-
alias = target
48-
49-
column_target = expr.output_field.__class__()
50-
column_target.db_column = alias
51-
column_target.set_attributes_from_name(alias)
52-
inner_column = Col(self.collection_name, column_target)
53-
if sub_expr.distinct:
54-
inner_expr = sub_expr.as_mql(self, self.connection, force_filters=True)
55-
rhs = next(iter(inner_expr.values()))
56-
group[alias] = {"$addToSet": rhs}
57-
replacing_expr = sub_expr.copy()
58-
replacing_expr.set_source_expressions([inner_column])
59-
else:
60-
group[alias] = sub_expr.as_mql(self, self.connection)
61-
replacing_expr = inner_column
62-
63-
replacements[sub_expr] = replacing_expr
59+
for idx, (target, expr) in enumerate(self.query.annotation_select.items()):
60+
if expr.contains_aggregate:
61+
replacements, expr_group = self._prepare_expressions_for_pipeline(expr, target, idx)
6462
result_expr = expr.replace_expressions(replacements)
6563
all_replacements.update(replacements)
64+
group.update(expr_group)
65+
else:
66+
result_expr = expr
6667
group_expressions |= set(expr.get_group_by_cols())
6768
self.annotations[target] = result_expr
69+
70+
having_replacements, having_group = self._prepare_expressions_for_pipeline(
71+
self.having, None, len(self.query.annotation_select)
72+
)
73+
all_replacements.update(having_replacements)
74+
group.update(having_group)
75+
6876
if group or self.query.group_by:
6977
order_by = self.get_order_by()
7078
for expr, (_, _, is_ref) in order_by:
@@ -94,7 +102,14 @@ def pre_sql_setup(self, with_col_aliases=False):
94102
break
95103
SEPARATOR = f"__{random_string}__"
96104

105+
annotation_group_idx = 0
106+
97107
def _ccc(col):
108+
nonlocal annotation_group_idx
109+
110+
if not isinstance(col, Col):
111+
annotation_group_idx += 1
112+
return "__annotation_group_1"
98113
if self.collection_name == col.alias:
99114
return col.target.column
100115
return f"{col.alias}{SEPARATOR}{col.target.column}"
@@ -106,7 +121,6 @@ def _ccc(col):
106121
_ccc(col): col.as_mql(self, self.connection)
107122
# expression aren't needed in the group by clouse ()
108123
for col in group_expressions
109-
if isinstance(col, Col)
110124
}
111125
)
112126
pipeline = []
@@ -140,10 +154,7 @@ def _ccc(col):
140154
else:
141155
sets[key] = value
142156

143-
pipeline.append(
144-
# {"$addFields": {key: f"$_id.{value[1:]}" for key, value in ids.items()}}
145-
{"$addFields": sets}
146-
)
157+
pipeline.append({"$addFields": sets})
147158
if "_id" not in sets:
148159
pipeline.append({"$unset": "_id"})
149160

0 commit comments

Comments
 (0)