7
7
from django .db import IntegrityError , NotSupportedError
8
8
from django .db .models import Count
9
9
from django .db .models .aggregates import Aggregate , Variance
10
- from django .db .models .expressions import Case , Col , Ref , Value , When
10
+ from django .db .models .expressions import Case , Col , OrderBy , Ref , Value , When
11
11
from django .db .models .functions .comparison import Coalesce
12
12
from django .db .models .functions .math import Power
13
13
from django .db .models .lookups import IsNull
@@ -349,23 +349,30 @@ def build_query(self, columns=None):
349
349
"""Check if the query is supported and prepare a MongoQuery."""
350
350
self .check_query ()
351
351
query = self .query_class (self )
352
- query .lookup_pipeline = self .get_lookup_pipeline ()
353
352
ordering_fields , sort_ordering , extra_fields = self ._get_ordering ()
354
- query .project_fields = self .get_project_fields (columns , ordering_fields )
355
353
query .ordering = sort_ordering
356
- # If columns is None, then get_project_fields() won't add
357
- # ordering_fields to $project. Use $addFields (extra_fields) instead.
358
- if columns is None :
359
- extra_fields += ordering_fields
354
+ if self .query .combinator :
355
+ if not getattr (self .connection .features , f"supports_select_{ self .query .combinator } " ):
356
+ raise NotSupportedError (
357
+ f"{ self .query .combinator } is not supported on this database backend."
358
+ )
359
+ query .combinator_pipeline = self .get_combinator_queries ()
360
+ else :
361
+ query .project_fields = self .get_project_fields (columns , ordering_fields )
362
+ # If columns is None, then get_project_fields() won't add
363
+ # ordering_fields to $project. Use $addFields (extra_fields) instead.
364
+ if columns is None :
365
+ extra_fields += ordering_fields
366
+ query .lookup_pipeline = self .get_lookup_pipeline ()
367
+ where = self .get_where ()
368
+ try :
369
+ expr = where .as_mql (self , self .connection ) if where else {}
370
+ except FullResultSet :
371
+ query .mongo_query = {}
372
+ else :
373
+ query .mongo_query = {"$expr" : expr }
360
374
if extra_fields :
361
375
query .extra_fields = self .get_project_fields (extra_fields , force_expression = True )
362
- where = self .get_where ()
363
- try :
364
- expr = where .as_mql (self , self .connection ) if where else {}
365
- except FullResultSet :
366
- query .mongo_query = {}
367
- else :
368
- query .mongo_query = {"$expr" : expr }
369
376
return query
370
377
371
378
def get_columns (self ):
@@ -391,6 +398,9 @@ def project_field(column):
391
398
if hasattr (column , "target" ):
392
399
# column is a Col.
393
400
target = column .target .column
401
+ # Handle Order by columns as refs columns.
402
+ elif isinstance (column , OrderBy ) and isinstance (column .expression , Ref ):
403
+ target = column .expression .refs
394
404
else :
395
405
# column is a Transform in values()/values_list() that needs a
396
406
# name for $proj.
@@ -412,6 +422,79 @@ def collection_name(self):
412
422
def collection (self ):
413
423
return self .connection .get_collection (self .collection_name )
414
424
425
+ def get_combinator_queries (self ):
426
+ parts = []
427
+ compilers = [
428
+ query .get_compiler (self .using , self .connection , self .elide_empty )
429
+ for query in self .query .combined_queries
430
+ ]
431
+ main_query_columns = self .get_columns ()
432
+ main_query_fields , _ = zip (* main_query_columns , strict = True )
433
+ for compiler_ in compilers :
434
+ try :
435
+ # If the columns list is limited, then all combined queries
436
+ # must have the same columns list. Set the selects defined on
437
+ # the query on all combined queries, if not already set.
438
+ if not compiler_ .query .values_select and self .query .values_select :
439
+ compiler_ .query = compiler_ .query .clone ()
440
+ compiler_ .query .set_values (
441
+ (
442
+ * self .query .extra_select ,
443
+ * self .query .values_select ,
444
+ * self .query .annotation_select ,
445
+ )
446
+ )
447
+ compiler_ .pre_sql_setup ()
448
+ # Standardize columns as main query required.
449
+ columns = compiler_ .get_columns ()
450
+ if self .query .values_select :
451
+ _ , exprs = zip (* columns , strict = True )
452
+ columns = tuple (zip (main_query_fields , exprs , strict = True ))
453
+ parts .append ((compiler_ .build_query (columns ), compiler_ .collection_name ))
454
+
455
+ except EmptyResultSet :
456
+ # Omit the empty queryset with UNION and with DIFFERENCE if the
457
+ # first queryset is nonempty.
458
+ if self .query .combinator == "union" :
459
+ continue
460
+ raise
461
+ # Raise EmptyResultSet if all the combinator queries are empty.
462
+ if not parts :
463
+ raise EmptyResultSet
464
+ combinator_pipeline = parts .pop (0 )[0 ].get_pipeline () if parts else None
465
+ if self .query .combinator == "union" :
466
+ for part , collection in parts :
467
+ combinator_pipeline .append (
468
+ {"$unionWith" : {"coll" : collection , "pipeline" : part .get_pipeline ()}}
469
+ )
470
+ if not self .query .combinator_all :
471
+ ids = {}
472
+ for alias , expr in main_query_columns :
473
+ collection = expr .alias if isinstance (expr , Col ) else None
474
+ if collection and collection != self .collection_name :
475
+ ids [
476
+ f"{ expr .alias } { self .GROUP_SEPARATOR } { expr .target .column } "
477
+ ] = expr .as_mql (self , self .connection )
478
+ else :
479
+ ids [alias ] = f"${ alias } "
480
+ combinator_pipeline .append ({"$group" : {"_id" : ids }})
481
+ projected_fields = defaultdict (dict )
482
+ for key in ids :
483
+ value = f"$_id.{ key } "
484
+ if self .GROUP_SEPARATOR in key :
485
+ table , field = key .split (self .GROUP_SEPARATOR )
486
+ projected_fields [table ][field ] = value
487
+ else :
488
+ projected_fields [key ] = value
489
+ # Convert defaultdict to dict so it doesn't appear as
490
+ # "defaultdict(<CLASS 'dict'>, ..." in query logging.
491
+ combinator_pipeline .append ({"$addFields" : dict (projected_fields )})
492
+ if "_id" not in projected_fields :
493
+ combinator_pipeline .append ({"$unset" : "_id" })
494
+ else :
495
+ raise NotSupportedError (f"Combinator { self .query .combinator } isn't supported." )
496
+ return combinator_pipeline
497
+
415
498
def get_lookup_pipeline (self ):
416
499
result = []
417
500
for alias in tuple (self .query .alias_map ):
0 commit comments