Skip to content

Commit 57ca9fa

Browse files
committed
Functional approach solution
1 parent 49a14d9 commit 57ca9fa

File tree

17 files changed

+571
-297
lines changed

17 files changed

+571
-297
lines changed

django_mongodb_backend/base.py

Lines changed: 50 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from .features import DatabaseFeatures
2121
from .introspection import DatabaseIntrospection
2222
from .operations import DatabaseOperations
23-
from .query_utils import regex_match
23+
from .query_utils import regex_expr, regex_match
2424
from .schema import DatabaseSchemaEditor
2525
from .utils import OperationDebugWrapper
2626
from .validation import DatabaseValidation
@@ -108,7 +108,12 @@ def _isnull_operator(a, b):
108108
}
109109
return is_null if b else {"$not": is_null}
110110

111-
mongo_operators = {
111+
def _isnull_operator_match(a, b):
112+
if b:
113+
return {"$or": [{a: {"$exists": False}}, {a: None}]}
114+
return {"$and": [{a: {"$exists": True}}, {a: {"$ne": None}}]}
115+
116+
mongo_operators_expr = {
112117
"exact": lambda a, b: {"$eq": [a, b]},
113118
"gt": lambda a, b: {"$gt": [a, b]},
114119
"gte": lambda a, b: {"$gte": [a, b]},
@@ -118,19 +123,56 @@ def _isnull_operator(a, b):
118123
"lte": lambda a, b: {
119124
"$and": [{"$lte": [a, b]}, DatabaseWrapper._isnull_operator(a, False)]
120125
},
121-
"in": lambda a, b: {"$in": [a, b]},
126+
"in": lambda a, b: {"$in": (a, b)},
122127
"isnull": _isnull_operator,
123128
"range": lambda a, b: {
124129
"$and": [
125130
{"$or": [DatabaseWrapper._isnull_operator(b[0], True), {"$gte": [a, b[0]]}]},
126131
{"$or": [DatabaseWrapper._isnull_operator(b[1], True), {"$lte": [a, b[1]]}]},
127132
]
128133
},
129-
"iexact": lambda a, b: regex_match(a, ("^", b, {"$literal": "$"}), insensitive=True),
130-
"startswith": lambda a, b: regex_match(a, ("^", b)),
131-
"istartswith": lambda a, b: regex_match(a, ("^", b), insensitive=True),
132-
"endswith": lambda a, b: regex_match(a, (b, {"$literal": "$"})),
133-
"iendswith": lambda a, b: regex_match(a, (b, {"$literal": "$"}), insensitive=True),
134+
"iexact": lambda a, b: regex_expr(a, ("^", b, {"$literal": "$"}), insensitive=True),
135+
"startswith": lambda a, b: regex_expr(a, ("^", b)),
136+
"istartswith": lambda a, b: regex_expr(a, ("^", b), insensitive=True),
137+
"endswith": lambda a, b: regex_expr(a, (b, {"$literal": "$"})),
138+
"iendswith": lambda a, b: regex_expr(a, (b, {"$literal": "$"}), insensitive=True),
139+
"contains": lambda a, b: regex_expr(a, b),
140+
"icontains": lambda a, b: regex_expr(a, b, insensitive=True),
141+
"regex": lambda a, b: regex_expr(a, b),
142+
"iregex": lambda a, b: regex_expr(a, b, insensitive=True),
143+
}
144+
145+
def range_match(a, b):
146+
## TODO: MAKE A TEST TO TEST WHEN BOTH ENDS ARE NONE. WHAT SHALL I RETURN?
147+
conditions = []
148+
if b[0] is not None:
149+
conditions.append({a: {"$gte": b[0]}})
150+
if b[1] is not None:
151+
conditions.append({a: {"$lte": b[1]}})
152+
if not conditions:
153+
return {"$literal": True}
154+
return {"$and": conditions}
155+
156+
mongo_operators_match = {
157+
"exact": lambda a, b: {a: b},
158+
"gt": lambda a, b: {a: {"$gt": b}},
159+
"gte": lambda a, b: {a: {"$gte": b}},
160+
# MongoDB considers null less than zero. Exclude null values to match
161+
# SQL behavior.
162+
"lt": lambda a, b: {
163+
"$and": [{a: {"$lt": b}}, DatabaseWrapper._isnull_operator_match(a, False)]
164+
},
165+
"lte": lambda a, b: {
166+
"$and": [{a: {"$lte": b}}, DatabaseWrapper._isnull_operator_match(a, False)]
167+
},
168+
"in": lambda a, b: {a: {"$in": tuple(b)}},
169+
"isnull": _isnull_operator_match,
170+
"range": range_match,
171+
"iexact": lambda a, b: regex_match(a, f"^{b}$", insensitive=True),
172+
"startswith": lambda a, b: regex_match(a, f"^{b}"),
173+
"istartswith": lambda a, b: regex_match(a, f"^{b}", insensitive=True),
174+
"endswith": lambda a, b: regex_match(a, f"{b}$"),
175+
"iendswith": lambda a, b: regex_match(a, f"{b}$", insensitive=True),
134176
"contains": lambda a, b: regex_match(a, b),
135177
"icontains": lambda a, b: regex_match(a, b, insensitive=True),
136178
"regex": lambda a, b: regex_match(a, b),

django_mongodb_backend/compiler.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -327,14 +327,14 @@ def pre_sql_setup(self, with_col_aliases=False):
327327
pipeline = self._build_aggregation_pipeline(ids, group)
328328
if self.having:
329329
having = self.having.replace_expressions(all_replacements).as_mql(
330-
self, self.connection
330+
self, self.connection, as_path=True
331331
)
332332
# Add HAVING subqueries.
333333
for query in self.subqueries or ():
334334
pipeline.extend(query.get_pipeline())
335335
# Remove the added subqueries.
336336
self.subqueries = []
337-
pipeline.append({"$match": {"$expr": having}})
337+
pipeline.append({"$match": having})
338338
self.aggregation_pipeline = pipeline
339339
self.annotations = {
340340
target: expr.replace_expressions(all_replacements)
@@ -481,11 +481,11 @@ def build_query(self, columns=None):
481481
query.lookup_pipeline = self.get_lookup_pipeline()
482482
where = self.get_where()
483483
try:
484-
expr = where.as_mql(self, self.connection) if where else {}
484+
expr = where.as_mql(self, self.connection, as_path=True) if where else {}
485485
except FullResultSet:
486486
query.match_mql = {}
487487
else:
488-
query.match_mql = {"$expr": expr}
488+
query.match_mql = expr
489489
if extra_fields:
490490
query.extra_fields = self.get_project_fields(extra_fields, force_expression=True)
491491
query.subqueries = self.subqueries
@@ -643,7 +643,9 @@ def get_combinator_queries(self):
643643
for alias, expr in self.columns:
644644
# Unfold foreign fields.
645645
if isinstance(expr, Col) and expr.alias != self.collection_name:
646-
ids[expr.alias][expr.target.column] = expr.as_mql(self, self.connection)
646+
ids[expr.alias][expr.target.column] = expr.as_mql(
647+
self, self.connection, as_path=False
648+
)
647649
else:
648650
ids[alias] = f"${alias}"
649651
# Convert defaultdict to dict so it doesn't appear as
@@ -707,16 +709,16 @@ def get_project_fields(self, columns=None, ordering=None, force_expression=False
707709
# For brevity/simplicity, project {"field_name": 1}
708710
# instead of {"field_name": "$field_name"}.
709711
if isinstance(expr, Col) and name == expr.target.column and not force_expression
710-
else expr.as_mql(self, self.connection)
712+
else expr.as_mql(self, self.connection, as_path=False)
711713
)
712714
except EmptyResultSet:
713715
empty_result_set_value = getattr(expr, "empty_result_set_value", NotImplemented)
714716
value = (
715717
False if empty_result_set_value is NotImplemented else empty_result_set_value
716718
)
717-
fields[collection][name] = Value(value).as_mql(self, self.connection)
719+
fields[collection][name] = Value(value).as_mql(self, self.connection, as_path=False)
718720
except FullResultSet:
719-
fields[collection][name] = Value(True).as_mql(self, self.connection)
721+
fields[collection][name] = Value(True).as_mql(self, self.connection, as_path=False)
720722
# Annotations (stored in None) and the main collection's fields
721723
# should appear in the top-level of the fields dict.
722724
fields.update(fields.pop(None, {}))

django_mongodb_backend/expressions/builtins.py

Lines changed: 84 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,15 @@
66
from django.core.exceptions import EmptyResultSet, FullResultSet
77
from django.db import NotSupportedError
88
from django.db.models.expressions import (
9+
BaseExpression,
910
Case,
1011
Col,
1112
ColPairs,
1213
CombinedExpression,
1314
Exists,
1415
ExpressionList,
1516
ExpressionWrapper,
17+
Func,
1618
NegatedExpression,
1719
OrderBy,
1820
RawSQL,
@@ -23,17 +25,19 @@
2325
Value,
2426
When,
2527
)
28+
from django.db.models.fields.json import KeyTransform
2629
from django.db.models.sql import Query
2730

28-
from django_mongodb_backend.query_utils import process_lhs
31+
from django_mongodb_backend.fields.array import Array
32+
from django_mongodb_backend.query_utils import is_direct_value, process_lhs
2933

3034

31-
def case(self, compiler, connection):
35+
def case(self, compiler, connection, as_path=False):
3236
case_parts = []
3337
for case in self.cases:
3438
case_mql = {}
3539
try:
36-
case_mql["case"] = case.as_mql(compiler, connection)
40+
case_mql["case"] = case.as_mql(compiler, connection, as_path=False)
3741
except EmptyResultSet:
3842
continue
3943
except FullResultSet:
@@ -45,12 +49,16 @@ def case(self, compiler, connection):
4549
default_mql = self.default.as_mql(compiler, connection)
4650
if not case_parts:
4751
return default_mql
48-
return {
52+
expr = {
4953
"$switch": {
5054
"branches": case_parts,
5155
"default": default_mql,
5256
}
5357
}
58+
if as_path:
59+
return {"$expr": expr}
60+
61+
return expr
5462

5563

5664
def col(self, compiler, connection, as_path=False): # noqa: ARG001
@@ -76,34 +84,34 @@ def col(self, compiler, connection, as_path=False): # noqa: ARG001
7684
return f"{prefix}{self.target.column}"
7785

7886

79-
def col_pairs(self, compiler, connection):
87+
def col_pairs(self, compiler, connection, as_path=False):
8088
cols = self.get_cols()
8189
if len(cols) > 1:
8290
raise NotSupportedError("ColPairs is not supported.")
83-
return cols[0].as_mql(compiler, connection)
91+
return cols[0].as_mql(compiler, connection, as_path=as_path)
8492

8593

86-
def combined_expression(self, compiler, connection):
94+
def combined_expression(self, compiler, connection, as_path=False):
8795
expressions = [
88-
self.lhs.as_mql(compiler, connection),
89-
self.rhs.as_mql(compiler, connection),
96+
self.lhs.as_mql(compiler, connection, as_path=as_path),
97+
self.rhs.as_mql(compiler, connection, as_path=as_path),
9098
]
9199
return connection.ops.combine_expression(self.connector, expressions)
92100

93101

94-
def expression_wrapper(self, compiler, connection):
95-
return self.expression.as_mql(compiler, connection)
102+
def expression_wrapper(self, compiler, connection, as_path=False):
103+
return self.expression.as_mql(compiler, connection, as_path=as_path)
96104

97105

98-
def negated_expression(self, compiler, connection):
99-
return {"$not": expression_wrapper(self, compiler, connection)}
106+
def negated_expression(self, compiler, connection, as_path=False):
107+
return {"$not": expression_wrapper(self, compiler, connection, as_path=as_path)}
100108

101109

102110
def order_by(self, compiler, connection):
103111
return self.expression.as_mql(compiler, connection)
104112

105113

106-
def query(self, compiler, connection, get_wrapping_pipeline=None):
114+
def query(self, compiler, connection, get_wrapping_pipeline=None, as_path=False):
107115
subquery_compiler = self.get_compiler(connection=connection)
108116
subquery_compiler.pre_sql_setup(with_col_aliases=False)
109117
field_name, expr = subquery_compiler.columns[0]
@@ -145,14 +153,16 @@ def query(self, compiler, connection, get_wrapping_pipeline=None):
145153
# Erase project_fields since the required value is projected above.
146154
subquery.project_fields = None
147155
compiler.subqueries.append(subquery)
156+
if as_path:
157+
return f"{table_output}.{field_name}"
148158
return f"${table_output}.{field_name}"
149159

150160

151161
def raw_sql(self, compiler, connection): # noqa: ARG001
152162
raise NotSupportedError("RawSQL is not supported on MongoDB.")
153163

154164

155-
def ref(self, compiler, connection): # noqa: ARG001
165+
def ref(self, compiler, connection, as_path=False): # noqa: ARG001
156166
prefix = (
157167
f"{self.source.alias}."
158168
if isinstance(self.source, Col) and self.source.alias != compiler.collection_name
@@ -162,32 +172,47 @@ def ref(self, compiler, connection): # noqa: ARG001
162172
refs, _ = compiler.columns[self.ordinal - 1]
163173
else:
164174
refs = self.refs
165-
return f"${prefix}{refs}"
175+
if not as_path:
176+
prefix = f"${prefix}"
177+
return f"{prefix}{refs}"
166178

167179

168-
def star(self, compiler, connection): # noqa: ARG001
180+
def star(self, compiler, connection, **extra): # noqa: ARG001
169181
return {"$literal": True}
170182

171183

172-
def subquery(self, compiler, connection, get_wrapping_pipeline=None):
173-
return self.query.as_mql(compiler, connection, get_wrapping_pipeline=get_wrapping_pipeline)
184+
def subquery(self, compiler, connection, get_wrapping_pipeline=None, as_path=False):
185+
expr = self.query.as_mql(
186+
compiler, connection, get_wrapping_pipeline=get_wrapping_pipeline, as_path=False
187+
)
188+
if as_path:
189+
return {"$expr": expr}
190+
return expr
174191

175192

176-
def exists(self, compiler, connection, get_wrapping_pipeline=None):
193+
def exists(self, compiler, connection, get_wrapping_pipeline=None, as_path=False):
177194
try:
178-
lhs_mql = subquery(self, compiler, connection, get_wrapping_pipeline=get_wrapping_pipeline)
195+
lhs_mql = subquery(
196+
self,
197+
compiler,
198+
connection,
199+
get_wrapping_pipeline=get_wrapping_pipeline,
200+
as_path=as_path,
201+
)
179202
except EmptyResultSet:
180203
return Value(False).as_mql(compiler, connection)
181-
return connection.mongo_operators["isnull"](lhs_mql, False)
204+
if as_path:
205+
return {"$expr": connection.mongo_operators_match["isnull"](lhs_mql, False)}
206+
return connection.mongo_operators_expr["isnull"](lhs_mql, False)
182207

183208

184-
def when(self, compiler, connection):
185-
return self.condition.as_mql(compiler, connection)
209+
def when(self, compiler, connection, as_path=False):
210+
return self.condition.as_mql(compiler, connection, as_path=as_path)
186211

187212

188-
def value(self, compiler, connection): # noqa: ARG001
213+
def value(self, compiler, connection, as_path=False): # noqa: ARG001
189214
value = self.value
190-
if isinstance(value, (list, int)):
215+
if isinstance(value, (list, int)) and not as_path:
191216
# Wrap lists & numbers in $literal to prevent ambiguity when Value
192217
# appears in $project.
193218
return {"$literal": value}
@@ -209,6 +234,36 @@ def value(self, compiler, connection): # noqa: ARG001
209234
return value
210235

211236

237+
@staticmethod
238+
def _is_constant_value(value):
239+
if isinstance(value, list | Array):
240+
iterable = value.get_source_expressions() if isinstance(value, Array) else value
241+
return all(_is_constant_value(e) for e in iterable)
242+
if is_direct_value(value):
243+
return True
244+
return isinstance(value, Func | Value) and not (
245+
value.contains_aggregate
246+
or value.contains_over_clause
247+
or value.contains_column_references
248+
or value.contains_subquery
249+
)
250+
251+
252+
@staticmethod
253+
def _is_simple_column(lhs):
254+
while isinstance(lhs, KeyTransform):
255+
if "." in getattr(lhs, "key_name", ""):
256+
return False
257+
lhs = lhs.lhs
258+
col = lhs.source if isinstance(lhs, Ref) else lhs
259+
# Foreign columns from parent cannot be addressed as single match
260+
return isinstance(col, Col) and col.alias is not None
261+
262+
263+
def _is_simple_expression(self):
264+
return self.is_simple_column(self.lhs) and self.is_constant_value(self.rhs)
265+
266+
212267
def register_expressions():
213268
Case.as_mql = case
214269
Col.as_mql = col
@@ -227,3 +282,6 @@ def register_expressions():
227282
Subquery.as_mql = subquery
228283
When.as_mql = when
229284
Value.as_mql = value
285+
BaseExpression.is_simple_expression = _is_simple_expression
286+
BaseExpression.is_simple_column = _is_simple_column
287+
BaseExpression.is_constant_value = _is_constant_value

django_mongodb_backend/expressions/search.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -933,10 +933,12 @@ def __str__(self):
933933
def __repr__(self):
934934
return f"SearchText({self.lhs}, {self.rhs})"
935935

936-
def as_mql(self, compiler, connection):
937-
lhs_mql = process_lhs(self, compiler, connection)
938-
value = process_rhs(self, compiler, connection)
939-
return {"$gte": [lhs_mql, value]}
936+
def as_mql(self, compiler, connection, as_path=False):
937+
lhs_mql = process_lhs(self, compiler, connection, as_path=as_path)
938+
value = process_rhs(self, compiler, connection, as_path=as_path)
939+
if as_path:
940+
return {lhs_mql: {"$gte": value}}
941+
return {"$expr": {"$gte": [lhs_mql, value]}}
940942

941943

942944
CharField.register_lookup(SearchTextLookup)

django_mongodb_backend/features.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -90,9 +90,6 @@ class DatabaseFeatures(GISFeatures, BaseDatabaseFeatures):
9090
"auth_tests.test_views.LoginTest.test_login_session_without_hash_session_key",
9191
# GenericRelation.value_to_string() assumes integer pk.
9292
"contenttypes_tests.test_fields.GenericRelationTests.test_value_to_string",
93-
# icontains doesn't work on ArrayField:
94-
# Unsupported conversion from array to string in $convert
95-
"model_fields_.test_arrayfield.QueryingTests.test_icontains",
9693
# ArrayField's contained_by lookup crashes with Exists: "both operands "
9794
# of $setIsSubset must be arrays. Second argument is of type: null"
9895
# https://jira.mongodb.org/browse/SERVER-99186

0 commit comments

Comments
 (0)