Skip to content

Commit a09fa5f

Browse files
committed
Support index definition on Embedded Models in top level model.
1 parent 41580cb commit a09fa5f

File tree

5 files changed

+245
-6
lines changed

5 files changed

+245
-6
lines changed

django_mongodb_backend/indexes.py

Lines changed: 38 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
from django.core.checks import Error, Warning
55
from django.db import NotSupportedError
6+
from django.db.backends.utils import names_digest, split_identifier
67
from django.db.models import FloatField, Index, IntegerField
78
from django.db.models.lookups import BuiltinLookup
89
from django.db.models.sql.query import Query
@@ -13,6 +14,7 @@
1314
from django_mongodb_backend.fields import ArrayField
1415

1516
from .query_utils import process_rhs
17+
from .utils import get_column_path
1618

1719
MONGO_INDEX_OPERATORS = {
1820
"exact": "$eq",
@@ -61,7 +63,7 @@ def get_pymongo_index_model(self, model, schema_editor, field=None, unique=False
6163
filter_expression[column].update({"$type": field.db_type(schema_editor.connection)})
6264
else:
6365
for field_name, _ in self.fields_orders:
64-
field_ = model._meta.get_field(field_name)
66+
field_ = get_column_path(model, field_name)
6567
filter_expression[field_.column].update(
6668
{"$type": field_.db_type(schema_editor.connection)}
6769
)
@@ -74,7 +76,7 @@ def get_pymongo_index_model(self, model, schema_editor, field=None, unique=False
7476
# order is "" if ASCENDING or "DESC" if DESCENDING (see
7577
# django.db.models.indexes.Index.fields_orders).
7678
(
77-
column_prefix + model._meta.get_field(field_name).column,
79+
column_prefix + get_column_path(model, field_name).column,
7880
ASCENDING if order == "" else DESCENDING,
7981
)
8082
for field_name, order in self.fields_orders
@@ -154,7 +156,7 @@ def get_pymongo_index_model(
154156
for field_name, _ in self.fields_orders:
155157
field = model._meta.get_field(field_name)
156158
type_ = self.search_index_data_types(field.db_type(schema_editor.connection))
157-
field_path = column_prefix + model._meta.get_field(field_name).column
159+
field_path = column_prefix + get_column_path(model, field_name)
158160
fields[field_path] = {"type": type_}
159161
return SearchIndexModel(
160162
definition={"mappings": {"dynamic": False, "fields": fields}}, name=self.name
@@ -264,7 +266,7 @@ def get_pymongo_index_model(
264266
fields = []
265267
for field_name, _ in self.fields_orders:
266268
field_ = model._meta.get_field(field_name)
267-
field_path = column_prefix + model._meta.get_field(field_name).column
269+
field_path = column_prefix + get_column_path(model, field_name)
268270
mappings = {"path": field_path}
269271
if isinstance(field_, ArrayField):
270272
mappings.update(
@@ -280,8 +282,40 @@ def get_pymongo_index_model(
280282
return SearchIndexModel(definition={"fields": fields}, name=self.name, type="vectorSearch")
281283

282284

285+
def set_name_with_model(self, model):
286+
"""
287+
Generate a unique name for the index.
288+
289+
The name is divided into 3 parts - table name (12 chars), field name
290+
(8 chars) and unique hash + suffix (10 chars). Each part is made to
291+
fit its size by truncating the excess length.
292+
"""
293+
_, table_name = split_identifier(model._meta.db_table)
294+
column_names = [
295+
get_column_path(model, field_name).column for field_name, order in self.fields_orders
296+
]
297+
column_names_with_order = [
298+
(f"-{column_name}" if order else column_name)
299+
for column_name, (field_name, order) in zip(column_names, self.fields_orders, strict=False)
300+
]
301+
# The length of the parts of the name is based on the default max
302+
# length of 30 characters.
303+
hash_data = [table_name, *column_names_with_order, self.suffix]
304+
self.name = (
305+
f"{table_name[:11]}_{column_names[0][:7]}_"
306+
f"{names_digest(*hash_data, length=6)}_{self.suffix}"
307+
)
308+
if len(self.name) > self.max_name_length:
309+
raise ValueError(
310+
"Index too long for multiple database support. Is self.suffix longer than 3 characters?"
311+
)
312+
if self.name[0] == "_" or self.name[0].isdigit():
313+
self.name = f"D{self.name[1:]}"
314+
315+
283316
def register_indexes():
284317
BuiltinLookup.as_mql_idx = builtin_lookup_idx
285318
Index._get_condition_mql = _get_condition_mql
286319
Index.get_pymongo_index_model = get_pymongo_index_model
320+
Index.set_name_with_model = set_name_with_model
287321
WhereNode.as_mql_idx = where_node_idx

django_mongodb_backend/lookups.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1+
import django.db.models.base as base
12
from django.db import NotSupportedError
3+
from django.db.models.constants import LOOKUP_SEP
24
from django.db.models.fields.related_lookups import In, RelatedIn
35
from django.db.models.lookups import (
46
BuiltinLookup,
@@ -8,6 +10,7 @@
810
UUIDTextMixin,
911
)
1012

13+
from .fields import EmbeddedModelField
1114
from .query_utils import process_lhs, process_rhs
1215

1316

@@ -121,6 +124,26 @@ def uuid_text_mixin(self, compiler, connection): # noqa: ARG001
121124
raise NotSupportedError("Pattern lookups on UUIDField are not supported.")
122125

123126

127+
class Options(base.Options):
128+
def get_field(self, field_name):
129+
if LOOKUP_SEP in field_name:
130+
previous = self
131+
keys = field_name.split(LOOKUP_SEP)
132+
path = []
133+
for field in keys:
134+
field = base.Options.get_field(previous, field)
135+
if isinstance(field, EmbeddedModelField):
136+
previous = field.embedded_model._meta
137+
else:
138+
previous = field
139+
path.append(field.column)
140+
column = ".".join(path)
141+
embedded_column = field.clone()
142+
embedded_column.column = column
143+
return embedded_column
144+
return super().get_field(field_name)
145+
146+
124147
def register_lookups():
125148
BuiltinLookup.as_mql = builtin_lookup
126149
FieldGetDbPrepValueIterableMixin.resolve_expression_parameter = (
@@ -131,3 +154,4 @@ def register_lookups():
131154
IsNull.as_mql = is_null
132155
PatternLookup.prep_lookup_value_mongo = pattern_lookup_prep_lookup_value
133156
UUIDTextMixin.as_mql = uuid_text_mixin
157+
base.Options = Options

django_mongodb_backend/schema.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from .fields import EmbeddedModelField
88
from .gis.schema import GISSchemaEditor
99
from .query import wrap_database_errors
10-
from .utils import OperationCollector
10+
from .utils import OperationCollector, get_column_path
1111

1212

1313
def ignore_embedded_models(func):
@@ -249,7 +249,7 @@ def alter_unique_together(
249249
)
250250
# Created uniques
251251
for field_names in news.difference(olds):
252-
columns = [model._meta.get_field(field).column for field in field_names]
252+
columns = [get_column_path(model, field).column for field in field_names]
253253
name = str(
254254
self._unique_constraint_name(
255255
model._meta.db_table, [column_prefix + col for col in columns]

django_mongodb_backend/utils.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from django.conf import settings
66
from django.core.exceptions import ImproperlyConfigured, ValidationError
77
from django.db.backends.utils import logger
8+
from django.db.models.constants import LOOKUP_SEP
89
from django.utils.functional import SimpleLazyObject
910
from django.utils.text import format_lazy
1011
from django.utils.version import get_version_tuple
@@ -186,3 +187,21 @@ def wrapper(self, *args, **kwargs):
186187
self.log(method, args, kwargs)
187188

188189
return wrapper
190+
191+
192+
def get_column_path(model, field_name):
193+
from .fields import EmbeddedModelField # noqa: PLC0415
194+
195+
if LOOKUP_SEP in field_name:
196+
previous = model
197+
keys = field_name.split(LOOKUP_SEP)
198+
path = []
199+
for field in keys:
200+
field = previous._meta.get_field(field)
201+
previous = field.embedded_model if isinstance(field, EmbeddedModelField) else field
202+
path.append(field.column)
203+
column = ".".join(path)
204+
embedded_column = field.clone()
205+
embedded_column.column = column
206+
return embedded_column
207+
return model._meta.get_field(field_name)

tests/schema_/test_embedded_model.py

Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -519,6 +519,168 @@ class Meta:
519519
self.assertTableNotExists(Author)
520520

521521

522+
class EmbeddedModelsTopLevelIndexTest(TestMixin, TransactionTestCase):
523+
@isolate_apps("schema_")
524+
def test_unique_together(self):
525+
"""Meta.unique_together defined at the top-level for embedded fields."""
526+
527+
class Address(EmbeddedModel):
528+
unique_together_one = models.CharField(max_length=10)
529+
unique_together_two = models.CharField(max_length=10)
530+
531+
class Meta:
532+
app_label = "schema_"
533+
534+
class Author(EmbeddedModel):
535+
address = EmbeddedModelField(Address)
536+
unique_together_three = models.CharField(max_length=10)
537+
unique_together_four = models.CharField(max_length=10)
538+
539+
class Meta:
540+
app_label = "schema_"
541+
542+
class Book(models.Model):
543+
author = EmbeddedModelField(Author)
544+
545+
class Meta:
546+
app_label = "schema_"
547+
unique_together = [
548+
("author__unique_together_three", "author__unique_together_four"),
549+
(
550+
"author__address__unique_together_one",
551+
"author__address__unique_together_two",
552+
),
553+
]
554+
555+
with connection.schema_editor() as editor:
556+
editor.create_model(Book)
557+
self.assertTableExists(Book)
558+
# Embedded uniques are created from top-level definition.
559+
self.assertEqual(
560+
self.get_constraints_for_columns(
561+
Book, ["author.unique_together_three", "author.unique_together_four"]
562+
),
563+
[
564+
"schema__book_author.unique_together_three_author.unique_together_four_09a570b8_uniq"
565+
],
566+
)
567+
self.assertEqual(
568+
self.get_constraints_for_columns(
569+
Book,
570+
["author.address.unique_together_one", "author.address.unique_together_two"],
571+
),
572+
[
573+
"schema__book_author.address.unique_together_one_author.address.unique_together_two_2c2d1477_uniq"
574+
],
575+
)
576+
editor.delete_model(Book)
577+
self.assertTableNotExists(Book)
578+
579+
@isolate_apps("schema_")
580+
def test_add_remove_field_indexes(self):
581+
"""AddField/RemoveField + EmbeddedModelField + Meta.indexes at top-level."""
582+
583+
class Address(EmbeddedModel):
584+
indexed_one = models.CharField(max_length=10)
585+
586+
class Meta:
587+
app_label = "schema_"
588+
589+
class Author(EmbeddedModel):
590+
address = EmbeddedModelField(Address)
591+
indexed_two = models.CharField(max_length=10)
592+
593+
class Meta:
594+
app_label = "schema_"
595+
596+
class Book(models.Model):
597+
author = EmbeddedModelField(Author)
598+
599+
class Meta:
600+
app_label = "schema_"
601+
indexes = [
602+
models.Index(fields=["author__indexed_two"]),
603+
models.Index(fields=["author__address__indexed_one"]),
604+
]
605+
606+
new_field = EmbeddedModelField(Author)
607+
new_field.set_attributes_from_name("author")
608+
609+
with connection.schema_editor() as editor:
610+
# Create the table and add the field.
611+
editor.create_model(Book)
612+
editor.add_field(Book, new_field)
613+
# Embedded indexes are created.
614+
self.assertEqual(
615+
self.get_constraints_for_columns(Book, ["author.indexed_two"]),
616+
["schema__boo_author._333c90_idx"],
617+
)
618+
self.assertEqual(
619+
self.get_constraints_for_columns(
620+
Book,
621+
["author.address.indexed_one"],
622+
),
623+
["schema__boo_author._f54386_idx"],
624+
)
625+
editor.delete_model(Book)
626+
self.assertTableNotExists(Book)
627+
628+
@isolate_apps("schema_")
629+
def test_add_remove_field_constraints(self):
630+
"""AddField/RemoveField + EmbeddedModelField + Meta.constraints at top-level."""
631+
632+
class Address(EmbeddedModel):
633+
unique_constraint_one = models.CharField(max_length=10)
634+
635+
class Meta:
636+
app_label = "schema_"
637+
638+
class Author(EmbeddedModel):
639+
address = EmbeddedModelField(Address)
640+
unique_constraint_two = models.CharField(max_length=10)
641+
642+
class Meta:
643+
app_label = "schema_"
644+
645+
class Book(models.Model):
646+
author = EmbeddedModelField(Author)
647+
648+
class Meta:
649+
app_label = "schema_"
650+
constraints = [
651+
models.UniqueConstraint(
652+
fields=["author__unique_constraint_two"],
653+
name="unique_two",
654+
),
655+
models.UniqueConstraint(
656+
fields=["author__address__unique_constraint_one"],
657+
name="unique_one",
658+
),
659+
]
660+
661+
new_field = EmbeddedModelField(Author)
662+
new_field.set_attributes_from_name("author")
663+
664+
with connection.schema_editor() as editor:
665+
# Create the table and add the field.
666+
editor.create_model(Book)
667+
editor.add_field(Book, new_field)
668+
# Embedded constraints are created.
669+
self.assertEqual(
670+
self.get_constraints_for_columns(Book, ["author.unique_constraint_two"]),
671+
["unique_two"],
672+
)
673+
self.assertEqual(
674+
self.get_constraints_for_columns(
675+
Book,
676+
["author.address.unique_constraint_one"],
677+
),
678+
["unique_one"],
679+
)
680+
editor.delete_model(Book)
681+
self.assertTableNotExists(Book)
682+
683+
522684
class EmbeddedModelsIgnoredTests(TestMixin, TransactionTestCase):
523685
def test_embedded_not_created(self):
524686
"""create_model() and delete_model() ignore EmbeddedModel."""

0 commit comments

Comments
 (0)