@@ -32,6 +32,33 @@ def __init__(self, *args, **kwargs):
32
32
# A list of OrderBy objects for this query.
33
33
self .order_by_objs = None
34
34
35
+ def _unfold_column (self , col ):
36
+ """
37
+ Flattens a field by returning its target or by replacing dots with GROUP_SEPARATOR
38
+ for foreign fields.
39
+ """
40
+ if self .collection_name == col .alias :
41
+ return col .target .column
42
+ # If this is a foreign field, replace the normal dot (.) with
43
+ # GROUP_SEPARATOR since FieldPath field names may not contain '.'.
44
+ return f"{ col .alias } { self .GROUP_SEPARATOR } { col .target .column } "
45
+
46
+ def _fold_columns (self , unfold_columns ):
47
+ """
48
+ Converts flat columns into a nested dictionary, grouping fields by table names.
49
+ """
50
+ result = defaultdict (dict )
51
+ for key in unfold_columns :
52
+ value = f"$_id.{ key } "
53
+ if self .GROUP_SEPARATOR in key :
54
+ table , field = key .split (self .GROUP_SEPARATOR )
55
+ result [table ][field ] = value
56
+ else :
57
+ result [key ] = value
58
+ # Convert defaultdict to dict so it doesn't appear as
59
+ # "defaultdict(<CLASS 'dict'>, ..." in query logging.
60
+ return dict (result )
61
+
35
62
def _get_group_alias_column (self , expr , annotation_group_idx ):
36
63
"""Generate a dummy field for use in the ids fields in $group."""
37
64
replacement = None
@@ -42,11 +69,7 @@ def _get_group_alias_column(self, expr, annotation_group_idx):
42
69
alias = f"__annotation_group{ next (annotation_group_idx )} "
43
70
col = self ._get_column_from_expression (expr , alias )
44
71
replacement = col
45
- if self .collection_name == col .alias :
46
- return col .target .column , replacement
47
- # If this is a foreign field, replace the normal dot (.) with
48
- # GROUP_SEPARATOR since FieldPath field names may not contain '.'.
49
- return f"{ col .alias } { self .GROUP_SEPARATOR } { col .target .column } " , replacement
72
+ return self ._unfold_column (col ), replacement
50
73
51
74
def _get_column_from_expression (self , expr , alias ):
52
75
"""
@@ -186,17 +209,8 @@ def _build_aggregation_pipeline(self, ids, group):
186
209
else :
187
210
group ["_id" ] = ids
188
211
pipeline .append ({"$group" : group })
189
- projected_fields = defaultdict (dict )
190
- for key in ids :
191
- value = f"$_id.{ key } "
192
- if self .GROUP_SEPARATOR in key :
193
- table , field = key .split (self .GROUP_SEPARATOR )
194
- projected_fields [table ][field ] = value
195
- else :
196
- projected_fields [key ] = value
197
- # Convert defaultdict to dict so it doesn't appear as
198
- # "defaultdict(<CLASS 'dict'>, ..." in query logging.
199
- pipeline .append ({"$addFields" : dict (projected_fields )})
212
+ projected_fields = self ._fold_columns (ids )
213
+ pipeline .append ({"$addFields" : projected_fields })
200
214
if "_id" not in projected_fields :
201
215
pipeline .append ({"$unset" : "_id" })
202
216
return pipeline
@@ -453,8 +467,7 @@ def get_combinator_queries(self):
453
467
parts .append ((compiler_ .build_query (columns ), compiler_ .collection_name ))
454
468
455
469
except EmptyResultSet :
456
- # Omit the empty queryset with UNION and with DIFFERENCE if the
457
- # first queryset is nonempty.
470
+ # Omit the empty queryset with UNION.
458
471
if self .query .combinator == "union" :
459
472
continue
460
473
raise
@@ -470,25 +483,14 @@ def get_combinator_queries(self):
470
483
if not self .query .combinator_all :
471
484
ids = {}
472
485
for alias , expr in main_query_columns :
473
- collection = expr .alias if isinstance (expr , Col ) else None
474
- if collection and collection != self .collection_name :
475
- ids [
476
- f"{ expr .alias } { self .GROUP_SEPARATOR } { expr .target .column } "
477
- ] = expr .as_mql (self , self .connection )
486
+ # Unfold foreign fields.
487
+ if isinstance (expr , Col ) and expr .alias != self .collection_name :
488
+ ids [self ._unfold_column (expr )] = expr .as_mql (self , self .connection )
478
489
else :
479
490
ids [alias ] = f"${ alias } "
480
491
combinator_pipeline .append ({"$group" : {"_id" : ids }})
481
- projected_fields = defaultdict (dict )
482
- for key in ids :
483
- value = f"$_id.{ key } "
484
- if self .GROUP_SEPARATOR in key :
485
- table , field = key .split (self .GROUP_SEPARATOR )
486
- projected_fields [table ][field ] = value
487
- else :
488
- projected_fields [key ] = value
489
- # Convert defaultdict to dict so it doesn't appear as
490
- # "defaultdict(<CLASS 'dict'>, ..." in query logging.
491
- combinator_pipeline .append ({"$addFields" : dict (projected_fields )})
492
+ projected_fields = self ._fold_columns (ids )
493
+ combinator_pipeline .append ({"$addFields" : projected_fields })
492
494
if "_id" not in projected_fields :
493
495
combinator_pipeline .append ({"$unset" : "_id" })
494
496
else :
0 commit comments