Skip to content

Commit a2e49c0

Browse files
committed
Support ordering with null first / null last.
1 parent 8e3839d commit a2e49c0

File tree

5 files changed

+27
-29
lines changed

5 files changed

+27
-29
lines changed

django_mongodb/compiler.py

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,10 @@
55
from django.db import DatabaseError, IntegrityError, NotSupportedError
66
from django.db.models import Count, Expression
77
from django.db.models.aggregates import Aggregate, Variance
8-
from django.db.models.expressions import Col, Ref, Value
8+
from django.db.models.expressions import Case, Col, Ref, Value, When
99
from django.db.models.functions.comparison import Coalesce
1010
from django.db.models.functions.math import Power
11+
from django.db.models.lookups import IsNull
1112
from django.db.models.sql import compiler
1213
from django.db.models.sql.constants import GET_ITERATOR_CHUNK_SIZE, MULTI, SINGLE
1314
from django.utils.functional import cached_property
@@ -333,13 +334,14 @@ def build_query(self, columns=None):
333334
query = self.query_class(self)
334335
query.aggregation_pipeline = self.get_aggregation_pipeline()
335336
query.lookup_pipeline = self.get_lookup_pipeline()
336-
ordering_fields, order, need_extra_fields = self.preprocess_orderby()
337+
orderby_annotations, ordering_fields, order = self.preprocess_orderby()
337338
query.project_fields = self.get_project_fields(columns, ordering_fields)
338339
query.ordering = order
339-
if need_extra_fields and columns is None:
340-
query.extra_fields = self.get_project_fields(
341-
((name, field) for name, field in ordering_fields if name.startswith("__order"))
342-
)
340+
341+
# Post pipeline fields, some of them need some refs to be compted, so we add this fields
342+
# after the main part of the pipeline has finished.
343+
if orderby_annotations:
344+
query.extra_fields = self.get_project_fields(orderby_annotations, add_fields=True)
343345
try:
344346
where = getattr(self, "where", self.query.where)
345347
query.mongo_query = (
@@ -412,7 +414,7 @@ def _get_aggregate_expressions(self, expr):
412414
def get_aggregation_pipeline(self):
413415
return self._group_pipeline
414416

415-
def get_project_fields(self, columns=None, ordering=None):
417+
def get_project_fields(self, columns=None, ordering=None, add_fields=False):
416418
fields = {}
417419
for name, expr in columns or []:
418420
try:
@@ -430,7 +432,7 @@ def get_project_fields(self, columns=None, ordering=None):
430432
# another column.
431433
fields[name] = 1 if name == column else f"${column}"
432434

433-
if fields:
435+
if fields and not add_fields:
434436
# Add related fields.
435437
for alias in self.query.alias_map:
436438
if self.query.alias_refcount[alias] and self.collection_name != alias:
@@ -446,20 +448,25 @@ def get_project_fields(self, columns=None, ordering=None):
446448

447449
def preprocess_orderby(self):
448450
fields = {}
451+
orderby_annotations = {}
449452
result = SON()
450-
need_extra_fields = False
451453
idx = itertools.count(start=1)
452454
for order in self._order_by or []:
453-
if isinstance(order.expression, Ref):
454-
fieldname = order.expression.refs
455-
elif isinstance(order.expression, Col):
455+
if isinstance(order.expression, Col | Ref):
456456
fieldname = order.expression.as_mql(self, self.connection).removeprefix("$")
457+
fields[fieldname] = order.expression
457458
else:
458459
fieldname = f"__order{next(idx)}"
459-
need_extra_fields = True
460-
fields[fieldname] = order.expression
460+
orderby_annotations[fieldname] = order.expression
461+
462+
if order.nulls_first or order.nulls_last:
463+
null_fieldname = f"__order{next(idx)}"
464+
condition = When(IsNull(order.expression, True), then=Value(1))
465+
orderby_annotations[null_fieldname] = Case(condition, default=Value(0))
466+
result[null_fieldname] = DESCENDING if order.nulls_first else ASCENDING
467+
461468
result[fieldname] = DESCENDING if order.descending else ASCENDING
462-
return tuple(fields.items()), result, need_extra_fields
469+
return tuple(orderby_annotations.items()), tuple(fields.items()), result
463470

464471

465472
class SQLInsertCompiler(SQLCompiler):

django_mongodb/features.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -377,6 +377,7 @@ def django_test_expected_failures(self):
377377
"QuerySet.distinct() is not supported.": {
378378
"aggregation.tests.AggregateTestCase.test_sum_distinct_aggregate",
379379
"lookup.tests.LookupTests.test_lookup_collision_distinct",
380+
"ordering.tests.OrderingTests.test_orders_nulls_first_on_filtered_subquery"
380381
"queries.tests.ExcludeTest17600.test_exclude_plain_distinct",
381382
"queries.tests.ExcludeTest17600.test_exclude_with_q_is_equal_to_plain_exclude",
382383
"queries.tests.ExcludeTest17600.test_exclude_with_q_object_distinct",
@@ -420,11 +421,6 @@ def django_test_expected_failures(self):
420421
"queries.tests.ValuesQuerysetTests.test_named_values_list_without_fields",
421422
"select_related.tests.SelectRelatedTests.test_select_related_with_extra",
422423
},
423-
"Ordering a QuerySet by null_first/nulls_last is not supported on MongoDB.": {
424-
"ordering.tests.OrderingTests.test_order_by_nulls_first",
425-
"ordering.tests.OrderingTests.test_order_by_nulls_last",
426-
"ordering.tests.OrderingTests.test_orders_nulls_first_on_filtered_subquery",
427-
},
428424
"QuerySet.update() crash: Unrecognized expression '$count'": {
429425
"update.tests.AdvancedTests.test_update_annotated_multi_table_queryset",
430426
},

django_mongodb/functions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,7 @@ def wrapped(self, compiler, connection):
205205
lhs_mql = process_lhs(self, compiler, connection)
206206
return {
207207
"$cond": {
208-
"if": {"$eq": [lhs_mql, None]},
208+
"if": connection.mongo_operators["isnull"](lhs_mql, True),
209209
"then": None,
210210
"else": {f"${operator}": lhs_mql},
211211
}

django_mongodb/operations.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@
66

77
from bson.decimal128 import Decimal128
88
from django.conf import settings
9-
from django.db import DataError, NotSupportedError
9+
from django.db import DataError
1010
from django.db.backends.base.operations import BaseDatabaseOperations
11-
from django.db.models.expressions import Combinable, OrderBy
11+
from django.db.models.expressions import Combinable
1212
from django.utils import timezone
1313
from django.utils.regex_helper import _lazy_re_compile
1414

@@ -158,11 +158,6 @@ def convert_uuidfield_value(self, value, expression, connection):
158158
value = uuid.UUID(value)
159159
return value
160160

161-
def check_expression_support(self, expression):
162-
if isinstance(expression, OrderBy) and (expression.nulls_first or expression.nulls_last):
163-
option = "null_first" if expression.nulls_first else "nulls_last"
164-
raise NotSupportedError(f"Ordering a QuerySet by {option} is not supported on MongoDB.")
165-
166161
def combine_expression(self, connector, sub_expressions):
167162
lhs, rhs = sub_expressions
168163
if connector == Combinable.BITLEFTSHIFT:

django_mongodb/query.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ def get_pipeline(self):
8686
if self.extra_fields:
8787
pipeline.append({"$addFields": self.extra_fields})
8888
if self.ordering:
89-
pipeline.append({"$sort": dict(self.ordering)})
89+
pipeline.append({"$sort": self.ordering})
9090
if self.query.low_mark > 0:
9191
pipeline.append({"$skip": self.query.low_mark})
9292
if self.query.high_mark is not None:

0 commit comments

Comments
 (0)