Skip to content

Commit b81c822

Browse files
committed
wip
1 parent 140df1e commit b81c822

File tree

5 files changed

+28
-22
lines changed

5 files changed

+28
-22
lines changed

django_mongodb_backend/compiler.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -334,7 +334,7 @@ def pre_sql_setup(self, with_col_aliases=False):
334334
pipeline.extend(query.get_pipeline())
335335
# Remove the added subqueries.
336336
self.subqueries = []
337-
pipeline.append({"$match": {"$expr": having}})
337+
pipeline.append({"$match": having})
338338
self.aggregation_pipeline = pipeline
339339
self.annotations = {
340340
target: expr.replace_expressions(all_replacements)
@@ -707,7 +707,7 @@ def get_project_fields(self, columns=None, ordering=None, force_expression=False
707707
# For brevity/simplicity, project {"field_name": 1}
708708
# instead of {"field_name": "$field_name"}.
709709
if isinstance(expr, Col) and name == expr.target.column and not force_expression
710-
else expr.as_mql(self, self.connection)
710+
else expr.as_mql(self, self.connection, as_expr=force_expression)
711711
)
712712
except EmptyResultSet:
713713
empty_result_set_value = getattr(expr, "empty_result_set_value", NotImplemented)

django_mongodb_backend/expressions/builtins.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,8 @@
2828
from ..query_utils import process_lhs
2929

3030

31-
def case(self, compiler, connection):
31+
# EXTRA IS TOTALLY IGNORED
32+
def case(self, compiler, connection, **extra): # noqa: ARG001
3233
case_parts = []
3334
for case in self.cases:
3435
case_mql = {}
@@ -53,7 +54,7 @@ def case(self, compiler, connection):
5354
}
5455

5556

56-
def col(self, compiler, connection, as_path=False): # noqa: ARG001
57+
def col(self, compiler, connection, as_path=False, as_expr=None): # noqa: ARG001
5758
# If the column is part of a subquery and belongs to one of the parent
5859
# queries, it will be stored for reference using $let in a $lookup stage.
5960
# If the query is built with `alias_cols=False`, treat the column as
@@ -71,7 +72,7 @@ def col(self, compiler, connection, as_path=False): # noqa: ARG001
7172
# Add the column's collection's alias for columns in joined collections.
7273
has_alias = self.alias and self.alias != compiler.collection_name
7374
prefix = f"{self.alias}." if has_alias else ""
74-
if not as_path:
75+
if not as_path or as_expr:
7576
prefix = f"${prefix}"
7677
return f"{prefix}{self.target.column}"
7778

@@ -83,16 +84,16 @@ def col_pairs(self, compiler, connection):
8384
return cols[0].as_mql(compiler, connection)
8485

8586

86-
def combined_expression(self, compiler, connection):
87+
def combined_expression(self, compiler, connection, **extra):
8788
expressions = [
88-
self.lhs.as_mql(compiler, connection),
89-
self.rhs.as_mql(compiler, connection),
89+
self.lhs.as_mql(compiler, connection, **extra),
90+
self.rhs.as_mql(compiler, connection, **extra),
9091
]
9192
return connection.ops.combine_expression(self.connector, expressions)
9293

9394

94-
def expression_wrapper(self, compiler, connection):
95-
return self.expression.as_mql(compiler, connection)
95+
def expression_wrapper(self, compiler, connection, **extra):
96+
return self.expression.as_mql(compiler, connection, **extra)
9697

9798

9899
def negated_expression(self, compiler, connection):
@@ -103,7 +104,7 @@ def order_by(self, compiler, connection):
103104
return self.expression.as_mql(compiler, connection)
104105

105106

106-
def query(self, compiler, connection, get_wrapping_pipeline=None, as_path=False):
107+
def query(self, compiler, connection, get_wrapping_pipeline=None, as_path=False, as_expr=None):
107108
subquery_compiler = self.get_compiler(connection=connection)
108109
subquery_compiler.pre_sql_setup(with_col_aliases=False)
109110
field_name, expr = subquery_compiler.columns[0]
@@ -145,7 +146,7 @@ def query(self, compiler, connection, get_wrapping_pipeline=None, as_path=False)
145146
# Erase project_fields since the required value is projected above.
146147
subquery.project_fields = None
147148
compiler.subqueries.append(subquery)
148-
if as_path:
149+
if as_path and not as_expr:
149150
return f"{table_output}.{field_name}"
150151
return f"${table_output}.{field_name}"
151152

@@ -200,7 +201,7 @@ def when(self, compiler, connection, **extra):
200201
return self.condition.as_mql(compiler, connection, **extra)
201202

202203

203-
def value(self, compiler, connection): # noqa: ARG001
204+
def value(self, compiler, connection, **extra): # noqa: ARG001
204205
value = self.value
205206
if isinstance(value, (list, int)):
206207
# Wrap lists & numbers in $literal to prevent ambiguity when Value

django_mongodb_backend/functions.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,11 @@
6565
}
6666

6767

68-
def cast(self, compiler, connection):
68+
# TODO: ALL THOSE FUNCTION MAY CHECK AS_EXPR OR AS_PATH=FALSE. JUST NEED TO REVIEW ALL THE
69+
# TEST THAT HAVE THOSE OPERATOR.
70+
71+
72+
def cast(self, compiler, connection, **extra): # noqa: ARG001
6973
output_type = connection.data_types[self.output_field.get_internal_type()]
7074
lhs_mql = process_lhs(self, compiler, connection)[0]
7175
if max_length := self.output_field.max_length:
@@ -95,7 +99,7 @@ def cot(self, compiler, connection):
9599
return {"$divide": [1, {"$tan": lhs_mql}]}
96100

97101

98-
def extract(self, compiler, connection):
102+
def extract(self, compiler, connection, **extra): # noqa: ARG001
99103
lhs_mql = process_lhs(self, compiler, connection)
100104
operator = EXTRACT_OPERATORS.get(self.lookup_name)
101105
if operator is None:
@@ -105,7 +109,7 @@ def extract(self, compiler, connection):
105109
return {f"${operator}": lhs_mql}
106110

107111

108-
def func(self, compiler, connection):
112+
def func(self, compiler, connection, **extra): # noqa: ARG001
109113
lhs_mql = process_lhs(self, compiler, connection)
110114
if self.function is None:
111115
raise NotSupportedError(f"{self} may need an as_mql() method.")
@@ -117,7 +121,7 @@ def left(self, compiler, connection):
117121
return self.get_substr().as_mql(compiler, connection)
118122

119123

120-
def length(self, compiler, connection):
124+
def length(self, compiler, connection, as_path=False, as_expr=None): # noqa: ARG001
121125
# Check for null first since $strLenCP only accepts strings.
122126
lhs_mql = process_lhs(self, compiler, connection)
123127
return {"$cond": {"if": {"$eq": [lhs_mql, None]}, "then": None, "else": {"$strLenCP": lhs_mql}}}
@@ -194,7 +198,7 @@ def wrapped(self, compiler, connection):
194198
return wrapped
195199

196200

197-
def trunc(self, compiler, connection):
201+
def trunc(self, compiler, connection, **extra): # noqa: ARG001
198202
lhs_mql = process_lhs(self, compiler, connection)
199203
lhs_mql = {"date": lhs_mql, "unit": self.kind, "startOfWeek": "mon"}
200204
if timezone := self.get_tzname():

django_mongodb_backend/lookups.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,14 +46,14 @@ def field_resolve_expression_parameter(self, compiler, connection, sql, param):
4646
return sql, sql_params
4747

4848

49-
def in_(self, compiler, connection):
49+
def in_(self, compiler, connection, **extra):
5050
db_rhs = getattr(self.rhs, "_db", None)
5151
if db_rhs is not None and db_rhs != connection.alias:
5252
raise ValueError(
5353
"Subqueries aren't allowed across different databases. Force "
5454
"the inner query to be evaluated using `list(inner_query)`."
5555
)
56-
return builtin_lookup(self, compiler, connection)
56+
return builtin_lookup(self, compiler, connection, **extra)
5757

5858

5959
def get_subquery_wrapping_pipeline(self, compiler, connection, field_name, expr): # noqa: ARG001
@@ -91,7 +91,7 @@ def get_subquery_wrapping_pipeline(self, compiler, connection, field_name, expr)
9191
def is_null(self, compiler, connection, as_expr=False):
9292
if not isinstance(self.rhs, bool):
9393
raise ValueError("The QuerySet value for an isnull lookup must be True or False.")
94-
if is_constant_value(self.rhs) and not as_expr:
94+
if is_constant_value(self.rhs) and not as_expr and is_simple_column(self.lhs):
9595
lhs_mql = process_lhs(self, compiler, connection, as_path=True)
9696
return connection.mongo_operators_match["isnull"](lhs_mql, self.rhs)
9797
lhs_mql = process_lhs(self, compiler, connection)

django_mongodb_backend/query.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -329,8 +329,9 @@ def where_node(self, compiler, connection, **extra):
329329
if not mql:
330330
raise FullResultSet
331331

332+
as_expr = extra.get("as_expr")
332333
if self.negated and mql:
333-
mql = {"$nor": [mql]}
334+
mql = {"$nor": [mql]} if not as_expr else {"$not": [mql]}
334335

335336
return mql
336337

0 commit comments

Comments
 (0)