Skip to content

Commit 8572489

Browse files
committed
Refactor: replace es_path for as_expr.
1 parent 78cde50 commit 8572489

File tree

17 files changed

+225
-213
lines changed

17 files changed

+225
-213
lines changed

django_mongodb_backend/aggregates.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def aggregate(
2424
node.set_source_expressions([Case(condition), *source_expressions[1:]])
2525
else:
2626
node = self
27-
lhs_mql = process_lhs(node, compiler, connection)
27+
lhs_mql = process_lhs(node, compiler, connection, as_expr=True)
2828
if resolve_inner_expression:
2929
return lhs_mql
3030
operator = operator or MONGO_AGGREGATIONS.get(self.__class__, self.function.lower())
@@ -46,9 +46,9 @@ def count(self, compiler, connection, resolve_inner_expression=False, **extra_co
4646
self.filter, then=Case(When(IsNull(source_expressions[0], False), then=Value(1)))
4747
)
4848
node.set_source_expressions([Case(condition), *source_expressions[1:]])
49-
inner_expression = process_lhs(node, compiler, connection)
49+
inner_expression = process_lhs(node, compiler, connection, as_expr=True)
5050
else:
51-
lhs_mql = process_lhs(self, compiler, connection)
51+
lhs_mql = process_lhs(self, compiler, connection, as_expr=True)
5252
null_cond = {"$in": [{"$type": lhs_mql}, ["missing", "null"]]}
5353
inner_expression = {
5454
"$cond": {"if": null_cond, "then": None, "else": lhs_mql if self.distinct else 1}
@@ -58,7 +58,7 @@ def count(self, compiler, connection, resolve_inner_expression=False, **extra_co
5858
return {"$sum": inner_expression}
5959
# If distinct=True or resolve_inner_expression=False, sum the size of the
6060
# set.
61-
lhs_mql = process_lhs(self, compiler, connection)
61+
lhs_mql = process_lhs(self, compiler, connection, as_expr=True)
6262
# None shouldn't be counted, so subtract 1 if it's present.
6363
exits_null = {"$cond": {"if": {"$in": [{"$literal": None}, lhs_mql]}, "then": -1, "else": 0}}
6464
return {"$add": [{"$size": lhs_mql}, exits_null]}

django_mongodb_backend/base.py

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -98,38 +98,35 @@ class DatabaseWrapper(BaseDatabaseWrapper):
9898
}
9999
_connection_pools = {}
100100

101-
def _isnull_operator(a, b):
101+
def _isnull_operator_expr(field, null):
102102
is_null = {
103103
"$or": [
104104
# The path does not exist (i.e. is "missing")
105-
{"$eq": [{"$type": a}, "missing"]},
105+
{"$eq": [{"$type": field}, "missing"]},
106106
# or the value is None.
107-
{"$eq": [a, None]},
107+
{"$eq": [field, None]},
108108
]
109109
}
110-
return is_null if b else {"$not": is_null}
111-
112-
def _isnull_operator_match(a, b):
113-
if b:
114-
return {"$or": [{a: {"$exists": False}}, {a: None}]}
115-
return {"$and": [{a: {"$exists": True}}, {a: {"$ne": None}}]}
110+
return is_null if null else {"$not": is_null}
116111

117112
mongo_expr_operators = {
118113
"exact": lambda a, b: {"$eq": [a, b]},
119114
"gt": lambda a, b: {"$gt": [a, b]},
120115
"gte": lambda a, b: {"$gte": [a, b]},
121116
# MongoDB considers null less than zero. Exclude null values to match
122117
# SQL behavior.
123-
"lt": lambda a, b: {"$and": [{"$lt": [a, b]}, DatabaseWrapper._isnull_operator(a, False)]},
118+
"lt": lambda a, b: {
119+
"$and": [{"$lt": [a, b]}, DatabaseWrapper._isnull_operator_expr(a, False)]
120+
},
124121
"lte": lambda a, b: {
125-
"$and": [{"$lte": [a, b]}, DatabaseWrapper._isnull_operator(a, False)]
122+
"$and": [{"$lte": [a, b]}, DatabaseWrapper._isnull_operator_expr(a, False)]
126123
},
127124
"in": lambda a, b: {"$in": (a, b)},
128-
"isnull": _isnull_operator,
125+
"isnull": _isnull_operator_expr,
129126
"range": lambda a, b: {
130127
"$and": [
131-
{"$or": [DatabaseWrapper._isnull_operator(b[0], True), {"$gte": [a, b[0]]}]},
132-
{"$or": [DatabaseWrapper._isnull_operator(b[1], True), {"$lte": [a, b[1]]}]},
128+
{"$or": [DatabaseWrapper._isnull_operator_expr(b[0], True), {"$gte": [a, b[0]]}]},
129+
{"$or": [DatabaseWrapper._isnull_operator_expr(b[1], True), {"$lte": [a, b[1]]}]},
133130
]
134131
},
135132
"iexact": lambda a, b: regex_expr(a, ("^", b, {"$literal": "$"}), insensitive=True),
@@ -161,6 +158,11 @@ def range_match(a, b):
161158
raise EmptyResultSet
162159
return {"$and": conditions}
163160

161+
def _isnull_operator_match(field, null):
162+
if null:
163+
return {"$or": [{field: {"$exists": False}}, {field: None}]}
164+
return {"$and": [{field: {"$exists": True}}, {field: {"$ne": None}}]}
165+
164166
mongo_operators = {
165167
"exact": lambda a, b: {a: b},
166168
"gt": lambda a, b: {a: {"$gt": b}},

django_mongodb_backend/compiler.py

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -69,12 +69,14 @@ def _get_replace_expr(self, sub_expr, group, alias):
6969
if getattr(sub_expr, "distinct", False):
7070
# If the expression should return distinct values, use $addToSet to
7171
# deduplicate.
72-
rhs = sub_expr.as_mql(self, self.connection, resolve_inner_expression=True)
72+
rhs = sub_expr.as_mql(
73+
self, self.connection, resolve_inner_expression=True, as_expr=True
74+
)
7375
group[alias] = {"$addToSet": rhs}
7476
replacing_expr = sub_expr.copy()
7577
replacing_expr.set_source_expressions([inner_column, None])
7678
else:
77-
group[alias] = sub_expr.as_mql(self, self.connection)
79+
group[alias] = sub_expr.as_mql(self, self.connection, as_expr=True)
7880
replacing_expr = inner_column
7981
# Count must return 0 rather than null.
8082
if isinstance(sub_expr, Count):
@@ -302,9 +304,7 @@ def _compound_searches_queries(self, search_replacements):
302304
search.as_mql(self, self.connection),
303305
{
304306
"$addFields": {
305-
result_col.as_mql(self, self.connection, as_path=True): {
306-
"$meta": score_function
307-
}
307+
result_col.as_mql(self, self.connection): {"$meta": score_function}
308308
}
309309
},
310310
]
@@ -327,7 +327,7 @@ def pre_sql_setup(self, with_col_aliases=False):
327327
pipeline = self._build_aggregation_pipeline(ids, group)
328328
if self.having:
329329
having = self.having.replace_expressions(all_replacements).as_mql(
330-
self, self.connection, as_path=True
330+
self, self.connection
331331
)
332332
# Add HAVING subqueries.
333333
for query in self.subqueries or ():
@@ -481,7 +481,7 @@ def build_query(self, columns=None):
481481
query.lookup_pipeline = self.get_lookup_pipeline()
482482
where = self.get_where()
483483
try:
484-
match = where.as_mql(self, self.connection, as_path=True) if where else {}
484+
match = where.as_mql(self, self.connection) if where else {}
485485
except FullResultSet:
486486
query.match_mql = {}
487487
else:
@@ -643,7 +643,9 @@ def get_combinator_queries(self):
643643
for alias, expr in self.columns:
644644
# Unfold foreign fields.
645645
if isinstance(expr, Col) and expr.alias != self.collection_name:
646-
ids[expr.alias][expr.target.column] = expr.as_mql(self, self.connection)
646+
ids[expr.alias][expr.target.column] = expr.as_mql(
647+
self, self.connection, as_expr=True
648+
)
647649
else:
648650
ids[alias] = f"${alias}"
649651
# Convert defaultdict to dict so it doesn't appear as
@@ -707,16 +709,16 @@ def get_project_fields(self, columns=None, ordering=None, force_expression=False
707709
# For brevity/simplicity, project {"field_name": 1}
708710
# instead of {"field_name": "$field_name"}.
709711
if isinstance(expr, Col) and name == expr.target.column and not force_expression
710-
else expr.as_mql(self, self.connection)
712+
else expr.as_mql(self, self.connection, as_expr=True)
711713
)
712714
except EmptyResultSet:
713715
empty_result_set_value = getattr(expr, "empty_result_set_value", NotImplemented)
714716
value = (
715717
False if empty_result_set_value is NotImplemented else empty_result_set_value
716718
)
717-
fields[collection][name] = Value(value).as_mql(self, self.connection)
719+
fields[collection][name] = Value(value).as_mql(self, self.connection, as_expr=True)
718720
except FullResultSet:
719-
fields[collection][name] = Value(True).as_mql(self, self.connection)
721+
fields[collection][name] = Value(True).as_mql(self, self.connection, as_expr=True)
720722
# Annotations (stored in None) and the main collection's fields
721723
# should appear in the top-level of the fields dict.
722724
fields.update(fields.pop(None, {}))
@@ -739,10 +741,10 @@ def _get_ordering(self):
739741
idx = itertools.count(start=1)
740742
for order in self.order_by_objs or []:
741743
if isinstance(order.expression, Col):
742-
field_name = order.as_mql(self, self.connection).removeprefix("$")
744+
field_name = order.as_mql(self, self.connection, as_expr=True).removeprefix("$")
743745
fields.append((order.expression.target.column, order.expression))
744746
elif isinstance(order.expression, Ref):
745-
field_name = order.as_mql(self, self.connection).removeprefix("$")
747+
field_name = order.as_mql(self, self.connection, as_expr=True).removeprefix("$")
746748
else:
747749
field_name = f"__order{next(idx)}"
748750
fields.append((field_name, order.expression))
@@ -879,7 +881,7 @@ def execute_sql(self, result_type):
879881
)
880882
prepared = field.get_db_prep_save(value, connection=self.connection)
881883
if hasattr(value, "as_mql"):
882-
prepared = prepared.as_mql(self, self.connection)
884+
prepared = prepared.as_mql(self, self.connection, as_expr=True)
883885
values[field.column] = prepared
884886
try:
885887
criteria = self.build_query().match_mql

django_mongodb_backend/expressions/builtins.py

Lines changed: 34 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import datetime
22
from decimal import Decimal
3+
from functools import partialmethod
34
from uuid import UUID
45

56
from bson import Decimal128
@@ -29,29 +30,29 @@
2930
from django_mongodb_backend.query_utils import process_lhs
3031

3132

32-
def base_expression(self, compiler, connection, as_path=False, **extra):
33-
if as_path and hasattr(self, "as_mql_path") and getattr(self, "can_use_path", False):
33+
def base_expression(self, compiler, connection, as_expr=False, **extra):
34+
if not as_expr and hasattr(self, "as_mql_path") and getattr(self, "can_use_path", False):
3435
return self.as_mql_path(compiler, connection, **extra)
3536

3637
expr = self.as_mql_expr(compiler, connection, **extra)
37-
return {"$expr": expr} if as_path else expr
38+
return expr if as_expr else {"$expr": expr}
3839

3940

4041
def case(self, compiler, connection):
4142
case_parts = []
4243
for case in self.cases:
4344
case_mql = {}
4445
try:
45-
case_mql["case"] = case.as_mql(compiler, connection)
46+
case_mql["case"] = case.as_mql(compiler, connection, as_expr=True)
4647
except EmptyResultSet:
4748
continue
4849
except FullResultSet:
49-
default_mql = case.result.as_mql(compiler, connection)
50+
default_mql = case.result.as_mql(compiler, connection, as_expr=True)
5051
break
51-
case_mql["then"] = case.result.as_mql(compiler, connection)
52+
case_mql["then"] = case.result.as_mql(compiler, connection, as_expr=True)
5253
case_parts.append(case_mql)
5354
else:
54-
default_mql = self.default.as_mql(compiler, connection)
55+
default_mql = self.default.as_mql(compiler, connection, as_expr=True)
5556
if not case_parts:
5657
return default_mql
5758
return {
@@ -62,7 +63,7 @@ def case(self, compiler, connection):
6263
}
6364

6465

65-
def col(self, compiler, connection, as_path=False): # noqa: ARG001
66+
def col(self, compiler, connection, as_expr=False): # noqa: ARG001
6667
# If the column is part of a subquery and belongs to one of the parent
6768
# queries, it will be stored for reference using $let in a $lookup stage.
6869
# If the query is built with `alias_cols=False`, treat the column as
@@ -80,39 +81,39 @@ def col(self, compiler, connection, as_path=False): # noqa: ARG001
8081
# Add the column's collection's alias for columns in joined collections.
8182
has_alias = self.alias and self.alias != compiler.collection_name
8283
prefix = f"{self.alias}." if has_alias else ""
83-
if not as_path:
84+
if as_expr:
8485
prefix = f"${prefix}"
8586
return f"{prefix}{self.target.column}"
8687

8788

88-
def col_pairs(self, compiler, connection, as_path=False):
89+
def col_pairs(self, compiler, connection, as_expr=False):
8990
cols = self.get_cols()
9091
if len(cols) > 1:
9192
raise NotSupportedError("ColPairs is not supported.")
92-
return cols[0].as_mql(compiler, connection, as_path=as_path)
93+
return cols[0].as_mql(compiler, connection, as_expr=as_expr)
9394

9495

9596
def combined_expression(self, compiler, connection):
9697
expressions = [
97-
self.lhs.as_mql(compiler, connection),
98-
self.rhs.as_mql(compiler, connection),
98+
self.lhs.as_mql(compiler, connection, as_expr=True),
99+
self.rhs.as_mql(compiler, connection, as_expr=True),
99100
]
100101
return connection.ops.combine_expression(self.connector, expressions)
101102

102103

103104
def expression_wrapper(self, compiler, connection):
104-
return self.expression.as_mql(compiler, connection)
105+
return self.expression.as_mql(compiler, connection, as_expr=True)
105106

106107

107108
def negated_expression(self, compiler, connection):
108109
return {"$not": expression_wrapper(self, compiler, connection)}
109110

110111

111112
def order_by(self, compiler, connection):
112-
return self.expression.as_mql(compiler, connection)
113+
return self.expression.as_mql(compiler, connection, as_expr=True)
113114

114115

115-
def query(self, compiler, connection, get_wrapping_pipeline=None, as_path=False):
116+
def query(self, compiler, connection, get_wrapping_pipeline=None, as_expr=False):
116117
subquery_compiler = self.get_compiler(connection=connection)
117118
subquery_compiler.pre_sql_setup(with_col_aliases=False)
118119
field_name, expr = subquery_compiler.columns[0]
@@ -132,7 +133,7 @@ def query(self, compiler, connection, get_wrapping_pipeline=None, as_path=False)
132133
"as": table_output,
133134
"from": from_table,
134135
"let": {
135-
compiler.PARENT_FIELD_TEMPLATE.format(i): col.as_mql(compiler, connection)
136+
compiler.PARENT_FIELD_TEMPLATE.format(i): col.as_mql(compiler, connection, as_expr=True)
136137
for col, i in subquery_compiler.column_indices.items()
137138
},
138139
}
@@ -154,16 +155,16 @@ def query(self, compiler, connection, get_wrapping_pipeline=None, as_path=False)
154155
# Erase project_fields since the required value is projected above.
155156
subquery.project_fields = None
156157
compiler.subqueries.append(subquery)
157-
if as_path:
158-
return f"{table_output}.{field_name}"
159-
return f"${table_output}.{field_name}"
158+
if as_expr:
159+
return f"${table_output}.{field_name}"
160+
return f"{table_output}.{field_name}"
160161

161162

162163
def raw_sql(self, compiler, connection): # noqa: ARG001
163164
raise NotSupportedError("RawSQL is not supported on MongoDB.")
164165

165166

166-
def ref(self, compiler, connection, as_path=False): # noqa: ARG001
167+
def ref(self, compiler, connection, as_expr=False): # noqa: ARG001
167168
prefix = (
168169
f"{self.source.alias}."
169170
if isinstance(self.source, Col) and self.source.alias != compiler.collection_name
@@ -173,7 +174,7 @@ def ref(self, compiler, connection, as_path=False): # noqa: ARG001
173174
refs, _ = compiler.columns[self.ordinal - 1]
174175
else:
175176
refs = self.refs
176-
if not as_path:
177+
if as_expr:
177178
prefix = f"${prefix}"
178179
return f"{prefix}{refs}"
179180

@@ -187,27 +188,29 @@ def star(self, compiler, connection): # noqa: ARG001
187188
return {"$literal": True}
188189

189190

190-
def subquery(self, compiler, connection, get_wrapping_pipeline=None):
191+
def subquery(self, compiler, connection, get_wrapping_pipeline=None, as_expr=False):
191192
return self.query.as_mql(
192-
compiler, connection, get_wrapping_pipeline=get_wrapping_pipeline, as_path=False
193+
compiler, connection, get_wrapping_pipeline=get_wrapping_pipeline, as_expr=as_expr
193194
)
194195

195196

196197
def exists(self, compiler, connection, get_wrapping_pipeline=None):
197198
try:
198-
lhs_mql = subquery(self, compiler, connection, get_wrapping_pipeline=get_wrapping_pipeline)
199+
lhs_mql = subquery(
200+
self, compiler, connection, get_wrapping_pipeline=get_wrapping_pipeline, as_expr=True
201+
)
199202
except EmptyResultSet:
200-
return Value(False).as_mql(compiler, connection)
203+
return Value(False).as_mql(compiler, connection, as_expr=True)
201204
return connection.mongo_expr_operators["isnull"](lhs_mql, False)
202205

203206

204207
def when(self, compiler, connection):
205-
return self.condition.as_mql(compiler, connection)
208+
return self.condition.as_mql(compiler, connection, as_expr=True)
206209

207210

208-
def value(self, compiler, connection, as_path=False): # noqa: ARG001
211+
def value(self, compiler, connection, as_expr=False): # noqa: ARG001
209212
value = self.value
210-
if isinstance(value, (list, int)) and not as_path:
213+
if isinstance(value, (list, int)) and as_expr:
211214
# Wrap lists & numbers in $literal to prevent ambiguity when Value
212215
# appears in $project.
213216
return {"$literal": value}
@@ -248,6 +251,7 @@ def register_expressions():
248251
Ref.is_simple_column = ref_is_simple_column
249252
ResolvedOuterRef.as_mql = ResolvedOuterRef.as_sql
250253
Star.as_mql_expr = star
251-
Subquery.as_mql_expr = subquery
254+
Subquery.as_mql_expr = partialmethod(subquery, as_expr=True)
255+
Subquery.as_mql_path = partialmethod(subquery, as_expr=False)
252256
When.as_mql_expr = when
253257
Value.as_mql = value

0 commit comments

Comments
 (0)