7
7
8
8
from .query_utils import process_lhs
9
9
10
- MONGO_AGGREGATIONS = {
11
- Count : "sum" ,
12
- StdDev : "stdDev" , # Samp or Pop suffix added in aggregate().
13
- Variance : "stdDev" , # Likewise.
14
- }
10
+ MONGO_AGGREGATIONS = {Count : "sum" }
15
11
16
12
17
- def aggregate (self , compiler , connection , ** extra_context ): # noqa: ARG001
13
+ def aggregate (self , compiler , connection , operator = None , ** extra_context ): # noqa: ARG001
18
14
if self .filter :
19
15
node = self .copy ()
20
16
node .filter = None
@@ -24,12 +20,7 @@ def aggregate(self, compiler, connection, **extra_context): # noqa: ARG001
24
20
else :
25
21
node = self
26
22
lhs_mql = process_lhs (node , compiler , connection )
27
- operator = MONGO_AGGREGATIONS .get (self .__class__ , self .function .lower ())
28
- # Add suffixes to StdDev/Variance.
29
- if self .function .endswith ("_SAMP" ):
30
- operator += "Samp"
31
- elif self .function .endswith ("_POP" ):
32
- operator += "Pop"
23
+ operator = operator or MONGO_AGGREGATIONS .get (self .__class__ , self .function .lower ())
33
24
return {f"${ operator } " : lhs_mql }
34
25
35
26
@@ -69,6 +60,16 @@ def count(self, compiler, connection, resolve_inner_expression=False, **extra_co
69
60
return {"$add" : [{"$size" : lhs_mql }, exits_null ]}
70
61
71
62
63
+ def stddev_variance (self , compiler , connection , ** extra_context ):
64
+ if self .function .endswith ("_SAMP" ):
65
+ operator = "stdDevSamp"
66
+ elif self .function .endswith ("_POP" ):
67
+ operator = "stdDevPop"
68
+ return aggregate (self , compiler , connection , operator = operator , ** extra_context )
69
+
70
+
72
71
def register_aggregates ():
73
72
Aggregate .as_mql = aggregate
74
73
Count .as_mql = count
74
+ StdDev .as_mql = stddev_variance
75
+ Variance .as_mql = stddev_variance
0 commit comments