Skip to content

Commit e2a57c1

Browse files
authored
Fix EnumField migrations changing choices on Django 4.1 (#935)
1 parent 282e484 commit e2a57c1

File tree

2 files changed

+28
-1
lines changed

2 files changed

+28
-1
lines changed

src/django_mysql/models/fields/enum.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from typing import Any, cast
44

5+
import django
56
from django.db.backends.base.base import BaseDatabaseWrapper
67
from django.db.models import CharField
78
from django.utils.encoding import force_str
@@ -13,6 +14,9 @@
1314
class EnumField(CharField):
1415
description = _("Enumeration")
1516

17+
if django.VERSION >= (4, 1):
18+
non_db_attrs = tuple(f for f in CharField.non_db_attrs if f != "choices")
19+
1620
def __init__(
1721
self,
1822
*args: Any,

tests/testapp/test_enumfield.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,10 @@
33
import pytest
44
from django.core.exceptions import ValidationError
55
from django.core.management import call_command
6-
from django.db import connection
6+
from django.db import connection, models
77
from django.db.utils import DataError
88
from django.test import TestCase, TransactionTestCase, override_settings
9+
from django.test.utils import isolate_apps
910

1011
from django_mysql.models import EnumField
1112
from tests.testapp.models import EnumModel, NullableEnumModel
@@ -152,6 +153,28 @@ def test_adding_field_with_default(self):
152153
with connection.cursor() as cursor:
153154
assert table_name not in table_names(cursor)
154155

156+
@isolate_apps("tests.testapp")
157+
def test_alter_field_choices_changes(self):
158+
class Temp(models.Model):
159+
field = EnumField(choices=["apple"])
160+
161+
with connection.schema_editor() as editor:
162+
editor.create_model(Temp)
163+
164+
@self.addCleanup
165+
def drop_table():
166+
with connection.schema_editor() as editor:
167+
editor.delete_model(Temp)
168+
169+
old_field = Temp._meta.get_field("field")
170+
new_field = EnumField(choices=["apple", "banana"])
171+
new_field.set_attributes_from_name("field")
172+
173+
with connection.schema_editor() as editor, self.assertNumQueries(1):
174+
editor.alter_field(Temp, old_field, new_field, strict=True)
175+
with connection.schema_editor() as editor, self.assertNumQueries(1):
176+
editor.alter_field(Temp, new_field, old_field, strict=True)
177+
155178

156179
class TestFormfield(TestCase):
157180
def test_formfield(self):

0 commit comments

Comments
 (0)