Skip to content

Commit d611631

Browse files
committed
Clean ups.
1 parent 3d7bdf3 commit d611631

21 files changed

+986
-996
lines changed

django_mongodb_backend/aggregates.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,14 @@
88
MONGO_AGGREGATIONS = {Count: "sum"}
99

1010

11-
def aggregate(self, compiler, connection, operator=None, resolve_inner_expression=False):
11+
def aggregate(
12+
self,
13+
compiler,
14+
connection,
15+
operator=None,
16+
resolve_inner_expression=False,
17+
**extra_context, # noqa: ARG001
18+
):
1219
if self.filter:
1320
node = self.copy()
1421
node.filter = None
@@ -24,7 +31,7 @@ def aggregate(self, compiler, connection, operator=None, resolve_inner_expressio
2431
return {f"${operator}": lhs_mql}
2532

2633

27-
def count(self, compiler, connection, resolve_inner_expression=False):
34+
def count(self, compiler, connection, resolve_inner_expression=False, **extra_context): # noqa: ARG001
2835
"""
2936
When resolve_inner_expression=True, return the MQL that resolves as a
3037
value. This is used to count different elements, so the inner values are
@@ -57,12 +64,12 @@ def count(self, compiler, connection, resolve_inner_expression=False):
5764
return {"$add": [{"$size": lhs_mql}, exits_null]}
5865

5966

60-
def stddev_variance(self, compiler, connection):
67+
def stddev_variance(self, compiler, connection, **extra_context):
6168
if self.function.endswith("_SAMP"):
6269
operator = "stdDevSamp"
6370
elif self.function.endswith("_POP"):
6471
operator = "stdDevPop"
65-
return aggregate(self, compiler, connection, operator=operator)
72+
return aggregate(self, compiler, connection, operator=operator, **extra_context)
6673

6774

6875
def register_aggregates():

django_mongodb_backend/base.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ def _isnull_operator_match(a, b):
113113
return {"$or": [{a: {"$exists": False}}, {a: None}]}
114114
return {"$and": [{a: {"$exists": True}}, {a: {"$ne": None}}]}
115115

116-
mongo_operators_expr = {
116+
mongo_expr_operators = {
117117
"exact": lambda a, b: {"$eq": [a, b]},
118118
"gt": lambda a, b: {"$gt": [a, b]},
119119
"gte": lambda a, b: {"$gte": [a, b]},
@@ -153,7 +153,8 @@ def range_match(a, b):
153153
return {"$literal": True}
154154
return {"$and": conditions}
155155

156-
mongo_operators_match = {
156+
# match, path, find? don't know which name use.
157+
mongo_match_operators = {
157158
"exact": lambda a, b: {a: b},
158159
"gt": lambda a, b: {a: {"$gt": b}},
159160
"gte": lambda a, b: {a: {"$gte": b}},

django_mongodb_backend/compiler.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -481,11 +481,11 @@ def build_query(self, columns=None):
481481
query.lookup_pipeline = self.get_lookup_pipeline()
482482
where = self.get_where()
483483
try:
484-
expr = where.as_mql(self, self.connection, as_path=True) if where else {}
484+
match = where.as_mql(self, self.connection, as_path=True) if where else {}
485485
except FullResultSet:
486486
query.match_mql = {}
487487
else:
488-
query.match_mql = expr
488+
query.match_mql = match
489489
if extra_fields:
490490
query.extra_fields = self.get_project_fields(extra_fields, force_expression=True)
491491
query.subqueries = self.subqueries
@@ -643,9 +643,7 @@ def get_combinator_queries(self):
643643
for alias, expr in self.columns:
644644
# Unfold foreign fields.
645645
if isinstance(expr, Col) and expr.alias != self.collection_name:
646-
ids[expr.alias][expr.target.column] = expr.as_mql(
647-
self, self.connection, as_path=False
648-
)
646+
ids[expr.alias][expr.target.column] = expr.as_mql(self, self.connection)
649647
else:
650648
ids[alias] = f"${alias}"
651649
# Convert defaultdict to dict so it doesn't appear as
@@ -716,9 +714,9 @@ def get_project_fields(self, columns=None, ordering=None, force_expression=False
716714
value = (
717715
False if empty_result_set_value is NotImplemented else empty_result_set_value
718716
)
719-
fields[collection][name] = Value(value).as_mql(self, self.connection, as_path=False)
717+
fields[collection][name] = Value(value).as_mql(self, self.connection)
720718
except FullResultSet:
721-
fields[collection][name] = Value(True).as_mql(self, self.connection, as_path=False)
719+
fields[collection][name] = Value(True).as_mql(self, self.connection)
722720
# Annotations (stored in None) and the main collection's fields
723721
# should appear in the top-level of the fields dict.
724722
fields.update(fields.pop(None, {}))

django_mongodb_backend/expressions/builtins.py

Lines changed: 25 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,20 @@
2929
from ..query_utils import process_lhs
3030

3131

32+
def base_expression(self, compiler, connection, as_path=False, **extra):
33+
if as_path and hasattr(self, "as_mql_path") and getattr(self, "can_use_path", False):
34+
return self.as_mql_path(compiler, connection, **extra)
35+
36+
expr = self.as_mql_expr(compiler, connection, **extra)
37+
return {"$expr": expr} if as_path else expr
38+
39+
3240
def case(self, compiler, connection):
3341
case_parts = []
3442
for case in self.cases:
3543
case_mql = {}
3644
try:
37-
case_mql["case"] = case.as_mql(compiler, connection, as_path=False)
45+
case_mql["case"] = case.as_mql(compiler, connection)
3846
except EmptyResultSet:
3947
continue
4048
except FullResultSet:
@@ -84,20 +92,20 @@ def col_pairs(self, compiler, connection, as_path=False):
8492
return cols[0].as_mql(compiler, connection, as_path=as_path)
8593

8694

87-
def combined_expression(self, compiler, connection, as_path=False):
95+
def combined_expression(self, compiler, connection):
8896
expressions = [
89-
self.lhs.as_mql(compiler, connection, as_path=as_path),
90-
self.rhs.as_mql(compiler, connection, as_path=as_path),
97+
self.lhs.as_mql(compiler, connection),
98+
self.rhs.as_mql(compiler, connection),
9199
]
92100
return connection.ops.combine_expression(self.connector, expressions)
93101

94102

95-
def expression_wrapper_expr(self, compiler, connection):
96-
return self.expression.as_mql(compiler, connection, as_path=False)
103+
def expression_wrapper(self, compiler, connection):
104+
return self.expression.as_mql(compiler, connection)
97105

98106

99-
def negated_expression_expr(self, compiler, connection):
100-
return {"$not": expression_wrapper_expr(self, compiler, connection)}
107+
def negated_expression(self, compiler, connection):
108+
return {"$not": expression_wrapper(self, compiler, connection)}
101109

102110

103111
def order_by(self, compiler, connection):
@@ -172,10 +180,10 @@ def ref(self, compiler, connection, as_path=False): # noqa: ARG001
172180

173181
@property
174182
def ref_is_simple_column(self):
175-
return isinstance(self.source, Col) and self.source.alias is not None
183+
return self.source.is_simple_column
176184

177185

178-
def star(self, compiler, connection, as_path=False): # noqa: ARG001
186+
def star(self, compiler, connection): # noqa: ARG001
179187
return {"$literal": True}
180188

181189

@@ -190,11 +198,11 @@ def exists(self, compiler, connection, get_wrapping_pipeline=None):
190198
lhs_mql = subquery(self, compiler, connection, get_wrapping_pipeline=get_wrapping_pipeline)
191199
except EmptyResultSet:
192200
return Value(False).as_mql(compiler, connection)
193-
return connection.mongo_operators_expr["isnull"](lhs_mql, False)
201+
return connection.mongo_expr_operators["isnull"](lhs_mql, False)
194202

195203

196-
def when(self, compiler, connection, as_path=False):
197-
return self.condition.as_mql(compiler, connection, as_path=as_path)
204+
def when(self, compiler, connection):
205+
return self.condition.as_mql(compiler, connection)
198206

199207

200208
def value(self, compiler, connection, as_path=False): # noqa: ARG001
@@ -221,18 +229,6 @@ def value(self, compiler, connection, as_path=False): # noqa: ARG001
221229
return value
222230

223231

224-
def base_expression(self, compiler, connection, as_path=False, **extra):
225-
if (
226-
as_path
227-
and hasattr(self, "as_mql_path")
228-
and getattr(self, "is_simple_expression", lambda: False)()
229-
):
230-
return self.as_mql_path(compiler, connection, **extra)
231-
232-
expr = self.as_mql_expr(compiler, connection, **extra)
233-
return {"$expr": expr} if as_path else expr
234-
235-
236232
def register_expressions():
237233
BaseExpression.as_mql = base_expression
238234
BaseExpression.is_simple_column = False
@@ -243,15 +239,15 @@ def register_expressions():
243239
CombinedExpression.as_mql_expr = combined_expression
244240
Exists.as_mql_expr = exists
245241
ExpressionList.as_mql = process_lhs
246-
ExpressionWrapper.as_mql_expr = expression_wrapper_expr
247-
NegatedExpression.as_mql_expr = negated_expression_expr
242+
ExpressionWrapper.as_mql_expr = expression_wrapper
243+
NegatedExpression.as_mql_expr = negated_expression
248244
OrderBy.as_mql_expr = order_by
249245
Query.as_mql = query
250246
RawSQL.as_mql = raw_sql
251247
Ref.as_mql = ref
252248
Ref.is_simple_column = ref_is_simple_column
253249
ResolvedOuterRef.as_mql = ResolvedOuterRef.as_sql
254-
Star.as_mql = star
250+
Star.as_mql_expr = star
255251
Subquery.as_mql_expr = subquery
256-
When.as_mql = when
252+
When.as_mql_expr = when
257253
Value.as_mql = value

django_mongodb_backend/fields/array.py

Lines changed: 32 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,11 @@
44
from django.db.models import Field, Func, IntegerField, Transform, Value
55
from django.db.models.fields.mixins import CheckFieldDefaultMixin
66
from django.db.models.lookups import Exact, FieldGetDbPrepValueMixin, In, Lookup
7+
from django.utils.functional import cached_property
78
from django.utils.translation import gettext_lazy as _
89

910
from ..forms import SimpleArrayField
10-
from ..query_utils import process_lhs, process_rhs
11+
from ..query_utils import is_constant_value, process_lhs, process_rhs
1112
from ..utils import prefix_validation_error
1213
from ..validators import ArrayMaxLengthValidator, LengthValidator
1314

@@ -236,6 +237,20 @@ def as_mql_expr(self, compiler, connection):
236237
for expr in self.get_source_expressions()
237238
]
238239

240+
def as_mql_path(self, compiler, connection):
241+
return [
242+
expr.as_mql(compiler, connection, as_path=True)
243+
for expr in self.get_source_expressions()
244+
]
245+
246+
@cached_property
247+
def can_use_path(self):
248+
return all(is_constant_value(expr) for expr in self.get_source_expressions())
249+
250+
@property
251+
def is_simple_column(self):
252+
return False
253+
239254

240255
class ArrayRHSMixin:
241256
def __init__(self, lhs, rhs):
@@ -254,13 +269,6 @@ def __init__(self, lhs, rhs):
254269
class ArrayContains(ArrayRHSMixin, FieldGetDbPrepValueMixin, Lookup):
255270
lookup_name = "contains"
256271

257-
def as_mql_path(self, compiler, connection):
258-
lhs_mql = process_lhs(self, compiler, connection, as_path=True)
259-
value = process_rhs(self, compiler, connection, as_path=True)
260-
if value is None:
261-
return False
262-
return {lhs_mql: {"$all": value}}
263-
264272
def as_mql_expr(self, compiler, connection):
265273
lhs_mql = process_lhs(self, compiler, connection, as_path=False)
266274
value = process_rhs(self, compiler, connection, as_path=False)
@@ -272,6 +280,11 @@ def as_mql_expr(self, compiler, connection):
272280
]
273281
}
274282

283+
def as_mql_path(self, compiler, connection):
284+
lhs_mql = process_lhs(self, compiler, connection, as_path=True)
285+
value = process_rhs(self, compiler, connection, as_path=True)
286+
return {lhs_mql: {"$all": value}}
287+
275288

276289
@ArrayField.register_lookup
277290
class ArrayContainedBy(ArrayRHSMixin, FieldGetDbPrepValueMixin, Lookup):
@@ -333,11 +346,6 @@ def get_subquery_wrapping_pipeline(self, compiler, connection, field_name, expr)
333346
},
334347
]
335348

336-
def as_mql_path(self, compiler, connection):
337-
lhs_mql = process_lhs(self, compiler, connection, as_path=True)
338-
value = process_rhs(self, compiler, connection, as_path=True)
339-
return {lhs_mql: {"$in": value}}
340-
341349
def as_mql_expr(self, compiler, connection):
342350
lhs_mql = process_lhs(self, compiler, connection, as_path=False)
343351
value = process_rhs(self, compiler, connection, as_path=False)
@@ -348,6 +356,11 @@ def as_mql_expr(self, compiler, connection):
348356
]
349357
}
350358

359+
def as_mql_path(self, compiler, connection):
360+
lhs_mql = process_lhs(self, compiler, connection, as_path=True)
361+
value = process_rhs(self, compiler, connection, as_path=True)
362+
return {lhs_mql: {"$in": value}}
363+
351364

352365
@ArrayField.register_lookup
353366
class ArrayLenTransform(Transform):
@@ -381,21 +394,22 @@ def __init__(self, index, base_field, *args, **kwargs):
381394
self.index = index
382395
self.base_field = base_field
383396

384-
def is_simple_expression(self):
397+
@property
398+
def can_use_path(self):
385399
return self.is_simple_column
386400

387401
@property
388402
def is_simple_column(self):
389403
return self.lhs.is_simple_column
390404

391-
def as_mql_path(self, compiler, connection):
392-
lhs_mql = process_lhs(self, compiler, connection, as_path=True)
393-
return f"{lhs_mql}.{self.index}"
394-
395405
def as_mql_expr(self, compiler, connection):
396406
lhs_mql = process_lhs(self, compiler, connection, as_path=False)
397407
return {"$arrayElemAt": [lhs_mql, self.index]}
398408

409+
def as_mql_path(self, compiler, connection):
410+
lhs_mql = process_lhs(self, compiler, connection, as_path=True)
411+
return f"{lhs_mql}.{self.index}"
412+
399413
@property
400414
def output_field(self):
401415
return self.base_field

django_mongodb_backend/fields/embedded_model.py

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from django.utils.functional import cached_property
99

1010
from .. import forms
11+
from ..query_utils import valid_path_key_name
1112

1213

1314
class EmbeddedModelField(models.Field):
@@ -167,14 +168,15 @@ def __init__(self, field, *args, **kwargs):
167168
def get_lookup(self, name):
168169
return self.field.get_lookup(name)
169170

170-
def is_simple_expression(self):
171+
@property
172+
def can_use_path(self):
171173
return self.is_simple_column
172174

173175
@cached_property
174176
def is_simple_column(self):
175177
previous = self
176-
while isinstance(previous, KeyTransform):
177-
if not previous.key_name.isalnum():
178+
while isinstance(previous, EmbeddedModelTransform):
179+
if not valid_path_key_name(previous._field.column):
178180
return False
179181
previous = previous.lhs
180182
return previous.is_simple_column
@@ -198,27 +200,27 @@ def get_transform(self, name):
198200
f"{suggestion}"
199201
)
200202

201-
def as_mql_path(self, compiler, connection):
202-
previous = self
203-
key_transforms = []
204-
while isinstance(previous, EmbeddedModelTransform):
205-
key_transforms.insert(0, previous.key_name)
206-
previous = previous.lhs
207-
mql = previous.as_mql(compiler, connection, as_path=True)
208-
mql_path = ".".join(key_transforms)
209-
return f"{mql}.{mql_path}"
210-
211-
def as_mql_expr(self, compiler, connection):
203+
def _get_target_path(self):
212204
previous = self
213205
columns = []
214206
while isinstance(previous, EmbeddedModelTransform):
215207
columns.insert(0, previous.field.column)
216208
previous = previous.lhs
217-
mql = previous.as_mql(compiler, connection)
218-
for column in columns:
219-
mql = {"$getField": {"input": mql, "field": column}}
209+
return columns, previous
210+
211+
def as_mql_expr(self, compiler, connection):
212+
columns, parent_field = self._get_target_path()
213+
mql = parent_field.as_mql(compiler, connection)
214+
for key in columns:
215+
mql = {"$getField": {"input": mql, "field": key}}
220216
return mql
221217

218+
def as_mql_path(self, compiler, connection):
219+
columns, parent_field = self._get_target_path()
220+
mql = parent_field.as_mql(compiler, connection, as_path=True)
221+
mql_path = ".".join(columns)
222+
return f"{mql}.{mql_path}"
223+
222224
@property
223225
def output_field(self):
224226
return self._field

0 commit comments

Comments
 (0)