7
7
from django .db import IntegrityError , NotSupportedError
8
8
from django .db .models import Count
9
9
from django .db .models .aggregates import Aggregate , Variance
10
- from django .db .models .expressions import Case , Col , Ref , Value , When
10
+ from django .db .models .expressions import Case , Col , OrderBy , Ref , Value , When
11
11
from django .db .models .functions .comparison import Coalesce
12
12
from django .db .models .functions .math import Power
13
13
from django .db .models .lookups import IsNull
@@ -32,6 +32,33 @@ def __init__(self, *args, **kwargs):
32
32
# A list of OrderBy objects for this query.
33
33
self .order_by_objs = None
34
34
35
+ def _unfold_column (self , col ):
36
+ """
37
+ Flatten a field by returning its target or by replacing dots with GROUP_SEPARATOR
38
+ for foreign fields.
39
+ """
40
+ if self .collection_name == col .alias :
41
+ return col .target .column
42
+ # If this is a foreign field, replace the normal dot (.) with
43
+ # GROUP_SEPARATOR since FieldPath field names may not contain '.'.
44
+ return f"{ col .alias } { self .GROUP_SEPARATOR } { col .target .column } "
45
+
46
+ def _fold_columns (self , unfold_columns ):
47
+ """
48
+ Convert flat columns into a nested dictionary, grouping fields by table names.
49
+ """
50
+ result = defaultdict (dict )
51
+ for key in unfold_columns :
52
+ value = f"$_id.{ key } "
53
+ if self .GROUP_SEPARATOR in key :
54
+ table , field = key .split (self .GROUP_SEPARATOR )
55
+ result [table ][field ] = value
56
+ else :
57
+ result [key ] = value
58
+ # Convert defaultdict to dict so it doesn't appear as
59
+ # "defaultdict(<CLASS 'dict'>, ..." in query logging.
60
+ return dict (result )
61
+
35
62
def _get_group_alias_column (self , expr , annotation_group_idx ):
36
63
"""Generate a dummy field for use in the ids fields in $group."""
37
64
replacement = None
@@ -42,11 +69,7 @@ def _get_group_alias_column(self, expr, annotation_group_idx):
42
69
alias = f"__annotation_group{ next (annotation_group_idx )} "
43
70
col = self ._get_column_from_expression (expr , alias )
44
71
replacement = col
45
- if self .collection_name == col .alias :
46
- return col .target .column , replacement
47
- # If this is a foreign field, replace the normal dot (.) with
48
- # GROUP_SEPARATOR since FieldPath field names may not contain '.'.
49
- return f"{ col .alias } { self .GROUP_SEPARATOR } { col .target .column } " , replacement
72
+ return self ._unfold_column (col ), replacement
50
73
51
74
def _get_column_from_expression (self , expr , alias ):
52
75
"""
@@ -186,17 +209,8 @@ def _build_aggregation_pipeline(self, ids, group):
186
209
else :
187
210
group ["_id" ] = ids
188
211
pipeline .append ({"$group" : group })
189
- projected_fields = defaultdict (dict )
190
- for key in ids :
191
- value = f"$_id.{ key } "
192
- if self .GROUP_SEPARATOR in key :
193
- table , field = key .split (self .GROUP_SEPARATOR )
194
- projected_fields [table ][field ] = value
195
- else :
196
- projected_fields [key ] = value
197
- # Convert defaultdict to dict so it doesn't appear as
198
- # "defaultdict(<CLASS 'dict'>, ..." in query logging.
199
- pipeline .append ({"$addFields" : dict (projected_fields )})
212
+ projected_fields = self ._fold_columns (ids )
213
+ pipeline .append ({"$addFields" : projected_fields })
200
214
if "_id" not in projected_fields :
201
215
pipeline .append ({"$unset" : "_id" })
202
216
return pipeline
@@ -349,23 +363,30 @@ def build_query(self, columns=None):
349
363
"""Check if the query is supported and prepare a MongoQuery."""
350
364
self .check_query ()
351
365
query = self .query_class (self )
352
- query .lookup_pipeline = self .get_lookup_pipeline ()
353
366
ordering_fields , sort_ordering , extra_fields = self ._get_ordering ()
354
- query .project_fields = self .get_project_fields (columns , ordering_fields )
355
367
query .ordering = sort_ordering
356
- # If columns is None, then get_project_fields() won't add
357
- # ordering_fields to $project. Use $addFields (extra_fields) instead.
358
- if columns is None :
359
- extra_fields += ordering_fields
368
+ if self .query .combinator :
369
+ if not getattr (self .connection .features , f"supports_select_{ self .query .combinator } " ):
370
+ raise NotSupportedError (
371
+ f"{ self .query .combinator } is not supported on this database backend."
372
+ )
373
+ query .combinator_pipeline = self .get_combinator_queries ()
374
+ else :
375
+ query .project_fields = self .get_project_fields (columns , ordering_fields )
376
+ # If columns is None, then get_project_fields() won't add
377
+ # ordering_fields to $project. Use $addFields (extra_fields) instead.
378
+ if columns is None :
379
+ extra_fields += ordering_fields
380
+ query .lookup_pipeline = self .get_lookup_pipeline ()
381
+ where = self .get_where ()
382
+ try :
383
+ expr = where .as_mql (self , self .connection ) if where else {}
384
+ except FullResultSet :
385
+ query .mongo_query = {}
386
+ else :
387
+ query .mongo_query = {"$expr" : expr }
360
388
if extra_fields :
361
389
query .extra_fields = self .get_project_fields (extra_fields , force_expression = True )
362
- where = self .get_where ()
363
- try :
364
- expr = where .as_mql (self , self .connection ) if where else {}
365
- except FullResultSet :
366
- query .mongo_query = {}
367
- else :
368
- query .mongo_query = {"$expr" : expr }
369
390
return query
370
391
371
392
def get_columns (self ):
@@ -391,6 +412,9 @@ def project_field(column):
391
412
if hasattr (column , "target" ):
392
413
# column is a Col.
393
414
target = column .target .column
415
+ # Handle Order By columns as refs columns.
416
+ elif isinstance (column , OrderBy ) and isinstance (column .expression , Ref ):
417
+ target = column .expression .refs
394
418
else :
395
419
# column is a Transform in values()/values_list() that needs a
396
420
# name for $proj.
@@ -412,6 +436,75 @@ def collection_name(self):
412
436
def collection (self ):
413
437
return self .connection .get_collection (self .collection_name )
414
438
439
+ def get_combinator_queries (self ):
440
+ parts = []
441
+ compilers = [
442
+ query .get_compiler (self .using , self .connection , self .elide_empty )
443
+ for query in self .query .combined_queries
444
+ ]
445
+ main_query_columns = self .get_columns ()
446
+ main_query_fields , _ = zip (* main_query_columns , strict = True )
447
+ for compiler_ in compilers :
448
+ try :
449
+ # If the columns list is limited, then all combined queries
450
+ # must have the same columns list. Set the selects defined on
451
+ # the query on all combined queries, if not already set.
452
+ if not compiler_ .query .values_select and self .query .values_select :
453
+ compiler_ .query = compiler_ .query .clone ()
454
+ compiler_ .query .set_values (
455
+ (
456
+ * self .query .extra_select ,
457
+ * self .query .values_select ,
458
+ * self .query .annotation_select ,
459
+ )
460
+ )
461
+ compiler_ .pre_sql_setup ()
462
+ columns = compiler_ .get_columns ()
463
+ parts .append ((compiler_ .build_query (columns ), compiler_ , columns ))
464
+ except EmptyResultSet :
465
+ # Omit the empty queryset with UNION.
466
+ if self .query .combinator == "union" :
467
+ continue
468
+ raise
469
+ # Raise EmptyResultSet if all the combinator queries are empty.
470
+ if not parts :
471
+ raise EmptyResultSet
472
+ # Make the combinator's stages.
473
+ combinator_pipeline = None
474
+ for part , compiler_ , columns in parts :
475
+ inner_pipeline = part .get_pipeline ()
476
+ # Standardize result fields.
477
+ fields = {}
478
+ # When a .count() is called, the main_query_field has length 1
479
+ # otherwise it has the same length as columns.
480
+ for alias , (ref , expr ) in zip (main_query_fields , columns , strict = False ):
481
+ if isinstance (expr , Col ) and expr .alias != compiler_ .collection_name :
482
+ fields [expr .alias ] = 1
483
+ else :
484
+ fields [alias ] = f"${ ref } " if alias != ref else 1
485
+ inner_pipeline .append ({"$project" : fields })
486
+ # Combine query with the current combinator pipeline.
487
+ if combinator_pipeline :
488
+ combinator_pipeline .append (
489
+ {"$unionWith" : {"coll" : compiler_ .collection_name , "pipeline" : inner_pipeline }}
490
+ )
491
+ else :
492
+ combinator_pipeline = inner_pipeline
493
+ if not self .query .combinator_all :
494
+ ids = {}
495
+ for alias , expr in main_query_columns :
496
+ # Unfold foreign fields.
497
+ if isinstance (expr , Col ) and expr .alias != self .collection_name :
498
+ ids [self ._unfold_column (expr )] = expr .as_mql (self , self .connection )
499
+ else :
500
+ ids [alias ] = f"${ alias } "
501
+ combinator_pipeline .append ({"$group" : {"_id" : ids }})
502
+ projected_fields = self ._fold_columns (ids )
503
+ combinator_pipeline .append ({"$addFields" : projected_fields })
504
+ if "_id" not in projected_fields :
505
+ combinator_pipeline .append ({"$unset" : "_id" })
506
+ return combinator_pipeline
507
+
415
508
def get_lookup_pipeline (self ):
416
509
result = []
417
510
for alias in tuple (self .query .alias_map ):
0 commit comments