Skip to content

Commit 554dacb

Browse files
committed
edits
1 parent 094e494 commit 554dacb

File tree

5 files changed

+30
-41
lines changed

5 files changed

+30
-41
lines changed

README.md

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -116,12 +116,9 @@ Migrations for 'admin':
116116
- `extra()`
117117
- `prefetch_related()`
118118

119-
120119
- `QuerySet.delete()` and `update()` do not support queries that span multiple
121120
collections.
122121

123-
- `aggregate()` Sort with aggregate may not work well.
124-
125122
- `Subquery`, `Exists`, and using a `QuerySet` in `QuerySet.annotate()` aren't
126123
supported.
127124

django_mongodb/compiler.py

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def _get_column_from_expression(self, expr, alias):
4545
return Col(self.collection_name, column_target)
4646

4747
def _prepare_expressions_for_pipeline(self, expression, target, count):
48-
"""Prepare expressions for the MongoDB aggregation pipeline."""
48+
"""Prepare expressions for the aggregation pipeline."""
4949
replacements = {}
5050
group = {}
5151
for sub_expr in self._get_aggregate_expressions(expression):
@@ -62,12 +62,12 @@ def _prepare_expressions_for_pipeline(self, expression, target, count):
6262
else:
6363
group[alias] = sub_expr.as_mql(self, self.connection)
6464
replacing_expr = inner_column
65-
65+
# Count must return 0 rather than null.
6666
if isinstance(sub_expr, Count):
6767
replacing_expr = Coalesce(replacing_expr, 0)
68+
# Variance = StdDev^2
6869
if isinstance(sub_expr, Variance):
6970
replacing_expr = Power(replacing_expr, 2)
70-
7171
replacements[sub_expr] = replacing_expr
7272
return replacements, group
7373

@@ -170,10 +170,10 @@ def _build_group_pipeline(self, ids, group):
170170
def pre_sql_setup(self, with_col_aliases=False):
171171
extra_select, order_by, group_by = super().pre_sql_setup(with_col_aliases=with_col_aliases)
172172
group, all_replacements = self._prepare_annotations_for_group_pipeline()
173-
174-
# The query.group_by is either None (no GROUP BY at all), True
175-
# (group by select fields), or a list of expressions to be added
176-
# to the group by.
173+
# query.group_by is either:
174+
# - None: no GROUP BY
175+
# - True: group by select fields
176+
# - a list of expressions to group by.
177177
if group or self.query.group_by:
178178
ids, replacements = self._get_group_id_expressions(order_by)
179179
all_replacements.update(replacements)
@@ -194,7 +194,6 @@ def pre_sql_setup(self, with_col_aliases=False):
194194
target: expr.replace_expressions(all_replacements)
195195
for target, expr in self.query.annotation_select.items()
196196
}
197-
198197
return extra_select, order_by, group_by
199198

200199
def execute_sql(
@@ -455,18 +454,16 @@ def get_project_fields(self, columns=None, ordering=None):
455454
# If name != column, then this is an annotatation referencing
456455
# another column.
457456
fields[name] = 1 if name == column else f"${column}"
458-
459457
if fields:
460458
# Add related fields.
461459
for alias in self.query.alias_map:
462460
if self.query.alias_refcount[alias] and self.collection_name != alias:
463461
fields[alias] = 1
464-
462+
# Add order_by() fields.
465463
for column, _ in ordering or []:
466464
foreign_table = column.split(".", 1)[0] if "." in column else None
467465
if column not in fields and foreign_table not in fields:
468466
fields[column] = 1
469-
470467
return fields
471468

472469

django_mongodb/features.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,6 @@ class DatabaseFeatures(BaseDatabaseFeatures):
7272
"db_functions.math.test_round.RoundTests.test_integer_with_negative_precision",
7373
# Truncating in another timezone doesn't work becauase MongoDB converts
7474
# the result back to UTC.
75-
"aggregation.tests.AggregateTestCase.test_aggregation_default_using_date_from_database",
7675
"db_functions.datetime.test_extract_trunc.DateFunctionWithTimeZoneTests.test_trunc_func_with_timezone",
7776
"db_functions.datetime.test_extract_trunc.DateFunctionWithTimeZoneTests.test_trunc_timezone_applied_before_truncation",
7877
# Length of null considered zero rather than null.
@@ -104,7 +103,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
104103
# QuerySet.explain() not implemented:
105104
# https://github.com/mongodb-labs/django-mongodb/issues/28
106105
"queries.test_explain.ExplainUnsupportedTests.test_message",
107-
# Sum returns 0 instead of None in mongodb.
106+
# The $sum aggregation returns 0 instead of None for null.
108107
"aggregation.test_filter_argument.FilteredAggregateTests.test_plain_annotate",
109108
"aggregation.tests.AggregateTestCase.test_aggregation_default_passed_another_aggregate",
110109
"aggregation.tests.AggregateTestCase.test_annotation_expressions",
@@ -504,7 +503,9 @@ def django_test_expected_failures(self):
504503
"timezones.tests.NewDatabaseTests.test_cursor_explicit_time_zone",
505504
"timezones.tests.NewDatabaseTests.test_raw_sql",
506505
},
507-
"Custom functions with SQL don't work on MongoDB.": {
506+
"Custom aggregations/functions with SQL don't work on MongoDB.": {
507+
"aggregation.tests.AggregateTestCase.test_add_implementation",
508+
"aggregation.tests.AggregateTestCase.test_multi_arg_aggregate",
508509
"annotations.tests.NonAggregateAnnotationTestCase.test_custom_functions",
509510
"annotations.tests.NonAggregateAnnotationTestCase.test_custom_functions_can_ref_other_functions",
510511
},
@@ -533,6 +534,7 @@ def django_test_expected_failures(self):
533534
"db_functions.datetime.test_extract_trunc.DateFunctionTests.test_extract_quarter_func_boundaries",
534535
},
535536
"TruncDate database function not supported.": {
537+
"aggregation.tests.AggregateTestCase.test_aggregation_default_using_date_from_database",
536538
"db_functions.datetime.test_extract_trunc.DateFunctionTests.test_trunc_date_func",
537539
"db_functions.datetime.test_extract_trunc.DateFunctionTests.test_trunc_date_none",
538540
"db_functions.datetime.test_extract_trunc.DateFunctionTests.test_trunc_lookup_name_sql_injection",
@@ -588,9 +590,6 @@ def django_test_expected_failures(self):
588590
},
589591
"Test not applicable for MongoDB's SQLCompiler.": {
590592
"queries.test_iterator.QuerySetIteratorTests",
591-
# Custom aggregations:
592-
"aggregation.tests.AggregateTestCase.test_add_implementation",
593-
"aggregation.tests.AggregateTestCase.test_multi_arg_aggregate",
594593
},
595594
}
596595

django_mongodb/functions.py

Lines changed: 15 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,8 @@
4141

4242
MONGO_AGGREGATIONS = {
4343
Count: "sum",
44-
StdDev: "stdDev",
45-
Variance: "stdDev",
44+
StdDev: "stdDev", # Samp or Pop suffix added in aggregate().
45+
Variance: "stdDev", # Likewise.
4646
}
4747
MONGO_OPERATORS = {
4848
Ceil: "ceil",
@@ -79,11 +79,11 @@ def aggregate(self, compiler, connection, **extra_context): # noqa: ARG001
7979
node = self
8080
lhs_mql = process_lhs(node, compiler, connection)
8181
operator = MONGO_AGGREGATIONS.get(self.__class__, self.function.lower())
82-
if "_SAMP" in self.function:
82+
# Add suffixes to StdDev/Variance.
83+
if self.function.endswith("_SAMP"):
8384
operator += "Samp"
84-
elif "_POP" in self.function:
85+
elif self.function.endswith("_POP"):
8586
operator += "Pop"
86-
8787
return {f"${operator}": lhs_mql}
8888

8989

@@ -119,26 +119,24 @@ def cot(self, compiler, connection):
119119

120120
def count(self, compiler, connection, resolve_inner_expression=False, **extra_context): # noqa: ARG001
121121
"""
122-
When resolve_inner_expression is True, it returns the argument as MQL that resolves as a value.
123-
This is used to count different elements, so the inner values are returned
124-
to be pushed into a set.
122+
When resolve_inner_expression is True, return the argument as MQL that
123+
resolves as a value. This is used to count different elements, so the inner
124+
values are returned to be pushed into a set.
125125
"""
126126
if not self.distinct or resolve_inner_expression:
127127
if self.filter:
128-
copy = self.copy()
129-
copy.filter = None
130-
source_expressions = copy.get_source_expressions()
128+
node = self.copy()
129+
node.filter = None
130+
source_expressions = node.get_source_expressions()
131131
filter_ = deepcopy(self.filter)
132132
filter_.add(
133133
WhereNode([Exact(source_expressions[0], Value(None))], negated=True),
134134
filter_.default,
135135
)
136136
condition = When(filter_, then=Value(1))
137-
copy.set_source_expressions([Case(condition)] + source_expressions[1:])
138-
node = copy
137+
node.set_source_expressions([Case(condition)] + source_expressions[1:])
139138
inner_expression = process_lhs(node, compiler, connection)
140139
else:
141-
node = self
142140
lhs_mql = process_lhs(self, compiler, connection)
143141
null_cond = {"$in": [{"$type": lhs_mql}, ["missing", "null"]]}
144142
inner_expression = {
@@ -147,10 +145,10 @@ def count(self, compiler, connection, resolve_inner_expression=False, **extra_co
147145
if resolve_inner_expression:
148146
return inner_expression
149147
return {"$sum": inner_expression}
150-
151-
# When count is called with distinct without the flag, we sum the size of the set.
148+
# If distinct=True or resolve_inner_expression=False, sum the size
149+
# of the set.
152150
lhs_mql = process_lhs(self, compiler, connection)
153-
# And subtract 1 if None is in the set (it shouldn't have been counted).
151+
# Subtract 1 if None is in the set (it shouldn't have been counted).
154152
exits_null = {"$cond": {"if": {"$in": [{"$literal": None}, lhs_mql]}, "then": -1, "else": 0}}
155153
return {"$add": [{"$size": lhs_mql}, exits_null]}
156154

django_mongodb/query.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -87,12 +87,10 @@ def get_cursor(self):
8787
"""
8888
if self.query.low_mark == self.query.high_mark:
8989
return []
90-
91-
pipeline = self.get_pipeline()
92-
return self.collection.aggregate(pipeline)
90+
return self.collection.aggregate(self.get_pipeline())
9391

9492
def get_pipeline(self):
95-
pipeline = [] if self.subquery is None else self.subquery.get_pipeline()
93+
pipeline = self.subquery.get_pipeline() if self.subquery else []
9694
if self.lookup_pipeline:
9795
pipeline.extend(self.lookup_pipeline)
9896
if self.mongo_query:

0 commit comments

Comments
 (0)