Skip to content

Commit 0c1cbac

Browse files
committed
edits
1 parent ae4a08a commit 0c1cbac

File tree

5 files changed

+37
-34
lines changed

5 files changed

+37
-34
lines changed

django_mongodb_backend/fields/embedded_model.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from django.db.models.lookups import Transform
88

99
from .. import forms
10-
from ..query_utils import key_transform_build_path
10+
from .json import build_json_mql_path
1111

1212

1313
class EmbeddedModelField(models.Field):
@@ -192,14 +192,17 @@ def preprocess_lhs(self, compiler, connection):
192192
json_key_transforms.insert(0, previous.key_name)
193193
previous = previous.lhs
194194
mql = previous.as_mql(compiler, connection)
195+
# The first json_key_transform is the field name.
195196
embedded_key_transforms.append(json_key_transforms.pop(0))
196197
return mql, embedded_key_transforms, json_key_transforms
197198

198199
def as_mql(self, compiler, connection):
199200
mql, key_transforms, json_key_transforms = self.preprocess_lhs(compiler, connection)
200201
transforms = ".".join(key_transforms)
201202
result = f"{mql}.{transforms}"
202-
return key_transform_build_path(json_key_transforms, result)
203+
if json_key_transforms:
204+
result = build_json_mql_path(result, json_key_transforms)
205+
return result
203206

204207

205208
class KeyTransformFactory:

django_mongodb_backend/fields/json.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,27 @@
1414
)
1515

1616
from ..lookups import builtin_lookup
17-
from ..query_utils import key_transform_build_path, process_lhs, process_rhs
17+
from ..query_utils import process_lhs, process_rhs
18+
19+
20+
def build_json_mql_path(lhs, key_trasnforms):
21+
# Build the MQL path using the collected key transforms.
22+
result = lhs
23+
for key in key_trasnforms:
24+
get_field = {"$getField": {"input": result, "field": key}}
25+
# Handle array indexing if the key is a digit. If key is something
26+
# like '001', it's not an array index despite isdigit() returning True.
27+
if key.isdigit() and str(int(key)) == key:
28+
result = {
29+
"$cond": {
30+
"if": {"$isArray": result},
31+
"then": {"$arrayElemAt": [result, int(key)]},
32+
"else": get_field,
33+
}
34+
}
35+
else:
36+
result = get_field
37+
return result
1838

1939

2040
def contained_by(self, compiler, connection): # noqa: ARG001
@@ -89,7 +109,7 @@ def key_transform(self, compiler, connection):
89109
key_transforms.insert(0, previous.key_name)
90110
previous = previous.lhs
91111
lhs_mql = previous.as_mql(compiler, connection)
92-
return key_transform_build_path(key_transforms, lhs_mql)
112+
return build_json_mql_path(lhs_mql, key_transforms)
93113

94114

95115
def key_transform_in(self, compiler, connection):

django_mongodb_backend/query_utils.py

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -7,26 +7,6 @@ def is_direct_value(node):
77
return not hasattr(node, "as_sql")
88

99

10-
def key_transform_build_path(key_trasnforms, lhs):
11-
# Build the MQL path using the collected key transforms.
12-
result = lhs
13-
for key in key_trasnforms:
14-
get_field = {"$getField": {"input": result, "field": key}}
15-
# Handle array indexing if the key is a digit. If key is something
16-
# like '001', it's not an array index despite isdigit() returning True.
17-
if key.isdigit() and str(int(key)) == key:
18-
result = {
19-
"$cond": {
20-
"if": {"$isArray": result},
21-
"then": {"$arrayElemAt": [result, int(key)]},
22-
"else": get_field,
23-
}
24-
}
25-
else:
26-
result = get_field
27-
return result
28-
29-
3010
def process_lhs(node, compiler, connection):
3111
if not hasattr(node, "lhs"):
3212
# node is a Func or Expression, possibly with multiple source expressions.

tests/model_fields_/models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ class Data(EmbeddedModel):
103103
integer = models.IntegerField(db_column="custom_column")
104104
auto_now = models.DateTimeField(auto_now=True)
105105
auto_now_add = models.DateTimeField(auto_now_add=True)
106-
json_value = models.JSONField(default=None)
106+
json_value = models.JSONField()
107107

108108

109109
class Address(EmbeddedModel):

tests/model_fields_/test_embedded_model.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -139,29 +139,29 @@ def test_order_and_group_by_embedded_field(self):
139139
],
140140
)
141141

142-
def test_embedded_with_json_field(self):
143-
models = []
144-
for i in range(4):
145-
m = Holder.objects.create(
142+
def test_embedded_json_field(self):
143+
objs = [
144+
Holder.objects.create(
146145
data=Data(json_value={"field1": i * 5, "field2": {"0": {"value": list(range(i))}}})
147146
)
148-
models.append(m)
147+
for i in range(4)
148+
]
149149
self.assertCountEqual(
150150
Holder.objects.filter(data__json_value__field2__0__value__0=0),
151-
models[1:],
151+
objs[1:],
152152
)
153153
self.assertCountEqual(
154154
Holder.objects.filter(data__json_value__field2__0__value__1=1),
155-
models[2:],
155+
objs[2:],
156156
)
157157
self.assertCountEqual(Holder.objects.filter(data__json_value__field2__0__value__1=5), [])
158-
self.assertCountEqual(Holder.objects.filter(data__json_value__field1__lt=100), models)
158+
self.assertCountEqual(Holder.objects.filter(data__json_value__field1__lt=100), objs)
159159
self.assertCountEqual(Holder.objects.filter(data__json_value__field1__gt=100), [])
160160
self.assertCountEqual(
161161
Holder.objects.filter(
162162
data__json_value__field1__gte=5, data__json_value__field1__lte=10
163163
),
164-
models[1:3],
164+
objs[1:3],
165165
)
166166

167167
def test_order_and_group_by_embedded_field_annotation(self):

0 commit comments

Comments
 (0)