Skip to content

Commit 86683de

Browse files
committed
Refactor group by as embedded doc.
1 parent e971120 commit 86683de

File tree

1 file changed

+9
-39
lines changed

1 file changed

+9
-39
lines changed

django_mongodb/compiler.py

Lines changed: 9 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@ class SQLCompiler(compiler.SQLCompiler):
2525
"""Base class for all Mongo compilers."""
2626

2727
query_class = MongoQuery
28-
GROUP_SEPARATOR = "___"
2928
PARENT_FIELD_TEMPLATE = "parent__field__{}"
3029

3130
def __init__(self, *args, **kwargs):
@@ -37,34 +36,6 @@ def __init__(self, *args, **kwargs):
3736
self.order_by_objs = None
3837
self.subqueries = []
3938

40-
def _unfold_column(self, col):
41-
"""
42-
Flatten a field by returning its target or by replacing dots with
43-
GROUP_SEPARATOR for foreign fields.
44-
"""
45-
if self.collection_name == col.alias:
46-
return col.target.column
47-
# If this is a foreign field, replace the normal dot (.) with
48-
# GROUP_SEPARATOR since FieldPath field names may not contain '.'.
49-
return f"{col.alias}{self.GROUP_SEPARATOR}{col.target.column}"
50-
51-
def _fold_columns(self, unfold_columns):
52-
"""
53-
Convert flat columns into a nested dictionary, grouping fields by
54-
table name.
55-
"""
56-
result = defaultdict(dict)
57-
for key in unfold_columns:
58-
value = f"$_id.{key}"
59-
if self.GROUP_SEPARATOR in key:
60-
table, field = key.split(self.GROUP_SEPARATOR)
61-
result[table][field] = value
62-
else:
63-
result[key] = value
64-
# Convert defaultdict to dict so it doesn't appear as
65-
# "defaultdict(<CLASS 'dict'>, ..." in query logging.
66-
return dict(result)
67-
6839
def _get_group_alias_column(self, expr, annotation_group_idx):
6940
"""Generate a dummy field for use in the ids fields in $group."""
7041
replacement = None
@@ -75,7 +46,7 @@ def _get_group_alias_column(self, expr, annotation_group_idx):
7546
alias = f"__annotation_group{next(annotation_group_idx)}"
7647
col = self._get_column_from_expression(expr, alias)
7748
replacement = col
78-
return self._unfold_column(col), replacement
49+
return col.target.column, replacement
7950

8051
def _get_column_from_expression(self, expr, alias):
8152
"""
@@ -198,18 +169,15 @@ def _get_group_id_expressions(self, order_by):
198169
else:
199170
annotation_group_idx = itertools.count(start=1)
200171
ids = {}
172+
columns = []
201173
for col in group_expressions:
202174
alias, replacement = self._get_group_alias_column(col, annotation_group_idx)
203-
try:
204-
ids[alias] = col.as_mql(self, self.connection)
205-
except EmptyResultSet:
206-
ids[alias] = Value(False).as_mql(self, self.connection)
207-
except FullResultSet:
208-
ids[alias] = Value(True).as_mql(self, self.connection)
175+
columns.append((alias, col))
209176
if replacement is not None:
210177
replacements[col] = replacement
211178
if isinstance(col, Ref):
212179
replacements[col.source] = replacement
180+
ids = self.get_project_fields(tuple(columns), force_expression=True)
213181
return ids, replacements
214182

215183
def _build_aggregation_pipeline(self, ids, group):
@@ -234,7 +202,7 @@ def _build_aggregation_pipeline(self, ids, group):
234202
else:
235203
group["_id"] = ids
236204
pipeline.append({"$group": group})
237-
projected_fields = self._fold_columns(ids)
205+
projected_fields = {key: f"$_id.{key}" for key in ids}
238206
pipeline.append({"$addFields": projected_fields})
239207
if "_id" not in projected_fields:
240208
pipeline.append({"$unset": "_id"})
@@ -526,11 +494,13 @@ def get_combinator_queries(self):
526494
for alias, expr in main_query_columns:
527495
# Unfold foreign fields.
528496
if isinstance(expr, Col) and expr.alias != self.collection_name:
529-
ids[self._unfold_column(expr)] = expr.as_mql(self, self.connection)
497+
if expr.alias not in ids:
498+
ids[expr.alias] = {}
499+
ids[expr.alias][expr.target.column] = expr.as_mql(self, self.connection)
530500
else:
531501
ids[alias] = f"${alias}"
532502
combinator_pipeline.append({"$group": {"_id": ids}})
533-
projected_fields = self._fold_columns(ids)
503+
projected_fields = {key: f"$_id.{key}" for key in ids}
534504
combinator_pipeline.append({"$addFields": projected_fields})
535505
if "_id" not in projected_fields:
536506
combinator_pipeline.append({"$unset": "_id"})

0 commit comments

Comments
 (0)