Skip to content

refactor "group by" to use embedded documents #168

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Oct 21, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 11 additions & 40 deletions django_mongodb/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ class SQLCompiler(compiler.SQLCompiler):
"""Base class for all Mongo compilers."""

query_class = MongoQuery
GROUP_SEPARATOR = "___"
PARENT_FIELD_TEMPLATE = "parent__field__{}"

def __init__(self, *args, **kwargs):
Expand All @@ -37,34 +36,6 @@ def __init__(self, *args, **kwargs):
self.order_by_objs = None
self.subqueries = []

def _unfold_column(self, col):
"""
Flatten a field by returning its target or by replacing dots with
GROUP_SEPARATOR for foreign fields.
"""
if self.collection_name == col.alias:
return col.target.column
# If this is a foreign field, replace the normal dot (.) with
# GROUP_SEPARATOR since FieldPath field names may not contain '.'.
return f"{col.alias}{self.GROUP_SEPARATOR}{col.target.column}"

def _fold_columns(self, unfold_columns):
"""
Convert flat columns into a nested dictionary, grouping fields by
table name.
"""
result = defaultdict(dict)
for key in unfold_columns:
value = f"$_id.{key}"
if self.GROUP_SEPARATOR in key:
table, field = key.split(self.GROUP_SEPARATOR)
result[table][field] = value
else:
result[key] = value
# Convert defaultdict to dict so it doesn't appear as
# "defaultdict(<CLASS 'dict'>, ..." in query logging.
return dict(result)

def _get_group_alias_column(self, expr, annotation_group_idx):
"""Generate a dummy field for use in the ids fields in $group."""
replacement = None
Expand All @@ -75,7 +46,7 @@ def _get_group_alias_column(self, expr, annotation_group_idx):
alias = f"__annotation_group{next(annotation_group_idx)}"
col = self._get_column_from_expression(expr, alias)
replacement = col
return self._unfold_column(col), replacement
return col.target.column, replacement

def _get_column_from_expression(self, expr, alias):
"""
Expand Down Expand Up @@ -198,18 +169,15 @@ def _get_group_id_expressions(self, order_by):
else:
annotation_group_idx = itertools.count(start=1)
ids = {}
columns = []
for col in group_expressions:
alias, replacement = self._get_group_alias_column(col, annotation_group_idx)
try:
ids[alias] = col.as_mql(self, self.connection)
except EmptyResultSet:
ids[alias] = Value(False).as_mql(self, self.connection)
except FullResultSet:
ids[alias] = Value(True).as_mql(self, self.connection)
columns.append((alias, col))
if replacement is not None:
replacements[col] = replacement
if isinstance(col, Ref):
replacements[col.source] = replacement
ids = self.get_project_fields(tuple(columns), force_expression=True)
return ids, replacements

def _build_aggregation_pipeline(self, ids, group):
Expand All @@ -234,7 +202,7 @@ def _build_aggregation_pipeline(self, ids, group):
else:
group["_id"] = ids
pipeline.append({"$group": group})
projected_fields = self._fold_columns(ids)
projected_fields = {key: f"$_id.{key}" for key in ids}
pipeline.append({"$addFields": projected_fields})
if "_id" not in projected_fields:
pipeline.append({"$unset": "_id"})
Expand Down Expand Up @@ -522,15 +490,18 @@ def get_combinator_queries(self):
else:
combinator_pipeline = inner_pipeline
if not self.query.combinator_all:
ids = {}
ids = defaultdict(dict)
for alias, expr in main_query_columns:
# Unfold foreign fields.
if isinstance(expr, Col) and expr.alias != self.collection_name:
ids[self._unfold_column(expr)] = expr.as_mql(self, self.connection)
ids[expr.alias][expr.target.column] = expr.as_mql(self, self.connection)
else:
ids[alias] = f"${alias}"
# Convert defaultdict to dict so it doesn't appear as
# "defaultdict(<CLASS 'dict'>, ..." in query logging.
ids = dict(ids)
combinator_pipeline.append({"$group": {"_id": ids}})
projected_fields = self._fold_columns(ids)
projected_fields = {key: f"$_id.{key}" for key in ids}
combinator_pipeline.append({"$addFields": projected_fields})
if "_id" not in projected_fields:
combinator_pipeline.append({"$unset": "_id"})
Expand Down