diff --git a/mssql/schema.py b/mssql/schema.py index 29da1046..3d8f9446 100644 --- a/mssql/schema.py +++ b/mssql/schema.py @@ -4,6 +4,8 @@ import binascii import datetime +from collections import defaultdict + from django.db.backends.base.schema import ( BaseDatabaseSchemaEditor, _is_relevant_relation, @@ -92,6 +94,8 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): sql_create_unique_null = "CREATE UNIQUE INDEX %(name)s ON %(table)s(%(columns)s) " \ "WHERE %(columns)s IS NOT NULL" + _deferred_unique_indexes = defaultdict(list) + def _alter_column_default_sql(self, model, old_field, new_field, drop=False): """ Hook to specialize column default alteration. @@ -279,6 +283,15 @@ def _db_table_delete_constraint_sql(self, template, db_table, name): include='' ) + def _delete_deferred_unique_indexes_for_field(self, field): + deferred_statements = self._deferred_unique_indexes.get(str(field), []) + for stmt in deferred_statements: + if stmt in self.deferred_sql: + self.deferred_sql.remove(stmt) + + def _add_deferred_unique_index_for_field(self, field, statement): + self._deferred_unique_indexes[str(field)].append(statement) + def _alter_field(self, model, old_field, new_field, old_type, new_type, old_db_params, new_db_params, strict=False): """Actually perform a "physical" (non-ManyToMany) field update.""" @@ -542,6 +555,7 @@ def _alter_field(self, model, old_field, new_field, old_type, new_type, self.execute(self._create_unique_sql(model, [new_field])) else: self.execute(self._create_unique_sql(model, [new_field.column])) + self._delete_deferred_unique_indexes_for_field(new_field) # Added an index? # constraint will no longer be used in lieu of an index. The following # lines from the truth table show all True cases; the rest are False: @@ -574,6 +588,7 @@ def _alter_field(self, model, old_field, new_field, old_type, new_type, self.execute(self._create_unique_sql(model, [old_field])) else: self.execute(self._create_unique_sql(model, columns=[old_field.column])) + self._delete_deferred_unique_indexes_for_field(old_field) else: if django_version >= (4, 0): for field_names in model._meta.unique_together: @@ -846,9 +861,11 @@ def add_field(self, model, field): not field.many_to_many and field.null and field.unique): definition = definition.replace(' UNIQUE', '') - self.deferred_sql.append(self._create_index_sql( + statement = self._create_index_sql( model, [field], sql=self.sql_create_unique_null, suffix="_uniq" - )) + ) + self.deferred_sql.append(statement) + self._add_deferred_unique_index_for_field(field, statement) # Check constraints can go on the column SQL here db_params = field.db_parameters(connection=self.connection) @@ -1012,9 +1029,11 @@ def create_model(self, model): not field.many_to_many and field.null and field.unique): definition = definition.replace(' UNIQUE', '') - self.deferred_sql.append(self._create_index_sql( + statement = self._create_index_sql( model, [field], sql=self.sql_create_unique_null, suffix="_uniq" - )) + ) + self.deferred_sql.append(statement) + self._add_deferred_unique_index_for_field(field, statement) # Check constraints can go on the column SQL here db_params = field.db_parameters(connection=self.connection) diff --git a/testapp/tests/test_indexes.py b/testapp/tests/test_indexes.py index 40100f6b..53e7ec38 100644 --- a/testapp/tests/test_indexes.py +++ b/testapp/tests/test_indexes.py @@ -3,7 +3,9 @@ import django.db from django import VERSION from django.apps import apps -from django.db import models +from django.db import models, migrations +from django.db.migrations.migration import Migration +from django.db.migrations.state import ProjectState from django.db.models import UniqueConstraint from django.db.utils import DEFAULT_DB_ALIAS, ConnectionHandler, ProgrammingError from django.test import TestCase @@ -175,3 +177,36 @@ def test_unique_index_dropped(self): editor.alter_field(Choice, old_field, new_field, strict=True) except ProgrammingError: self.fail("Unique indexes not being dropped") + +class TestAddAndAlterUniqueIndex(TestCase): + + def test_alter_unique_nullable_to_non_nullable(self): + """ + Test a single migration that creates a field with unique=True and null=True and then alters + the field to set null=False. See https://github.com/microsoft/mssql-django/issues/22 + """ + operations = [ + migrations.CreateModel( + "TestAlterNullableInUniqueField", + [ + ("id", models.AutoField(primary_key=True)), + ("a", models.CharField(max_length=4, unique=True, null=True)), + ] + ), + migrations.AlterField( + "testalternullableinuniquefield", + "a", + models.CharField(max_length=4, unique=True) + ) + ] + + project_state = ProjectState() + new_state = project_state.clone() + migration = Migration("name", "testapp") + migration.operations = operations + + try: + with connection.schema_editor(atomic=True) as editor: + migration.apply(new_state, editor) + except django.db.utils.ProgrammingError as e: + self.fail('Check if can alter field from unique, nullable to unique non-nullable for issue #23, AlterField failed with exception: %s' % e)