Skip to content

Commit 8c0e138

Browse files
timgrahamWaVEV
authored andcommitted
support group by
1 parent fd57ce8 commit 8c0e138

File tree

3 files changed

+74
-6
lines changed

3 files changed

+74
-6
lines changed

django_mongodb_backend/compiler.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def _get_column_from_expression(self, expr, alias):
5353
Create a column named `alias` from the given expression to hold the
5454
aggregate value.
5555
"""
56-
column_target = expr.output_field.__class__()
56+
column_target = expr.output_field.clone()
5757
column_target.db_column = alias
5858
column_target.set_attributes_from_name(alias)
5959
return Col(self.collection_name, column_target)
@@ -81,7 +81,7 @@ def _prepare_expressions_for_pipeline(self, expression, target, annotation_group
8181
alias = (
8282
f"__aggregation{next(annotation_group_idx)}" if sub_expr != expression else target
8383
)
84-
column_target = sub_expr.output_field.__class__()
84+
column_target = sub_expr.output_field.clone()
8585
column_target.db_column = alias
8686
column_target.set_attributes_from_name(alias)
8787
inner_column = Col(self.collection_name, column_target)

django_mongodb_backend/fields/embedded_model.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from django.core import checks
2+
from django.core.exceptions import FieldDoesNotExist
23
from django.db import models
34
from django.db.models.fields.related import lazy_related_operation
45
from django.db.models.lookups import Transform
@@ -112,7 +113,8 @@ def get_transform(self, name):
112113
transform = super().get_transform(name)
113114
if transform:
114115
return transform
115-
return KeyTransformFactory(name)
116+
field = self.embedded_model._meta.get_field(name)
117+
return KeyTransformFactory(name, field)
116118

117119
def validate(self, value, model_instance):
118120
super().validate(value, model_instance)
@@ -134,9 +136,25 @@ def formfield(self, **kwargs):
134136

135137

136138
class KeyTransform(Transform):
137-
def __init__(self, key_name, *args, **kwargs):
139+
def __init__(self, key_name, ref_field, *args, **kwargs):
138140
super().__init__(*args, **kwargs)
139141
self.key_name = str(key_name)
142+
self.ref_field = ref_field
143+
144+
def get_transform(self, name):
145+
result = None
146+
if isinstance(self.ref_field, EmbeddedModelField):
147+
opts = self.ref_field.embedded_model._meta
148+
new_field = opts.get_field(name)
149+
result = KeyTransformFactory(name, new_field)
150+
else:
151+
if self.ref_field.get_transform(name) is None:
152+
raise FieldDoesNotExist(
153+
f"{self.ref_field.model._meta.object_name}.{self.ref_field.name}"
154+
f" has no field named '{name}'"
155+
)
156+
result = KeyTransformFactory(name, self.ref_field)
157+
return result
140158

141159
def preprocess_lhs(self, compiler, connection):
142160
key_transforms = [self.key_name]
@@ -154,8 +172,9 @@ def as_mql(self, compiler, connection):
154172

155173

156174
class KeyTransformFactory:
157-
def __init__(self, key_name):
175+
def __init__(self, key_name, ref_field):
158176
self.key_name = key_name
177+
self.ref_field = ref_field
159178

160179
def __call__(self, *args, **kwargs):
161-
return KeyTransform(self.key_name, *args, **kwargs)
180+
return KeyTransform(self.key_name, self.ref_field, *args, **kwargs)

tests/model_fields_/test_embedded_model.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,14 @@
1+
import operator
2+
13
from django.core.exceptions import ValidationError
24
from django.db import models
5+
from django.db.models import (
6+
ExpressionWrapper,
7+
F,
8+
IntegerField,
9+
Max,
10+
Sum,
11+
)
312
from django.test import SimpleTestCase, TestCase
413
from django.test.utils import isolate_apps
514

@@ -104,6 +113,46 @@ def test_nested(self):
104113
)
105114
self.assertCountEqual(Book.objects.filter(author__address__city="NYC"), [obj])
106115

116+
def truncate_ms(self, value):
117+
"""Truncate microsends to millisecond precision as supported by MongoDB."""
118+
return value.replace(microsecond=(value.microsecond // 1000) * 1000)
119+
120+
def test_ordering_by_embedded_field(self):
121+
query = Holder.objects.filter(data__integer__gt=3).order_by("-data__integer").values("pk")
122+
expected = [{"pk": e.pk} for e in list(reversed(self.objs[4:]))]
123+
self.assertSequenceEqual(query, expected)
124+
125+
def test_ordering_grouping_by_embedded_field(self):
126+
expected = sorted(
127+
(Holder.objects.create(data=Data(integer=x)) for x in range(6)),
128+
key=lambda x: x.data.integer,
129+
)
130+
query = (
131+
Holder.objects.annotate(
132+
group=ExpressionWrapper(F("data__integer") + 5, output_field=IntegerField())
133+
)
134+
.values("group")
135+
.annotate(max_auto_now=Max("data__auto_now"))
136+
.order_by("data__integer")
137+
)
138+
query_response = [{**e, "max_auto_now": self.truncate_ms(e["max_auto_now"])} for e in query]
139+
self.assertSequenceEqual(
140+
query_response,
141+
[
142+
{"group": e.data.integer + 5, "max_auto_now": self.truncate_ms(e.data.auto_now)}
143+
for e in expected
144+
],
145+
)
146+
147+
def test_ordering_grouping_by_sum(self):
148+
[Holder.objects.create(data=Data(integer=x)) for x in range(6)]
149+
qs = (
150+
Holder.objects.values("data__integer")
151+
.annotate(sum=Sum("data__integer"))
152+
.order_by("sum")
153+
)
154+
self.assertQuerySetEqual(qs, [0, 2, 4, 6, 8, 10], operator.itemgetter("sum"))
155+
107156

108157
@isolate_apps("model_fields_")
109158
class CheckTests(SimpleTestCase):

0 commit comments

Comments
 (0)