Skip to content

Commit 98ad0b7

Browse files
committed
Edits.
1 parent 8b3e3c2 commit 98ad0b7

File tree

4 files changed

+27
-32
lines changed

4 files changed

+27
-32
lines changed

django_mongodb/compiler.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -319,7 +319,7 @@ def build_query(self, columns=None):
319319
"""Check if the query is supported and prepare a MongoQuery."""
320320
self.check_query()
321321
query = self.query_class(self)
322-
query.aggregation_stage = self.get_aggregation_pipeline()
322+
query.aggregation_pipeline = self.get_aggregation_pipeline()
323323
query.lookup_pipeline = self.get_lookup_pipeline()
324324
query.order_by(self._get_ordering())
325325
query.project_fields = self.get_project_fields(columns, ordering=query.ordering)
@@ -583,7 +583,7 @@ class SQLAggregateCompiler(SQLCompiler):
583583
def build_query(self, columns=None):
584584
query = self.query_class(self)
585585
query.project_fields = self.get_project_fields(tuple(self.annotations.items()))
586-
query.aggregation_stage = self.get_aggregation_pipeline()
586+
query.aggregation_pipeline = self.get_aggregation_pipeline()
587587

588588
compiler = self.query.inner_query.get_compiler(
589589
self.using,

django_mongodb/features.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,6 @@ class DatabaseFeatures(BaseDatabaseFeatures):
3939
"aggregation.tests.AggregateTestCase.test_grouped_annotation_in_group_by",
4040
"aggregation.tests.AggregateTestCase.test_non_grouped_annotation_not_in_group_by",
4141
"aggregation.tests.AggregateTestCase.test_values_annotation_with_expression",
42-
"aggregation.tests.AggregateTestCase.test_annotate_ordering",
43-
"aggregation.tests.AggregateTestCase.test_even_more_aggregate",
4442
"annotations.tests.NonAggregateAnnotationTestCase.test_order_by_aggregate",
4543
"model_fields.test_jsonfield.TestQuerying.test_ordering_grouping_by_count",
4644
"ordering.tests.OrderingTests.test_default_ordering_does_not_affect_group_by",
@@ -72,6 +70,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
7270
"db_functions.math.test_round.RoundTests.test_integer_with_negative_precision",
7371
# Truncating in another timezone doesn't work becauase MongoDB converts
7472
# the result back to UTC.
73+
"aggregation.tests.AggregateTestCase.test_aggregation_default_using_date_from_database",
7574
"db_functions.datetime.test_extract_trunc.DateFunctionWithTimeZoneTests.test_trunc_func_with_timezone",
7675
"db_functions.datetime.test_extract_trunc.DateFunctionWithTimeZoneTests.test_trunc_timezone_applied_before_truncation",
7776
# Length of null considered zero rather than null.
@@ -110,6 +109,9 @@ class DatabaseFeatures(BaseDatabaseFeatures):
110109
"aggregation.tests.AggregateTestCase.test_reverse_fkey_annotate",
111110
# Manage empty result when the flag elide_empty is False
112111
"aggregation.tests.AggregateTestCase.test_empty_result_optimization",
112+
# Incorrect order: pipeline does not order by the right fields.
113+
"aggregation.tests.AggregateTestCase.test_annotate_ordering",
114+
"aggregation.tests.AggregateTestCase.test_even_more_aggregate",
113115
}
114116
# $bitAnd, #bitOr, and $bitXor are new in MongoDB 6.3.
115117
_django_test_expected_failures_bitwise = {
@@ -138,7 +140,6 @@ def django_test_expected_failures(self):
138140
"expressions.tests.BasicExpressionsTests.test_object_create_with_aggregate",
139141
"expressions.tests.BasicExpressionsTests.test_object_create_with_f_expression_in_subquery",
140142
# PI()
141-
"aggregation.tests.AggregateTestCase.test_aggregation_default_using_decimal_from_database",
142143
"db_functions.math.test_round.RoundTests.test_decimal_with_precision",
143144
"db_functions.math.test_round.RoundTests.test_float_with_precision",
144145
},
@@ -471,6 +472,7 @@ def django_test_expected_failures(self):
471472
"update.tests.AdvancedTests.test_update_annotated_multi_table_queryset",
472473
},
473474
"Test inspects query for SQL": {
475+
"aggregation.tests.AggregateAnnotationPruningTests.test_non_aggregate_annotation_pruned",
474476
"aggregation.tests.AggregateAnnotationPruningTests.test_unreferenced_aggregate_annotation_pruned",
475477
"aggregation.tests.AggregateAnnotationPruningTests.test_unused_aliased_aggregate_pruned",
476478
"aggregation.tests.AggregateAnnotationPruningTests.test_referenced_aggregate_annotation_kept",
@@ -543,6 +545,7 @@ def django_test_expected_failures(self):
543545
"db_functions.datetime.test_extract_trunc.DateFunctionTests.test_trunc_time_none",
544546
},
545547
"MongoDB can't annotate ($project) a function like PI().": {
548+
"aggregation.tests.AggregateTestCase.test_aggregation_default_using_decimal_from_database",
546549
"db_functions.math.test_pi.PiTests.test",
547550
},
548551
"Can't cast from date to datetime without MongoDB interpreting the new value in UTC.": {
@@ -584,11 +587,8 @@ def django_test_expected_failures(self):
584587
"Test not applicable for MongoDB's SQLCompiler.": {
585588
"queries.test_iterator.QuerySetIteratorTests",
586589
# Custom aggregations:
587-
"aggregation.tests.AggregateAnnotationPruningTests.test_non_aggregate_annotation_pruned",
588590
"aggregation.tests.AggregateTestCase.test_add_implementation",
589-
"aggregation.tests.AggregateTestCase.test_aggregation_default_using_datetime_from_database",
590591
"aggregation.tests.AggregateTestCase.test_multi_arg_aggregate",
591-
"aggregation.tests.AggregateTestCase.test_aggregation_default_using_date_from_database",
592592
},
593593
}
594594

django_mongodb/functions.py

Lines changed: 16 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from copy import deepcopy
22

33
from django.db import NotSupportedError
4-
from django.db.models.aggregates import Aggregate, Avg, Count, Max, Min, StdDev, Sum, Variance
4+
from django.db.models.aggregates import Aggregate, Count, StdDev, Variance
55
from django.db.models.expressions import Case, Func, Star, Value, When
66
from django.db.models.functions import Now
77
from django.db.models.functions.comparison import Cast, Coalesce, Greatest, Least, NullIf
@@ -40,13 +40,9 @@
4040
from .query_utils import process_lhs
4141

4242
MONGO_AGGREGATIONS = {
43-
Avg: "avg",
4443
Count: "sum",
45-
Max: "max",
46-
Min: "min",
47-
StdDev: "stdDevPop",
48-
Sum: "sum",
49-
Variance: "stdDevPop",
44+
StdDev: "stdDev",
45+
Variance: "stdDev",
5046
}
5147
MONGO_OPERATORS = {
5248
Ceil: "ceil",
@@ -74,18 +70,20 @@
7470

7571
def aggregate(self, compiler, connection, **extra_context): # noqa: ARG001
7672
if self.filter:
77-
copy = self.copy()
78-
copy.filter = None
79-
source_expressions = copy.get_source_expressions()
73+
node = self.copy()
74+
node.filter = None
75+
source_expressions = node.get_source_expressions()
8076
condition = When(self.filter, then=source_expressions[0])
81-
copy.set_source_expressions([Case(condition)] + source_expressions[1:])
82-
node = copy
77+
node.set_source_expressions([Case(condition)] + source_expressions[1:])
8378
else:
8479
node = self
8580
lhs_mql = process_lhs(node, compiler, connection)
86-
operator = MONGO_AGGREGATIONS.get(self.__class__)
87-
if self.__class__ in (StdDev, Variance) and "_SAMP" in self.function:
88-
operator = operator.replace("Pop", "Samp")
81+
operator = MONGO_AGGREGATIONS.get(self.__class__, self.function.lower())
82+
if "_SAMP" in self.function:
83+
operator += "Samp"
84+
elif "_POP" in self.function:
85+
operator += "Pop"
86+
8987
return {f"${operator}": lhs_mql}
9088

9189

@@ -119,8 +117,8 @@ def cot(self, compiler, connection):
119117
return {"$divide": [1, {"$tan": lhs_mql}]}
120118

121119

122-
def count(self, compiler, connection, **extra_context):
123-
if not self.distinct or extra_context.get("force_filters"):
120+
def count(self, compiler, connection, force_filters=False, **extra_context): # noqa: ARG001
121+
if not self.distinct or force_filters:
124122
if self.filter:
125123
copy = self.copy()
126124
copy.filter = None
@@ -160,9 +158,7 @@ def extract(self, compiler, connection):
160158

161159
def func(self, compiler, connection):
162160
lhs_mql = process_lhs(self, compiler, connection)
163-
operator = MONGO_OPERATORS.get(
164-
self.__class__, (self.extra["function"] if self.function is None else self.function).lower()
165-
)
161+
operator = MONGO_OPERATORS.get(self.__class__, self.function.lower())
166162
return {f"${operator}": lhs_mql}
167163

168164

django_mongodb/query.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,8 @@ def __init__(self, compiler):
4949
self.mongo_query = getattr(compiler.query, "raw_query", {})
5050
self.subquery = None
5151
self.lookup_pipeline = None
52-
self.annotation_stage = None
5352
self.project_fields = None
54-
self.aggregation_stage = None
53+
self.aggregation_pipeline = None
5554

5655
def __repr__(self):
5756
return f"<MongoQuery: {self.mongo_query!r} ORDER {self.ordering!r}>"
@@ -98,8 +97,8 @@ def get_pipeline(self):
9897
pipeline.extend(self.lookup_pipeline)
9998
if self.mongo_query:
10099
pipeline.append({"$match": self.mongo_query})
101-
if self.aggregation_stage:
102-
pipeline.extend(self.aggregation_stage)
100+
if self.aggregation_pipeline:
101+
pipeline.extend(self.aggregation_pipeline)
103102
if self.project_fields:
104103
pipeline.append({"$project": self.project_fields})
105104
if self.ordering:

0 commit comments

Comments
 (0)