Skip to content

Commit 5b99790

Browse files
committed
beginnings of querying
1 parent 5a442f8 commit 5b99790

File tree

3 files changed

+75
-1
lines changed

3 files changed

+75
-1
lines changed

django_mongodb/fields/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
from .auto import ObjectIdAutoField
22
from .duration import register_duration_field
3-
from .embedded_model import EmbeddedModelField
3+
from .embedded_model import EmbeddedModelField, register_embedded_model_field
44
from .json import register_json_field
55

66
__all__ = ["register_fields", "EmbeddedModelField", "ObjectIdAutoField"]
77

88

99
def register_fields():
1010
register_duration_field()
11+
register_embedded_model_field()
1112
register_json_field()

django_mongodb/fields/embedded_model.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from django.db import models
22
from django.db.models.fields.related import lazy_related_operation
3+
from django.db.models.lookups import Transform
34

45

56
class EmbeddedModelField(models.Field):
@@ -108,10 +109,48 @@ def get_db_prep_save(self, embedded_instance, connection):
108109
embedded_instance._state.adding = False
109110
return field_values
110111

112+
def get_transform(self, name):
113+
transform = super().get_transform(name)
114+
if transform:
115+
return transform
116+
return KeyTransformFactory(name)
117+
111118
def validate(self, value, model_instance):
112119
super().validate(value, model_instance)
113120
if self.embedded_model is None:
114121
return
115122
for field in self.embedded_model._meta.fields:
116123
attname = field.attname
117124
field.validate(getattr(value, attname), model_instance)
125+
126+
127+
class KeyTransform(Transform):
128+
def __init__(self, key_name, *args, **kwargs):
129+
super().__init__(*args, **kwargs)
130+
self.key_name = str(key_name)
131+
132+
def preprocess_lhs(self, compiler, connection):
133+
key_transforms = [self.key_name]
134+
previous = self.lhs
135+
while isinstance(previous, KeyTransform):
136+
key_transforms.insert(0, previous.key_name)
137+
previous = previous.lhs
138+
mql = previous.as_mql(compiler, connection)
139+
return mql, key_transforms
140+
141+
142+
def key_transform(self, compiler, connection):
143+
mql, key_transforms = self.preprocess_lhs(compiler, connection)
144+
return f"{mql}.{key_transforms[0]}"
145+
146+
147+
class KeyTransformFactory:
148+
def __init__(self, key_name):
149+
self.key_name = key_name
150+
151+
def __call__(self, *args, **kwargs):
152+
return KeyTransform(self.key_name, *args, **kwargs)
153+
154+
155+
def register_embedded_model_field():
156+
KeyTransform.as_mql = key_transform

tests/model_fields_/test_embedded_model.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,3 +91,37 @@ def test_embedded_field_with_foreign_conversion(self):
9191
decimal = DecimalKey.objects.create(decimal=Decimal("1.5"))
9292
decimal_parent = DecimalParent.objects.create(child=decimal)
9393
EmbeddedModelFieldModel.objects.create(decimal_parent=decimal_parent)
94+
95+
96+
class QueryingTests(TestCase):
97+
@classmethod
98+
def setUpTestData(cls):
99+
cls.objs = [
100+
EmbeddedModelFieldModel.objects.create(simple=EmbeddedModel(someint=x))
101+
for x in range(6)
102+
]
103+
104+
def test_exact(self):
105+
self.assertCountEqual(
106+
EmbeddedModelFieldModel.objects.filter(simple__someint=3), [self.objs[3]]
107+
)
108+
109+
def test_lt(self):
110+
self.assertCountEqual(
111+
EmbeddedModelFieldModel.objects.filter(simple__someint__lt=3), self.objs[:3]
112+
)
113+
114+
def test_lte(self):
115+
self.assertCountEqual(
116+
EmbeddedModelFieldModel.objects.filter(simple__someint__lte=3), self.objs[:4]
117+
)
118+
119+
def test_gt(self):
120+
self.assertCountEqual(
121+
EmbeddedModelFieldModel.objects.filter(simple__someint__gt=3), self.objs[4:]
122+
)
123+
124+
def test_gte(self):
125+
self.assertCountEqual(
126+
EmbeddedModelFieldModel.objects.filter(simple__someint__gte=3), self.objs[3:]
127+
)

0 commit comments

Comments
 (0)