Skip to content

Improve KeyTransform initializer and types #940

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Aug 28, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 21 additions & 19 deletions src/django_mysql/models/fields/dynamic.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,13 +179,18 @@ def _check_spec_recursively(
subpath = f"{path}.{key}"
errors.extend(self._check_spec_recursively(value, subpath))
elif value not in KeyTransform.SPEC_MAP:
valid_names = ", ".join(
sorted(x.__name__ for x in KeyTransform.SPEC_MAP.keys())
)
errors.append(
checks.Error(
"The value for '{}' in 'spec{}' is not an allowed type".format(
key, path
),
hint="'spec' values must be one of the following "
"types: {}".format(KeyTransform.SPEC_MAP_NAMES),
hint=(
"'spec' values must be one of the following types: "
+ valid_names
),
obj=self,
id="django_mysql.E011",
)
Expand Down Expand Up @@ -306,10 +311,8 @@ class KeyTransform(Transform):
dict: "BINARY",
}

SPEC_MAP_NAMES = ", ".join(sorted(x.__name__ for x in SPEC_MAP.keys()))

TYPE_MAP: dict[str, type[Field] | Field] = {
"BINARY": DynamicField,
TYPE_MAP: dict[str, Field[Any, Any]] = {
# Excludes BINARY -> DynamicField as that’s requires spec
"CHAR": TextField(),
"DATE": DateField(),
"DATETIME": DateTimeField(),
Expand All @@ -322,23 +325,22 @@ def __init__(
self,
key_name: str,
data_type: str,
*args: Any,
*expressions: Any,
subspec: SpecDict | None = None,
**kwargs: Any,
) -> None:
super().__init__(*args, **kwargs)
self.key_name = key_name
self.data_type = data_type

try:
output_field = self.TYPE_MAP[data_type]
except KeyError: # pragma: no cover
raise ValueError(f"Invalid data_type '{data_type}'")

output_field: Field[Any, Any]
if data_type == "BINARY":
self.output_field = output_field(spec=subspec)
output_field = DynamicField(spec=subspec)
else:
self.output_field = output_field
try:
output_field = self.TYPE_MAP[data_type]
except KeyError:
raise ValueError(f"Invalid data_type {data_type!r}")

super().__init__(*expressions, output_field=output_field)

self.key_name = key_name
self.data_type = data_type

def as_sql(
self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
Expand Down
33 changes: 13 additions & 20 deletions src/django_mysql/models/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,20 +441,16 @@ class AsType(Func):
template = "%(expressions)s AS %(data_type)s"

def __init__(self, expression: ExpressionArgument, data_type: str) -> None:
from django_mysql.models.fields.dynamic import KeyTransform

if not hasattr(expression, "resolve_expression"):
expression = Value(expression)

if data_type not in self.TYPE_MAP:
if data_type not in KeyTransform.TYPE_MAP and data_type != "BINARY":
raise ValueError(f"Invalid data_type '{data_type}'")

super().__init__(expression, data_type=data_type)

@property
def TYPE_MAP(self) -> dict[str, type[DjangoField] | DjangoField]:
from django_mysql.models.fields.dynamic import KeyTransform

return KeyTransform.TYPE_MAP


class ColumnAdd(Func):
function = "COLUMN_ADD"
Expand Down Expand Up @@ -508,25 +504,22 @@ def __init__(
self,
expression: ExpressionArgument,
column_name: ExpressionArgument,
data_type: ExpressionArgument,
data_type: str,
):
from django_mysql.models.fields.dynamic import DynamicField, KeyTransform

if not hasattr(column_name, "resolve_expression"):
column_name = Value(column_name)

try:
output_field = self.TYPE_MAP[data_type]
except KeyError:
raise ValueError(f"Invalid data_type '{data_type}'")

output_field: DjangoField[Any, Any]
if data_type == "BINARY":
output_field = output_field()
output_field = DynamicField()
else:
try:
output_field = KeyTransform.TYPE_MAP[data_type]
except KeyError:
raise ValueError(f"Invalid data_type {data_type!r}")

super().__init__(
expression, column_name, output_field=output_field, data_type=data_type
)

@property
def TYPE_MAP(self) -> dict[str, DjangoField | type[DjangoField]]:
from django_mysql.models.fields.dynamic import KeyTransform

return KeyTransform.TYPE_MAP
7 changes: 7 additions & 0 deletions tests/testapp/test_dynamicfield.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from django.test.utils import isolate_apps

from django_mysql.models import DynamicField
from django_mysql.models.fields.dynamic import KeyTransform
from tests.testapp.models import DynamicModel, SpeclessDynamicModel


Expand Down Expand Up @@ -159,6 +160,12 @@ def test_non_existent_transform(self):
def test_has_key(self):
assert list(DynamicModel.objects.filter(attrs__has_key="c")) == self.objs[1:3]

def test_key_transform_initialize_bad_type(self):
with pytest.raises(ValueError) as excinfo:
KeyTransform("x", "unknown")

assert str(excinfo.value) == "Invalid data_type 'unknown'"

def test_key_transform_datey(self):
assert list(DynamicModel.objects.filter(attrs__datey=dt.date(2001, 1, 4))) == [
self.objs[4]
Expand Down