Skip to content

Commit 12969e4

Browse files
committed
Object-oriented approach solution
1 parent 57ca9fa commit 12969e4

File tree

13 files changed

+373
-309
lines changed

13 files changed

+373
-309
lines changed

django_mongodb_backend/aggregates.py

Lines changed: 8 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,7 @@
88
MONGO_AGGREGATIONS = {Count: "sum"}
99

1010

11-
def aggregate(
12-
self,
13-
compiler,
14-
connection,
15-
operator=None,
16-
resolve_inner_expression=False,
17-
**extra_context, # noqa: ARG001
18-
):
11+
def aggregate(self, compiler, connection, operator=None, resolve_inner_expression=False):
1912
if self.filter:
2013
node = self.copy()
2114
node.filter = None
@@ -31,7 +24,7 @@ def aggregate(
3124
return {f"${operator}": lhs_mql}
3225

3326

34-
def count(self, compiler, connection, resolve_inner_expression=False, **extra_context): # noqa: ARG001
27+
def count(self, compiler, connection, resolve_inner_expression=False):
3528
"""
3629
When resolve_inner_expression=True, return the MQL that resolves as a
3730
value. This is used to count different elements, so the inner values are
@@ -64,16 +57,16 @@ def count(self, compiler, connection, resolve_inner_expression=False, **extra_co
6457
return {"$add": [{"$size": lhs_mql}, exits_null]}
6558

6659

67-
def stddev_variance(self, compiler, connection, **extra_context):
60+
def stddev_variance(self, compiler, connection):
6861
if self.function.endswith("_SAMP"):
6962
operator = "stdDevSamp"
7063
elif self.function.endswith("_POP"):
7164
operator = "stdDevPop"
72-
return aggregate(self, compiler, connection, operator=operator, **extra_context)
65+
return aggregate(self, compiler, connection, operator=operator)
7366

7467

7568
def register_aggregates():
76-
Aggregate.as_mql = aggregate
77-
Count.as_mql = count
78-
StdDev.as_mql = stddev_variance
79-
Variance.as_mql = stddev_variance
69+
Aggregate.as_mql_expr = aggregate
70+
Count.as_mql_expr = count
71+
StdDev.as_mql_expr = stddev_variance
72+
Variance.as_mql_expr = stddev_variance

django_mongodb_backend/compiler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -709,7 +709,7 @@ def get_project_fields(self, columns=None, ordering=None, force_expression=False
709709
# For brevity/simplicity, project {"field_name": 1}
710710
# instead of {"field_name": "$field_name"}.
711711
if isinstance(expr, Col) and name == expr.target.column and not force_expression
712-
else expr.as_mql(self, self.connection, as_path=False)
712+
else expr.as_mql(self, self.connection)
713713
)
714714
except EmptyResultSet:
715715
empty_result_set_value = getattr(expr, "empty_result_set_value", NotImplemented)

django_mongodb_backend/expressions/builtins.py

Lines changed: 37 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
Exists,
1515
ExpressionList,
1616
ExpressionWrapper,
17-
Func,
1817
NegatedExpression,
1918
OrderBy,
2019
RawSQL,
@@ -25,14 +24,12 @@
2524
Value,
2625
When,
2726
)
28-
from django.db.models.fields.json import KeyTransform
2927
from django.db.models.sql import Query
3028

31-
from django_mongodb_backend.fields.array import Array
32-
from django_mongodb_backend.query_utils import is_direct_value, process_lhs
29+
from django_mongodb_backend.query_utils import process_lhs
3330

3431

35-
def case(self, compiler, connection, as_path=False):
32+
def case(self, compiler, connection):
3633
case_parts = []
3734
for case in self.cases:
3835
case_mql = {}
@@ -49,16 +46,12 @@ def case(self, compiler, connection, as_path=False):
4946
default_mql = self.default.as_mql(compiler, connection)
5047
if not case_parts:
5148
return default_mql
52-
expr = {
49+
return {
5350
"$switch": {
5451
"branches": case_parts,
5552
"default": default_mql,
5653
}
5754
}
58-
if as_path:
59-
return {"$expr": expr}
60-
61-
return expr
6255

6356

6457
def col(self, compiler, connection, as_path=False): # noqa: ARG001
@@ -99,12 +92,12 @@ def combined_expression(self, compiler, connection, as_path=False):
9992
return connection.ops.combine_expression(self.connector, expressions)
10093

10194

102-
def expression_wrapper(self, compiler, connection, as_path=False):
103-
return self.expression.as_mql(compiler, connection, as_path=as_path)
95+
def expression_wrapper_expr(self, compiler, connection):
96+
return self.expression.as_mql(compiler, connection, as_path=False)
10497

10598

106-
def negated_expression(self, compiler, connection, as_path=False):
107-
return {"$not": expression_wrapper(self, compiler, connection, as_path=as_path)}
99+
def negated_expression_expr(self, compiler, connection):
100+
return {"$not": expression_wrapper_expr(self, compiler, connection)}
108101

109102

110103
def order_by(self, compiler, connection):
@@ -177,32 +170,26 @@ def ref(self, compiler, connection, as_path=False): # noqa: ARG001
177170
return f"{prefix}{refs}"
178171

179172

180-
def star(self, compiler, connection, **extra): # noqa: ARG001
173+
@property
174+
def ref_is_simple_column(self):
175+
return isinstance(self.source, Col) and self.source.alias is not None
176+
177+
178+
def star(self, compiler, connection, as_path=False): # noqa: ARG001
181179
return {"$literal": True}
182180

183181

184-
def subquery(self, compiler, connection, get_wrapping_pipeline=None, as_path=False):
185-
expr = self.query.as_mql(
182+
def subquery(self, compiler, connection, get_wrapping_pipeline=None):
183+
return self.query.as_mql(
186184
compiler, connection, get_wrapping_pipeline=get_wrapping_pipeline, as_path=False
187185
)
188-
if as_path:
189-
return {"$expr": expr}
190-
return expr
191186

192187

193-
def exists(self, compiler, connection, get_wrapping_pipeline=None, as_path=False):
188+
def exists(self, compiler, connection, get_wrapping_pipeline=None):
194189
try:
195-
lhs_mql = subquery(
196-
self,
197-
compiler,
198-
connection,
199-
get_wrapping_pipeline=get_wrapping_pipeline,
200-
as_path=as_path,
201-
)
190+
lhs_mql = subquery(self, compiler, connection, get_wrapping_pipeline=get_wrapping_pipeline)
202191
except EmptyResultSet:
203192
return Value(False).as_mql(compiler, connection)
204-
if as_path:
205-
return {"$expr": connection.mongo_operators_match["isnull"](lhs_mql, False)}
206193
return connection.mongo_operators_expr["isnull"](lhs_mql, False)
207194

208195

@@ -234,54 +221,37 @@ def value(self, compiler, connection, as_path=False): # noqa: ARG001
234221
return value
235222

236223

237-
@staticmethod
238-
def _is_constant_value(value):
239-
if isinstance(value, list | Array):
240-
iterable = value.get_source_expressions() if isinstance(value, Array) else value
241-
return all(_is_constant_value(e) for e in iterable)
242-
if is_direct_value(value):
243-
return True
244-
return isinstance(value, Func | Value) and not (
245-
value.contains_aggregate
246-
or value.contains_over_clause
247-
or value.contains_column_references
248-
or value.contains_subquery
249-
)
250-
251-
252-
@staticmethod
253-
def _is_simple_column(lhs):
254-
while isinstance(lhs, KeyTransform):
255-
if "." in getattr(lhs, "key_name", ""):
256-
return False
257-
lhs = lhs.lhs
258-
col = lhs.source if isinstance(lhs, Ref) else lhs
259-
# Foreign columns from parent cannot be addressed as single match
260-
return isinstance(col, Col) and col.alias is not None
261-
224+
def base_expression(self, compiler, connection, as_path=False, **extra):
225+
if (
226+
as_path
227+
and hasattr(self, "as_mql_path")
228+
and getattr(self, "is_simple_expression", lambda: False)()
229+
):
230+
return self.as_mql_path(compiler, connection, **extra)
262231

263-
def _is_simple_expression(self):
264-
return self.is_simple_column(self.lhs) and self.is_constant_value(self.rhs)
232+
expr = self.as_mql_expr(compiler, connection, **extra)
233+
return {"$expr": expr} if as_path else expr
265234

266235

267236
def register_expressions():
268-
Case.as_mql = case
237+
BaseExpression.as_mql = base_expression
238+
BaseExpression.is_simple_column = False
239+
Case.as_mql_expr = case
269240
Col.as_mql = col
241+
Col.is_simple_column = True
270242
ColPairs.as_mql = col_pairs
271-
CombinedExpression.as_mql = combined_expression
272-
Exists.as_mql = exists
243+
CombinedExpression.as_mql_expr = combined_expression
244+
Exists.as_mql_expr = exists
273245
ExpressionList.as_mql = process_lhs
274-
ExpressionWrapper.as_mql = expression_wrapper
275-
NegatedExpression.as_mql = negated_expression
276-
OrderBy.as_mql = order_by
246+
ExpressionWrapper.as_mql_expr = expression_wrapper_expr
247+
NegatedExpression.as_mql_expr = negated_expression_expr
248+
OrderBy.as_mql_expr = order_by
277249
Query.as_mql = query
278250
RawSQL.as_mql = raw_sql
279251
Ref.as_mql = ref
252+
Ref.is_simple_column = ref_is_simple_column
280253
ResolvedOuterRef.as_mql = ResolvedOuterRef.as_sql
281254
Star.as_mql = star
282-
Subquery.as_mql = subquery
255+
Subquery.as_mql_expr = subquery
283256
When.as_mql = when
284257
Value.as_mql = value
285-
BaseExpression.is_simple_expression = _is_simple_expression
286-
BaseExpression.is_simple_column = _is_simple_column
287-
BaseExpression.is_constant_value = _is_constant_value

django_mongodb_backend/expressions/search.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -933,12 +933,15 @@ def __str__(self):
933933
def __repr__(self):
934934
return f"SearchText({self.lhs}, {self.rhs})"
935935

936-
def as_mql(self, compiler, connection, as_path=False):
937-
lhs_mql = process_lhs(self, compiler, connection, as_path=as_path)
938-
value = process_rhs(self, compiler, connection, as_path=as_path)
939-
if as_path:
940-
return {lhs_mql: {"$gte": value}}
941-
return {"$expr": {"$gte": [lhs_mql, value]}}
936+
def as_mql_expr(self, compiler, connection):
937+
lhs_mql = process_lhs(self, compiler, connection, as_path=False)
938+
value = process_rhs(self, compiler, connection, as_path=False)
939+
return {"$gte": [lhs_mql, value]}
940+
941+
def as_mql_path(self, compiler, connection):
942+
lhs_mql = process_lhs(self, compiler, connection, as_path=True)
943+
value = process_rhs(self, compiler, connection, as_path=True)
944+
return {lhs_mql: {"$gte": value}}
942945

943946

944947
CharField.register_lookup(SearchTextLookup)

0 commit comments

Comments
 (0)