Skip to content

Commit 2034a33

Browse files
committed
test fixed.
1 parent 419b37c commit 2034a33

File tree

4 files changed

+42
-17
lines changed

4 files changed

+42
-17
lines changed

django_mongodb/compiler.py

Lines changed: 32 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,20 @@ class SQLCompiler(compiler.SQLCompiler):
1818

1919
query_class = MongoQuery
2020
_group_pipeline = None
21+
aggregation_idx = 0
2122

22-
def _prepare_expressions_for_pipeline(self, expression, target, aggregation_idx):
23+
def _get_colum_from_expression(self, expr, alias):
24+
column_target = expr.output_field.__class__()
25+
column_target.db_column = alias
26+
column_target.set_attributes_from_name(alias)
27+
return Col(self.collection_name, column_target)
28+
29+
def _prepare_expressions_for_pipeline(self, expression, target):
2330
replacements = {}
2431
group = {}
2532
for sub_expr in self._get_aggregate_expressions(expression):
26-
alias = f"__aggregation{aggregation_idx}" if sub_expr != expression else target
33+
alias = f"__aggregation{self.aggregation_idx}" if sub_expr != expression else target
34+
self.aggregation_idx += 1
2735

2836
column_target = sub_expr.output_field.__class__()
2937
column_target.db_column = alias
@@ -39,6 +47,7 @@ def _prepare_expressions_for_pipeline(self, expression, target, aggregation_idx)
3947
group[alias] = sub_expr.as_mql(self, self.connection)
4048
replacing_expr = inner_column
4149

50+
sub_expr.as_mql(self, self.connection)
4251
replacements[sub_expr] = replacing_expr
4352
return replacements, group
4453

@@ -57,19 +66,16 @@ def pre_sql_setup(self, with_col_aliases=False):
5766
group = {}
5867
group_expressions = set()
5968
all_replacements = {}
60-
for idx, (target, expr) in enumerate(self.query.annotation_select.items()):
69+
self.aggregation_idx = 0
70+
for target, expr in self.query.annotation_select.items():
6171
if expr.contains_aggregate:
62-
replacements, expr_group = self._prepare_expressions_for_pipeline(expr, target, idx)
63-
result_expr = expr.replace_expressions(replacements)
72+
replacements, expr_group = self._prepare_expressions_for_pipeline(expr, target)
6473
all_replacements.update(replacements)
6574
group.update(expr_group)
66-
else:
67-
result_expr = expr
6875
group_expressions |= set(expr.get_group_by_cols())
69-
self.annotations[target] = result_expr
7076

7177
having_replacements, having_group = self._prepare_expressions_for_pipeline(
72-
self.having, None, len(self.query.annotation_select)
78+
self.having, None
7379
)
7480
all_replacements.update(having_replacements)
7581
group.update(having_group)
@@ -110,7 +116,10 @@ def _ccc(col):
110116

111117
if not isinstance(col, Col):
112118
annotation_group_idx += 1
113-
return "__annotation_group_1"
119+
alias = f"__annotation_group_{annotation_group_idx}"
120+
col_expr = self._get_colum_from_expression(col, alias)
121+
all_replacements[col] = col_expr
122+
col = col_expr
114123
if self.collection_name == col.alias:
115124
return col.target.column
116125
return f"{col.alias}{SEPARATOR}{col.target.column}"
@@ -124,13 +133,17 @@ def _ccc(col):
124133
for col in group_expressions
125134
}
126135
)
136+
self.annotations = {
137+
target: expr.replace_expressions(all_replacements)
138+
for target, expr in self.query.annotation_select.items()
139+
}
127140
pipeline = []
128141
if not ids:
129142
group["_id"] = None
130143
pipeline.append({"$facet": {"group": [{"$group": group}]}})
131144
pipeline.append(
132145
{
133-
"$project": {
146+
"$addFields": {
134147
key: {
135148
"$getField": {
136149
"input": {"$arrayElemAt": ["$group", 0]},
@@ -174,6 +187,11 @@ def _ccc(col):
174187
else:
175188
self._group_pipeline = None
176189

190+
self.annotations = {
191+
target: expr.replace_expressions(all_replacements)
192+
for target, expr in self.query.annotation_select.items()
193+
}
194+
177195
return pre_setup
178196

179197
def execute_sql(
@@ -334,9 +352,8 @@ def build_query(self, columns=None):
334352
query.order_by(self._get_ordering())
335353
query.project_fields = self.get_project_fields(columns, ordering=query.ordering)
336354
try:
337-
query.mongo_query = (
338-
{"$expr": self.where.as_mql(self, self.connection)} if self.where else None
339-
)
355+
where = getattr(self, "where", self.query.where)
356+
query.mongo_query = {"$expr": where.as_mql(self, self.connection)} if where else None
340357
except FullResultSet:
341358
query.mongo_query = {}
342359
return query
@@ -515,7 +532,7 @@ def insert(self, docs, returning_fields=None):
515532
class SQLDeleteCompiler(compiler.SQLDeleteCompiler, SQLCompiler):
516533
def execute_sql(self, result_type=MULTI):
517534
cursor = Cursor()
518-
cursor.rowcount = self.build_query([self.query.get_meta().pk]).delete()
535+
cursor.rowcount = self.build_query().delete()
519536
return cursor
520537

521538
def check_query(self):

django_mongodb/features.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,8 +112,11 @@ class DatabaseFeatures(BaseDatabaseFeatures):
112112
"aggregation.test_filter_argument.FilteredAggregateTests.test_filtered_numerical_aggregates",
113113
# Sum returns 0 instead of None in mongodb.
114114
"aggregation.test_filter_argument.FilteredAggregateTests.test_plain_annotate",
115+
"aggregation.tests.AggregateTestCase.test_aggregation_default_passed_another_aggregate",
115116
"aggregation.tests.AggregateTestCase.test_annotation_expressions",
116117
"aggregation.tests.AggregateTestCase.test_reverse_fkey_annotate",
118+
# Manage empty result when the flag elide_empty is False
119+
"aggregation.tests.AggregateTestCase.test_empty_result_optimization",
117120
}
118121
# $bitAnd, #bitOr, and $bitXor are new in MongoDB 6.3.
119122
_django_test_expected_failures_bitwise = {

django_mongodb/functions.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,9 @@ def extract(self, compiler, connection):
166166

167167
def func(self, compiler, connection):
168168
lhs_mql = process_lhs(self, compiler, connection)
169-
operator = MONGO_OPERATORS.get(self.__class__, self.function.lower())
169+
operator = MONGO_OPERATORS.get(
170+
self.__class__, (self.extra["function"] if self.function is None else self.function).lower()
171+
)
170172
return {f"${operator}": lhs_mql}
171173

172174

django_mongodb/query_utils.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from django.core.exceptions import FullResultSet
22
from django.db.models.aggregates import Aggregate
3-
from django.db.models.expressions import Value
3+
from django.db.models.expressions import Expression, Value
44

55

66
def is_direct_value(node):
@@ -37,6 +37,9 @@ def process_rhs(node, compiler, connection):
3737
value = value[0]
3838
if hasattr(node, "prep_lookup_value_mongo"):
3939
value = node.prep_lookup_value_mongo(value)
40+
# Can't apply converters to symbolic things.
41+
if isinstance(rhs, Expression):
42+
return value
4043
return connection.ops.prep_lookup_value(value, node.lhs.output_field, node.lookup_name)
4144

4245

0 commit comments

Comments
 (0)