|
1 | 1 | from django.db import NotSupportedError
|
2 | 2 | from django.db.models.aggregates import Aggregate, Avg, Count, Max, Min, StdDev, Sum, Variance
|
3 |
| -from django.db.models.expressions import Func |
| 3 | +from django.db.models.expressions import Case, Func, Value, When |
4 | 4 | from django.db.models.functions.comparison import Cast, Coalesce, Greatest, Least, NullIf
|
5 | 5 | from django.db.models.functions.datetime import (
|
6 | 6 | Extract,
|
|
41 | 41 | Min: "min",
|
42 | 42 | StdDev: "stddev",
|
43 | 43 | Sum: "sum",
|
44 |
| - Variance: "variance", |
| 44 | + Variance: "stdDevPop", |
45 | 45 | }
|
46 | 46 |
|
47 | 47 |
|
|
69 | 69 | }
|
70 | 70 |
|
71 | 71 |
|
72 |
| -def aggregate(self, compiler, connection): # noqa: ARG001 |
73 |
| - pass |
| 72 | +def aggregate(self, compiler, connection, **extra_context): # noqa: ARG001 |
| 73 | + if self.filter: |
| 74 | + copy = self.copy() |
| 75 | + copy.filter = None |
| 76 | + source_expressions = copy.get_source_expressions() |
| 77 | + condition = When(self.filter, then=source_expressions[0]) |
| 78 | + copy.set_source_expressions([Case(condition)] + source_expressions[1:]) |
| 79 | + node = copy |
| 80 | + else: |
| 81 | + node = self |
| 82 | + lhs_mql = process_lhs(node, compiler, connection) |
| 83 | + operator = MONGO_AGGREGATION.get(self.__class__) |
| 84 | + return {f"${operator}": lhs_mql} |
74 | 85 |
|
75 | 86 |
|
76 | 87 | def cast(self, compiler, connection):
|
@@ -103,6 +114,19 @@ def cot(self, compiler, connection):
|
103 | 114 | return {"$divide": [1, {"$tan": lhs_mql}]}
|
104 | 115 |
|
105 | 116 |
|
| 117 | +def count(self, compiler, connection, **extra_context): # noqa: ARG001 |
| 118 | + if self.filter: |
| 119 | + copy = self.copy() |
| 120 | + copy.filter = None |
| 121 | + source_expressions = copy.get_source_expressions() |
| 122 | + condition = When(self.filter, then=Value(1)) |
| 123 | + copy.set_source_expressions([Case(condition)] + source_expressions[1:]) |
| 124 | + lhs_mql = process_lhs(copy, compiler, connection) |
| 125 | + else: |
| 126 | + lhs_mql = Value(1).as_mql(compiler, connection) |
| 127 | + return {"$sum": lhs_mql} |
| 128 | + |
| 129 | + |
106 | 130 | def extract(self, compiler, connection):
|
107 | 131 | lhs_mql = process_lhs(self, compiler, connection)
|
108 | 132 | operator = EXTRACT_OPERATORS.get(self.lookup_name)
|
@@ -208,6 +232,7 @@ def register_functions():
|
208 | 232 | Concat.as_mql = concat
|
209 | 233 | ConcatPair.as_mql = concat_pair
|
210 | 234 | Cot.as_mql = cot
|
| 235 | + Count.as_mql = count |
211 | 236 | Extract.as_mql = extract
|
212 | 237 | Func.as_mql = func
|
213 | 238 | Left.as_mql = left
|
|
0 commit comments