7
7
from django .db import IntegrityError , NotSupportedError
8
8
from django .db .models import Count
9
9
from django .db .models .aggregates import Aggregate , Variance
10
- from django .db .models .expressions import Case , Col , Ref , Value , When
10
+ from django .db .models .expressions import Case , Col , OrderBy , Ref , Value , When
11
11
from django .db .models .functions .comparison import Coalesce
12
12
from django .db .models .functions .math import Power
13
13
from django .db .models .lookups import IsNull
@@ -381,8 +381,10 @@ def get_columns(self):
381
381
which should be loaded by the query.
382
382
"""
383
383
select_mask = self .query .get_select_mask ()
384
- columns = (
385
- self .get_default_columns (select_mask ) if self .query .default_cols else self .query .select
384
+ columns = filter (
385
+ # The extra order by columns are handled by order_by_objs variables.
386
+ lambda col : not isinstance (col , OrderBy ),
387
+ self .get_default_columns (select_mask ) if self .query .default_cols else self .query .select ,
386
388
)
387
389
# Populate QuerySet.select_related() data.
388
390
related_columns = []
@@ -439,13 +441,10 @@ def get_combinator_queries(self):
439
441
* self .query .annotation_select ,
440
442
)
441
443
)
442
- compiler_ .pre_sql_setup (with_col_aliases = False )
443
- # Avoid $project (columns=None) if unneeded.
444
- columns = (
445
- compiler_ .get_columns ()
446
- if compiler_ .query .annotations or not compiler_ .query .default_cols
447
- else None
448
- )
444
+ compiler_ .pre_sql_setup ()
445
+ # Standardize columns as main query required.
446
+ _ , exprs = zip (* compiler_ .get_columns (), strict = True )
447
+ columns = tuple (zip (self .query .values_select , exprs , strict = True ))
449
448
parts .append ((compiler_ .build_query (columns ), compiler_ .collection_name ))
450
449
451
450
except EmptyResultSet :
@@ -454,7 +453,9 @@ def get_combinator_queries(self):
454
453
if self .query .combinator == "union" :
455
454
continue
456
455
raise
457
-
456
+ # Raise EmptyResultSet if all the combinator queries are empty.
457
+ if not parts :
458
+ raise EmptyResultSet
458
459
combinator_pipeline = parts .pop (0 )[0 ].get_pipeline () if parts else None
459
460
if self .query .combinator == "union" :
460
461
for part , collection in parts :
@@ -463,13 +464,26 @@ def get_combinator_queries(self):
463
464
)
464
465
if not self .query .combinator_all :
465
466
ids = {}
466
- annotation_group_idx = itertools .count (start = 1 )
467
- for _ , expr in self .get_columns ():
468
- alias , replacement = self ._get_group_alias_column (
469
- expr , annotation_group_idx
467
+ for alias , expr in self .get_columns ():
468
+ ids [alias ] = (
469
+ expr .as_mql (self , self .connection )
470
+ if isinstance (expr , Col | Ref )
471
+ else f"${ alias } "
470
472
)
471
- ids [alias ] = expr .as_mql (self , self .connection )
472
473
combinator_pipeline .append ({"$group" : {"_id" : ids }})
474
+ projected_fields = defaultdict (dict )
475
+ for key in ids :
476
+ value = f"$_id.{ key } "
477
+ if self .GROUP_SEPARATOR in key :
478
+ table , field = key .split (self .GROUP_SEPARATOR )
479
+ projected_fields [table ][field ] = value
480
+ else :
481
+ projected_fields [key ] = value
482
+ # Convert defaultdict to dict so it doesn't appear as
483
+ # "defaultdict(<CLASS 'dict'>, ..." in query logging.
484
+ combinator_pipeline .append ({"$addFields" : dict (projected_fields )})
485
+ if "_id" not in projected_fields :
486
+ combinator_pipeline .append ({"$unset" : "_id" })
473
487
else :
474
488
raise NotSupportedError (f"Combinator { self .query .combinator } isn't supported." )
475
489
return combinator_pipeline
0 commit comments