-
Notifications
You must be signed in to change notification settings - Fork 26
add support for QuerySet.aggregate() #84
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
from copy import deepcopy | ||
|
||
from django.db.models.aggregates import Aggregate, Count, StdDev, Variance | ||
from django.db.models.expressions import Case, Value, When | ||
from django.db.models.lookups import Exact | ||
from django.db.models.sql.where import WhereNode | ||
|
||
from .query_utils import process_lhs | ||
|
||
# Aggregates whose MongoDB aggregation name differ from Aggregate.function.lower(). | ||
MONGO_AGGREGATIONS = {Count: "sum"} | ||
|
||
|
||
def aggregate( | ||
self, | ||
compiler, | ||
connection, | ||
operator=None, | ||
resolve_inner_expression=False, | ||
**extra_context, # noqa: ARG001 | ||
): | ||
if self.filter: | ||
node = self.copy() | ||
node.filter = None | ||
source_expressions = node.get_source_expressions() | ||
condition = When(self.filter, then=source_expressions[0]) | ||
node.set_source_expressions([Case(condition)] + source_expressions[1:]) | ||
else: | ||
node = self | ||
lhs_mql = process_lhs(node, compiler, connection) | ||
if resolve_inner_expression: | ||
return lhs_mql | ||
operator = operator or MONGO_AGGREGATIONS.get(self.__class__, self.function.lower()) | ||
return {f"${operator}": lhs_mql} | ||
|
||
|
||
def count(self, compiler, connection, resolve_inner_expression=False, **extra_context): # noqa: ARG001 | ||
""" | ||
When resolve_inner_expression=True, return the MQL that resolves as a | ||
value. This is used to count different elements, so the inner values are | ||
returned to be pushed into a set. | ||
""" | ||
if not self.distinct or resolve_inner_expression: | ||
if self.filter: | ||
node = self.copy() | ||
node.filter = None | ||
source_expressions = node.get_source_expressions() | ||
filter_ = deepcopy(self.filter) | ||
filter_.add( | ||
WhereNode([Exact(source_expressions[0], Value(None))], negated=True), | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sorry @timgraham I deleted your comment by mistake, the answer is: There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I understand that both branches are excluding null values. My question is how does the query end up with null values? I could probably work through some tests to understand it better. Thought you might be able to give a quick example. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ok, there are like 3 places with nulls. 😬 or maybe two. Select count(name) filter (where surname = 'Lupi')
from T1 Here we have to sum of the elements that fulfill the filter, we handle it with a case and the first (and only) source_expression would be the transformed (idk if there is others, but I just copy the mechanism from source) I will explain it one by one. But we have to change the exact(value, None) for IsNull(value, True) they are not the same. |
||
filter_.default, | ||
) | ||
condition = When(filter_, then=Value(1)) | ||
node.set_source_expressions([Case(condition)] + source_expressions[1:]) | ||
inner_expression = process_lhs(node, compiler, connection) | ||
else: | ||
lhs_mql = process_lhs(self, compiler, connection) | ||
null_cond = {"$in": [{"$type": lhs_mql}, ["missing", "null"]]} | ||
inner_expression = { | ||
"$cond": {"if": null_cond, "then": None, "else": lhs_mql if self.distinct else 1} | ||
} | ||
if resolve_inner_expression: | ||
return inner_expression | ||
return {"$sum": inner_expression} | ||
# If distinct=True or resolve_inner_expression=False, sum the size of the | ||
# set. | ||
lhs_mql = process_lhs(self, compiler, connection) | ||
# None shouldn't be counted, so subtract 1 if it's present. | ||
exits_null = {"$cond": {"if": {"$in": [{"$literal": None}, lhs_mql]}, "then": -1, "else": 0}} | ||
return {"$add": [{"$size": lhs_mql}, exits_null]} | ||
|
||
|
||
def stddev_variance(self, compiler, connection, **extra_context): | ||
if self.function.endswith("_SAMP"): | ||
operator = "stdDevSamp" | ||
elif self.function.endswith("_POP"): | ||
operator = "stdDevPop" | ||
return aggregate(self, compiler, connection, operator=operator, **extra_context) | ||
|
||
|
||
def register_aggregates(): | ||
Aggregate.as_mql = aggregate | ||
Count.as_mql = count | ||
StdDev.as_mql = stddev_variance | ||
Variance.as_mql = stddev_variance |
Uh oh!
There was an error while loading. Please reload this page.