Skip to content

Commit 4277f6c

Browse files
committed
Refactor.
1 parent 1311c43 commit 4277f6c

File tree

1 file changed

+36
-34
lines changed

1 file changed

+36
-34
lines changed

django_mongodb/compiler.py

Lines changed: 36 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,33 @@ def __init__(self, *args, **kwargs):
3232
# A list of OrderBy objects for this query.
3333
self.order_by_objs = None
3434

35+
def _unfold_column(self, col):
36+
"""
37+
Flattens a field by returning its target or by replacing dots with GROUP_SEPARATOR
38+
for foreign fields.
39+
"""
40+
if self.collection_name == col.alias:
41+
return col.target.column
42+
# If this is a foreign field, replace the normal dot (.) with
43+
# GROUP_SEPARATOR since FieldPath field names may not contain '.'.
44+
return f"{col.alias}{self.GROUP_SEPARATOR}{col.target.column}"
45+
46+
def _fold_columns(self, unfold_columns):
47+
"""
48+
Converts flat columns into a nested dictionary, grouping fields by table names.
49+
"""
50+
result = defaultdict(dict)
51+
for key in unfold_columns:
52+
value = f"$_id.{key}"
53+
if self.GROUP_SEPARATOR in key:
54+
table, field = key.split(self.GROUP_SEPARATOR)
55+
result[table][field] = value
56+
else:
57+
result[key] = value
58+
# Convert defaultdict to dict so it doesn't appear as
59+
# "defaultdict(<CLASS 'dict'>, ..." in query logging.
60+
return dict(result)
61+
3562
def _get_group_alias_column(self, expr, annotation_group_idx):
3663
"""Generate a dummy field for use in the ids fields in $group."""
3764
replacement = None
@@ -42,11 +69,7 @@ def _get_group_alias_column(self, expr, annotation_group_idx):
4269
alias = f"__annotation_group{next(annotation_group_idx)}"
4370
col = self._get_column_from_expression(expr, alias)
4471
replacement = col
45-
if self.collection_name == col.alias:
46-
return col.target.column, replacement
47-
# If this is a foreign field, replace the normal dot (.) with
48-
# GROUP_SEPARATOR since FieldPath field names may not contain '.'.
49-
return f"{col.alias}{self.GROUP_SEPARATOR}{col.target.column}", replacement
72+
return self._unfold_column(col), replacement
5073

5174
def _get_column_from_expression(self, expr, alias):
5275
"""
@@ -186,17 +209,8 @@ def _build_aggregation_pipeline(self, ids, group):
186209
else:
187210
group["_id"] = ids
188211
pipeline.append({"$group": group})
189-
projected_fields = defaultdict(dict)
190-
for key in ids:
191-
value = f"$_id.{key}"
192-
if self.GROUP_SEPARATOR in key:
193-
table, field = key.split(self.GROUP_SEPARATOR)
194-
projected_fields[table][field] = value
195-
else:
196-
projected_fields[key] = value
197-
# Convert defaultdict to dict so it doesn't appear as
198-
# "defaultdict(<CLASS 'dict'>, ..." in query logging.
199-
pipeline.append({"$addFields": dict(projected_fields)})
212+
projected_fields = self._fold_columns(ids)
213+
pipeline.append({"$addFields": projected_fields})
200214
if "_id" not in projected_fields:
201215
pipeline.append({"$unset": "_id"})
202216
return pipeline
@@ -453,8 +467,7 @@ def get_combinator_queries(self):
453467
parts.append((compiler_.build_query(columns), compiler_.collection_name))
454468

455469
except EmptyResultSet:
456-
# Omit the empty queryset with UNION and with DIFFERENCE if the
457-
# first queryset is nonempty.
470+
# Omit the empty queryset with UNION.
458471
if self.query.combinator == "union":
459472
continue
460473
raise
@@ -470,25 +483,14 @@ def get_combinator_queries(self):
470483
if not self.query.combinator_all:
471484
ids = {}
472485
for alias, expr in main_query_columns:
473-
collection = expr.alias if isinstance(expr, Col) else None
474-
if collection and collection != self.collection_name:
475-
ids[
476-
f"{expr.alias}{self.GROUP_SEPARATOR}{expr.target.column}"
477-
] = expr.as_mql(self, self.connection)
486+
# Unfold foreign fields.
487+
if isinstance(expr, Col) and expr.alias != self.collection_name:
488+
ids[self._unfold_column(expr)] = expr.as_mql(self, self.connection)
478489
else:
479490
ids[alias] = f"${alias}"
480491
combinator_pipeline.append({"$group": {"_id": ids}})
481-
projected_fields = defaultdict(dict)
482-
for key in ids:
483-
value = f"$_id.{key}"
484-
if self.GROUP_SEPARATOR in key:
485-
table, field = key.split(self.GROUP_SEPARATOR)
486-
projected_fields[table][field] = value
487-
else:
488-
projected_fields[key] = value
489-
# Convert defaultdict to dict so it doesn't appear as
490-
# "defaultdict(<CLASS 'dict'>, ..." in query logging.
491-
combinator_pipeline.append({"$addFields": dict(projected_fields)})
492+
projected_fields = self._fold_columns(ids)
493+
combinator_pipeline.append({"$addFields": projected_fields})
492494
if "_id" not in projected_fields:
493495
combinator_pipeline.append({"$unset": "_id"})
494496
else:

0 commit comments

Comments
 (0)