Skip to content

add support for QuerySet.aggregate() #84

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/test-python.yml
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ jobs:
- name: Run tests
run: >
python3 django_repo/tests/runtests.py --settings mongodb_settings -v 2
aggregation
annotations
auth_tests.test_models.UserManagerTestCase
backends.base.test_base.DatabaseWrapperTests
Expand Down
1 change: 0 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,6 @@ Migrations for 'admin':
## Known issues and limitations

- The following `QuerySet` methods aren't supported:
- `aggregate()`
- `bulk_update()`
- `dates()`
- `datetimes()`
Expand Down
2 changes: 2 additions & 0 deletions django_mongodb/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,14 @@

check_django_compatability()

from .aggregates import register_aggregates # noqa: E402
from .expressions import register_expressions # noqa: E402
from .fields import register_fields # noqa: E402
from .functions import register_functions # noqa: E402
from .lookups import register_lookups # noqa: E402
from .query import register_nodes # noqa: E402

register_aggregates()
register_expressions()
register_fields()
register_functions()
Expand Down
85 changes: 85 additions & 0 deletions django_mongodb/aggregates.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
from copy import deepcopy

from django.db.models.aggregates import Aggregate, Count, StdDev, Variance
from django.db.models.expressions import Case, Value, When
from django.db.models.lookups import Exact
from django.db.models.sql.where import WhereNode

from .query_utils import process_lhs

# Aggregates whose MongoDB aggregation name differ from Aggregate.function.lower().
MONGO_AGGREGATIONS = {Count: "sum"}


def aggregate(
self,
compiler,
connection,
operator=None,
resolve_inner_expression=False,
**extra_context, # noqa: ARG001
):
if self.filter:
node = self.copy()
node.filter = None
source_expressions = node.get_source_expressions()
condition = When(self.filter, then=source_expressions[0])
node.set_source_expressions([Case(condition)] + source_expressions[1:])
else:
node = self
lhs_mql = process_lhs(node, compiler, connection)
if resolve_inner_expression:
return lhs_mql
operator = operator or MONGO_AGGREGATIONS.get(self.__class__, self.function.lower())
return {f"${operator}": lhs_mql}


def count(self, compiler, connection, resolve_inner_expression=False, **extra_context): # noqa: ARG001
"""
When resolve_inner_expression=True, return the MQL that resolves as a
value. This is used to count different elements, so the inner values are
returned to be pushed into a set.
"""
if not self.distinct or resolve_inner_expression:
if self.filter:
node = self.copy()
node.filter = None
source_expressions = node.get_source_expressions()
filter_ = deepcopy(self.filter)
filter_.add(
WhereNode([Exact(source_expressions[0], Value(None))], negated=True),
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry @timgraham I deleted your comment by mistake, the answer is:
The count only counts values if they aren't none. if the expression result is a string, number or something it sums as 1. Looking again the code, maybe there is a bug when there is filter and distinct options. Will check.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I understand that both branches are excluding null values. My question is how does the query end up with null values? I could probably work through some tests to understand it better. Thought you might be able to give a quick example.

Copy link
Collaborator Author

@WaVEV WaVEV Aug 6, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, there are like 3 places with nulls. 😬 or maybe two.
The example when this code is use is:

Select count(name) filter (where surname = 'Lupi')
from T1

Here we have to sum of the elements that fulfill the filter, we handle it with a case and the first (and only) source_expression would be the transformed (idk if there is others, but I just copy the mechanism from source)

I will explain it one by one.

But we have to change the exact(value, None) for IsNull(value, True) they are not the same.

filter_.default,
)
condition = When(filter_, then=Value(1))
node.set_source_expressions([Case(condition)] + source_expressions[1:])
inner_expression = process_lhs(node, compiler, connection)
else:
lhs_mql = process_lhs(self, compiler, connection)
null_cond = {"$in": [{"$type": lhs_mql}, ["missing", "null"]]}
inner_expression = {
"$cond": {"if": null_cond, "then": None, "else": lhs_mql if self.distinct else 1}
}
if resolve_inner_expression:
return inner_expression
return {"$sum": inner_expression}
# If distinct=True or resolve_inner_expression=False, sum the size of the
# set.
lhs_mql = process_lhs(self, compiler, connection)
# None shouldn't be counted, so subtract 1 if it's present.
exits_null = {"$cond": {"if": {"$in": [{"$literal": None}, lhs_mql]}, "then": -1, "else": 0}}
return {"$add": [{"$size": lhs_mql}, exits_null]}


def stddev_variance(self, compiler, connection, **extra_context):
if self.function.endswith("_SAMP"):
operator = "stdDevSamp"
elif self.function.endswith("_POP"):
operator = "stdDevPop"
return aggregate(self, compiler, connection, operator=operator, **extra_context)


def register_aggregates():
Aggregate.as_mql = aggregate
Count.as_mql = count
StdDev.as_mql = stddev_variance
Variance.as_mql = stddev_variance
Loading