Skip to content

Commit 59c8faf

Browse files
committed
Functional approach solution
1 parent e747973 commit 59c8faf

File tree

17 files changed

+572
-297
lines changed

17 files changed

+572
-297
lines changed

django_mongodb_backend/base.py

Lines changed: 50 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from .features import DatabaseFeatures
2121
from .introspection import DatabaseIntrospection
2222
from .operations import DatabaseOperations
23-
from .query_utils import regex_match
23+
from .query_utils import regex_expr, regex_match
2424
from .schema import DatabaseSchemaEditor
2525
from .utils import OperationDebugWrapper
2626
from .validation import DatabaseValidation
@@ -108,7 +108,12 @@ def _isnull_operator(a, b):
108108
}
109109
return is_null if b else {"$not": is_null}
110110

111-
mongo_operators = {
111+
def _isnull_operator_match(a, b):
112+
if b:
113+
return {"$or": [{a: {"$exists": False}}, {a: None}]}
114+
return {"$and": [{a: {"$exists": True}}, {a: {"$ne": None}}]}
115+
116+
mongo_operators_expr = {
112117
"exact": lambda a, b: {"$eq": [a, b]},
113118
"gt": lambda a, b: {"$gt": [a, b]},
114119
"gte": lambda a, b: {"$gte": [a, b]},
@@ -118,19 +123,56 @@ def _isnull_operator(a, b):
118123
"lte": lambda a, b: {
119124
"$and": [{"$lte": [a, b]}, DatabaseWrapper._isnull_operator(a, False)]
120125
},
121-
"in": lambda a, b: {"$in": [a, b]},
126+
"in": lambda a, b: {"$in": (a, b)},
122127
"isnull": _isnull_operator,
123128
"range": lambda a, b: {
124129
"$and": [
125130
{"$or": [DatabaseWrapper._isnull_operator(b[0], True), {"$gte": [a, b[0]]}]},
126131
{"$or": [DatabaseWrapper._isnull_operator(b[1], True), {"$lte": [a, b[1]]}]},
127132
]
128133
},
129-
"iexact": lambda a, b: regex_match(a, ("^", b, {"$literal": "$"}), insensitive=True),
130-
"startswith": lambda a, b: regex_match(a, ("^", b)),
131-
"istartswith": lambda a, b: regex_match(a, ("^", b), insensitive=True),
132-
"endswith": lambda a, b: regex_match(a, (b, {"$literal": "$"})),
133-
"iendswith": lambda a, b: regex_match(a, (b, {"$literal": "$"}), insensitive=True),
134+
"iexact": lambda a, b: regex_expr(a, ("^", b, {"$literal": "$"}), insensitive=True),
135+
"startswith": lambda a, b: regex_expr(a, ("^", b)),
136+
"istartswith": lambda a, b: regex_expr(a, ("^", b), insensitive=True),
137+
"endswith": lambda a, b: regex_expr(a, (b, {"$literal": "$"})),
138+
"iendswith": lambda a, b: regex_expr(a, (b, {"$literal": "$"}), insensitive=True),
139+
"contains": lambda a, b: regex_expr(a, b),
140+
"icontains": lambda a, b: regex_expr(a, b, insensitive=True),
141+
"regex": lambda a, b: regex_expr(a, b),
142+
"iregex": lambda a, b: regex_expr(a, b, insensitive=True),
143+
}
144+
145+
def range_match(a, b):
146+
## TODO: MAKE A TEST TO TEST WHEN BOTH ENDS ARE NONE. WHAT SHALL I RETURN?
147+
conditions = []
148+
if b[0] is not None:
149+
conditions.append({a: {"$gte": b[0]}})
150+
if b[1] is not None:
151+
conditions.append({a: {"$lte": b[1]}})
152+
if not conditions:
153+
return {"$literal": True}
154+
return {"$and": conditions}
155+
156+
mongo_operators_match = {
157+
"exact": lambda a, b: {a: b},
158+
"gt": lambda a, b: {a: {"$gt": b}},
159+
"gte": lambda a, b: {a: {"$gte": b}},
160+
# MongoDB considers null less than zero. Exclude null values to match
161+
# SQL behavior.
162+
"lt": lambda a, b: {
163+
"$and": [{a: {"$lt": b}}, DatabaseWrapper._isnull_operator_match(a, False)]
164+
},
165+
"lte": lambda a, b: {
166+
"$and": [{a: {"$lte": b}}, DatabaseWrapper._isnull_operator_match(a, False)]
167+
},
168+
"in": lambda a, b: {a: {"$in": tuple(b)}},
169+
"isnull": _isnull_operator_match,
170+
"range": range_match,
171+
"iexact": lambda a, b: regex_match(a, f"^{b}$", insensitive=True),
172+
"startswith": lambda a, b: regex_match(a, f"^{b}"),
173+
"istartswith": lambda a, b: regex_match(a, f"^{b}", insensitive=True),
174+
"endswith": lambda a, b: regex_match(a, f"{b}$"),
175+
"iendswith": lambda a, b: regex_match(a, f"{b}$", insensitive=True),
134176
"contains": lambda a, b: regex_match(a, b),
135177
"icontains": lambda a, b: regex_match(a, b, insensitive=True),
136178
"regex": lambda a, b: regex_match(a, b),

django_mongodb_backend/compiler.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -327,14 +327,14 @@ def pre_sql_setup(self, with_col_aliases=False):
327327
pipeline = self._build_aggregation_pipeline(ids, group)
328328
if self.having:
329329
having = self.having.replace_expressions(all_replacements).as_mql(
330-
self, self.connection
330+
self, self.connection, as_path=True
331331
)
332332
# Add HAVING subqueries.
333333
for query in self.subqueries or ():
334334
pipeline.extend(query.get_pipeline())
335335
# Remove the added subqueries.
336336
self.subqueries = []
337-
pipeline.append({"$match": {"$expr": having}})
337+
pipeline.append({"$match": having})
338338
self.aggregation_pipeline = pipeline
339339
self.annotations = {
340340
target: expr.replace_expressions(all_replacements)
@@ -481,11 +481,11 @@ def build_query(self, columns=None):
481481
query.lookup_pipeline = self.get_lookup_pipeline()
482482
where = self.get_where()
483483
try:
484-
expr = where.as_mql(self, self.connection) if where else {}
484+
expr = where.as_mql(self, self.connection, as_path=True) if where else {}
485485
except FullResultSet:
486486
query.match_mql = {}
487487
else:
488-
query.match_mql = {"$expr": expr}
488+
query.match_mql = expr
489489
if extra_fields:
490490
query.extra_fields = self.get_project_fields(extra_fields, force_expression=True)
491491
query.subqueries = self.subqueries
@@ -643,7 +643,9 @@ def get_combinator_queries(self):
643643
for alias, expr in self.columns:
644644
# Unfold foreign fields.
645645
if isinstance(expr, Col) and expr.alias != self.collection_name:
646-
ids[expr.alias][expr.target.column] = expr.as_mql(self, self.connection)
646+
ids[expr.alias][expr.target.column] = expr.as_mql(
647+
self, self.connection, as_path=False
648+
)
647649
else:
648650
ids[alias] = f"${alias}"
649651
# Convert defaultdict to dict so it doesn't appear as
@@ -707,16 +709,16 @@ def get_project_fields(self, columns=None, ordering=None, force_expression=False
707709
# For brevity/simplicity, project {"field_name": 1}
708710
# instead of {"field_name": "$field_name"}.
709711
if isinstance(expr, Col) and name == expr.target.column and not force_expression
710-
else expr.as_mql(self, self.connection)
712+
else expr.as_mql(self, self.connection, as_path=False)
711713
)
712714
except EmptyResultSet:
713715
empty_result_set_value = getattr(expr, "empty_result_set_value", NotImplemented)
714716
value = (
715717
False if empty_result_set_value is NotImplemented else empty_result_set_value
716718
)
717-
fields[collection][name] = Value(value).as_mql(self, self.connection)
719+
fields[collection][name] = Value(value).as_mql(self, self.connection, as_path=False)
718720
except FullResultSet:
719-
fields[collection][name] = Value(True).as_mql(self, self.connection)
721+
fields[collection][name] = Value(True).as_mql(self, self.connection, as_path=False)
720722
# Annotations (stored in None) and the main collection's fields
721723
# should appear in the top-level of the fields dict.
722724
fields.update(fields.pop(None, {}))

django_mongodb_backend/expressions/builtins.py

Lines changed: 85 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,15 @@
66
from django.core.exceptions import EmptyResultSet, FullResultSet
77
from django.db import NotSupportedError
88
from django.db.models.expressions import (
9+
BaseExpression,
910
Case,
1011
Col,
1112
ColPairs,
1213
CombinedExpression,
1314
Exists,
1415
ExpressionList,
1516
ExpressionWrapper,
17+
Func,
1618
NegatedExpression,
1719
OrderBy,
1820
RawSQL,
@@ -23,17 +25,20 @@
2325
Value,
2426
When,
2527
)
28+
from django.db.models.fields.json import KeyTransform
2629
from django.db.models.sql import Query
2730

28-
from ..query_utils import process_lhs
31+
from django_mongodb_backend.fields.array import Array
2932

33+
from ..query_utils import is_direct_value, process_lhs
3034

31-
def case(self, compiler, connection):
35+
36+
def case(self, compiler, connection, as_path=False):
3237
case_parts = []
3338
for case in self.cases:
3439
case_mql = {}
3540
try:
36-
case_mql["case"] = case.as_mql(compiler, connection)
41+
case_mql["case"] = case.as_mql(compiler, connection, as_path=False)
3742
except EmptyResultSet:
3843
continue
3944
except FullResultSet:
@@ -45,12 +50,16 @@ def case(self, compiler, connection):
4550
default_mql = self.default.as_mql(compiler, connection)
4651
if not case_parts:
4752
return default_mql
48-
return {
53+
expr = {
4954
"$switch": {
5055
"branches": case_parts,
5156
"default": default_mql,
5257
}
5358
}
59+
if as_path:
60+
return {"$expr": expr}
61+
62+
return expr
5463

5564

5665
def col(self, compiler, connection, as_path=False): # noqa: ARG001
@@ -76,34 +85,34 @@ def col(self, compiler, connection, as_path=False): # noqa: ARG001
7685
return f"{prefix}{self.target.column}"
7786

7887

79-
def col_pairs(self, compiler, connection):
88+
def col_pairs(self, compiler, connection, as_path=False):
8089
cols = self.get_cols()
8190
if len(cols) > 1:
8291
raise NotSupportedError("ColPairs is not supported.")
83-
return cols[0].as_mql(compiler, connection)
92+
return cols[0].as_mql(compiler, connection, as_path=as_path)
8493

8594

86-
def combined_expression(self, compiler, connection):
95+
def combined_expression(self, compiler, connection, as_path=False):
8796
expressions = [
88-
self.lhs.as_mql(compiler, connection),
89-
self.rhs.as_mql(compiler, connection),
97+
self.lhs.as_mql(compiler, connection, as_path=as_path),
98+
self.rhs.as_mql(compiler, connection, as_path=as_path),
9099
]
91100
return connection.ops.combine_expression(self.connector, expressions)
92101

93102

94-
def expression_wrapper(self, compiler, connection):
95-
return self.expression.as_mql(compiler, connection)
103+
def expression_wrapper(self, compiler, connection, as_path=False):
104+
return self.expression.as_mql(compiler, connection, as_path=as_path)
96105

97106

98-
def negated_expression(self, compiler, connection):
99-
return {"$not": expression_wrapper(self, compiler, connection)}
107+
def negated_expression(self, compiler, connection, as_path=False):
108+
return {"$not": expression_wrapper(self, compiler, connection, as_path=as_path)}
100109

101110

102111
def order_by(self, compiler, connection):
103112
return self.expression.as_mql(compiler, connection)
104113

105114

106-
def query(self, compiler, connection, get_wrapping_pipeline=None):
115+
def query(self, compiler, connection, get_wrapping_pipeline=None, as_path=False):
107116
subquery_compiler = self.get_compiler(connection=connection)
108117
subquery_compiler.pre_sql_setup(with_col_aliases=False)
109118
field_name, expr = subquery_compiler.columns[0]
@@ -145,14 +154,16 @@ def query(self, compiler, connection, get_wrapping_pipeline=None):
145154
# Erase project_fields since the required value is projected above.
146155
subquery.project_fields = None
147156
compiler.subqueries.append(subquery)
157+
if as_path:
158+
return f"{table_output}.{field_name}"
148159
return f"${table_output}.{field_name}"
149160

150161

151162
def raw_sql(self, compiler, connection): # noqa: ARG001
152163
raise NotSupportedError("RawSQL is not supported on MongoDB.")
153164

154165

155-
def ref(self, compiler, connection): # noqa: ARG001
166+
def ref(self, compiler, connection, as_path=False): # noqa: ARG001
156167
prefix = (
157168
f"{self.source.alias}."
158169
if isinstance(self.source, Col) and self.source.alias != compiler.collection_name
@@ -162,32 +173,47 @@ def ref(self, compiler, connection): # noqa: ARG001
162173
refs, _ = compiler.columns[self.ordinal - 1]
163174
else:
164175
refs = self.refs
165-
return f"${prefix}{refs}"
176+
if not as_path:
177+
prefix = f"${prefix}"
178+
return f"{prefix}{refs}"
166179

167180

168-
def star(self, compiler, connection): # noqa: ARG001
181+
def star(self, compiler, connection, **extra): # noqa: ARG001
169182
return {"$literal": True}
170183

171184

172-
def subquery(self, compiler, connection, get_wrapping_pipeline=None):
173-
return self.query.as_mql(compiler, connection, get_wrapping_pipeline=get_wrapping_pipeline)
185+
def subquery(self, compiler, connection, get_wrapping_pipeline=None, as_path=False):
186+
expr = self.query.as_mql(
187+
compiler, connection, get_wrapping_pipeline=get_wrapping_pipeline, as_path=False
188+
)
189+
if as_path:
190+
return {"$expr": expr}
191+
return expr
174192

175193

176-
def exists(self, compiler, connection, get_wrapping_pipeline=None):
194+
def exists(self, compiler, connection, get_wrapping_pipeline=None, as_path=False):
177195
try:
178-
lhs_mql = subquery(self, compiler, connection, get_wrapping_pipeline=get_wrapping_pipeline)
196+
lhs_mql = subquery(
197+
self,
198+
compiler,
199+
connection,
200+
get_wrapping_pipeline=get_wrapping_pipeline,
201+
as_path=as_path,
202+
)
179203
except EmptyResultSet:
180204
return Value(False).as_mql(compiler, connection)
181-
return connection.mongo_operators["isnull"](lhs_mql, False)
205+
if as_path:
206+
return {"$expr": connection.mongo_operators_match["isnull"](lhs_mql, False)}
207+
return connection.mongo_operators_expr["isnull"](lhs_mql, False)
182208

183209

184-
def when(self, compiler, connection):
185-
return self.condition.as_mql(compiler, connection)
210+
def when(self, compiler, connection, as_path=False):
211+
return self.condition.as_mql(compiler, connection, as_path=as_path)
186212

187213

188-
def value(self, compiler, connection): # noqa: ARG001
214+
def value(self, compiler, connection, as_path=False): # noqa: ARG001
189215
value = self.value
190-
if isinstance(value, (list, int)):
216+
if isinstance(value, (list, int)) and not as_path:
191217
# Wrap lists & numbers in $literal to prevent ambiguity when Value
192218
# appears in $project.
193219
return {"$literal": value}
@@ -209,6 +235,36 @@ def value(self, compiler, connection): # noqa: ARG001
209235
return value
210236

211237

238+
@staticmethod
239+
def _is_constant_value(value):
240+
if isinstance(value, list | Array):
241+
iterable = value.get_source_expressions() if isinstance(value, Array) else value
242+
return all(_is_constant_value(e) for e in iterable)
243+
if is_direct_value(value):
244+
return True
245+
return isinstance(value, Func | Value) and not (
246+
value.contains_aggregate
247+
or value.contains_over_clause
248+
or value.contains_column_references
249+
or value.contains_subquery
250+
)
251+
252+
253+
@staticmethod
254+
def _is_simple_column(lhs):
255+
while isinstance(lhs, KeyTransform):
256+
if "." in getattr(lhs, "key_name", ""):
257+
return False
258+
lhs = lhs.lhs
259+
col = lhs.source if isinstance(lhs, Ref) else lhs
260+
# Foreign columns from parent cannot be addressed as single match
261+
return isinstance(col, Col) and col.alias is not None
262+
263+
264+
def _is_simple_expression(self):
265+
return self.is_simple_column(self.lhs) and self.is_constant_value(self.rhs)
266+
267+
212268
def register_expressions():
213269
Case.as_mql = case
214270
Col.as_mql = col
@@ -227,3 +283,6 @@ def register_expressions():
227283
Subquery.as_mql = subquery
228284
When.as_mql = when
229285
Value.as_mql = value
286+
BaseExpression.is_simple_expression = _is_simple_expression
287+
BaseExpression.is_simple_column = _is_simple_column
288+
BaseExpression.is_constant_value = _is_constant_value

django_mongodb_backend/expressions/search.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -933,10 +933,12 @@ def __str__(self):
933933
def __repr__(self):
934934
return f"SearchText({self.lhs}, {self.rhs})"
935935

936-
def as_mql(self, compiler, connection):
937-
lhs_mql = process_lhs(self, compiler, connection)
938-
value = process_rhs(self, compiler, connection)
939-
return {"$gte": [lhs_mql, value]}
936+
def as_mql(self, compiler, connection, as_path=False):
937+
lhs_mql = process_lhs(self, compiler, connection, as_path=as_path)
938+
value = process_rhs(self, compiler, connection, as_path=as_path)
939+
if as_path:
940+
return {lhs_mql: {"$gte": value}}
941+
return {"$expr": {"$gte": [lhs_mql, value]}}
940942

941943

942944
CharField.register_lookup(SearchTextLookup)

django_mongodb_backend/features.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -90,9 +90,6 @@ class DatabaseFeatures(GISFeatures, BaseDatabaseFeatures):
9090
"auth_tests.test_views.LoginTest.test_login_session_without_hash_session_key",
9191
# GenericRelation.value_to_string() assumes integer pk.
9292
"contenttypes_tests.test_fields.GenericRelationTests.test_value_to_string",
93-
# icontains doesn't work on ArrayField:
94-
# Unsupported conversion from array to string in $convert
95-
"model_fields_.test_arrayfield.QueryingTests.test_icontains",
9693
# ArrayField's contained_by lookup crashes with Exists: "both operands "
9794
# of $setIsSubset must be arrays. Second argument is of type: null"
9895
# https://jira.mongodb.org/browse/SERVER-99186

0 commit comments

Comments
 (0)