@@ -349,7 +349,6 @@ def build_query(self, columns=None):
349
349
"""Check if the query is supported and prepare a MongoQuery."""
350
350
self .check_query ()
351
351
query = self .query_class (self )
352
- query .lookup_pipeline = self .get_lookup_pipeline ()
353
352
ordering_fields , sort_ordering , extra_fields = self ._get_ordering ()
354
353
query .project_fields = self .get_project_fields (columns , ordering_fields )
355
354
query .ordering = sort_ordering
@@ -359,13 +358,21 @@ def build_query(self, columns=None):
359
358
extra_fields += ordering_fields
360
359
if extra_fields :
361
360
query .extra_fields = self .get_project_fields (extra_fields , force_expression = True )
362
- where = self .get_where ()
363
- try :
364
- expr = where .as_mql (self , self .connection ) if where else {}
365
- except FullResultSet :
366
- query .mongo_query = {}
361
+ if self .query .combinator :
362
+ if not getattr (self .connection .features , f"supports_select_{ self .query .combinator } " ):
363
+ raise NotSupportedError (
364
+ f"{ self .query .combinator } is not supported on this database backend."
365
+ )
366
+ query .combinator_pipeline = self .get_combinator_queries ()
367
367
else :
368
- query .mongo_query = {"$expr" : expr }
368
+ query .lookup_pipeline = self .get_lookup_pipeline ()
369
+ where = self .get_where ()
370
+ try :
371
+ expr = where .as_mql (self , self .connection ) if where else {}
372
+ except FullResultSet :
373
+ query .mongo_query = {}
374
+ else :
375
+ query .mongo_query = {"$expr" : expr }
369
376
return query
370
377
371
378
def get_columns (self ):
@@ -412,6 +419,61 @@ def collection_name(self):
412
419
def collection (self ):
413
420
return self .connection .get_collection (self .collection_name )
414
421
422
+ def get_combinator_queries (self ):
423
+ parts = []
424
+ compilers = [
425
+ query .get_compiler (self .using , self .connection , self .elide_empty )
426
+ for query in self .query .combined_queries
427
+ ]
428
+ for compiler_ in compilers :
429
+ try :
430
+ # If the columns list is limited, then all combined queries
431
+ # must have the same columns list. Set the selects defined on
432
+ # the query on all combined queries, if not already set.
433
+ if not compiler_ .query .values_select and self .query .values_select :
434
+ compiler_ .query = compiler_ .query .clone ()
435
+ compiler_ .query .set_values (
436
+ (
437
+ * self .query .extra_select ,
438
+ * self .query .values_select ,
439
+ * self .query .annotation_select ,
440
+ )
441
+ )
442
+ compiler_ .pre_sql_setup (with_col_aliases = False )
443
+ # Avoid $project (columns=None) if unneeded.
444
+ columns = (
445
+ compiler_ .get_columns ()
446
+ if compiler_ .query .annotations or not compiler_ .query .default_cols
447
+ else None
448
+ )
449
+ parts .append ((compiler_ .build_query (columns ), compiler_ .collection_name ))
450
+
451
+ except EmptyResultSet :
452
+ # Omit the empty queryset with UNION and with DIFFERENCE if the
453
+ # first queryset is nonempty.
454
+ if self .query .combinator == "union" :
455
+ continue
456
+ raise
457
+
458
+ combinator_pipeline = parts .pop (0 )[0 ].get_pipeline () if parts else None
459
+ if self .query .combinator == "union" :
460
+ for part , collection in parts :
461
+ combinator_pipeline .append (
462
+ {"$unionWith" : {"coll" : collection , "pipeline" : part .get_pipeline ()}}
463
+ )
464
+ if not self .query .combinator_all :
465
+ ids = {}
466
+ annotation_group_idx = itertools .count (start = 1 )
467
+ for _ , expr in self .get_columns ():
468
+ alias , replacement = self ._get_group_alias_column (
469
+ expr , annotation_group_idx
470
+ )
471
+ ids [alias ] = expr .as_mql (self , self .connection )
472
+ combinator_pipeline .append ({"$group" : {"_id" : ids }})
473
+ else :
474
+ raise NotSupportedError (f"Combinator { self .query .combinator } isn't supported." )
475
+ return combinator_pipeline
476
+
415
477
def get_lookup_pipeline (self ):
416
478
result = []
417
479
for alias in tuple (self .query .alias_map ):
0 commit comments