Skip to content

Commit 36c177f

Browse files
committed
Improve KeyTransform initializer and types
1 parent 5019529 commit 36c177f

File tree

2 files changed

+29
-14
lines changed

2 files changed

+29
-14
lines changed

src/django_mysql/models/fields/dynamic.py

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -308,8 +308,7 @@ class KeyTransform(Transform):
308308

309309
SPEC_MAP_NAMES = ", ".join(sorted(x.__name__ for x in SPEC_MAP.keys()))
310310

311-
TYPE_MAP: dict[str, type[Field] | Field] = {
312-
"BINARY": DynamicField,
311+
TYPE_MAP: dict[str, Field[Any, Any]] = {
313312
"CHAR": TextField(),
314313
"DATE": DateField(),
315314
"DATETIME": DateTimeField(),
@@ -322,23 +321,26 @@ def __init__(
322321
self,
323322
key_name: str,
324323
data_type: str,
325-
*args: Any,
324+
*expressions: Any,
326325
subspec: SpecDict | None = None,
327-
**kwargs: Any,
326+
output_field: Field[Any, Any] | None = None,
327+
**extra: Any,
328328
) -> None:
329-
super().__init__(*args, **kwargs)
330-
self.key_name = key_name
331-
self.data_type = data_type
332-
333-
try:
334-
output_field = self.TYPE_MAP[data_type]
335-
except KeyError: # pragma: no cover
336-
raise ValueError(f"Invalid data_type '{data_type}'")
329+
if output_field is not None:
330+
raise ValueError("Cannot set output_field for KeyTransform")
337331

338332
if data_type == "BINARY":
339-
self.output_field = output_field(spec=subspec)
333+
output_field = DynamicField(spec=subspec)
340334
else:
341-
self.output_field = output_field
335+
try:
336+
output_field = self.TYPE_MAP[data_type]
337+
except KeyError:
338+
raise ValueError(f"Invalid data_type {data_type!r}")
339+
340+
super().__init__(*expressions, output_field=output_field, **extra)
341+
342+
self.key_name = key_name
343+
self.data_type = data_type
342344

343345
def as_sql(
344346
self, compiler: SQLCompiler, connection: BaseDatabaseWrapper

tests/testapp/test_dynamicfield.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from django.test.utils import isolate_apps
1616

1717
from django_mysql.models import DynamicField
18+
from django_mysql.models.fields.dynamic import KeyTransform
1819
from tests.testapp.models import DynamicModel, SpeclessDynamicModel
1920

2021

@@ -159,6 +160,18 @@ def test_non_existent_transform(self):
159160
def test_has_key(self):
160161
assert list(DynamicModel.objects.filter(attrs__has_key="c")) == self.objs[1:3]
161162

163+
def test_key_transform_initialize_output_field(self):
164+
with pytest.raises(ValueError) as excinfo:
165+
KeyTransform("x", "y", output_field=CharField())
166+
167+
assert str(excinfo.value) == "Cannot set output_field for KeyTransform"
168+
169+
def test_key_transform_initialize_bad_type(self):
170+
with pytest.raises(ValueError) as excinfo:
171+
KeyTransform("x", "unknown")
172+
173+
assert str(excinfo.value) == "Invalid data_type 'unknown'"
174+
162175
def test_key_transform_datey(self):
163176
assert list(DynamicModel.objects.filter(attrs__datey=dt.date(2001, 1, 4))) == [
164177
self.objs[4]

0 commit comments

Comments
 (0)