Skip to content

Use Mypy's strict mode #937

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 30 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
148d29c
Use Mypy's strict mode
adamchainz Aug 15, 2022
22fc94a
Finish hints for cache module
adamchainz Aug 22, 2022
5565211
Finish hints for utils
adamchainz Aug 22, 2022
8434689
Finish hints for operations
adamchainz Aug 22, 2022
4efd38e
Finish hints for locks
adamchainz Aug 22, 2022
42d41ca
Finish hints for management commands
adamchainz Aug 22, 2022
f9d985a
Finish hints for status
adamchainz Aug 22, 2022
b0017d9
Stricter signature for contribute_to_class
adamchainz Aug 22, 2022
33bff0c
Fix JSONExtract output_field arg
adamchainz Aug 22, 2022
8628c09
Fix types for IndexLookup.as_sql()
adamchainz Aug 22, 2022
f48605a
Fix type for GroupConcat arg 1
adamchainz Aug 22, 2022
850369f
Fix some as_sql() signatures
adamchainz Aug 26, 2022
8d3e327
assert
adamchainz Aug 27, 2022
37e91e3
mute isinstance
adamchainz Aug 28, 2022
4995ef6
more fixes
adamchainz Aug 30, 2022
c6fd31e
Improve function types
adamchainz Oct 18, 2022
ae753be
Add cast
adamchainz Oct 18, 2022
cabb48b
Fix AsType signature
adamchainz Oct 19, 2022
f77f0ed
Fix name of pytest fixture to avoid collision
adamchainz Oct 19, 2022
2a2f298
Fix some errors in cache tests
adamchainz Oct 19, 2022
2dbe884
Fix return type of IndexLookup.as_sql
adamchainz Oct 19, 2022
c796a86
make formfield() methods return Any
adamchainz Oct 19, 2022
6d8fa46
Add extra model type asserts for deserialization tests
adamchainz Sep 9, 2024
e2006df
Some extra hints
adamchainz Sep 9, 2024
3d38989
Fix formfield() methods
adamchainz Sep 9, 2024
352054b
Allow bad arg types for tests checking that
adamchainz Sep 9, 2024
7455bcb
Pass connection to db_type()
adamchainz Sep 9, 2024
22e57fa
Correct mysql_connections()
adamchainz Sep 9, 2024
ac86029
Correct types of source expression functions
adamchainz Sep 9, 2024
b81f6b2
Upgrade django-stubs and Mypy pytest
adamchainz May 13, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -73,3 +73,7 @@ repos:
rev: v1.15.0
hooks:
- id: mypy
additional_dependencies:
- django-stubs==5.2.0
- mysqlclient
- pytest==8.3.5
11 changes: 5 additions & 6 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -108,20 +108,19 @@ enable_error_code = [
"redundant-expr",
"truthy-bool",
]
check_untyped_defs = true
disallow_any_generics = true
disallow_incomplete_defs = true
disallow_untyped_defs = true
mypy_path = "src/"
namespace_packages = false
no_implicit_optional = true
plugins = [ "mypy_django_plugin.main" ]
strict = true
warn_unreachable = true
warn_unused_ignores = true

[[tool.mypy.overrides]]
module = "tests.*"
allow_untyped_defs = true

[tool.django-stubs]
django_settings_module = "tests.settings"

[tool.rstcheck]
ignore_directives = [
"automodule",
Expand Down
18 changes: 12 additions & 6 deletions src/django_mysql/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from django.core.cache.backends.base import default_key_func
from django.db import connections
from django.db import router
from django.db.models import Model
from django.utils.encoding import force_bytes
from django.utils.module_loading import import_string

Expand Down Expand Up @@ -61,7 +62,9 @@ def __init__(self, table: str, params: dict[str, Any]) -> None:
super().__init__(params)
self._table = table

class CacheEntry:
CacheEntry: type[Model] # force Mypy to accept duck typing

class CacheEntry: # type: ignore [no-redef]
_meta = Options(table)

self.cache_model_class = CacheEntry
Expand Down Expand Up @@ -182,7 +185,7 @@ def get_many(
self, keys: Iterable[str], version: int | None = None
) -> dict[str, Any]:
made_key_to_key = {self.make_key(key, version=version): key for key in keys}
made_keys = list(made_key_to_key.keys())
made_keys: list[Any] = list(made_key_to_key.keys())
for key in made_keys:
self.validate_key(key)

Expand Down Expand Up @@ -265,7 +268,7 @@ def _base_set(
return True
else: # mode = 'add'
# Use a special code in the add query for "did insert"
insert_id = cursor.lastrowid
insert_id: int = cursor.lastrowid
return insert_id != 444

_set_many_query = collapse_spaces(
Expand Down Expand Up @@ -415,7 +418,8 @@ def _base_delta(
raise ValueError("Key '%s' not found, or not an integer" % key)

# New value stored in insert_id
return cursor.lastrowid
result: int = cursor.lastrowid
return result

# Looks a bit tangled to turn the blob back into an int for updating, but
# it works. Stores the new value for insert_id() with LAST_INSERT_ID
Expand Down Expand Up @@ -447,7 +451,7 @@ def touch(
db = router.db_for_write(self.cache_model_class)
table = connections[db].ops.quote_name(self._table)
with connections[db].cursor() as cursor:
affected_rows = cursor.execute(
affected_rows: int = cursor.execute(
self._touch_query.format(table=table), [exp, key, self._now()]
)
return affected_rows > 0
Expand Down Expand Up @@ -611,18 +615,20 @@ def delete_with_prefix(self, prefix: str, version: int | None = None) -> int:
prefix = self.make_key(prefix + "%", version=version)

with connections[db].cursor() as cursor:
return cursor.execute(
result: int = cursor.execute(
"""DELETE FROM {table}
WHERE cache_key LIKE %s""".format(
table=table
),
(prefix,),
)
return result

def cull(self) -> int:
db = router.db_for_write(self.cache_model_class)
table = connections[db].ops.quote_name(self._table)

num_deleted: int
with connections[db].cursor() as cursor:
# First, try just deleting expired keys
num_deleted = cursor.execute(
Expand Down
6 changes: 5 additions & 1 deletion src/django_mysql/locks.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from django.db import connections
from django.db.backends.utils import CursorWrapper
from django.db.models import Model
from django.db.transaction import Atomic
from django.db.transaction import TransactionManagementError
from django.db.transaction import atomic
from django.db.utils import DEFAULT_DB_ALIAS
Expand Down Expand Up @@ -77,7 +78,8 @@ def is_held(self) -> bool:
def holding_connection_id(self) -> int | None:
with self.get_cursor() as cursor:
cursor.execute("SELECT IS_USED_LOCK(%s)", (self.name,))
return cursor.fetchone()[0]
result: int | None = cursor.fetchone()[0]
return result

@classmethod
def held_with_prefix(
Expand Down Expand Up @@ -108,6 +110,7 @@ def __init__(
self.read: list[str] = self._process_names(read)
self.write: list[str] = self._process_names(write)
self.db = DEFAULT_DB_ALIAS if using is None else using
self._atomic: Atomic | None = None

def _process_names(self, names: list[str | type[Model]] | None) -> list[str]:
"""
Expand Down Expand Up @@ -170,6 +173,7 @@ def release(
) -> None:
connection = connections[self.db]
with connection.cursor() as cursor:
assert self._atomic is not None
self._atomic.__exit__(exc_type, exc_value, exc_traceback)
self._atomic = None
cursor.execute("UNLOCK TABLES")
7 changes: 4 additions & 3 deletions src/django_mysql/management/commands/cull_mysql_caches.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,10 @@ def add_arguments(self, parser: argparse.ArgumentParser) -> None:
help="Specify the cache alias(es) to cull.",
)

def handle(
self, *args: Any, verbosity: int, aliases: list[str], **options: Any
) -> None:
def handle(self, *args: Any, **options: Any) -> None:
verbosity: int = options["verbosity"]
aliases: list[str] = options["aliases"]

if not aliases:
aliases = list(settings.CACHES)

Expand Down
8 changes: 5 additions & 3 deletions src/django_mysql/management/commands/dbparams.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,11 @@ def add_arguments(self, parser: argparse.ArgumentParser) -> None:
"pt-online-schema-change $(./manage.py dbparams --dsn)",
)

def handle(
self, *args: Any, alias: str, show_mysql: bool, show_dsn: bool, **options: Any
) -> None:
def handle(self, *args: Any, **options: Any) -> None:
alias: str = options["alias"]
show_mysql: bool = options["show_mysql"]
show_dsn: bool = options["show_dsn"]

try:
connection = connections[alias]
except ConnectionDoesNotExist:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ def add_arguments(self, parser: argparse.ArgumentParser) -> None:
help="Specify the cache alias(es) to create migrations for.",
)

def handle(self, *args: Any, aliases: list[str], **options: Any) -> None:
def handle(self, *args: Any, **options: Any) -> None:
aliases: list[str] = options["aliases"]
if not aliases:
aliases = list(settings.CACHES)

Expand Down
30 changes: 29 additions & 1 deletion src/django_mysql/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from django_mysql.models.aggregates import BitOr
from django_mysql.models.aggregates import BitXor
from django_mysql.models.aggregates import GroupConcat
from django_mysql.models.base import Model # noqa
from django_mysql.models.base import Model
from django_mysql.models.expressions import ListF
from django_mysql.models.expressions import SetF
from django_mysql.models.fields import Bit1BooleanField
Expand All @@ -27,3 +27,31 @@
from django_mysql.models.query import SmartIterator
from django_mysql.models.query import add_QuerySetMixin
from django_mysql.models.query import pt_visual_explain

__all__ = (
"add_QuerySetMixin",
"ApproximateInt",
"Bit1BooleanField",
"BitAnd",
"BitOr",
"BitXor",
"DynamicField",
"EnumField",
"FixedCharField",
"GroupConcat",
"ListCharField",
"ListF",
"ListTextField",
"Model",
"NullBit1BooleanField",
"pt_visual_explain",
"QuerySet",
"QuerySetMixin",
"SetCharField",
"SetF",
"SetTextField",
"SizedBinaryField",
"SizedTextField",
"SmartChunkedIterator",
"SmartIterator",
)
3 changes: 1 addition & 2 deletions src/django_mysql/models/aggregates.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from django.db.backends.base.base import BaseDatabaseWrapper
from django.db.models import Aggregate
from django.db.models import CharField
from django.db.models import Expression
from django.db.models.sql.compiler import SQLCompiler


Expand All @@ -29,7 +28,7 @@ class GroupConcat(Aggregate):

def __init__(
self,
expression: Expression,
expression: Any,
distinct: bool = False,
separator: str | None = None,
ordering: str | None = None,
Expand Down
15 changes: 9 additions & 6 deletions src/django_mysql/models/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,14 @@

from collections.abc import Iterable
from typing import Any
from typing import Sequence

from django.db.backends.base.base import BaseDatabaseWrapper
from django.db.models import F
from django.db.models import Value
from django.db.models.expressions import BaseExpression
from django.db.models.expressions import Combinable
from django.db.models.expressions import Expression
from django.db.models.sql.compiler import SQLCompiler

from django_mysql.utils import collapse_spaces
Expand All @@ -18,10 +21,10 @@ def __init__(self, lhs: BaseExpression, rhs: BaseExpression) -> None:
self.lhs = lhs
self.rhs = rhs

def get_source_expressions(self) -> list[BaseExpression]:
def get_source_expressions(self) -> list[Expression]:
return [self.lhs, self.rhs]

def set_source_expressions(self, exprs: Iterable[BaseExpression]) -> None:
def set_source_expressions(self, exprs: Sequence[Combinable | Expression]) -> None:
self.lhs, self.rhs = exprs


Expand Down Expand Up @@ -138,10 +141,10 @@ def __init__(self, lhs: BaseExpression) -> None:
super().__init__()
self.lhs = lhs

def get_source_expressions(self) -> list[BaseExpression]:
def get_source_expressions(self) -> list[Expression]:
return [self.lhs]

def set_source_expressions(self, exprs: Iterable[BaseExpression]) -> None:
def set_source_expressions(self, exprs: Sequence[Combinable | Expression]) -> None:
(self.lhs,) = exprs

def as_sql(
Expand Down Expand Up @@ -170,10 +173,10 @@ def __init__(self, lhs: BaseExpression) -> None:
super().__init__()
self.lhs = lhs

def get_source_expressions(self) -> list[BaseExpression]:
def get_source_expressions(self) -> list[Expression]:
return [self.lhs]

def set_source_expressions(self, exprs: Iterable[BaseExpression]) -> None:
def set_source_expressions(self, exprs: Sequence[Combinable | Expression]) -> None:
(self.lhs,) = exprs

def as_sql(
Expand Down
4 changes: 2 additions & 2 deletions src/django_mysql/models/fields/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from django_mysql.models.fields.tiny_integer import PositiveTinyIntegerField
from django_mysql.models.fields.tiny_integer import TinyIntegerField

__all__ = [
__all__ = (
"Bit1BooleanField",
"DynamicField",
"EnumField",
Expand All @@ -28,4 +28,4 @@
"SizedBinaryField",
"SizedTextField",
"TinyIntegerField",
]
)
14 changes: 7 additions & 7 deletions src/django_mysql/models/fields/dynamic.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from typing import Union
from typing import cast

from django import forms
from django.core import checks
from django.db.backends.base.base import BaseDatabaseWrapper
from django.db.models import DateField
Expand All @@ -20,12 +21,11 @@
from django.db.models import TimeField
from django.db.models import Transform
from django.db.models.sql.compiler import SQLCompiler
from django.forms import Field as FormField
from django.utils.translation import gettext_lazy as _

from django_mysql.checks import mysql_connections
from django_mysql.models.lookups import DynColHasKey
from django_mysql.typing import DeconstructResult
from django_mysql.utils import mysql_connections

try:
import mariadb_dyncol
Expand Down Expand Up @@ -87,7 +87,7 @@ def check(self, **kwargs: Any) -> list[checks.CheckMessage]:
return errors

def _check_mariadb_dyncol(self) -> list[checks.CheckMessage]:
errors = []
errors: list[checks.CheckMessage] = []
if mariadb_dyncol is None:
errors.append(
checks.Error(
Expand All @@ -100,7 +100,7 @@ def _check_mariadb_dyncol(self) -> list[checks.CheckMessage]:
return errors

def _check_mariadb_version(self) -> list[checks.CheckMessage]:
errors = []
errors: list[checks.CheckMessage] = []

any_conn_works = any(
(conn.vendor == "mysql" and conn.mysql_is_mariadb)
Expand All @@ -119,7 +119,7 @@ def _check_mariadb_version(self) -> list[checks.CheckMessage]:
return errors

def _check_character_set(self) -> list[checks.CheckMessage]:
errors = []
errors: list[checks.CheckMessage] = []

conn = None
for _alias, check_conn in mysql_connections():
Expand Down Expand Up @@ -151,7 +151,7 @@ def _check_character_set(self) -> list[checks.CheckMessage]:
def _check_spec_recursively(
self, spec: Any, path: str = ""
) -> list[checks.CheckMessage]:
errors = []
errors: list[checks.CheckMessage] = []

if not isinstance(spec, dict):
errors.append(
Expand Down Expand Up @@ -290,7 +290,7 @@ def deconstruct(self) -> DeconstructResult:
kwargs["blank"] = False
return name, path, args, kwargs

def formfield(self, *args: Any, **kwargs: Any) -> FormField | None:
def formfield(self, *args: Any, **kwargs: Any) -> forms.Field | None:
"""
Disabled in forms - there is no sensible way of editing this
"""
Expand Down
Loading
Loading