7
7
from django .db import IntegrityError , NotSupportedError
8
8
from django .db .models import Count
9
9
from django .db .models .aggregates import Aggregate , Variance
10
- from django .db .models .expressions import Case , Col , Ref , Value , When
10
+ from django .db .models .expressions import Case , Col , OrderBy , Ref , Value , When
11
11
from django .db .models .functions .comparison import Coalesce
12
12
from django .db .models .functions .math import Power
13
13
from django .db .models .lookups import IsNull
@@ -32,6 +32,34 @@ def __init__(self, *args, **kwargs):
32
32
# A list of OrderBy objects for this query.
33
33
self .order_by_objs = None
34
34
35
+ def _unfold_column (self , col ):
36
+ """
37
+ Flatten a field by returning its target or by replacing dots with
38
+ GROUP_SEPARATOR for foreign fields.
39
+ """
40
+ if self .collection_name == col .alias :
41
+ return col .target .column
42
+ # If this is a foreign field, replace the normal dot (.) with
43
+ # GROUP_SEPARATOR since FieldPath field names may not contain '.'.
44
+ return f"{ col .alias } { self .GROUP_SEPARATOR } { col .target .column } "
45
+
46
+ def _fold_columns (self , unfold_columns ):
47
+ """
48
+ Convert flat columns into a nested dictionary, grouping fields by
49
+ table name.
50
+ """
51
+ result = defaultdict (dict )
52
+ for key in unfold_columns :
53
+ value = f"$_id.{ key } "
54
+ if self .GROUP_SEPARATOR in key :
55
+ table , field = key .split (self .GROUP_SEPARATOR )
56
+ result [table ][field ] = value
57
+ else :
58
+ result [key ] = value
59
+ # Convert defaultdict to dict so it doesn't appear as
60
+ # "defaultdict(<CLASS 'dict'>, ..." in query logging.
61
+ return dict (result )
62
+
35
63
def _get_group_alias_column (self , expr , annotation_group_idx ):
36
64
"""Generate a dummy field for use in the ids fields in $group."""
37
65
replacement = None
@@ -42,11 +70,7 @@ def _get_group_alias_column(self, expr, annotation_group_idx):
42
70
alias = f"__annotation_group{ next (annotation_group_idx )} "
43
71
col = self ._get_column_from_expression (expr , alias )
44
72
replacement = col
45
- if self .collection_name == col .alias :
46
- return col .target .column , replacement
47
- # If this is a foreign field, replace the normal dot (.) with
48
- # GROUP_SEPARATOR since FieldPath field names may not contain '.'.
49
- return f"{ col .alias } { self .GROUP_SEPARATOR } { col .target .column } " , replacement
73
+ return self ._unfold_column (col ), replacement
50
74
51
75
def _get_column_from_expression (self , expr , alias ):
52
76
"""
@@ -186,17 +210,8 @@ def _build_aggregation_pipeline(self, ids, group):
186
210
else :
187
211
group ["_id" ] = ids
188
212
pipeline .append ({"$group" : group })
189
- projected_fields = defaultdict (dict )
190
- for key in ids :
191
- value = f"$_id.{ key } "
192
- if self .GROUP_SEPARATOR in key :
193
- table , field = key .split (self .GROUP_SEPARATOR )
194
- projected_fields [table ][field ] = value
195
- else :
196
- projected_fields [key ] = value
197
- # Convert defaultdict to dict so it doesn't appear as
198
- # "defaultdict(<CLASS 'dict'>, ..." in query logging.
199
- pipeline .append ({"$addFields" : dict (projected_fields )})
213
+ projected_fields = self ._fold_columns (ids )
214
+ pipeline .append ({"$addFields" : projected_fields })
200
215
if "_id" not in projected_fields :
201
216
pipeline .append ({"$unset" : "_id" })
202
217
return pipeline
@@ -349,23 +364,30 @@ def build_query(self, columns=None):
349
364
"""Check if the query is supported and prepare a MongoQuery."""
350
365
self .check_query ()
351
366
query = self .query_class (self )
352
- query .lookup_pipeline = self .get_lookup_pipeline ()
353
367
ordering_fields , sort_ordering , extra_fields = self ._get_ordering ()
354
- query .project_fields = self .get_project_fields (columns , ordering_fields )
355
368
query .ordering = sort_ordering
356
- # If columns is None, then get_project_fields() won't add
357
- # ordering_fields to $project. Use $addFields (extra_fields) instead.
358
- if columns is None :
359
- extra_fields += ordering_fields
369
+ if self .query .combinator :
370
+ if not getattr (self .connection .features , f"supports_select_{ self .query .combinator } " ):
371
+ raise NotSupportedError (
372
+ f"{ self .query .combinator } is not supported on this database backend."
373
+ )
374
+ query .combinator_pipeline = self .get_combinator_queries ()
375
+ else :
376
+ query .project_fields = self .get_project_fields (columns , ordering_fields )
377
+ # If columns is None, then get_project_fields() won't add
378
+ # ordering_fields to $project. Use $addFields (extra_fields) instead.
379
+ if columns is None :
380
+ extra_fields += ordering_fields
381
+ query .lookup_pipeline = self .get_lookup_pipeline ()
382
+ where = self .get_where ()
383
+ try :
384
+ expr = where .as_mql (self , self .connection ) if where else {}
385
+ except FullResultSet :
386
+ query .mongo_query = {}
387
+ else :
388
+ query .mongo_query = {"$expr" : expr }
360
389
if extra_fields :
361
390
query .extra_fields = self .get_project_fields (extra_fields , force_expression = True )
362
- where = self .get_where ()
363
- try :
364
- expr = where .as_mql (self , self .connection ) if where else {}
365
- except FullResultSet :
366
- query .mongo_query = {}
367
- else :
368
- query .mongo_query = {"$expr" : expr }
369
391
return query
370
392
371
393
def get_columns (self ):
@@ -391,6 +413,9 @@ def project_field(column):
391
413
if hasattr (column , "target" ):
392
414
# column is a Col.
393
415
target = column .target .column
416
+ # Handle Order By columns as refs columns.
417
+ elif isinstance (column , OrderBy ) and isinstance (column .expression , Ref ):
418
+ target = column .expression .refs
394
419
else :
395
420
# column is a Transform in values()/values_list() that needs a
396
421
# name for $proj.
@@ -412,6 +437,75 @@ def collection_name(self):
412
437
def collection (self ):
413
438
return self .connection .get_collection (self .collection_name )
414
439
440
+ def get_combinator_queries (self ):
441
+ parts = []
442
+ compilers = [
443
+ query .get_compiler (self .using , self .connection , self .elide_empty )
444
+ for query in self .query .combined_queries
445
+ ]
446
+ main_query_columns = self .get_columns ()
447
+ main_query_fields , _ = zip (* main_query_columns , strict = True )
448
+ for compiler_ in compilers :
449
+ try :
450
+ # If the columns list is limited, then all combined queries
451
+ # must have the same columns list. Set the selects defined on
452
+ # the query on all combined queries, if not already set.
453
+ if not compiler_ .query .values_select and self .query .values_select :
454
+ compiler_ .query = compiler_ .query .clone ()
455
+ compiler_ .query .set_values (
456
+ (
457
+ * self .query .extra_select ,
458
+ * self .query .values_select ,
459
+ * self .query .annotation_select ,
460
+ )
461
+ )
462
+ compiler_ .pre_sql_setup ()
463
+ columns = compiler_ .get_columns ()
464
+ parts .append ((compiler_ .build_query (columns ), compiler_ , columns ))
465
+ except EmptyResultSet :
466
+ # Omit the empty queryset with UNION.
467
+ if self .query .combinator == "union" :
468
+ continue
469
+ raise
470
+ # Raise EmptyResultSet if all the combinator queries are empty.
471
+ if not parts :
472
+ raise EmptyResultSet
473
+ # Make the combinator's stages.
474
+ combinator_pipeline = None
475
+ for part , compiler_ , columns in parts :
476
+ inner_pipeline = part .get_pipeline ()
477
+ # Standardize result fields.
478
+ fields = {}
479
+ # When a .count() is called, the main_query_field has length 1
480
+ # otherwise it has the same length as columns.
481
+ for alias , (ref , expr ) in zip (main_query_fields , columns , strict = False ):
482
+ if isinstance (expr , Col ) and expr .alias != compiler_ .collection_name :
483
+ fields [expr .alias ] = 1
484
+ else :
485
+ fields [alias ] = f"${ ref } " if alias != ref else 1
486
+ inner_pipeline .append ({"$project" : fields })
487
+ # Combine query with the current combinator pipeline.
488
+ if combinator_pipeline :
489
+ combinator_pipeline .append (
490
+ {"$unionWith" : {"coll" : compiler_ .collection_name , "pipeline" : inner_pipeline }}
491
+ )
492
+ else :
493
+ combinator_pipeline = inner_pipeline
494
+ if not self .query .combinator_all :
495
+ ids = {}
496
+ for alias , expr in main_query_columns :
497
+ # Unfold foreign fields.
498
+ if isinstance (expr , Col ) and expr .alias != self .collection_name :
499
+ ids [self ._unfold_column (expr )] = expr .as_mql (self , self .connection )
500
+ else :
501
+ ids [alias ] = f"${ alias } "
502
+ combinator_pipeline .append ({"$group" : {"_id" : ids }})
503
+ projected_fields = self ._fold_columns (ids )
504
+ combinator_pipeline .append ({"$addFields" : projected_fields })
505
+ if "_id" not in projected_fields :
506
+ combinator_pipeline .append ({"$unset" : "_id" })
507
+ return combinator_pipeline
508
+
415
509
def get_lookup_pipeline (self ):
416
510
result = []
417
511
for alias in tuple (self .query .alias_map ):
0 commit comments