Skip to content

Commit 56cd604

Browse files
authored
refactor "group by" to use embedded documents
Foreign fields are computed as {'T3': {'age': '$T3.age'} instead of {'T3___age': '$T3.age'}.
1 parent b3fd2c4 commit 56cd604

File tree

1 file changed

+11
-40
lines changed

1 file changed

+11
-40
lines changed

django_mongodb/compiler.py

Lines changed: 11 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@ class SQLCompiler(compiler.SQLCompiler):
2525
"""Base class for all Mongo compilers."""
2626

2727
query_class = MongoQuery
28-
GROUP_SEPARATOR = "___"
2928
PARENT_FIELD_TEMPLATE = "parent__field__{}"
3029

3130
def __init__(self, *args, **kwargs):
@@ -37,34 +36,6 @@ def __init__(self, *args, **kwargs):
3736
self.order_by_objs = None
3837
self.subqueries = []
3938

40-
def _unfold_column(self, col):
41-
"""
42-
Flatten a field by returning its target or by replacing dots with
43-
GROUP_SEPARATOR for foreign fields.
44-
"""
45-
if self.collection_name == col.alias:
46-
return col.target.column
47-
# If this is a foreign field, replace the normal dot (.) with
48-
# GROUP_SEPARATOR since FieldPath field names may not contain '.'.
49-
return f"{col.alias}{self.GROUP_SEPARATOR}{col.target.column}"
50-
51-
def _fold_columns(self, unfold_columns):
52-
"""
53-
Convert flat columns into a nested dictionary, grouping fields by
54-
table name.
55-
"""
56-
result = defaultdict(dict)
57-
for key in unfold_columns:
58-
value = f"$_id.{key}"
59-
if self.GROUP_SEPARATOR in key:
60-
table, field = key.split(self.GROUP_SEPARATOR)
61-
result[table][field] = value
62-
else:
63-
result[key] = value
64-
# Convert defaultdict to dict so it doesn't appear as
65-
# "defaultdict(<CLASS 'dict'>, ..." in query logging.
66-
return dict(result)
67-
6839
def _get_group_alias_column(self, expr, annotation_group_idx):
6940
"""Generate a dummy field for use in the ids fields in $group."""
7041
replacement = None
@@ -75,7 +46,7 @@ def _get_group_alias_column(self, expr, annotation_group_idx):
7546
alias = f"__annotation_group{next(annotation_group_idx)}"
7647
col = self._get_column_from_expression(expr, alias)
7748
replacement = col
78-
return self._unfold_column(col), replacement
49+
return col.target.column, replacement
7950

8051
def _get_column_from_expression(self, expr, alias):
8152
"""
@@ -198,18 +169,15 @@ def _get_group_id_expressions(self, order_by):
198169
else:
199170
annotation_group_idx = itertools.count(start=1)
200171
ids = {}
172+
columns = []
201173
for col in group_expressions:
202174
alias, replacement = self._get_group_alias_column(col, annotation_group_idx)
203-
try:
204-
ids[alias] = col.as_mql(self, self.connection)
205-
except EmptyResultSet:
206-
ids[alias] = Value(False).as_mql(self, self.connection)
207-
except FullResultSet:
208-
ids[alias] = Value(True).as_mql(self, self.connection)
175+
columns.append((alias, col))
209176
if replacement is not None:
210177
replacements[col] = replacement
211178
if isinstance(col, Ref):
212179
replacements[col.source] = replacement
180+
ids = self.get_project_fields(tuple(columns), force_expression=True)
213181
return ids, replacements
214182

215183
def _build_aggregation_pipeline(self, ids, group):
@@ -234,7 +202,7 @@ def _build_aggregation_pipeline(self, ids, group):
234202
else:
235203
group["_id"] = ids
236204
pipeline.append({"$group": group})
237-
projected_fields = self._fold_columns(ids)
205+
projected_fields = {key: f"$_id.{key}" for key in ids}
238206
pipeline.append({"$addFields": projected_fields})
239207
if "_id" not in projected_fields:
240208
pipeline.append({"$unset": "_id"})
@@ -522,15 +490,18 @@ def get_combinator_queries(self):
522490
else:
523491
combinator_pipeline = inner_pipeline
524492
if not self.query.combinator_all:
525-
ids = {}
493+
ids = defaultdict(dict)
526494
for alias, expr in main_query_columns:
527495
# Unfold foreign fields.
528496
if isinstance(expr, Col) and expr.alias != self.collection_name:
529-
ids[self._unfold_column(expr)] = expr.as_mql(self, self.connection)
497+
ids[expr.alias][expr.target.column] = expr.as_mql(self, self.connection)
530498
else:
531499
ids[alias] = f"${alias}"
500+
# Convert defaultdict to dict so it doesn't appear as
501+
# "defaultdict(<CLASS 'dict'>, ..." in query logging.
502+
ids = dict(ids)
532503
combinator_pipeline.append({"$group": {"_id": ids}})
533-
projected_fields = self._fold_columns(ids)
504+
projected_fields = {key: f"$_id.{key}" for key in ids}
534505
combinator_pipeline.append({"$addFields": projected_fields})
535506
if "_id" not in projected_fields:
536507
combinator_pipeline.append({"$unset": "_id"})

0 commit comments

Comments
 (0)