Skip to content

Commit cbedd74

Browse files
committed
test fixed.
1 parent 37bc7c6 commit cbedd74

File tree

4 files changed

+42
-17
lines changed

4 files changed

+42
-17
lines changed

django_mongodb/compiler.py

Lines changed: 32 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,20 @@ class SQLCompiler(compiler.SQLCompiler):
1717

1818
query_class = MongoQuery
1919
_group_pipeline = None
20+
aggregation_idx = 0
2021

21-
def _prepare_expressions_for_pipeline(self, expression, target, aggregation_idx):
22+
def _get_colum_from_expression(self, expr, alias):
23+
column_target = expr.output_field.__class__()
24+
column_target.db_column = alias
25+
column_target.set_attributes_from_name(alias)
26+
return Col(self.collection_name, column_target)
27+
28+
def _prepare_expressions_for_pipeline(self, expression, target):
2229
replacements = {}
2330
group = {}
2431
for sub_expr in self._get_aggregate_expressions(expression):
25-
alias = f"__aggregation{aggregation_idx}" if sub_expr != expression else target
32+
alias = f"__aggregation{self.aggregation_idx}" if sub_expr != expression else target
33+
self.aggregation_idx += 1
2634

2735
column_target = sub_expr.output_field.__class__()
2836
column_target.db_column = alias
@@ -38,6 +46,7 @@ def _prepare_expressions_for_pipeline(self, expression, target, aggregation_idx)
3846
group[alias] = sub_expr.as_mql(self, self.connection)
3947
replacing_expr = inner_column
4048

49+
sub_expr.as_mql(self, self.connection)
4150
replacements[sub_expr] = replacing_expr
4251
return replacements, group
4352

@@ -56,19 +65,16 @@ def pre_sql_setup(self, with_col_aliases=False):
5665
group = {}
5766
group_expressions = set()
5867
all_replacements = {}
59-
for idx, (target, expr) in enumerate(self.query.annotation_select.items()):
68+
self.aggregation_idx = 0
69+
for target, expr in self.query.annotation_select.items():
6070
if expr.contains_aggregate:
61-
replacements, expr_group = self._prepare_expressions_for_pipeline(expr, target, idx)
62-
result_expr = expr.replace_expressions(replacements)
71+
replacements, expr_group = self._prepare_expressions_for_pipeline(expr, target)
6372
all_replacements.update(replacements)
6473
group.update(expr_group)
65-
else:
66-
result_expr = expr
6774
group_expressions |= set(expr.get_group_by_cols())
68-
self.annotations[target] = result_expr
6975

7076
having_replacements, having_group = self._prepare_expressions_for_pipeline(
71-
self.having, None, len(self.query.annotation_select)
77+
self.having, None
7278
)
7379
all_replacements.update(having_replacements)
7480
group.update(having_group)
@@ -109,7 +115,10 @@ def _ccc(col):
109115

110116
if not isinstance(col, Col):
111117
annotation_group_idx += 1
112-
return "__annotation_group_1"
118+
alias = f"__annotation_group_{annotation_group_idx}"
119+
col_expr = self._get_colum_from_expression(col, alias)
120+
all_replacements[col] = col_expr
121+
col = col_expr
113122
if self.collection_name == col.alias:
114123
return col.target.column
115124
return f"{col.alias}{SEPARATOR}{col.target.column}"
@@ -123,13 +132,17 @@ def _ccc(col):
123132
for col in group_expressions
124133
}
125134
)
135+
self.annotations = {
136+
target: expr.replace_expressions(all_replacements)
137+
for target, expr in self.query.annotation_select.items()
138+
}
126139
pipeline = []
127140
if not ids:
128141
group["_id"] = None
129142
pipeline.append({"$facet": {"group": [{"$group": group}]}})
130143
pipeline.append(
131144
{
132-
"$project": {
145+
"$addFields": {
133146
key: {
134147
"$getField": {
135148
"input": {"$arrayElemAt": ["$group", 0]},
@@ -173,6 +186,11 @@ def _ccc(col):
173186
else:
174187
self._group_pipeline = None
175188

189+
self.annotations = {
190+
target: expr.replace_expressions(all_replacements)
191+
for target, expr in self.query.annotation_select.items()
192+
}
193+
176194
return pre_setup
177195

178196
def execute_sql(
@@ -306,9 +324,8 @@ def build_query(self, columns=None):
306324
query.order_by(self._get_ordering())
307325
query.project_fields = self.get_project_fields(columns, ordering=query.ordering)
308326
try:
309-
query.mongo_query = (
310-
{"$expr": self.where.as_mql(self, self.connection)} if self.where else None
311-
)
327+
where = getattr(self, "where", self.query.where)
328+
query.mongo_query = {"$expr": where.as_mql(self, self.connection)} if where else None
312329
except FullResultSet:
313330
query.mongo_query = {}
314331
return query
@@ -487,7 +504,7 @@ def insert(self, docs, returning_fields=None):
487504
class SQLDeleteCompiler(compiler.SQLDeleteCompiler, SQLCompiler):
488505
def execute_sql(self, result_type=MULTI):
489506
cursor = Cursor()
490-
cursor.rowcount = self.build_query([self.query.get_meta().pk]).delete()
507+
cursor.rowcount = self.build_query().delete()
491508
return cursor
492509

493510
def check_query(self):

django_mongodb/features.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,8 +132,11 @@ class DatabaseFeatures(BaseDatabaseFeatures):
132132
"aggregation.test_filter_argument.FilteredAggregateTests.test_filtered_numerical_aggregates",
133133
# Sum returns 0 instead of None in mongodb.
134134
"aggregation.test_filter_argument.FilteredAggregateTests.test_plain_annotate",
135+
"aggregation.tests.AggregateTestCase.test_aggregation_default_passed_another_aggregate",
135136
"aggregation.tests.AggregateTestCase.test_annotation_expressions",
136137
"aggregation.tests.AggregateTestCase.test_reverse_fkey_annotate",
138+
# Manage empty result when the flag elide_empty is False
139+
"aggregation.tests.AggregateTestCase.test_empty_result_optimization",
137140
}
138141
# $bitAnd, #bitOr, and $bitXor are new in MongoDB 6.3.
139142
_django_test_expected_failures_bitwise = {

django_mongodb/functions.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,9 @@ def extract(self, compiler, connection):
166166

167167
def func(self, compiler, connection):
168168
lhs_mql = process_lhs(self, compiler, connection)
169-
operator = MONGO_OPERATORS.get(self.__class__, self.function.lower())
169+
operator = MONGO_OPERATORS.get(
170+
self.__class__, (self.extra["function"] if self.function is None else self.function).lower()
171+
)
170172
return {f"${operator}": lhs_mql}
171173

172174

django_mongodb/query_utils.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from django.core.exceptions import FullResultSet
22
from django.db.models.aggregates import Aggregate
3-
from django.db.models.expressions import Value
3+
from django.db.models.expressions import Expression, Value
44

55

66
def is_direct_value(node):
@@ -37,6 +37,9 @@ def process_rhs(node, compiler, connection):
3737
value = value[0]
3838
if hasattr(node, "prep_lookup_value_mongo"):
3939
value = node.prep_lookup_value_mongo(value)
40+
# Can't apply converters to symbolic things.
41+
if isinstance(rhs, Expression):
42+
return value
4043
return connection.ops.prep_lookup_value(value, node.lhs.output_field, node.lookup_name)
4144

4245

0 commit comments

Comments
 (0)