diff --git a/django_mongodb_backend/compiler.py b/django_mongodb_backend/compiler.py index afb5d2526..c4ce0ab41 100644 --- a/django_mongodb_backend/compiler.py +++ b/django_mongodb_backend/compiler.py @@ -53,7 +53,7 @@ def _get_column_from_expression(self, expr, alias): Create a column named `alias` from the given expression to hold the aggregate value. """ - column_target = expr.output_field.__class__() + column_target = expr.output_field.clone() column_target.db_column = alias column_target.set_attributes_from_name(alias) return Col(self.collection_name, column_target) @@ -81,7 +81,7 @@ def _prepare_expressions_for_pipeline(self, expression, target, annotation_group alias = ( f"__aggregation{next(annotation_group_idx)}" if sub_expr != expression else target ) - column_target = sub_expr.output_field.__class__() + column_target = sub_expr.output_field.clone() column_target.db_column = alias column_target.set_attributes_from_name(alias) inner_column = Col(self.collection_name, column_target) diff --git a/tests/model_fields_/test_embedded_model.py b/tests/model_fields_/test_embedded_model.py index a75bd3ddd..10cd84f89 100644 --- a/tests/model_fields_/test_embedded_model.py +++ b/tests/model_fields_/test_embedded_model.py @@ -1,5 +1,8 @@ +import operator + from django.core.exceptions import ValidationError from django.db import models +from django.db.models import ExpressionWrapper, F, Max, Sum from django.test import SimpleTestCase, TestCase from django.test.utils import isolate_apps @@ -13,6 +16,7 @@ Data, Holder, ) +from .utils import truncate_ms class MethodTests(SimpleTestCase): @@ -38,10 +42,6 @@ def test_validate(self): class ModelTests(TestCase): - def truncate_ms(self, value): - """Truncate microseconds to milliseconds as supported by MongoDB.""" - return value.replace(microsecond=(value.microsecond // 1000) * 1000) - def test_save_load(self): Holder.objects.create(data=Data(integer="5")) obj = Holder.objects.get() @@ -64,12 +64,12 @@ def test_save_load_null(self): def test_pre_save(self): """Field.pre_save() is called on embedded model fields.""" obj = Holder.objects.create(data=Data()) - auto_now = self.truncate_ms(obj.data.auto_now) - auto_now_add = self.truncate_ms(obj.data.auto_now_add) + auto_now = truncate_ms(obj.data.auto_now) + auto_now_add = truncate_ms(obj.data.auto_now_add) self.assertEqual(auto_now, auto_now_add) # save() updates auto_now but not auto_now_add. obj.save() - self.assertEqual(self.truncate_ms(obj.data.auto_now_add), auto_now_add) + self.assertEqual(truncate_ms(obj.data.auto_now_add), auto_now_add) auto_now_two = obj.data.auto_now self.assertGreater(auto_now_two, obj.data.auto_now_add) # And again, save() updates auto_now but not auto_now_add. @@ -99,6 +99,47 @@ def test_gt(self): def test_gte(self): self.assertCountEqual(Holder.objects.filter(data__integer__gte=3), self.objs[3:]) + def test_order_by_embedded_field(self): + qs = Holder.objects.filter(data__integer__gt=3).order_by("-data__integer") + self.assertSequenceEqual(qs, list(reversed(self.objs[4:]))) + + def test_order_and_group_by_embedded_field(self): + # Create and sort test data by `data__integer`. + expected_objs = sorted( + (Holder.objects.create(data=Data(integer=x)) for x in range(6)), + key=lambda x: x.data.integer, + ) + # Group by `data__integer + 5` and get the latest `data__auto_now` + # datetime. + qs = ( + Holder.objects.annotate( + group=ExpressionWrapper(F("data__integer") + 5, output_field=models.IntegerField()), + ) + .values("group") + .annotate(max_auto_now=Max("data__auto_now")) + .order_by("data__integer") + ) + # Each unique `data__integer` is correctly grouped and annotated. + self.assertSequenceEqual( + [{**e, "max_auto_now": e["max_auto_now"]} for e in qs], + [ + {"group": e.data.integer + 5, "max_auto_now": truncate_ms(e.data.auto_now)} + for e in expected_objs + ], + ) + + def test_order_and_group_by_embedded_field_annotation(self): + # Create repeated `data__integer` values. + [Holder.objects.create(data=Data(integer=x)) for x in range(6)] + # Group by `data__integer` and compute the sum of occurrences. + qs = ( + Holder.objects.values("data__integer") + .annotate(sum=Sum("data__integer")) + .order_by("sum") + ) + # The sum is twice the integer values since each appears twice. + self.assertQuerySetEqual(qs, [0, 2, 4, 6, 8, 10], operator.itemgetter("sum")) + def test_nested(self): obj = Book.objects.create( author=Author(name="Shakespeare", age=55, address=Address(city="NYC", state="NY")) diff --git a/tests/model_fields_/utils.py b/tests/model_fields_/utils.py new file mode 100644 index 000000000..cf9dcc403 --- /dev/null +++ b/tests/model_fields_/utils.py @@ -0,0 +1,3 @@ +def truncate_ms(value): + """Truncate microseconds to milliseconds as supported by MongoDB.""" + return value.replace(microsecond=(value.microsecond // 1000) * 1000)