|
| 1 | +import itertools |
| 2 | + |
| 3 | +from django.db import ( |
| 4 | + connection, |
| 5 | +) |
| 6 | +from django.test import ( |
| 7 | + TransactionTestCase, |
| 8 | +) |
| 9 | + |
| 10 | +from .models import Address, Author, Book, new_apps |
| 11 | + |
| 12 | + |
| 13 | +class SchemaTests(TransactionTestCase): |
| 14 | + available_apps = [] |
| 15 | + models = [Address, Author, Book] |
| 16 | + |
| 17 | + # Utility functions |
| 18 | + |
| 19 | + def setUp(self): |
| 20 | + # local_models should contain test dependent model classes that will be |
| 21 | + # automatically removed from the app cache on test tear down. |
| 22 | + self.local_models = [] |
| 23 | + # isolated_local_models contains models that are in test methods |
| 24 | + # decorated with @isolate_apps. |
| 25 | + self.isolated_local_models = [] |
| 26 | + |
| 27 | + def tearDown(self): |
| 28 | + # Delete any tables made for our models |
| 29 | + self.delete_tables() |
| 30 | + new_apps.clear_cache() |
| 31 | + for model in new_apps.get_models(): |
| 32 | + model._meta._expire_cache() |
| 33 | + if "schema" in new_apps.all_models: |
| 34 | + for model in self.local_models: |
| 35 | + for many_to_many in model._meta.many_to_many: |
| 36 | + through = many_to_many.remote_field.through |
| 37 | + if through and through._meta.auto_created: |
| 38 | + del new_apps.all_models["schema"][through._meta.model_name] |
| 39 | + del new_apps.all_models["schema"][model._meta.model_name] |
| 40 | + if self.isolated_local_models: |
| 41 | + with connection.schema_editor() as editor: |
| 42 | + for model in self.isolated_local_models: |
| 43 | + editor.delete_model(model) |
| 44 | + |
| 45 | + def delete_tables(self): |
| 46 | + "Deletes all model tables for our models for a clean test environment" |
| 47 | + converter = connection.introspection.identifier_converter |
| 48 | + with connection.schema_editor() as editor: |
| 49 | + connection.disable_constraint_checking() |
| 50 | + table_names = connection.introspection.table_names() |
| 51 | + if connection.features.ignores_table_name_case: |
| 52 | + table_names = [table_name.lower() for table_name in table_names] |
| 53 | + for model in itertools.chain(SchemaTests.models, self.local_models): |
| 54 | + tbl = converter(model._meta.db_table) |
| 55 | + if connection.features.ignores_table_name_case: |
| 56 | + tbl = tbl.lower() |
| 57 | + if tbl in table_names: |
| 58 | + editor.delete_model(model) |
| 59 | + table_names.remove(tbl) |
| 60 | + connection.enable_constraint_checking() |
| 61 | + |
| 62 | + def column_classes(self, model): |
| 63 | + with connection.cursor() as cursor: |
| 64 | + columns = { |
| 65 | + d[0]: (connection.introspection.get_field_type(d[1], d), d) |
| 66 | + for d in connection.introspection.get_table_description( |
| 67 | + cursor, |
| 68 | + model._meta.db_table, |
| 69 | + ) |
| 70 | + } |
| 71 | + # SQLite has a different format for field_type |
| 72 | + for name, (type, desc) in columns.items(): |
| 73 | + if isinstance(type, tuple): |
| 74 | + columns[name] = (type[0], desc) |
| 75 | + return columns |
| 76 | + |
| 77 | + def get_primary_key(self, table): |
| 78 | + with connection.cursor() as cursor: |
| 79 | + return connection.introspection.get_primary_key_column(cursor, table) |
| 80 | + |
| 81 | + def get_indexes(self, table): |
| 82 | + """ |
| 83 | + Get the indexes on the table using a new cursor. |
| 84 | + """ |
| 85 | + with connection.cursor() as cursor: |
| 86 | + return [ |
| 87 | + c["columns"][0] |
| 88 | + for c in connection.introspection.get_constraints(cursor, table).values() |
| 89 | + if c["index"] and len(c["columns"]) == 1 |
| 90 | + ] |
| 91 | + |
| 92 | + def get_uniques(self, table): |
| 93 | + with connection.cursor() as cursor: |
| 94 | + return [ |
| 95 | + c["columns"][0] |
| 96 | + for c in connection.introspection.get_constraints(cursor, table).values() |
| 97 | + if c["unique"] and len(c["columns"]) == 1 |
| 98 | + ] |
| 99 | + |
| 100 | + def get_constraints(self, table): |
| 101 | + """ |
| 102 | + Get the constraints on a table using a new cursor. |
| 103 | + """ |
| 104 | + with connection.cursor() as cursor: |
| 105 | + return connection.introspection.get_constraints(cursor, table) |
| 106 | + |
| 107 | + def get_constraints_for_column(self, model, column_name): |
| 108 | + constraints = self.get_constraints(model._meta.db_table) |
| 109 | + constraints_for_column = [] |
| 110 | + for name, details in constraints.items(): |
| 111 | + if details["columns"] == [column_name]: |
| 112 | + constraints_for_column.append(name) |
| 113 | + return sorted(constraints_for_column) |
| 114 | + |
| 115 | + def get_constraint_opclasses(self, constraint_name): |
| 116 | + with connection.cursor() as cursor: |
| 117 | + sql = """ |
| 118 | + SELECT opcname |
| 119 | + FROM pg_opclass AS oc |
| 120 | + JOIN pg_index as i on oc.oid = ANY(i.indclass) |
| 121 | + JOIN pg_class as c on c.oid = i.indexrelid |
| 122 | + WHERE c.relname = %s |
| 123 | + """ |
| 124 | + cursor.execute(sql, [constraint_name]) |
| 125 | + return [row[0] for row in cursor.fetchall()] |
| 126 | + |
| 127 | + def check_added_field_default( |
| 128 | + self, |
| 129 | + schema_editor, |
| 130 | + model, |
| 131 | + field, |
| 132 | + field_name, |
| 133 | + expected_default, |
| 134 | + cast_function=None, |
| 135 | + ): |
| 136 | + schema_editor.add_field(model, field) |
| 137 | + database_default = connection.database[model._meta.db_table].find_one().get(field_name) |
| 138 | + # cursor.execute( |
| 139 | + # "SELECT {} FROM {};".format(field_name, model._meta.db_table) |
| 140 | + # ) |
| 141 | + # database_default = cursor.fetchall()[0][0] |
| 142 | + if cast_function and type(database_default) is not type(expected_default): |
| 143 | + database_default = cast_function(database_default) |
| 144 | + self.assertEqual(database_default, expected_default) |
| 145 | + |
| 146 | + def get_constraints_count(self, table, column, fk_to): |
| 147 | + """ |
| 148 | + Return a dict with keys 'fks', 'uniques, and 'indexes' indicating the |
| 149 | + number of foreign keys, unique constraints, and indexes on |
| 150 | + `table`.`column`. The `fk_to` argument is a 2-tuple specifying the |
| 151 | + expected foreign key relationship's (table, column). |
| 152 | + """ |
| 153 | + with connection.cursor() as cursor: |
| 154 | + constraints = connection.introspection.get_constraints(cursor, table) |
| 155 | + counts = {"fks": 0, "uniques": 0, "indexes": 0} |
| 156 | + for c in constraints.values(): |
| 157 | + if c["columns"] == [column]: |
| 158 | + if c["foreign_key"] == fk_to: |
| 159 | + counts["fks"] += 1 |
| 160 | + if c["unique"]: |
| 161 | + counts["uniques"] += 1 |
| 162 | + elif c["index"]: |
| 163 | + counts["indexes"] += 1 |
| 164 | + return counts |
| 165 | + |
| 166 | + def get_column_collation(self, table, column): |
| 167 | + with connection.cursor() as cursor: |
| 168 | + return next( |
| 169 | + f.collation |
| 170 | + for f in connection.introspection.get_table_description(cursor, table) |
| 171 | + if f.name == column |
| 172 | + ) |
| 173 | + |
| 174 | + def get_column_comment(self, table, column): |
| 175 | + with connection.cursor() as cursor: |
| 176 | + return next( |
| 177 | + f.comment |
| 178 | + for f in connection.introspection.get_table_description(cursor, table) |
| 179 | + if f.name == column |
| 180 | + ) |
| 181 | + |
| 182 | + def get_table_comment(self, table): |
| 183 | + with connection.cursor() as cursor: |
| 184 | + return next( |
| 185 | + t.comment |
| 186 | + for t in connection.introspection.get_table_list(cursor) |
| 187 | + if t.name == table |
| 188 | + ) |
| 189 | + |
| 190 | + def assert_column_comment_not_exists(self, table, column): |
| 191 | + with connection.cursor() as cursor: |
| 192 | + columns = connection.introspection.get_table_description(cursor, table) |
| 193 | + self.assertFalse(any(c.name == column and c.comment for c in columns)) |
| 194 | + |
| 195 | + def assertIndexOrder(self, table, index, order): |
| 196 | + constraints = self.get_constraints(table) |
| 197 | + self.assertIn(index, constraints) |
| 198 | + index_orders = constraints[index]["orders"] |
| 199 | + self.assertTrue( |
| 200 | + all(val == expected for val, expected in zip(index_orders, order, strict=True)) |
| 201 | + ) |
| 202 | + |
| 203 | + def assertForeignKeyExists(self, model, column, expected_fk_table, field="id"): |
| 204 | + """ |
| 205 | + Fail if the FK constraint on `model.Meta.db_table`.`column` to |
| 206 | + `expected_fk_table`.id doesn't exist. |
| 207 | + """ |
| 208 | + if not connection.features.can_introspect_foreign_keys: |
| 209 | + return |
| 210 | + constraints = self.get_constraints(model._meta.db_table) |
| 211 | + constraint_fk = None |
| 212 | + for details in constraints.values(): |
| 213 | + if details["columns"] == [column] and details["foreign_key"]: |
| 214 | + constraint_fk = details["foreign_key"] |
| 215 | + break |
| 216 | + self.assertEqual(constraint_fk, (expected_fk_table, field)) |
| 217 | + |
| 218 | + def assertForeignKeyNotExists(self, model, column, expected_fk_table): |
| 219 | + if not connection.features.can_introspect_foreign_keys: |
| 220 | + return |
| 221 | + with self.assertRaises(AssertionError): |
| 222 | + self.assertForeignKeyExists(model, column, expected_fk_table) |
| 223 | + |
| 224 | + def assertTableExists(self, model): |
| 225 | + self.assertIn(model._meta.db_table, connection.introspection.table_names()) |
| 226 | + |
| 227 | + def assertTableNotExists(self, model): |
| 228 | + self.assertNotIn(model._meta.db_table, connection.introspection.table_names()) |
| 229 | + |
| 230 | + # Tests |
| 231 | + def test_embedded_index(self): |
| 232 | + """db_index on an embedded model.""" |
| 233 | + with connection.schema_editor() as editor: |
| 234 | + # Create the table |
| 235 | + editor.create_model(Book) |
| 236 | + # The table is there |
| 237 | + self.assertTableExists(Book) |
| 238 | + # Embedded indexes are created. |
| 239 | + self.assertEqual( |
| 240 | + self.get_constraints_for_column(Book, "author.age"), |
| 241 | + ["schema__book_author.age_dc08100b"], |
| 242 | + ) |
| 243 | + self.assertEqual( |
| 244 | + self.get_constraints_for_column(Book, "author.address.zip_code"), |
| 245 | + ["schema__book_author.address.zip_code_7b9a9307"], |
| 246 | + ) |
| 247 | + # Clean up that table |
| 248 | + editor.delete_model(Author) |
| 249 | + # Indexes are gone. |
| 250 | + self.assertEqual( |
| 251 | + self.get_constraints_for_column(Author, "author.address.zip_code"), |
| 252 | + [], |
| 253 | + ) |
| 254 | + # The table is gone |
| 255 | + self.assertTableNotExists(Author) |
0 commit comments