Skip to content

Commit 6a87d84

Browse files
WaVEVtimgraham
authored andcommitted
add support for JSONField lookups in an embedded model
1 parent 6343462 commit 6a87d84

File tree

4 files changed

+63
-23
lines changed

4 files changed

+63
-23
lines changed

django_mongodb_backend/fields/embedded_model.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from django.db.models.lookups import Transform
88

99
from .. import forms
10+
from .json import build_json_mql_path
1011

1112

1213
class EmbeddedModelField(models.Field):
@@ -181,18 +182,27 @@ def get_transform(self, name):
181182
return result
182183

183184
def preprocess_lhs(self, compiler, connection):
184-
key_transforms = [self.key_name]
185-
previous = self.lhs
185+
previous = self
186+
embedded_key_transforms = []
187+
json_key_transforms = []
186188
while isinstance(previous, KeyTransform):
187-
key_transforms.insert(0, previous.key_name)
189+
if isinstance(previous.ref_field, EmbeddedModelField):
190+
embedded_key_transforms.insert(0, previous.key_name)
191+
else:
192+
json_key_transforms.insert(0, previous.key_name)
188193
previous = previous.lhs
189194
mql = previous.as_mql(compiler, connection)
190-
return mql, key_transforms
195+
# The first json_key_transform is the field name.
196+
embedded_key_transforms.append(json_key_transforms.pop(0))
197+
return mql, embedded_key_transforms, json_key_transforms
191198

192199
def as_mql(self, compiler, connection):
193-
mql, key_transforms = self.preprocess_lhs(compiler, connection)
200+
mql, key_transforms, json_key_transforms = self.preprocess_lhs(compiler, connection)
194201
transforms = ".".join(key_transforms)
195-
return f"{mql}.{transforms}"
202+
result = f"{mql}.{transforms}"
203+
if json_key_transforms:
204+
result = build_json_mql_path(result, json_key_transforms)
205+
return result
196206

197207

198208
class KeyTransformFactory:

django_mongodb_backend/fields/json.py

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,26 @@
1717
from ..query_utils import process_lhs, process_rhs
1818

1919

20+
def build_json_mql_path(lhs, key_transforms):
21+
# Build the MQL path using the collected key transforms.
22+
result = lhs
23+
for key in key_transforms:
24+
get_field = {"$getField": {"input": result, "field": key}}
25+
# Handle array indexing if the key is a digit. If key is something
26+
# like '001', it's not an array index despite isdigit() returning True.
27+
if key.isdigit() and str(int(key)) == key:
28+
result = {
29+
"$cond": {
30+
"if": {"$isArray": result},
31+
"then": {"$arrayElemAt": [result, int(key)]},
32+
"else": get_field,
33+
}
34+
}
35+
else:
36+
result = get_field
37+
return result
38+
39+
2040
def contained_by(self, compiler, connection): # noqa: ARG001
2141
raise NotSupportedError("contained_by lookup is not supported on this database backend.")
2242

@@ -89,23 +109,7 @@ def key_transform(self, compiler, connection):
89109
key_transforms.insert(0, previous.key_name)
90110
previous = previous.lhs
91111
lhs_mql = previous.as_mql(compiler, connection)
92-
result = lhs_mql
93-
# Build the MQL path using the collected key transforms.
94-
for key in key_transforms:
95-
get_field = {"$getField": {"input": result, "field": key}}
96-
# Handle array indexing if the key is a digit. If key is something
97-
# like '001', it's not an array index despite isdigit() returning True.
98-
if key.isdigit() and str(int(key)) == key:
99-
result = {
100-
"$cond": {
101-
"if": {"$isArray": result},
102-
"then": {"$arrayElemAt": [result, int(key)]},
103-
"else": get_field,
104-
}
105-
}
106-
else:
107-
result = get_field
108-
return result
112+
return build_json_mql_path(lhs_mql, key_transforms)
109113

110114

111115
def key_transform_in(self, compiler, connection):

tests/model_fields_/models.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@ class Data(EmbeddedModel):
103103
integer = models.IntegerField(db_column="custom_column")
104104
auto_now = models.DateTimeField(auto_now=True)
105105
auto_now_add = models.DateTimeField(auto_now_add=True)
106+
json_value = models.JSONField()
106107

107108

108109
class Address(EmbeddedModel):

tests/model_fields_/test_embedded_model.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,31 @@ def test_order_by_embedded_field(self):
114114
qs = Holder.objects.filter(data__integer__gt=3).order_by("-data__integer")
115115
self.assertSequenceEqual(qs, list(reversed(self.objs[4:])))
116116

117+
def test_embedded_json_field_lookups(self):
118+
objs = [
119+
Holder.objects.create(
120+
data=Data(json_value={"field1": i * 5, "field2": {"0": {"value": list(range(i))}}})
121+
)
122+
for i in range(4)
123+
]
124+
self.assertCountEqual(
125+
Holder.objects.filter(data__json_value__field2__0__value__0=0),
126+
objs[1:],
127+
)
128+
self.assertCountEqual(
129+
Holder.objects.filter(data__json_value__field2__0__value__1=1),
130+
objs[2:],
131+
)
132+
self.assertCountEqual(Holder.objects.filter(data__json_value__field2__0__value__1=5), [])
133+
self.assertCountEqual(Holder.objects.filter(data__json_value__field1__lt=100), objs)
134+
self.assertCountEqual(Holder.objects.filter(data__json_value__field1__gt=100), [])
135+
self.assertCountEqual(
136+
Holder.objects.filter(
137+
data__json_value__field1__gte=5, data__json_value__field1__lte=10
138+
),
139+
objs[1:3],
140+
)
141+
117142
def test_order_and_group_by_embedded_field(self):
118143
# Create and sort test data by `data__integer`.
119144
expected_objs = sorted(

0 commit comments

Comments
 (0)