Skip to content

Commit d9ea0e0

Browse files
committed
Fix refactor
1 parent 4130665 commit d9ea0e0

File tree

2 files changed

+74
-81
lines changed

2 files changed

+74
-81
lines changed

django_mongodb_backend/functions.py

Lines changed: 55 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@
6565
}
6666

6767

68-
def cast(self, compiler, connection, as_path=False):
68+
def cast(self, compiler, connection):
6969
output_type = connection.data_types[self.output_field.get_internal_type()]
7070
lhs_mql = process_lhs(self, compiler, connection, as_path=False)[0]
7171
if max_length := self.output_field.max_length:
@@ -78,27 +78,21 @@ def cast(self, compiler, connection, as_path=False):
7878
if decimal_places := getattr(self.output_field, "decimal_places", None):
7979
lhs_mql = {"$trunc": [lhs_mql, decimal_places]}
8080

81-
if as_path:
82-
return {"$expr": lhs_mql}
8381
return lhs_mql
8482

8583

86-
def concat(self, compiler, connection, as_path=False):
87-
return self.get_source_expressions()[0].as_mql(compiler, connection, as_path=as_path)
84+
def concat(self, compiler, connection):
85+
return self.get_source_expressions()[0].as_mql(compiler, connection, as_path=False)
8886

8987

90-
def concat_pair(self, compiler, connection, as_path=False):
88+
def concat_pair(self, compiler, connection):
9189
# null on either side results in null for expression, wrap with coalesce.
9290
coalesced = self.coalesce()
93-
if as_path:
94-
return {"$expr": super(ConcatPair, coalesced).as_mql(compiler, connection, as_path=False)}
9591
return super(ConcatPair, coalesced).as_mql(compiler, connection, as_path=False)
9692

9793

98-
def cot(self, compiler, connection, as_path=False):
94+
def cot(self, compiler, connection):
9995
lhs_mql = process_lhs(self, compiler, connection, as_path=False)
100-
if as_path:
101-
return {"$expr": {"$divide": [1, {"$tan": lhs_mql}]}}
10296
return {"$divide": [1, {"$tan": lhs_mql}]}
10397

10498

@@ -109,10 +103,7 @@ def extract(self, compiler, connection, as_path=False):
109103
raise NotSupportedError(f"{self.__class__.__name__} is not supported.")
110104
if timezone := self.get_tzname():
111105
lhs_mql = {"date": lhs_mql, "timezone": timezone}
112-
expr = {f"${operator}": lhs_mql}
113-
if as_path:
114-
return {"$expr": expr}
115-
return expr
106+
return {f"${operator}": lhs_mql}
116107

117108

118109
def func(self, compiler, connection, as_path=False):
@@ -137,73 +128,57 @@ def func_expr(self, compiler, connection):
137128
return {f"${operator}": lhs_mql}
138129

139130

140-
def left(self, compiler, connection, as_path=False):
141-
return self.get_substr().as_mql(compiler, connection, as_path=as_path)
131+
def left(self, compiler, connection):
132+
return self.get_substr().as_mql(compiler, connection, as_path=False)
142133

143134

144-
def length(self, compiler, connection, as_path=False):
135+
def length(self, compiler, connection):
145136
# Check for null first since $strLenCP only accepts strings.
146137
lhs_mql = process_lhs(self, compiler, connection, as_path=False)
147-
expr = {"$cond": {"if": {"$eq": [lhs_mql, None]}, "then": None, "else": {"$strLenCP": lhs_mql}}}
148-
if as_path:
149-
return {"$expr": expr}
150-
return expr
138+
return {"$cond": {"if": {"$eq": [lhs_mql, None]}, "then": None, "else": {"$strLenCP": lhs_mql}}}
151139

152140

153-
def log(self, compiler, connection, as_path=False):
141+
def log(self, compiler, connection):
154142
# This function is usually log(base, num) but on MongoDB it's log(num, base).
155143
clone = self.copy()
156144
clone.set_source_expressions(self.get_source_expressions()[::-1])
157-
return func(clone, compiler, connection, as_path=as_path)
145+
return func(clone, compiler, connection, as_path=False)
158146

159147

160-
def now(self, compiler, connection, as_path=False): # noqa: ARG001
148+
def now(self, compiler, connection): # noqa: ARG001
161149
return "$$NOW"
162150

163151

164-
def null_if(self, compiler, connection, as_path=False):
152+
def null_if(self, compiler, connection):
165153
"""Return None if expr1==expr2 else expr1."""
166154
expr1, expr2 = (
167155
expr.as_mql(compiler, connection, as_path=False) for expr in self.get_source_expressions()
168156
)
169-
expr = {"$cond": {"if": {"$eq": [expr1, expr2]}, "then": None, "else": expr1}}
170-
if as_path:
171-
return {"$expr": expr}
172-
return expr
157+
return {"$cond": {"if": {"$eq": [expr1, expr2]}, "then": None, "else": expr1}}
173158

174159

175160
def preserve_null(operator):
176161
# If the argument is null, the function should return null, not
177162
# $toLower/Upper's behavior of returning an empty string.
178-
def wrapped(self, compiler, connection, as_path=False):
179-
if as_path and self.is_constant_value(self.lhs):
180-
if self.lhs is None:
181-
return None
182-
lhs_mql = process_lhs(self, compiler, connection, as_path=True)
183-
return lhs_mql.upper()
163+
def wrapped(self, compiler, connection):
184164
lhs_mql = process_lhs(self, compiler, connection, as_path=False)
185-
inner_expression = {
165+
return {
186166
"$cond": {
187167
"if": connection.mongo_operators_expr["isnull"](lhs_mql, True),
188168
"then": None,
189169
"else": {f"${operator}": lhs_mql},
190170
}
191171
}
192-
# we need to wrap this, because it will be handled in a no expression tree.
193-
# needed in MongoDB 6.
194-
if as_path:
195-
return {"$expr": inner_expression}
196-
return inner_expression
197172

198173
return wrapped
199174

200175

201-
def replace(self, compiler, connection, as_path=False):
202-
expression, text, replacement = process_lhs(self, compiler, connection, as_path=as_path)
176+
def replace(self, compiler, connection):
177+
expression, text, replacement = process_lhs(self, compiler, connection, as_path=False)
203178
return {"$replaceAll": {"input": expression, "find": text, "replacement": replacement}}
204179

205180

206-
def round_(self, compiler, connection, as_path=False): # noqa: ARG001
181+
def round_(self, compiler, connection):
207182
# Round needs its own function because it's a special case that inherits
208183
# from Transform but has two arguments.
209184
return {
@@ -214,13 +189,13 @@ def round_(self, compiler, connection, as_path=False): # noqa: ARG001
214189
}
215190

216191

217-
def str_index(self, compiler, connection, as_path=False): # noqa: ARG001
218-
lhs = process_lhs(self, compiler, connection)
192+
def str_index(self, compiler, connection):
193+
lhs = process_lhs(self, compiler, connection, as_path=False)
219194
# StrIndex should be 0-indexed (not found) but it's -1-indexed on MongoDB.
220195
return {"$add": [{"$indexOfCP": lhs}, 1]}
221196

222197

223-
def substr(self, compiler, connection, as_path=False): # noqa: ARG001
198+
def substr(self, compiler, connection):
224199
lhs = process_lhs(self, compiler, connection)
225200
# The starting index is zero-indexed on MongoDB rather than one-indexed.
226201
lhs[1] = {"$add": [lhs[1], -1]}
@@ -232,14 +207,14 @@ def substr(self, compiler, connection, as_path=False): # noqa: ARG001
232207

233208

234209
def trim(operator):
235-
def wrapped(self, compiler, connection, as_path=False): # noqa: ARG001
210+
def wrapped(self, compiler, connection):
236211
lhs = process_lhs(self, compiler, connection)
237212
return {f"${operator}": {"input": lhs}}
238213

239214
return wrapped
240215

241216

242-
def trunc(self, compiler, connection, as_path=False): # noqa: ARG001
217+
def trunc(self, compiler, connection):
243218
lhs_mql = process_lhs(self, compiler, connection)
244219
lhs_mql = {"date": lhs_mql, "unit": self.kind, "startOfWeek": "mon"}
245220
if timezone := self.get_tzname():
@@ -298,7 +273,7 @@ def trunc_date(self, compiler, connection, **extra): # noqa: ARG001
298273
}
299274

300275

301-
def trunc_time(self, compiler, connection, as_path=False): # noqa: ARG001
276+
def trunc_time(self, compiler, connection):
302277
tzname = self.get_tzname()
303278
if tzname and tzname != "UTC":
304279
raise NotSupportedError(f"TruncTime with tzinfo ({tzname}) isn't supported on MongoDB.")
@@ -318,30 +293,35 @@ def trunc_time(self, compiler, connection, as_path=False): # noqa: ARG001
318293
}
319294

320295

296+
def is_simple_expression(self): # noqa: ARG001
297+
return False
298+
299+
321300
def register_functions():
322-
Cast.as_mql = cast
323-
Concat.as_mql = concat
324-
ConcatPair.as_mql = concat_pair
325-
Cot.as_mql = cot
326-
Extract.as_mql = extract
301+
Cast.as_mql_expr = cast
302+
Concat.as_mql_expr = concat
303+
ConcatPair.as_mql_expr = concat_pair
304+
Cot.as_mql_expr = cot
305+
Extract.as_mql_expr = extract
327306
Func.as_mql_path = func_path
328307
Func.as_mql_expr = func_expr
329-
JSONArray.as_mql = process_lhs
330-
Left.as_mql = left
331-
Length.as_mql = length
332-
Log.as_mql = log
333-
Lower.as_mql = preserve_null("toLower")
334-
LTrim.as_mql = trim("ltrim")
335-
Now.as_mql = now
336-
NullIf.as_mql = null_if
337-
Replace.as_mql = replace
338-
Round.as_mql = round_
339-
RTrim.as_mql = trim("rtrim")
340-
StrIndex.as_mql = str_index
341-
Substr.as_mql = substr
342-
Trim.as_mql = trim("trim")
343-
TruncBase.as_mql = trunc
344-
TruncBase.convert_value = trunc_convert_value
345-
TruncDate.as_mql = trunc_date
346-
TruncTime.as_mql = trunc_time
347-
Upper.as_mql = preserve_null("toUpper")
308+
JSONArray.as_mql_expr = process_lhs
309+
Left.as_mql_expr = left
310+
Length.as_mql_expr = length
311+
Log.as_mql_expr = log
312+
Lower.as_mql_expr = preserve_null("toLower")
313+
LTrim.as_mql_expr = trim("ltrim")
314+
Now.as_mql_expr = now
315+
NullIf.as_mql_expr = null_if
316+
Replace.as_mql_expr = replace
317+
Round.as_mql_expr = round_
318+
RTrim.as_mql_expr = trim("rtrim")
319+
StrIndex.as_mql_expr = str_index
320+
Substr.as_mql_expr = substr
321+
Trim.as_mql_expr = trim("trim")
322+
TruncBase.as_mql_expr = trunc
323+
TruncBase.convert_value_expr = trunc_convert_value
324+
TruncDate.as_mql_expr = trunc_date
325+
TruncTime.as_mql_expr = trunc_time
326+
Upper.as_mql_expr = preserve_null("toUpper")
327+
Func.is_simple_expression = is_simple_expression

django_mongodb_backend/query_utils.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
from django.core.exceptions import FullResultSet
22
from django.db.models.aggregates import Aggregate
3-
from django.db.models.expressions import Col, Func, Ref, Value
3+
from django.db.models.expressions import Col, CombinedExpression, Ref, Value
44
from django.db.models.fields.json import KeyTransform
5+
from django.db.models.sql.query import Query
56

67

78
def is_direct_value(node):
@@ -64,15 +65,27 @@ def regex_match(field, regex, insensitive=False):
6465

6566

6667
def is_constant_value(value):
68+
# This should be remove once CombinedExpression which are cosntant value
69+
# are resolved
70+
if isinstance(value, CombinedExpression):
71+
return False
6772
if isinstance(value, list):
6873
return all(map(is_constant_value, value))
6974
if is_direct_value(value):
7075
return True
71-
return isinstance(value, Func | Value) and not (
72-
value.contains_aggregate
73-
or value.contains_over_clause
74-
or value.contains_column_references
75-
or value.contains_subquery
76+
# This should be remove when the same thing above.
77+
if hasattr(value, "get_source_expressions"):
78+
simple_sub_expressions = all(map(is_constant_value, value.get_source_expressions()))
79+
return (
80+
simple_sub_expressions
81+
and isinstance(value, Value)
82+
and not (
83+
isinstance(value, Query)
84+
or value.contains_aggregate
85+
or value.contains_over_clause
86+
or value.contains_column_references
87+
or value.contains_subquery
88+
)
7689
)
7790

7891

0 commit comments

Comments
 (0)