|
| 1 | +from copy import deepcopy |
| 2 | + |
1 | 3 | from django.db import NotSupportedError
|
2 | 4 | from django.db.models.aggregates import Aggregate, Avg, Count, Max, Min, StdDev, Sum, Variance
|
3 | 5 | from django.db.models.expressions import Case, Func, Star, Value, When
|
|
32 | 34 | Trim,
|
33 | 35 | Upper,
|
34 | 36 | )
|
| 37 | +from django.db.models.lookups import Exact |
| 38 | +from django.db.models.query_utils import Q |
35 | 39 |
|
36 | 40 | from .query_utils import process_lhs
|
37 | 41 |
|
@@ -115,21 +119,34 @@ def cot(self, compiler, connection):
|
115 | 119 | return {"$divide": [1, {"$tan": lhs_mql}]}
|
116 | 120 |
|
117 | 121 |
|
118 |
| -def count(self, compiler, connection, **extra_context): # noqa: ARG001 |
119 |
| - if self.filter: |
120 |
| - copy = self.copy() |
121 |
| - copy.filter = None |
122 |
| - source_expressions = copy.get_source_expressions() |
123 |
| - condition = When(self.filter, then=Value(1)) |
124 |
| - copy.set_source_expressions([Case(condition)] + source_expressions[1:]) |
125 |
| - node = copy |
126 |
| - cond = process_lhs(node, compiler, connection) |
127 |
| - else: |
128 |
| - node = self |
129 |
| - lhs_mql = process_lhs(self, compiler, connection) |
130 |
| - null_cond = {"$in": [{"$type": lhs_mql}, ["missing", "null"]]} |
131 |
| - cond = {"$cond": {"if": null_cond, "then": 0, "else": 1}} |
132 |
| - return {"$sum": cond} |
| 122 | +def count(self, compiler, connection, **extra_context): |
| 123 | + if not self.distinct or extra_context.get("force_filters"): |
| 124 | + if self.filter: |
| 125 | + copy = self.copy() |
| 126 | + copy.filter = None |
| 127 | + source_expressions = copy.get_source_expressions() |
| 128 | + filter_ = deepcopy(self.filter) |
| 129 | + filter_.add(~Q(Exact(source_expressions[0], Value(None))), filter_.default) |
| 130 | + condition = When(filter_, then=Value(1)) |
| 131 | + copy.set_source_expressions([Case(condition)] + source_expressions[1:]) |
| 132 | + node = copy |
| 133 | + cond = process_lhs(node, compiler, connection) |
| 134 | + else: |
| 135 | + node = self |
| 136 | + lhs_mql = process_lhs(self, compiler, connection) |
| 137 | + null_cond = {"$in": [{"$type": lhs_mql}, ["missing", "null"]]} |
| 138 | + cond = { |
| 139 | + "$cond": {"if": null_cond, "then": None, "else": lhs_mql if self.distinct else 1} |
| 140 | + } |
| 141 | + operator = "$sum" |
| 142 | + return {operator: cond} |
| 143 | + |
| 144 | + operator = "$size" |
| 145 | + lhs_mql = process_lhs(self, compiler, connection) |
| 146 | + |
| 147 | + exits_null = {"$cond": {"if": {"$in": [{"$literal": None}, lhs_mql]}, "then": -1, "else": 0}} |
| 148 | + |
| 149 | + return {"$add": [{"$size": lhs_mql}, exits_null]} |
133 | 150 |
|
134 | 151 |
|
135 | 152 | def extract(self, compiler, connection):
|
|
0 commit comments