Skip to content

Commit d3c042e

Browse files
timgrahamWaVEV
authored andcommitted
rebase from main, update the test, this commit should be overwritten with the refactor
1 parent 11754df commit d3c042e

File tree

4 files changed

+211
-13
lines changed

4 files changed

+211
-13
lines changed

django_mongodb_backend/compiler.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def _get_column_from_expression(self, expr, alias):
5353
Create a column named `alias` from the given expression to hold the
5454
aggregate value.
5555
"""
56-
column_target = expr.output_field.__class__()
56+
column_target = expr.output_field.clone()
5757
column_target.db_column = alias
5858
column_target.set_attributes_from_name(alias)
5959
return Col(self.collection_name, column_target)
@@ -81,7 +81,7 @@ def _prepare_expressions_for_pipeline(self, expression, target, annotation_group
8181
alias = (
8282
f"__aggregation{next(annotation_group_idx)}" if sub_expr != expression else target
8383
)
84-
column_target = sub_expr.output_field.__class__()
84+
column_target = sub_expr.output_field.clone()
8585
column_target.db_column = alias
8686
column_target.set_attributes_from_name(alias)
8787
inner_column = Col(self.collection_name, column_target)

django_mongodb_backend/fields/embedded_model.py

Lines changed: 49 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from django.core import checks
2+
from django.core.exceptions import FieldDoesNotExist
23
from django.db import models
34
from django.db.models.fields.related import lazy_related_operation
45
from django.db.models.lookups import Transform
@@ -112,7 +113,8 @@ def get_transform(self, name):
112113
transform = super().get_transform(name)
113114
if transform:
114115
return transform
115-
return KeyTransformFactory(name)
116+
field = self.embedded_model._meta.get_field(name)
117+
return KeyTransformFactory(name, field)
116118

117119
def validate(self, value, model_instance):
118120
super().validate(value, model_instance)
@@ -134,28 +136,65 @@ def formfield(self, **kwargs):
134136

135137

136138
class KeyTransform(Transform):
137-
def __init__(self, key_name, *args, **kwargs):
139+
def __init__(self, key_name, ref_field, *args, **kwargs):
138140
super().__init__(*args, **kwargs)
139141
self.key_name = str(key_name)
142+
self.ref_field = ref_field
143+
144+
def get_transform(self, name):
145+
result = None
146+
if isinstance(self.ref_field, EmbeddedModelField):
147+
opts = self.ref_field.embedded_model._meta
148+
new_field = opts.get_field(name)
149+
result = KeyTransformFactory(name, new_field)
150+
else:
151+
if self.ref_field.get_transform(name) is None:
152+
raise FieldDoesNotExist(
153+
f"{self.ref_field.model._meta.object_name}.{self.ref_field.name}"
154+
f" has no field named '{name}'"
155+
)
156+
result = KeyTransformFactory(name, self.ref_field)
157+
return result
140158

141159
def preprocess_lhs(self, compiler, connection):
142-
key_transforms = [self.key_name]
143-
previous = self.lhs
160+
previous = self
161+
embedded_key_transforms = []
162+
json_key_transforms = []
144163
while isinstance(previous, KeyTransform):
145-
key_transforms.insert(0, previous.key_name)
164+
if isinstance(previous.ref_field, EmbeddedModelField):
165+
embedded_key_transforms.insert(0, previous.key_name)
166+
else:
167+
json_key_transforms.insert(0, previous.key_name)
146168
previous = previous.lhs
147169
mql = previous.as_mql(compiler, connection)
148-
return mql, key_transforms
170+
embedded_key_transforms.append(json_key_transforms.pop(0))
171+
return mql, embedded_key_transforms, json_key_transforms
149172

150173
def as_mql(self, compiler, connection):
151-
mql, key_transforms = self.preprocess_lhs(compiler, connection)
174+
mql, key_transforms, json_key_transforms = self.preprocess_lhs(compiler, connection)
152175
transforms = ".".join(key_transforms)
153-
return f"{mql}.{transforms}"
176+
result = f"{mql}.{transforms}"
177+
for key in json_key_transforms:
178+
get_field = {"$getField": {"input": result, "field": key}}
179+
# Handle array indexing if the key is a digit. If key is something
180+
# like '001', it's not an array index despite isdigit() returning True.
181+
if key.isdigit() and str(int(key)) == key:
182+
result = {
183+
"$cond": {
184+
"if": {"$isArray": result},
185+
"then": {"$arrayElemAt": [result, int(key)]},
186+
"else": get_field,
187+
}
188+
}
189+
else:
190+
result = get_field
191+
return result
154192

155193

156194
class KeyTransformFactory:
157-
def __init__(self, key_name):
195+
def __init__(self, key_name, ref_field):
158196
self.key_name = key_name
197+
self.ref_field = ref_field
159198

160199
def __call__(self, *args, **kwargs):
161-
return KeyTransform(self.key_name, *args, **kwargs)
200+
return KeyTransform(self.key_name, self.ref_field, *args, **kwargs)

tests/model_fields_/models.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ class Data(models.Model):
102102
integer = models.IntegerField(db_column="custom_column")
103103
auto_now = models.DateTimeField(auto_now=True)
104104
auto_now_add = models.DateTimeField(auto_now_add=True)
105+
json_value = models.JSONField(default=dict)
105106

106107

107108
class Address(models.Model):
@@ -119,3 +120,10 @@ class Author(models.Model):
119120
class Book(models.Model):
120121
name = models.CharField(max_length=100)
121122
author = EmbeddedModelField(Author)
123+
124+
125+
class Library(models.Model):
126+
name = models.CharField(max_length=100)
127+
books = models.ManyToManyField("Book", related_name="libraries")
128+
location = models.CharField(max_length=100, null=True, blank=True)
129+
best_seller = models.CharField(max_length=100, null=True, blank=True)

tests/model_fields_/test_embedded_model.py

Lines changed: 152 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,17 @@
1-
from django.core.exceptions import ValidationError
1+
import operator
2+
3+
from django.core.exceptions import FieldDoesNotExist, ValidationError
24
from django.db import models
5+
from django.db.models import (
6+
Exists,
7+
ExpressionWrapper,
8+
F,
9+
IntegerField,
10+
Max,
11+
OuterRef,
12+
Subquery,
13+
Sum,
14+
)
315
from django.test import SimpleTestCase, TestCase
416
from django.test.utils import isolate_apps
517

@@ -11,6 +23,7 @@
1123
Book,
1224
Data,
1325
Holder,
26+
Library,
1427
)
1528

1629

@@ -104,6 +117,85 @@ def test_nested(self):
104117
)
105118
self.assertCountEqual(Book.objects.filter(author__address__city="NYC"), [obj])
106119

120+
def test_nested_not_exists(self):
121+
msg = "Address.city has no field named 'president'"
122+
with self.assertRaisesMessage(FieldDoesNotExist, msg):
123+
Book.objects.filter(author__address__city__president="NYC")
124+
125+
def test_not_exists_in_embedded(self):
126+
msg = "Address has no field named 'floor'"
127+
with self.assertRaisesMessage(FieldDoesNotExist, msg):
128+
Book.objects.filter(author__address__floor="NYC")
129+
130+
def test_embedded_with_json_field(self):
131+
models = []
132+
for i in range(4):
133+
m = Holder.objects.create(
134+
data=Data(json_value={"field1": i * 5, "field2": {"0": {"value": list(range(i))}}})
135+
)
136+
models.append(m)
137+
138+
all_models = Holder.objects.all()
139+
140+
self.assertCountEqual(
141+
Holder.objects.filter(data__json_value__field2__0__value__0=0),
142+
models[1:],
143+
)
144+
self.assertCountEqual(
145+
Holder.objects.filter(data__json_value__field2__0__value__1=1),
146+
models[2:],
147+
)
148+
self.assertCountEqual(Holder.objects.filter(data__json_value__field2__0__value__1=5), [])
149+
150+
self.assertCountEqual(Holder.objects.filter(data__json_value__field1__lt=100), all_models)
151+
self.assertCountEqual(Holder.objects.filter(data__json_value__field1__gt=100), [])
152+
self.assertCountEqual(
153+
Holder.objects.filter(
154+
data__json_value__field1__gte=5, data__json_value__field1__lte=10
155+
),
156+
models[1:3],
157+
)
158+
159+
def truncate_ms(self, value):
160+
"""Truncate microsends to millisecond precision as supported by MongoDB."""
161+
return value.replace(microsecond=(value.microsecond // 1000) * 1000)
162+
163+
def test_ordering_by_embedded_field(self):
164+
query = Holder.objects.filter(data__integer__gt=3).order_by("-data__integer").values("pk")
165+
expected = [{"pk": e.pk} for e in list(reversed(self.objs[4:]))]
166+
self.assertSequenceEqual(query, expected)
167+
168+
def test_ordering_grouping_by_embedded_field(self):
169+
expected = sorted(
170+
(Holder.objects.create(data=Data(integer=x)) for x in range(6)),
171+
key=lambda x: x.data.integer,
172+
)
173+
query = (
174+
Holder.objects.annotate(
175+
group=ExpressionWrapper(F("data__integer") + 5, output_field=IntegerField())
176+
)
177+
.values("group")
178+
.annotate(max_auto_now=Max("data__auto_now"))
179+
.order_by("data__integer")
180+
)
181+
query_response = [{**e, "max_auto_now": self.truncate_ms(e["max_auto_now"])} for e in query]
182+
self.assertSequenceEqual(
183+
query_response,
184+
[
185+
{"group": e.data.integer + 5, "max_auto_now": self.truncate_ms(e.data.auto_now)}
186+
for e in expected
187+
],
188+
)
189+
190+
def test_ordering_grouping_by_sum(self):
191+
[Holder.objects.create(data=Data(integer=x)) for x in range(6)]
192+
qs = (
193+
Holder.objects.values("data__integer")
194+
.annotate(sum=Sum("data__integer"))
195+
.order_by("sum")
196+
)
197+
self.assertQuerySetEqual(qs, [0, 2, 4, 6, 8, 10], operator.itemgetter("sum"))
198+
107199

108200
@isolate_apps("model_fields_")
109201
class CheckTests(SimpleTestCase):
@@ -123,3 +215,62 @@ class MyModel(models.Model):
123215
self.assertEqual(
124216
msg, "Embedded models cannot have relational fields (Target.key is a ForeignKey)."
125217
)
218+
219+
220+
class SubqueryExistsTest(TestCase):
221+
def setUp(self):
222+
# Create test data
223+
address1 = Address.objects.create(city="New York", state="NY", zip_code=10001)
224+
address2 = Address.objects.create(city="Boston", state="MA", zip_code=20002)
225+
author1 = Author.objects.create(name="Alice", age=30, address=address1)
226+
author2 = Author.objects.create(name="Bob", age=40, address=address2)
227+
book1 = Book.objects.create(name="Book A", author=author1)
228+
book2 = Book.objects.create(name="Book B", author=author2)
229+
Book.objects.create(name="Book C", author=author2)
230+
Book.objects.create(name="Book D", author=author2)
231+
Book.objects.create(name="Book E", author=author1)
232+
233+
library1 = Library.objects.create(
234+
name="Central Library", location="Downtown", best_seller="Book A"
235+
)
236+
library2 = Library.objects.create(
237+
name="Community Library", location="Suburbs", best_seller="Book A"
238+
)
239+
240+
# Add books to libraries
241+
library1.books.add(book1, book2)
242+
library2.books.add(book2)
243+
244+
def test_exists_subquery(self):
245+
subquery = Book.objects.filter(
246+
author__name=OuterRef("name"), author__address__city="Boston"
247+
)
248+
queryset = Author.objects.filter(Exists(subquery))
249+
250+
self.assertEqual(queryset.count(), 1)
251+
252+
def test_in_subquery(self):
253+
subquery = Author.objects.filter(age__gt=35).values("name")
254+
queryset = Book.objects.filter(author__name__in=Subquery(subquery)).order_by("name")
255+
256+
self.assertEqual(queryset.count(), 3)
257+
self.assertQuerySetEqual(queryset, ["Book B", "Book C", "Book D"], lambda book: book.name)
258+
259+
def test_range_query(self):
260+
queryset = Author.objects.filter(age__range=(25, 45)).order_by("name")
261+
262+
self.assertEqual(queryset.count(), 2)
263+
self.assertQuerySetEqual(queryset, ["Alice", "Bob"], lambda author: author.name)
264+
265+
def test_exists_with_foreign_object(self):
266+
subquery = Library.objects.filter(best_seller=OuterRef("name"))
267+
queryset = Book.objects.filter(Exists(subquery))
268+
269+
self.assertEqual(queryset.count(), 1)
270+
self.assertEqual(queryset.first().name, "Book A")
271+
272+
def test_foreign_field_with_ranges(self):
273+
queryset = Library.objects.filter(books__author__age__range=(25, 35))
274+
275+
self.assertEqual(queryset.count(), 1)
276+
self.assertEqual(queryset.first().name, "Central Library")

0 commit comments

Comments
 (0)