Skip to content

Commit 8d2e794

Browse files
timgrahamWaVEV
authored andcommitted
EMF support json field.
1 parent 9b4083f commit 8d2e794

File tree

5 files changed

+65
-24
lines changed

5 files changed

+65
-24
lines changed

django_mongodb_backend/fields/embedded_model.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from django.db.models.lookups import Transform
88

99
from .. import forms
10+
from ..query_utils import key_transform_build_path
1011

1112

1213
class EmbeddedModelField(models.Field):
@@ -181,18 +182,24 @@ def get_transform(self, name):
181182
return result
182183

183184
def preprocess_lhs(self, compiler, connection):
184-
key_transforms = [self.key_name]
185-
previous = self.lhs
185+
previous = self
186+
embedded_key_transforms = []
187+
json_key_transforms = []
186188
while isinstance(previous, KeyTransform):
187-
key_transforms.insert(0, previous.key_name)
189+
if isinstance(previous.ref_field, EmbeddedModelField):
190+
embedded_key_transforms.insert(0, previous.key_name)
191+
else:
192+
json_key_transforms.insert(0, previous.key_name)
188193
previous = previous.lhs
189194
mql = previous.as_mql(compiler, connection)
190-
return mql, key_transforms
195+
embedded_key_transforms.append(json_key_transforms.pop(0))
196+
return mql, embedded_key_transforms, json_key_transforms
191197

192198
def as_mql(self, compiler, connection):
193-
mql, key_transforms = self.preprocess_lhs(compiler, connection)
199+
mql, key_transforms, json_key_transforms = self.preprocess_lhs(compiler, connection)
194200
transforms = ".".join(key_transforms)
195-
return f"{mql}.{transforms}"
201+
result = f"{mql}.{transforms}"
202+
return key_transform_build_path(json_key_transforms, result)
196203

197204

198205
class KeyTransformFactory:

django_mongodb_backend/fields/json.py

Lines changed: 2 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
)
1515

1616
from ..lookups import builtin_lookup
17-
from ..query_utils import process_lhs, process_rhs
17+
from ..query_utils import key_transform_build_path, process_lhs, process_rhs
1818

1919

2020
def contained_by(self, compiler, connection): # noqa: ARG001
@@ -89,23 +89,7 @@ def key_transform(self, compiler, connection):
8989
key_transforms.insert(0, previous.key_name)
9090
previous = previous.lhs
9191
lhs_mql = previous.as_mql(compiler, connection)
92-
result = lhs_mql
93-
# Build the MQL path using the collected key transforms.
94-
for key in key_transforms:
95-
get_field = {"$getField": {"input": result, "field": key}}
96-
# Handle array indexing if the key is a digit. If key is something
97-
# like '001', it's not an array index despite isdigit() returning True.
98-
if key.isdigit() and str(int(key)) == key:
99-
result = {
100-
"$cond": {
101-
"if": {"$isArray": result},
102-
"then": {"$arrayElemAt": [result, int(key)]},
103-
"else": get_field,
104-
}
105-
}
106-
else:
107-
result = get_field
108-
return result
92+
return key_transform_build_path(key_transforms, lhs_mql)
10993

11094

11195
def key_transform_in(self, compiler, connection):

django_mongodb_backend/query_utils.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,26 @@ def is_direct_value(node):
77
return not hasattr(node, "as_sql")
88

99

10+
def key_transform_build_path(key_trasnforms, lhs):
11+
# Build the MQL path using the collected key transforms.
12+
result = lhs
13+
for key in key_trasnforms:
14+
get_field = {"$getField": {"input": result, "field": key}}
15+
# Handle array indexing if the key is a digit. If key is something
16+
# like '001', it's not an array index despite isdigit() returning True.
17+
if key.isdigit() and str(int(key)) == key:
18+
result = {
19+
"$cond": {
20+
"if": {"$isArray": result},
21+
"then": {"$arrayElemAt": [result, int(key)]},
22+
"else": get_field,
23+
}
24+
}
25+
else:
26+
result = get_field
27+
return result
28+
29+
1030
def process_lhs(node, compiler, connection):
1131
if not hasattr(node, "lhs"):
1232
# node is a Func or Expression, possibly with multiple source expressions.

tests/model_fields_/models.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@ class Data(EmbeddedModel):
103103
integer = models.IntegerField(db_column="custom_column")
104104
auto_now = models.DateTimeField(auto_now=True)
105105
auto_now_add = models.DateTimeField(auto_now_add=True)
106+
json_value = models.JSONField(default=dict)
106107

107108

108109
class Address(EmbeddedModel):

tests/model_fields_/test_embedded_model.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,35 @@ def test_order_and_group_by_embedded_field(self):
139139
],
140140
)
141141

142+
def test_embedded_with_json_field(self):
143+
models = []
144+
for i in range(4):
145+
m = Holder.objects.create(
146+
data=Data(json_value={"field1": i * 5, "field2": {"0": {"value": list(range(i))}}})
147+
)
148+
models.append(m)
149+
150+
all_models = Holder.objects.all()
151+
152+
self.assertCountEqual(
153+
Holder.objects.filter(data__json_value__field2__0__value__0=0),
154+
models[1:],
155+
)
156+
self.assertCountEqual(
157+
Holder.objects.filter(data__json_value__field2__0__value__1=1),
158+
models[2:],
159+
)
160+
self.assertCountEqual(Holder.objects.filter(data__json_value__field2__0__value__1=5), [])
161+
162+
self.assertCountEqual(Holder.objects.filter(data__json_value__field1__lt=100), all_models)
163+
self.assertCountEqual(Holder.objects.filter(data__json_value__field1__gt=100), [])
164+
self.assertCountEqual(
165+
Holder.objects.filter(
166+
data__json_value__field1__gte=5, data__json_value__field1__lte=10
167+
),
168+
models[1:3],
169+
)
170+
142171
def test_order_and_group_by_embedded_field_annotation(self):
143172
# Create repeated `data__integer` values.
144173
[Holder.objects.create(data=Data(integer=x)) for x in range(6)]

0 commit comments

Comments
 (0)