Skip to content

Commit 9ac8839

Browse files
committed
querying support
1 parent 9c24d47 commit 9ac8839

File tree

3 files changed

+101
-18
lines changed

3 files changed

+101
-18
lines changed

django_mongodb_backend/fields/embedded_model.py

Lines changed: 42 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from django.db import models
22
from django.db.models.fields.related import lazy_related_operation
3+
from django.db.models.lookups import Transform
34

45

56
class EmbeddedModelField(models.Field):
@@ -31,13 +32,8 @@ def _set_model(self, model):
3132
Resolve embedded model class once the field knows the model it belongs
3233
to.
3334
34-
If the model argument passed to __init__() was a string, resolve that
35-
string to the corresponding model class, similar to relation fields.
36-
However, we need to know our own model to generate a valid key
37-
for the embedded model class lookup and EmbeddedModelFields are
38-
not contributed_to_class if used in iterable fields. Thus the
39-
collection field sets this field's "model" attribute in its
40-
contribute_to_class().
35+
If __init__()'s embedded_model argument is a string, resolve it to the
36+
corresponding model class, similar to relation fields.
4137
"""
4238
self._model = model
4339
if model is not None and isinstance(self.embedded_model, str):
@@ -54,8 +50,8 @@ def from_db_value(self, value, expression, connection):
5450

5551
def to_python(self, value):
5652
"""
57-
Passes embedded model fields' values through embedded fields
58-
to_python() and reinstiatates the embedded instance.
53+
Pass embedded model fields' values through each field's to_python() and
54+
reinstiatate the embedded instance.
5955
"""
6056
if value is None:
6157
return None
@@ -76,14 +72,8 @@ def to_python(self, value):
7672

7773
def get_db_prep_save(self, embedded_instance, connection):
7874
"""
79-
Apply pre_save() and get_db_prep_save() of embedded instance
80-
fields and passes a field => value mapping down to database
81-
type conversions.
82-
83-
The embedded instance will be saved as a column => value dict, but
84-
because we need to apply database type conversions on embedded instance
85-
fields' values and for these we need to know fields those values come
86-
from, we need to entrust the database layer with creating the dict.
75+
Apply pre_save() and get_db_prep_save() of embedded instance fields and
76+
create the {field: value} dict to be saved.
8777
"""
8878
if embedded_instance is None:
8979
return None
@@ -106,14 +96,48 @@ def get_db_prep_save(self, embedded_instance, connection):
10696
continue
10797
field_values[field.attname] = value
10898
# This instance will exist in the database soon.
109-
# TODO.XXX: Ensure that this doesn't cause race conditions.
99+
# TODO: Ensure that this doesn't cause race conditions.
110100
embedded_instance._state.adding = False
111101
return field_values
112102

103+
def get_transform(self, name):
104+
transform = super().get_transform(name)
105+
if transform:
106+
return transform
107+
return KeyTransformFactory(name)
108+
113109
def validate(self, value, model_instance):
114110
super().validate(value, model_instance)
115111
if self.embedded_model is None:
116112
return
117113
for field in self.embedded_model._meta.fields:
118114
attname = field.attname
119115
field.validate(getattr(value, attname), model_instance)
116+
117+
118+
class KeyTransform(Transform):
119+
def __init__(self, key_name, *args, **kwargs):
120+
super().__init__(*args, **kwargs)
121+
self.key_name = str(key_name)
122+
123+
def preprocess_lhs(self, compiler, connection):
124+
key_transforms = [self.key_name]
125+
previous = self.lhs
126+
while isinstance(previous, KeyTransform):
127+
key_transforms.insert(0, previous.key_name)
128+
previous = previous.lhs
129+
mql = previous.as_mql(compiler, connection)
130+
return mql, key_transforms
131+
132+
def as_mql(self, compiler, connection):
133+
mql, key_transforms = self.preprocess_lhs(compiler, connection)
134+
transforms = ".".join(key_transforms)
135+
return f"{mql}.{transforms}"
136+
137+
138+
class KeyTransformFactory:
139+
def __init__(self, key_name):
140+
self.key_name = key_name
141+
142+
def __call__(self, *args, **kwargs):
143+
return KeyTransform(self.key_name, *args, **kwargs)

tests/model_fields_/models.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,3 +120,19 @@ class EmbeddedModel(models.Model):
120120
someint = models.IntegerField(db_column="custom_column")
121121
auto_now = models.DateTimeField(auto_now=True)
122122
auto_now_add = models.DateTimeField(auto_now_add=True)
123+
124+
125+
class Address(models.Model):
126+
city = models.CharField(max_length=20)
127+
state = models.CharField(max_length=2)
128+
129+
130+
class Author(models.Model):
131+
name = models.CharField(max_length=10)
132+
age = models.IntegerField()
133+
address = EmbeddedModelField(Address)
134+
135+
136+
class Book(models.Model):
137+
name = models.CharField(max_length=100)
138+
author = EmbeddedModelField(Author)

tests/model_fields_/test_embedded_model.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@
66
from django_mongodb_backend.fields import EmbeddedModelField
77

88
from .models import (
9+
Address,
10+
Author,
11+
Book,
912
DecimalKey,
1013
DecimalParent,
1114
EmbeddedModel,
@@ -91,3 +94,43 @@ def test_embedded_field_with_foreign_conversion(self):
9194
decimal = DecimalKey.objects.create(decimal=Decimal("1.5"))
9295
decimal_parent = DecimalParent.objects.create(child=decimal)
9396
EmbeddedModelFieldModel.objects.create(decimal_parent=decimal_parent)
97+
98+
99+
class QueryingTests(TestCase):
100+
@classmethod
101+
def setUpTestData(cls):
102+
cls.objs = [
103+
EmbeddedModelFieldModel.objects.create(simple=EmbeddedModel(someint=x))
104+
for x in range(6)
105+
]
106+
107+
def test_exact(self):
108+
self.assertCountEqual(
109+
EmbeddedModelFieldModel.objects.filter(simple__someint=3), [self.objs[3]]
110+
)
111+
112+
def test_lt(self):
113+
self.assertCountEqual(
114+
EmbeddedModelFieldModel.objects.filter(simple__someint__lt=3), self.objs[:3]
115+
)
116+
117+
def test_lte(self):
118+
self.assertCountEqual(
119+
EmbeddedModelFieldModel.objects.filter(simple__someint__lte=3), self.objs[:4]
120+
)
121+
122+
def test_gt(self):
123+
self.assertCountEqual(
124+
EmbeddedModelFieldModel.objects.filter(simple__someint__gt=3), self.objs[4:]
125+
)
126+
127+
def test_gte(self):
128+
self.assertCountEqual(
129+
EmbeddedModelFieldModel.objects.filter(simple__someint__gte=3), self.objs[3:]
130+
)
131+
132+
def test_nested(self):
133+
obj = Book.objects.create(
134+
author=Author(name="Shakespeare", age=55, address=Address(city="NYC", state="NY"))
135+
)
136+
self.assertCountEqual(Book.objects.filter(author__address__city="NYC"), [obj])

0 commit comments

Comments
 (0)