Skip to content

Commit fb9322b

Browse files
authored
Improve KeyTransform initializer and types (#940)
1 parent 5019529 commit fb9322b

File tree

3 files changed

+41
-39
lines changed

3 files changed

+41
-39
lines changed

src/django_mysql/models/fields/dynamic.py

Lines changed: 21 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -179,13 +179,18 @@ def _check_spec_recursively(
179179
subpath = f"{path}.{key}"
180180
errors.extend(self._check_spec_recursively(value, subpath))
181181
elif value not in KeyTransform.SPEC_MAP:
182+
valid_names = ", ".join(
183+
sorted(x.__name__ for x in KeyTransform.SPEC_MAP.keys())
184+
)
182185
errors.append(
183186
checks.Error(
184187
"The value for '{}' in 'spec{}' is not an allowed type".format(
185188
key, path
186189
),
187-
hint="'spec' values must be one of the following "
188-
"types: {}".format(KeyTransform.SPEC_MAP_NAMES),
190+
hint=(
191+
"'spec' values must be one of the following types: "
192+
+ valid_names
193+
),
189194
obj=self,
190195
id="django_mysql.E011",
191196
)
@@ -306,10 +311,8 @@ class KeyTransform(Transform):
306311
dict: "BINARY",
307312
}
308313

309-
SPEC_MAP_NAMES = ", ".join(sorted(x.__name__ for x in SPEC_MAP.keys()))
310-
311-
TYPE_MAP: dict[str, type[Field] | Field] = {
312-
"BINARY": DynamicField,
314+
TYPE_MAP: dict[str, Field[Any, Any]] = {
315+
# Excludes BINARY -> DynamicField as that’s requires spec
313316
"CHAR": TextField(),
314317
"DATE": DateField(),
315318
"DATETIME": DateTimeField(),
@@ -322,23 +325,22 @@ def __init__(
322325
self,
323326
key_name: str,
324327
data_type: str,
325-
*args: Any,
328+
*expressions: Any,
326329
subspec: SpecDict | None = None,
327-
**kwargs: Any,
328330
) -> None:
329-
super().__init__(*args, **kwargs)
330-
self.key_name = key_name
331-
self.data_type = data_type
332-
333-
try:
334-
output_field = self.TYPE_MAP[data_type]
335-
except KeyError: # pragma: no cover
336-
raise ValueError(f"Invalid data_type '{data_type}'")
337-
331+
output_field: Field[Any, Any]
338332
if data_type == "BINARY":
339-
self.output_field = output_field(spec=subspec)
333+
output_field = DynamicField(spec=subspec)
340334
else:
341-
self.output_field = output_field
335+
try:
336+
output_field = self.TYPE_MAP[data_type]
337+
except KeyError:
338+
raise ValueError(f"Invalid data_type {data_type!r}")
339+
340+
super().__init__(*expressions, output_field=output_field)
341+
342+
self.key_name = key_name
343+
self.data_type = data_type
342344

343345
def as_sql(
344346
self, compiler: SQLCompiler, connection: BaseDatabaseWrapper

src/django_mysql/models/functions.py

Lines changed: 13 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -441,20 +441,16 @@ class AsType(Func):
441441
template = "%(expressions)s AS %(data_type)s"
442442

443443
def __init__(self, expression: ExpressionArgument, data_type: str) -> None:
444+
from django_mysql.models.fields.dynamic import KeyTransform
445+
444446
if not hasattr(expression, "resolve_expression"):
445447
expression = Value(expression)
446448

447-
if data_type not in self.TYPE_MAP:
449+
if data_type not in KeyTransform.TYPE_MAP and data_type != "BINARY":
448450
raise ValueError(f"Invalid data_type '{data_type}'")
449451

450452
super().__init__(expression, data_type=data_type)
451453

452-
@property
453-
def TYPE_MAP(self) -> dict[str, type[DjangoField] | DjangoField]:
454-
from django_mysql.models.fields.dynamic import KeyTransform
455-
456-
return KeyTransform.TYPE_MAP
457-
458454

459455
class ColumnAdd(Func):
460456
function = "COLUMN_ADD"
@@ -508,25 +504,22 @@ def __init__(
508504
self,
509505
expression: ExpressionArgument,
510506
column_name: ExpressionArgument,
511-
data_type: ExpressionArgument,
507+
data_type: str,
512508
):
509+
from django_mysql.models.fields.dynamic import DynamicField, KeyTransform
510+
513511
if not hasattr(column_name, "resolve_expression"):
514512
column_name = Value(column_name)
515513

516-
try:
517-
output_field = self.TYPE_MAP[data_type]
518-
except KeyError:
519-
raise ValueError(f"Invalid data_type '{data_type}'")
520-
514+
output_field: DjangoField[Any, Any]
521515
if data_type == "BINARY":
522-
output_field = output_field()
516+
output_field = DynamicField()
517+
else:
518+
try:
519+
output_field = KeyTransform.TYPE_MAP[data_type]
520+
except KeyError:
521+
raise ValueError(f"Invalid data_type {data_type!r}")
523522

524523
super().__init__(
525524
expression, column_name, output_field=output_field, data_type=data_type
526525
)
527-
528-
@property
529-
def TYPE_MAP(self) -> dict[str, DjangoField | type[DjangoField]]:
530-
from django_mysql.models.fields.dynamic import KeyTransform
531-
532-
return KeyTransform.TYPE_MAP

tests/testapp/test_dynamicfield.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from django.test.utils import isolate_apps
1616

1717
from django_mysql.models import DynamicField
18+
from django_mysql.models.fields.dynamic import KeyTransform
1819
from tests.testapp.models import DynamicModel, SpeclessDynamicModel
1920

2021

@@ -159,6 +160,12 @@ def test_non_existent_transform(self):
159160
def test_has_key(self):
160161
assert list(DynamicModel.objects.filter(attrs__has_key="c")) == self.objs[1:3]
161162

163+
def test_key_transform_initialize_bad_type(self):
164+
with pytest.raises(ValueError) as excinfo:
165+
KeyTransform("x", "unknown")
166+
167+
assert str(excinfo.value) == "Invalid data_type 'unknown'"
168+
162169
def test_key_transform_datey(self):
163170
assert list(DynamicModel.objects.filter(attrs__datey=dt.date(2001, 1, 4))) == [
164171
self.objs[4]

0 commit comments

Comments
 (0)