Skip to content

Commit cd40490

Browse files
committed
wip schema changes
1 parent 76e1bad commit cd40490

File tree

5 files changed

+313
-9
lines changed

5 files changed

+313
-9
lines changed

django_mongodb/schema.py

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from pymongo import ASCENDING, DESCENDING
44
from pymongo.operations import IndexModel
55

6+
from .fields import EmbeddedModelField
67
from .query import wrap_database_errors
78
from .utils import OperationCollector
89

@@ -27,17 +28,27 @@ def create_model(self, model):
2728
if field.remote_field.through._meta.auto_created:
2829
self.create_model(field.remote_field.through)
2930

30-
def _create_model_indexes(self, model):
31+
def _create_model_indexes(self, model, current_path=None, parent_model=None):
3132
"""
3233
Create all indexes (field indexes & uniques, Meta.index_together,
3334
Meta.unique_together, Meta.constraints, Meta.indexes) for the model.
35+
36+
If this is a recursive call to due to an embedded model, `current_path`
37+
tracks the path that must be prepended to the index's column, and
38+
`parent_model` tracks the collection to add the index/constraint to.
3439
"""
3540
if not model._meta.managed or model._meta.proxy or model._meta.swapped:
3641
return
3742
# Field indexes and uniques
3843
for field in model._meta.local_fields:
44+
if isinstance(field, EmbeddedModelField):
45+
new_path = ".".join((current_path, field.column)) if current_path else field.column
46+
self._create_model_indexes(
47+
field.embedded_model, parent_model=parent_model or model, current_path=new_path
48+
)
3949
if self._field_should_be_indexed(model, field):
40-
self._add_field_index(model, field)
50+
column_name = f"{current_path}.{field.column}" if current_path else None
51+
self._add_field_index(parent_model or model, field, column_name=column_name)
4152
elif self._field_should_have_unique(field):
4253
self._add_field_unique(model, field)
4354
# Meta.index_together (RemovedInDjango51Warning)
@@ -162,7 +173,7 @@ def alter_unique_together(self, model, old_unique_together, new_unique_together)
162173
constraint = UniqueConstraint(fields=field_names, name=name)
163174
self.add_constraint(model, constraint)
164175

165-
def add_index(self, model, index, field=None, unique=False):
176+
def add_index(self, model, index, *, field=None, unique=False, column_name=None):
166177
if index.contains_expressions:
167178
return
168179
kwargs = {}
@@ -171,14 +182,15 @@ def add_index(self, model, index, field=None, unique=False):
171182
# Indexing on $type matches the value of most SQL databases by
172183
# allowing multiple null values for the unique constraint.
173184
if field:
174-
filter_expression[field.column] = {"$type": field.db_type(self.connection)}
185+
column = column_name or field.column
186+
filter_expression[column] = {"$type": field.db_type(self.connection)}
175187
else:
176188
for field_name, _ in index.fields_orders:
177189
field_ = model._meta.get_field(field_name)
178190
filter_expression[field_.column] = {"$type": field_.db_type(self.connection)}
179191
kwargs = {"partialFilterExpression": filter_expression, "unique": True}
180192
index_orders = (
181-
[(field.column, ASCENDING)]
193+
[(column_name or field.column, ASCENDING)]
182194
if field
183195
else [
184196
# order is "" if ASCENDING or "DESC" if DESCENDING (see
@@ -196,11 +208,11 @@ def _add_composed_index(self, model, field_names):
196208
idx.set_name_with_model(model)
197209
self.add_index(model, idx)
198210

199-
def _add_field_index(self, model, field):
211+
def _add_field_index(self, model, field, *, column_name=None):
200212
"""Add an index on a field with db_index=True."""
201-
index = Index(fields=[field.name])
202-
index.name = self._create_index_name(model._meta.db_table, [field.column])
203-
self.add_index(model, index, field=field)
213+
index = Index(fields=[column_name or field.name])
214+
index.name = self._create_index_name(model._meta.db_table, [column_name or field.column])
215+
self.add_index(model, index, field=field, column_name=column_name)
204216

205217
def remove_index(self, model, index):
206218
if index.contains_expressions:

tests/model_fields_/models.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ class EmbeddedModel(models.Model):
3434
class Address(models.Model):
3535
city = models.CharField(max_length=20)
3636
state = models.CharField(max_length=2)
37+
zip_code = models.IntegerField(db_index=True)
3738

3839

3940
class Author(models.Model):

tests/schema_/__init__.py

Whitespace-only changes.

tests/schema_/models.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
from django.apps.registry import Apps
2+
from django.db import models
3+
4+
from django_mongodb.fields import EmbeddedModelField
5+
6+
# Because we want to test creation and deletion of these as separate things,
7+
# these models are all inserted into a separate Apps so the main test
8+
# runner doesn't migrate them.
9+
10+
new_apps = Apps()
11+
12+
13+
class Address(models.Model):
14+
city = models.CharField(max_length=20)
15+
state = models.CharField(max_length=2)
16+
zip_code = models.IntegerField(db_index=True)
17+
18+
class Meta:
19+
apps = new_apps
20+
21+
22+
class Author(models.Model):
23+
name = models.CharField(max_length=10)
24+
age = models.IntegerField(db_index=True)
25+
address = EmbeddedModelField(Address)
26+
27+
class Meta:
28+
apps = new_apps
29+
30+
31+
class Book(models.Model):
32+
name = models.CharField(max_length=100)
33+
author = EmbeddedModelField(Author)
34+
35+
class Meta:
36+
apps = new_apps

tests/schema_/test_embedded_model.py

Lines changed: 255 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,255 @@
1+
import itertools
2+
3+
from django.db import (
4+
connection,
5+
)
6+
from django.test import (
7+
TransactionTestCase,
8+
)
9+
10+
from .models import Address, Author, Book, new_apps
11+
12+
13+
class SchemaTests(TransactionTestCase):
14+
available_apps = []
15+
models = [Address, Author, Book]
16+
17+
# Utility functions
18+
19+
def setUp(self):
20+
# local_models should contain test dependent model classes that will be
21+
# automatically removed from the app cache on test tear down.
22+
self.local_models = []
23+
# isolated_local_models contains models that are in test methods
24+
# decorated with @isolate_apps.
25+
self.isolated_local_models = []
26+
27+
def tearDown(self):
28+
# Delete any tables made for our models
29+
self.delete_tables()
30+
new_apps.clear_cache()
31+
for model in new_apps.get_models():
32+
model._meta._expire_cache()
33+
if "schema" in new_apps.all_models:
34+
for model in self.local_models:
35+
for many_to_many in model._meta.many_to_many:
36+
through = many_to_many.remote_field.through
37+
if through and through._meta.auto_created:
38+
del new_apps.all_models["schema"][through._meta.model_name]
39+
del new_apps.all_models["schema"][model._meta.model_name]
40+
if self.isolated_local_models:
41+
with connection.schema_editor() as editor:
42+
for model in self.isolated_local_models:
43+
editor.delete_model(model)
44+
45+
def delete_tables(self):
46+
"Deletes all model tables for our models for a clean test environment"
47+
converter = connection.introspection.identifier_converter
48+
with connection.schema_editor() as editor:
49+
connection.disable_constraint_checking()
50+
table_names = connection.introspection.table_names()
51+
if connection.features.ignores_table_name_case:
52+
table_names = [table_name.lower() for table_name in table_names]
53+
for model in itertools.chain(SchemaTests.models, self.local_models):
54+
tbl = converter(model._meta.db_table)
55+
if connection.features.ignores_table_name_case:
56+
tbl = tbl.lower()
57+
if tbl in table_names:
58+
editor.delete_model(model)
59+
table_names.remove(tbl)
60+
connection.enable_constraint_checking()
61+
62+
def column_classes(self, model):
63+
with connection.cursor() as cursor:
64+
columns = {
65+
d[0]: (connection.introspection.get_field_type(d[1], d), d)
66+
for d in connection.introspection.get_table_description(
67+
cursor,
68+
model._meta.db_table,
69+
)
70+
}
71+
# SQLite has a different format for field_type
72+
for name, (type, desc) in columns.items():
73+
if isinstance(type, tuple):
74+
columns[name] = (type[0], desc)
75+
return columns
76+
77+
def get_primary_key(self, table):
78+
with connection.cursor() as cursor:
79+
return connection.introspection.get_primary_key_column(cursor, table)
80+
81+
def get_indexes(self, table):
82+
"""
83+
Get the indexes on the table using a new cursor.
84+
"""
85+
with connection.cursor() as cursor:
86+
return [
87+
c["columns"][0]
88+
for c in connection.introspection.get_constraints(cursor, table).values()
89+
if c["index"] and len(c["columns"]) == 1
90+
]
91+
92+
def get_uniques(self, table):
93+
with connection.cursor() as cursor:
94+
return [
95+
c["columns"][0]
96+
for c in connection.introspection.get_constraints(cursor, table).values()
97+
if c["unique"] and len(c["columns"]) == 1
98+
]
99+
100+
def get_constraints(self, table):
101+
"""
102+
Get the constraints on a table using a new cursor.
103+
"""
104+
with connection.cursor() as cursor:
105+
return connection.introspection.get_constraints(cursor, table)
106+
107+
def get_constraints_for_column(self, model, column_name):
108+
constraints = self.get_constraints(model._meta.db_table)
109+
constraints_for_column = []
110+
for name, details in constraints.items():
111+
if details["columns"] == [column_name]:
112+
constraints_for_column.append(name)
113+
return sorted(constraints_for_column)
114+
115+
def get_constraint_opclasses(self, constraint_name):
116+
with connection.cursor() as cursor:
117+
sql = """
118+
SELECT opcname
119+
FROM pg_opclass AS oc
120+
JOIN pg_index as i on oc.oid = ANY(i.indclass)
121+
JOIN pg_class as c on c.oid = i.indexrelid
122+
WHERE c.relname = %s
123+
"""
124+
cursor.execute(sql, [constraint_name])
125+
return [row[0] for row in cursor.fetchall()]
126+
127+
def check_added_field_default(
128+
self,
129+
schema_editor,
130+
model,
131+
field,
132+
field_name,
133+
expected_default,
134+
cast_function=None,
135+
):
136+
schema_editor.add_field(model, field)
137+
database_default = connection.database[model._meta.db_table].find_one().get(field_name)
138+
# cursor.execute(
139+
# "SELECT {} FROM {};".format(field_name, model._meta.db_table)
140+
# )
141+
# database_default = cursor.fetchall()[0][0]
142+
if cast_function and type(database_default) is not type(expected_default):
143+
database_default = cast_function(database_default)
144+
self.assertEqual(database_default, expected_default)
145+
146+
def get_constraints_count(self, table, column, fk_to):
147+
"""
148+
Return a dict with keys 'fks', 'uniques, and 'indexes' indicating the
149+
number of foreign keys, unique constraints, and indexes on
150+
`table`.`column`. The `fk_to` argument is a 2-tuple specifying the
151+
expected foreign key relationship's (table, column).
152+
"""
153+
with connection.cursor() as cursor:
154+
constraints = connection.introspection.get_constraints(cursor, table)
155+
counts = {"fks": 0, "uniques": 0, "indexes": 0}
156+
for c in constraints.values():
157+
if c["columns"] == [column]:
158+
if c["foreign_key"] == fk_to:
159+
counts["fks"] += 1
160+
if c["unique"]:
161+
counts["uniques"] += 1
162+
elif c["index"]:
163+
counts["indexes"] += 1
164+
return counts
165+
166+
def get_column_collation(self, table, column):
167+
with connection.cursor() as cursor:
168+
return next(
169+
f.collation
170+
for f in connection.introspection.get_table_description(cursor, table)
171+
if f.name == column
172+
)
173+
174+
def get_column_comment(self, table, column):
175+
with connection.cursor() as cursor:
176+
return next(
177+
f.comment
178+
for f in connection.introspection.get_table_description(cursor, table)
179+
if f.name == column
180+
)
181+
182+
def get_table_comment(self, table):
183+
with connection.cursor() as cursor:
184+
return next(
185+
t.comment
186+
for t in connection.introspection.get_table_list(cursor)
187+
if t.name == table
188+
)
189+
190+
def assert_column_comment_not_exists(self, table, column):
191+
with connection.cursor() as cursor:
192+
columns = connection.introspection.get_table_description(cursor, table)
193+
self.assertFalse(any(c.name == column and c.comment for c in columns))
194+
195+
def assertIndexOrder(self, table, index, order):
196+
constraints = self.get_constraints(table)
197+
self.assertIn(index, constraints)
198+
index_orders = constraints[index]["orders"]
199+
self.assertTrue(
200+
all(val == expected for val, expected in zip(index_orders, order, strict=True))
201+
)
202+
203+
def assertForeignKeyExists(self, model, column, expected_fk_table, field="id"):
204+
"""
205+
Fail if the FK constraint on `model.Meta.db_table`.`column` to
206+
`expected_fk_table`.id doesn't exist.
207+
"""
208+
if not connection.features.can_introspect_foreign_keys:
209+
return
210+
constraints = self.get_constraints(model._meta.db_table)
211+
constraint_fk = None
212+
for details in constraints.values():
213+
if details["columns"] == [column] and details["foreign_key"]:
214+
constraint_fk = details["foreign_key"]
215+
break
216+
self.assertEqual(constraint_fk, (expected_fk_table, field))
217+
218+
def assertForeignKeyNotExists(self, model, column, expected_fk_table):
219+
if not connection.features.can_introspect_foreign_keys:
220+
return
221+
with self.assertRaises(AssertionError):
222+
self.assertForeignKeyExists(model, column, expected_fk_table)
223+
224+
def assertTableExists(self, model):
225+
self.assertIn(model._meta.db_table, connection.introspection.table_names())
226+
227+
def assertTableNotExists(self, model):
228+
self.assertNotIn(model._meta.db_table, connection.introspection.table_names())
229+
230+
# Tests
231+
def test_embedded_index(self):
232+
"""db_index on an embedded model."""
233+
with connection.schema_editor() as editor:
234+
# Create the table
235+
editor.create_model(Book)
236+
# The table is there
237+
self.assertTableExists(Book)
238+
# Embedded indexes are created.
239+
self.assertEqual(
240+
self.get_constraints_for_column(Book, "author.age"),
241+
["schema__book_author.age_dc08100b"],
242+
)
243+
self.assertEqual(
244+
self.get_constraints_for_column(Book, "author.address.zip_code"),
245+
["schema__book_author.address.zip_code_7b9a9307"],
246+
)
247+
# Clean up that table
248+
editor.delete_model(Author)
249+
# Indexes are gone.
250+
self.assertEqual(
251+
self.get_constraints_for_column(Author, "author.address.zip_code"),
252+
[],
253+
)
254+
# The table is gone
255+
self.assertTableNotExists(Author)

0 commit comments

Comments
 (0)