@@ -25,7 +25,6 @@ class SQLCompiler(compiler.SQLCompiler):
25
25
"""Base class for all Mongo compilers."""
26
26
27
27
query_class = MongoQuery
28
- GROUP_SEPARATOR = "___"
29
28
PARENT_FIELD_TEMPLATE = "parent__field__{}"
30
29
31
30
def __init__ (self , * args , ** kwargs ):
@@ -37,34 +36,6 @@ def __init__(self, *args, **kwargs):
37
36
self .order_by_objs = None
38
37
self .subqueries = []
39
38
40
- def _unfold_column (self , col ):
41
- """
42
- Flatten a field by returning its target or by replacing dots with
43
- GROUP_SEPARATOR for foreign fields.
44
- """
45
- if self .collection_name == col .alias :
46
- return col .target .column
47
- # If this is a foreign field, replace the normal dot (.) with
48
- # GROUP_SEPARATOR since FieldPath field names may not contain '.'.
49
- return f"{ col .alias } { self .GROUP_SEPARATOR } { col .target .column } "
50
-
51
- def _fold_columns (self , unfold_columns ):
52
- """
53
- Convert flat columns into a nested dictionary, grouping fields by
54
- table name.
55
- """
56
- result = defaultdict (dict )
57
- for key in unfold_columns :
58
- value = f"$_id.{ key } "
59
- if self .GROUP_SEPARATOR in key :
60
- table , field = key .split (self .GROUP_SEPARATOR )
61
- result [table ][field ] = value
62
- else :
63
- result [key ] = value
64
- # Convert defaultdict to dict so it doesn't appear as
65
- # "defaultdict(<CLASS 'dict'>, ..." in query logging.
66
- return dict (result )
67
-
68
39
def _get_group_alias_column (self , expr , annotation_group_idx ):
69
40
"""Generate a dummy field for use in the ids fields in $group."""
70
41
replacement = None
@@ -75,7 +46,7 @@ def _get_group_alias_column(self, expr, annotation_group_idx):
75
46
alias = f"__annotation_group{ next (annotation_group_idx )} "
76
47
col = self ._get_column_from_expression (expr , alias )
77
48
replacement = col
78
- return self . _unfold_column ( col ) , replacement
49
+ return col . target . column , replacement
79
50
80
51
def _get_column_from_expression (self , expr , alias ):
81
52
"""
@@ -198,18 +169,15 @@ def _get_group_id_expressions(self, order_by):
198
169
else :
199
170
annotation_group_idx = itertools .count (start = 1 )
200
171
ids = {}
172
+ columns = []
201
173
for col in group_expressions :
202
174
alias , replacement = self ._get_group_alias_column (col , annotation_group_idx )
203
- try :
204
- ids [alias ] = col .as_mql (self , self .connection )
205
- except EmptyResultSet :
206
- ids [alias ] = Value (False ).as_mql (self , self .connection )
207
- except FullResultSet :
208
- ids [alias ] = Value (True ).as_mql (self , self .connection )
175
+ columns .append ((alias , col ))
209
176
if replacement is not None :
210
177
replacements [col ] = replacement
211
178
if isinstance (col , Ref ):
212
179
replacements [col .source ] = replacement
180
+ ids = self .get_project_fields (tuple (columns ), force_expression = True )
213
181
return ids , replacements
214
182
215
183
def _build_aggregation_pipeline (self , ids , group ):
@@ -234,7 +202,7 @@ def _build_aggregation_pipeline(self, ids, group):
234
202
else :
235
203
group ["_id" ] = ids
236
204
pipeline .append ({"$group" : group })
237
- projected_fields = self . _fold_columns ( ids )
205
+ projected_fields = { key : f"$_id. { key } " for key in ids }
238
206
pipeline .append ({"$addFields" : projected_fields })
239
207
if "_id" not in projected_fields :
240
208
pipeline .append ({"$unset" : "_id" })
@@ -522,15 +490,18 @@ def get_combinator_queries(self):
522
490
else :
523
491
combinator_pipeline = inner_pipeline
524
492
if not self .query .combinator_all :
525
- ids = {}
493
+ ids = defaultdict ( dict )
526
494
for alias , expr in main_query_columns :
527
495
# Unfold foreign fields.
528
496
if isinstance (expr , Col ) and expr .alias != self .collection_name :
529
- ids [self . _unfold_column ( expr ) ] = expr .as_mql (self , self .connection )
497
+ ids [expr . alias ][ expr . target . column ] = expr .as_mql (self , self .connection )
530
498
else :
531
499
ids [alias ] = f"${ alias } "
500
+ # Convert defaultdict to dict so it doesn't appear as
501
+ # "defaultdict(<CLASS 'dict'>, ..." in query logging.
502
+ ids = dict (ids )
532
503
combinator_pipeline .append ({"$group" : {"_id" : ids }})
533
- projected_fields = self . _fold_columns ( ids )
504
+ projected_fields = { key : f"$_id. { key } " for key in ids }
534
505
combinator_pipeline .append ({"$addFields" : projected_fields })
535
506
if "_id" not in projected_fields :
536
507
combinator_pipeline .append ({"$unset" : "_id" })
0 commit comments