Skip to content

Commit a344004

Browse files
committed
Handle array as path and update unit test.
1 parent 111b78c commit a344004

File tree

8 files changed

+370
-122
lines changed

8 files changed

+370
-122
lines changed

django_mongodb_backend/fields/array.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,11 @@
44
from django.db.models import Field, Func, IntegerField, Transform, Value
55
from django.db.models.fields.mixins import CheckFieldDefaultMixin
66
from django.db.models.lookups import Exact, FieldGetDbPrepValueMixin, In, Lookup
7+
from django.utils.functional import cached_property
78
from django.utils.translation import gettext_lazy as _
89

910
from ..forms import SimpleArrayField
10-
from ..query_utils import process_lhs, process_rhs
11+
from ..query_utils import is_constant_value, process_lhs, process_rhs
1112
from ..utils import prefix_validation_error
1213
from ..validators import ArrayMaxLengthValidator, LengthValidator
1314

@@ -236,6 +237,20 @@ def as_mql_expr(self, compiler, connection):
236237
for expr in self.get_source_expressions()
237238
]
238239

240+
def as_mql_path(self, compiler, connection):
241+
return [
242+
expr.as_mql(compiler, connection, as_path=True)
243+
for expr in self.get_source_expressions()
244+
]
245+
246+
@cached_property
247+
def can_use_path(self):
248+
return all(is_constant_value(expr) for expr in self.get_source_expressions())
249+
250+
@property
251+
def is_simple_column(self):
252+
return False
253+
239254

240255
class ArrayRHSMixin:
241256
def __init__(self, lhs, rhs):

django_mongodb_backend/fields/embedded_model.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from django.utils.functional import cached_property
99

1010
from .. import forms
11+
from ..query_utils import valid_path_key_name
1112

1213

1314
class EmbeddedModelField(models.Field):
@@ -174,8 +175,8 @@ def can_use_path(self):
174175
@cached_property
175176
def is_simple_column(self):
176177
previous = self
177-
while isinstance(previous, KeyTransform):
178-
if not previous.key_name.isalnum():
178+
while isinstance(previous, EmbeddedModelTransform):
179+
if not valid_path_key_name(previous._field.column):
179180
return False
180181
previous = previous.lhs
181182
return previous.is_simple_column

django_mongodb_backend/fields/embedded_model_array.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
from .. import forms
1111
from ..lookups import builtin_lookup_path
12-
from ..query_utils import process_lhs, process_rhs
12+
from ..query_utils import process_lhs, process_rhs, valid_path_key_name
1313
from . import EmbeddedModelField
1414
from .array import ArrayField, ArrayLenTransform
1515

@@ -240,6 +240,7 @@ def __init__(self, field, *args, **kwargs):
240240
column_name = f"$item.{field.column}"
241241
column_target.db_column = column_name
242242
column_target.set_attributes_from_name(column_name)
243+
self._field = field
243244
self._lhs = Col(None, column_target)
244245
self._sub_transform = None
245246

@@ -255,7 +256,7 @@ def can_use_path(self):
255256
def is_simple_column(self):
256257
previous = self
257258
while isinstance(previous, EmbeddedModelArrayFieldTransform):
258-
if not previous.key_name.isalnum():
259+
if not valid_path_key_name(previous._field.column):
259260
return False
260261
previous = previous.lhs
261262
return previous.is_simple_column and self._lhs.is_simple_column

django_mongodb_backend/fields/json.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
)
1919

2020
from ..lookups import builtin_lookup_expr, builtin_lookup_path
21-
from ..query_utils import process_lhs, process_rhs
21+
from ..query_utils import process_lhs, process_rhs, valid_path_key_name
2222

2323

2424
def build_json_mql_path(lhs, key_transforms, as_path=False):
@@ -75,7 +75,7 @@ def _has_key_predicate(path, root_column=None, negated=False, as_path=False):
7575
@property
7676
def has_key_check_simple_expression(self):
7777
rhs = [self.rhs] if not isinstance(self.rhs, (list, tuple)) else self.rhs
78-
return self.is_simple_column and all(key.isalnum() for key in rhs)
78+
return self.is_simple_column and all(valid_path_key_name(key) for key in rhs)
7979

8080

8181
def has_key_lookup(self, compiler, connection, as_path=False):
@@ -231,7 +231,7 @@ def key_transform_numeric_lookup_mixin_path(self, compiler, connection):
231231
def keytransform_is_simple_column(self):
232232
previous = self
233233
while isinstance(previous, KeyTransform):
234-
if not previous.key_name.isalnum():
234+
if not valid_path_key_name(previous.key_name):
235235
return False
236236
previous = previous.lhs
237237
return previous.is_simple_column

django_mongodb_backend/query_utils.py

Lines changed: 22 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1+
import re
2+
13
from django.core.exceptions import FullResultSet
24
from django.db.models.aggregates import Aggregate
3-
from django.db.models.expressions import CombinedExpression, Value
5+
from django.db.models.expressions import CombinedExpression, Func, Value
46
from django.db.models.sql.query import Query
57

68

@@ -74,17 +76,24 @@ def is_constant_value(value):
7476
if hasattr(value, "get_source_expressions"):
7577
# Temporary: similar limitation as above, sub-expressions should be
7678
# resolved in the future
77-
simple_sub_expressions = all(map(is_constant_value, value.get_source_expressions()))
79+
constants_sub_expressions = all(map(is_constant_value, value.get_source_expressions()))
7880
else:
79-
simple_sub_expressions = True
80-
return (
81-
simple_sub_expressions
82-
and isinstance(value, Value)
83-
and not (
84-
isinstance(value, Query)
85-
or value.contains_aggregate
86-
or value.contains_over_clause
87-
or value.contains_column_references
88-
or value.contains_subquery
89-
)
81+
constants_sub_expressions = True
82+
constants_sub_expressions = constants_sub_expressions and not (
83+
isinstance(value, Query)
84+
or value.contains_aggregate
85+
or value.contains_over_clause
86+
or value.contains_column_references
87+
or value.contains_subquery
88+
)
89+
return constants_sub_expressions and (
90+
isinstance(value, Value)
91+
or
92+
# Some closed functions cannot yet be converted to constant values.
93+
# Allow Func with can_use_path as a temporary exception.
94+
(isinstance(value, Func) and value.can_use_path)
9095
)
96+
97+
98+
def valid_path_key_name(key_name):
99+
return bool(re.fullmatch(r"[A-Za-z0-9_]+", key_name))

0 commit comments

Comments
 (0)