From 8076d68ea84c6f8f479a78f7f2081524d03675ff Mon Sep 17 00:00:00 2001 From: Adam Johnson Date: Sun, 28 Aug 2022 13:21:32 +0100 Subject: [PATCH] Update tests to pass Mypy strict mode --- tests/testapp/models.py | 3 +- tests/testapp/test_aggregates.py | 5 ++ tests/testapp/test_checks.py | 4 +- tests/testapp/test_dynamicfield.py | 76 ++++++++++------- tests/testapp/test_enumfield.py | 2 +- tests/testapp/test_functions.py | 129 ++++++++++++++++------------- tests/testapp/test_locks.py | 7 +- tests/testapp/test_operations.py | 86 ++++++------------- tests/testapp/test_status.py | 12 ++- tests/testapp/test_test_utils.py | 2 +- tests/testapp/test_utils.py | 2 +- tests/testapp/utils.py | 4 +- tests/urls.py | 4 +- 13 files changed, 170 insertions(+), 166 deletions(-) diff --git a/tests/testapp/models.py b/tests/testapp/models.py index 1e655bba..2abac8ad 100644 --- a/tests/testapp/models.py +++ b/tests/testapp/models.py @@ -5,6 +5,7 @@ from typing import Any import django +from django.core import checks from django.db import connection from django.db.models import ( CASCADE, @@ -119,7 +120,7 @@ class DynamicModel(Model): ) @classmethod - def check(cls, **kwargs): + def check(cls, **kwargs: Any) -> list[checks.CheckMessage]: # Disable the checks on MySQL so that checks tests don't fail if not connection.mysql_is_mariadb: return [] diff --git a/tests/testapp/test_aggregates.py b/tests/testapp/test_aggregates.py index 26dfb2d6..67182dc9 100644 --- a/tests/testapp/test_aggregates.py +++ b/tests/testapp/test_aggregates.py @@ -66,6 +66,11 @@ def test_no_rows(self): class GroupConcatTests(TestCase): + shakes: Author + jk: Author + grisham: Author + str_tutee_ids: list[str] + @classmethod def setUpTestData(cls): super().setUpTestData() diff --git a/tests/testapp/test_checks.py b/tests/testapp/test_checks.py index 7cf21b93..bb244e3c 100644 --- a/tests/testapp/test_checks.py +++ b/tests/testapp/test_checks.py @@ -9,7 +9,7 @@ class CallCheckTest(TestCase): - databases = ["default", "other"] + databases = {"default", "other"} def test_check(self): call_command("check", "--database", "default", "--database", "other") @@ -17,7 +17,7 @@ def test_check(self): class VariablesTests(TransactionTestCase): - databases = ["default", "other"] + databases = {"default", "other"} def test_passes(self): assert check_variables(app_configs=None, databases=["default", "other"]) == [] diff --git a/tests/testapp/test_dynamicfield.py b/tests/testapp/test_dynamicfield.py index d324de6f..0b671ebd 100644 --- a/tests/testapp/test_dynamicfield.py +++ b/tests/testapp/test_dynamicfield.py @@ -118,28 +118,31 @@ def as_sql(self, compiler, connection): class QueryTests(DynColTestCase): + objs: list[DynamicModel] + @classmethod def setUpTestData(cls): super().setUpTestData() - cls.objs = [ - DynamicModel(attrs={"a": "b"}), - DynamicModel(attrs={"a": "b", "c": "d"}), - DynamicModel(attrs={"c": "d"}), - DynamicModel(attrs={}), - DynamicModel( - attrs={ - "datetimey": dt.datetime(2001, 1, 4, 14, 15, 16), - "datey": dt.date(2001, 1, 4), - "floaty": 128.5, - "inty": 9001, - "stry": "strvalue", - "str_underscorey": "strvalue2", - "timey": dt.time(14, 15, 16), - "nesty": {"level2": "chirp"}, - } - ), - ] - DynamicModel.objects.bulk_create(cls.objs) + DynamicModel.objects.bulk_create( + [ + DynamicModel(attrs={"a": "b"}), + DynamicModel(attrs={"a": "b", "c": "d"}), + DynamicModel(attrs={"c": "d"}), + DynamicModel(attrs={}), + DynamicModel( + attrs={ + "datetimey": dt.datetime(2001, 1, 4, 14, 15, 16), + "datey": dt.date(2001, 1, 4), + "floaty": 128.5, + "inty": 9001, + "stry": "strvalue", + "str_underscorey": "strvalue2", + "timey": dt.time(14, 15, 16), + "nesty": {"level2": "chirp"}, + } + ), + ] + ) cls.objs = list(DynamicModel.objects.order_by("id")) def test_equal(self): @@ -271,14 +274,17 @@ def test_key_transform_nesty__level2__startswith(self): class SpeclessQueryTests(DynColTestCase): + objs: list[SpeclessDynamicModel] + @classmethod def setUpTestData(cls): super().setUpTestData() - objs = [ - SpeclessDynamicModel(attrs={"a": "b"}), - SpeclessDynamicModel(attrs={"a": "c"}), - ] - SpeclessDynamicModel.objects.bulk_create(objs) + SpeclessDynamicModel.objects.bulk_create( + [ + SpeclessDynamicModel(attrs={"a": "b"}), + SpeclessDynamicModel(attrs={"a": "c"}), + ] + ) cls.objs = list(SpeclessDynamicModel.objects.order_by("id")) def test_simple(self): @@ -290,7 +296,7 @@ def test_simple(self): @isolate_apps("tests.testapp") class TestCheck(DynColTestCase): - databases = ["default", "other"] + databases = {"default", "other"} def test_db_not_mariadb(self): class Valid(models.Model): @@ -337,7 +343,9 @@ class Invalid(models.Model): assert len(errors) == 1 assert errors[0].id == "django_mysql.E009" assert "'spec' must be a dict" in errors[0].msg - assert "The value passed is of type list" in errors[0].hint + hint = errors[0].hint + assert hint is not None + assert "The value passed is of type list" in hint def test_spec_key_not_valid(self): class Invalid(models.Model): @@ -347,8 +355,10 @@ class Invalid(models.Model): assert len(errors) == 1 assert errors[0].id == "django_mysql.E010" assert "The key '2.0' in 'spec' is not a string" in errors[0].msg - assert "'spec' keys must be of type " in errors[0].hint - assert "'2.0' is of type float" in errors[0].hint + hint = errors[0].hint + assert hint is not None + assert "'spec' keys must be of type " in hint + assert "'2.0' is of type float" in hint def test_spec_value_not_valid(self): class Invalid(models.Model): @@ -358,9 +368,10 @@ class Invalid(models.Model): assert len(errors) == 1 assert errors[0].id == "django_mysql.E011" assert "The value for 'bad' in 'spec' is not an allowed type" in errors[0].msg + hint = errors[0].hint + assert hint is not None assert ( - "'spec' values must be one of the following types: date, datetime" - in errors[0].hint + "'spec' values must be one of the following types: date, datetime" in hint ) def test_spec_nested_value_not_valid(self): @@ -375,9 +386,10 @@ class Invalid(models.Model): assert ( "The value for 'bad' in 'spec.l1' is not an allowed type" in errors[0].msg ) + hint = errors[0].hint + assert hint is not None assert ( - "'spec' values must be one of the following types: date, datetime" - in errors[0].hint + "'spec' values must be one of the following types: date, datetime" in hint ) diff --git a/tests/testapp/test_enumfield.py b/tests/testapp/test_enumfield.py index 1921f28a..7914eaf9 100644 --- a/tests/testapp/test_enumfield.py +++ b/tests/testapp/test_enumfield.py @@ -98,7 +98,7 @@ def test_contains_lookup(self): class TestCheck(TestCase): - databases = ["default", "other"] + databases = {"default", "other"} def test_check(self): errors = EnumModel.check() diff --git a/tests/testapp/test_functions.py b/tests/testapp/test_functions.py index c3e134fa..cf220845 100644 --- a/tests/testapp/test_functions.py +++ b/tests/testapp/test_functions.py @@ -128,7 +128,8 @@ def test_if_false_default_None(self): class NumericFunctionTests(TestCase): def test_crc32(self): Alphabet.objects.create(d="AAAAAA") - ab = Alphabet.objects.annotate(crc=CRC32("d")).first() + ab = Alphabet.objects.annotate(crc=CRC32("d")).get() + # Precalculated this in MySQL prompt. Python's binascii.crc32 doesn't # match - maybe sign issues? assert ab.crc == 2854018686 @@ -145,7 +146,7 @@ def test_sign(self): "csign": Sign("c"), } - ab = Alphabet.objects.annotate(**kwargs).first() + ab = Alphabet.objects.annotate(**kwargs).get() assert ab.asign == 1 assert ab.bsign == 0 @@ -155,34 +156,34 @@ def test_sign(self): class StringFunctionTests(TestCase): def test_concat_ws(self): Alphabet.objects.create(d="AAA", e="BBB") - ab = Alphabet.objects.annotate(de=ConcatWS("d", "e")).first() + ab = Alphabet.objects.annotate(de=ConcatWS("d", "e")).get() assert ab.de == "AAA,BBB" def test_concat_ws_integers(self): Alphabet.objects.create(a=1, b=2) - ab = Alphabet.objects.annotate(ab=ConcatWS("a", "b")).first() + ab = Alphabet.objects.annotate(ab=ConcatWS("a", "b")).get() assert ab.ab == "1,2" def test_concat_ws_skips_nulls(self): Alphabet.objects.create(d="AAA", e=None, f=2) - ab = Alphabet.objects.annotate(de=ConcatWS("d", "e", "f")).first() + ab = Alphabet.objects.annotate(de=ConcatWS("d", "e", "f")).get() assert ab.de == "AAA,2" def test_concat_ws_separator(self): Alphabet.objects.create(d="AAA", e="BBB") - ab = Alphabet.objects.annotate(de=ConcatWS("d", "e", separator=":")).first() + ab = Alphabet.objects.annotate(de=ConcatWS("d", "e", separator=":")).get() assert ab.de == "AAA:BBB" def test_concat_ws_separator_null_returns_none(self): Alphabet.objects.create(a=1, b=2) concat = ConcatWS("a", "b", separator=None) - ab = Alphabet.objects.annotate(ab=concat).first() + ab = Alphabet.objects.annotate(ab=concat).get() assert ab.ab is None def test_concat_ws_separator_field(self): Alphabet.objects.create(a=1, d="AAA", e="BBB") concat = ConcatWS("d", "e", separator=F("a")) - ab = Alphabet.objects.annotate(de=concat).first() + ab = Alphabet.objects.annotate(de=concat).get() assert ab.de == "AAA1BBB" def test_concat_ws_too_few_fields(self): @@ -196,7 +197,7 @@ def test_concat_ws_then_lookups_from_textfield(self): ab = ( Alphabet.objects.annotate(de=ConcatWS("d", "e", separator=":")) .filter(de__endswith=":BBB") - .first() + .get() ) assert ab.de == "AAA:BBB" @@ -216,16 +217,16 @@ def test_elt_expression(self): def test_field_simple(self): Alphabet.objects.create(d="a") - ab = Alphabet.objects.annotate(dp=Field("d", ["a", "b"])).first() + ab = Alphabet.objects.annotate(dp=Field("d", ["a", "b"])).get() assert ab.dp == 1 - ab = Alphabet.objects.annotate(dp=Field("d", ["b", "a"])).first() + ab = Alphabet.objects.annotate(dp=Field("d", ["b", "a"])).get() assert ab.dp == 2 - ab = Alphabet.objects.annotate(dp=Field("d", ["c", "d"])).first() + ab = Alphabet.objects.annotate(dp=Field("d", ["c", "d"])).get() assert ab.dp == 0 def test_field_expression(self): Alphabet.objects.create(d="b") - ab = Alphabet.objects.annotate(dp=Field("d", [Value("a"), Value("b")])).first() + ab = Alphabet.objects.annotate(dp=Field("d", [Value("a"), Value("b")])).get() assert ab.dp == 2 def test_order_by(self): @@ -287,7 +288,7 @@ def test_xmlextractvalue_expression(self): def test_xmlextractvalue_invalid_xml(self): Alphabet.objects.create(d='{"this": "isNotXML"}') - ab = Alphabet.objects.annotate(ev=XMLExtractValue("d", "/some")).first() + ab = Alphabet.objects.annotate(ev=XMLExtractValue("d", "/some")).get() assert ab.ev == "" @@ -301,7 +302,7 @@ def test_md5_string(self): DeprecationWarning, "This function is deprecated." ): md5 = MD5("d") - ab = Alphabet.objects.annotate(md5=md5).first() + ab = Alphabet.objects.annotate(md5=md5).get() assert ab.md5 == pymd5 @@ -314,7 +315,7 @@ def test_sha1_string(self): DeprecationWarning, "This function is deprecated." ): sha1 = SHA1("d") - ab = Alphabet.objects.annotate(sha=sha1).first() + ab = Alphabet.objects.annotate(sha=sha1).get() assert ab.sha == pysha1 @@ -330,7 +331,7 @@ def test_sha2_string(self): DeprecationWarning, "This function is deprecated." ): sha2 = SHA2("d", hash_len) - ab = Alphabet.objects.annotate(sha=sha2).first() + ab = Alphabet.objects.annotate(sha=sha2).get() assert ab.sha == pysha @@ -343,7 +344,7 @@ def test_sha2_string_hash_len_default(self): DeprecationWarning, "This function is deprecated." ): sha2 = SHA2("d") - ab = Alphabet.objects.annotate(sha=sha2).first() + ab = Alphabet.objects.annotate(sha=sha2).get() assert ab.sha == pysha512 @@ -354,7 +355,7 @@ def test_sha2_bad_hash_len(self): class InformationFunctionTests(TestCase): - databases = ["default", "other"] + databases = {"default", "other"} def test_last_insert_id(self): Alphabet.objects.create(a=7891) @@ -381,6 +382,8 @@ def test_last_insert_id_in_query(self): class JSONFunctionTests(TestCase): + obj: JSONModel + @classmethod def setUpTestData(cls): super().setUpTestData() @@ -508,35 +511,39 @@ def test_json_length_type(self): def test_json_insert(self): self.obj.attrs = JSONInsert("attrs", {"$.int": 99, "$.int2": 102}) self.obj.save() - self.obj.refresh_from_db() - assert self.obj.attrs["int"] == 88 - assert self.obj.attrs["int2"] == 102 + + obj = JSONModel.objects.get() + assert obj.attrs["int"] == 88 + assert obj.attrs["int2"] == 102 def test_json_insert_expression(self): self.obj.attrs = JSONInsert("attrs", {Value("$.int"): Value(99)}) self.obj.save() - self.obj.refresh_from_db() - assert self.obj.attrs["int"] == 88 + + obj = JSONModel.objects.get() + assert obj.attrs["int"] == 88 def test_json_insert_dict(self): self.obj.attrs = JSONInsert( "attrs", {"$.sub": {"paper": "drop"}, "$.sub2": {"int": 42, "foo": "bar"}} ) self.obj.save() - self.obj.refresh_from_db() - assert self.obj.attrs["sub"] == {"document": "store"} - assert self.obj.attrs["sub2"]["int"] == 42 - assert self.obj.attrs["sub2"]["foo"] == "bar" + + obj = JSONModel.objects.get() + assert obj.attrs["sub"] == {"document": "store"} + assert obj.attrs["sub2"]["int"] == 42 + assert obj.attrs["sub2"]["foo"] == "bar" def test_json_insert_array(self): self.obj.attrs = JSONInsert( "attrs", {"$.arr": [1, "two", 3], "$.arr2": ["one", 2]} ) self.obj.save() - self.obj.refresh_from_db() - assert self.obj.attrs["arr"] == ["dee", "arr", "arr"] - assert self.obj.attrs["arr2"][0] == "one" - assert self.obj.attrs["arr2"][1] == 2 + + obj = JSONModel.objects.get() + assert obj.attrs["arr"] == ["dee", "arr", "arr"] + assert obj.attrs["arr2"][0] == "one" + assert obj.attrs["arr2"][1] == 2 def test_json_insert_empty_data(self): with pytest.raises(ValueError) as excinfo: @@ -546,27 +553,30 @@ def test_json_insert_empty_data(self): def test_json_replace_pairs(self): self.obj.attrs = JSONReplace("attrs", {"$.int": 101, "$.int2": 102}) self.obj.save() - self.obj.refresh_from_db() - assert self.obj.attrs["int"] == 101 - assert "int2" not in self.obj.attrs + + obj = JSONModel.objects.get() + assert obj.attrs["int"] == 101 + assert "int2" not in obj.attrs def test_json_replace_dict(self): self.obj.attrs = JSONReplace( "attrs", {"$.sub": {"paper": "drop"}, "$.sub2": {"int": 42, "foo": "bar"}} ) self.obj.save() - self.obj.refresh_from_db() - assert self.obj.attrs["sub"] == {"paper": "drop"} - assert "sub2" not in self.obj.attrs + + obj = JSONModel.objects.get() + assert obj.attrs["sub"] == {"paper": "drop"} + assert "sub2" not in obj.attrs def test_json_replace_array(self): self.obj.attrs = JSONReplace( "attrs", {"$.arr": [1, "two", 3], "$.arr2": ["one", 2]} ) self.obj.save() - self.obj.refresh_from_db() - assert self.obj.attrs["arr"] == [1, "two", 3] - assert "arr2" not in self.obj.attrs + + obj = JSONModel.objects.get() + assert obj.attrs["arr"] == [1, "two", 3] + assert "arr2" not in obj.attrs def test_json_replace_empty_data(self): with pytest.raises(ValueError) as excinfo: @@ -577,29 +587,32 @@ def test_json_set_pairs(self): with print_all_queries(): self.obj.attrs = JSONSet("attrs", {"$.int": 101, "$.int2": 102}) self.obj.save() - self.obj.refresh_from_db() - assert self.obj.attrs["int"] == 101 - assert self.obj.attrs["int2"] == 102 + + obj = JSONModel.objects.get() + assert obj.attrs["int"] == 101 + assert obj.attrs["int2"] == 102 def test_json_set_dict(self): self.obj.attrs = JSONSet( "attrs", {"$.sub": {"paper": "drop"}, "$.sub2": {"int": 42, "foo": "bar"}} ) self.obj.save() - self.obj.refresh_from_db() - assert self.obj.attrs["sub"] == {"paper": "drop"} - assert self.obj.attrs["sub2"]["int"] == 42 - assert self.obj.attrs["sub2"]["foo"] == "bar" + + obj = JSONModel.objects.get() + assert obj.attrs["sub"] == {"paper": "drop"} + assert obj.attrs["sub2"]["int"] == 42 + assert obj.attrs["sub2"]["foo"] == "bar" def test_json_set_array(self): self.obj.attrs = JSONSet( "attrs", {"$.arr": [1, "two", 3], "$.arr2": ["one", 2]} ) self.obj.save() - self.obj.refresh_from_db() - assert self.obj.attrs["arr"] == [1, "two", 3] - assert self.obj.attrs["arr2"][0] == "one" - assert self.obj.attrs["arr2"][1] == 2 + + obj = JSONModel.objects.get() + assert obj.attrs["arr"] == [1, "two", 3] + assert obj.attrs["arr2"][0] == "one" + assert obj.attrs["arr2"][1] == 2 def test_json_set_complex_data(self): data = { @@ -615,8 +628,9 @@ def test_json_set_complex_data(self): } self.obj.attrs = JSONSet("attrs", {"$.data": data}) self.obj.save() - self.obj.refresh_from_db() - assert self.obj.attrs["data"] == data + + obj = JSONModel.objects.get() + assert obj.attrs["data"] == data def test_json_set_empty_data(self): with pytest.raises(ValueError) as excinfo: @@ -628,9 +642,10 @@ def test_json_array_append(self): "attrs", {"$.arr": "max", "$.arr[0]": 1.1, "$.sub.document": 3} ) self.obj.save() - self.obj.refresh_from_db() - assert self.obj.attrs["arr"] == [["dee", 1.1], "arr", "arr", "max"] - assert self.obj.attrs["sub"]["document"] == ["store", 3] + + obj = JSONModel.objects.get() + assert obj.attrs["arr"] == [["dee", 1.1], "arr", "arr", "max"] + assert obj.attrs["sub"]["document"] == ["store", 3] class RegexpFunctionTests(TestCase): diff --git a/tests/testapp/test_locks.py b/tests/testapp/test_locks.py index 83c30df4..acf29050 100644 --- a/tests/testapp/test_locks.py +++ b/tests/testapp/test_locks.py @@ -23,7 +23,10 @@ class LockTests(TestCase): - databases = ["default", "other"] + databases = {"default", "other"} + + supports_lock_info: bool + lock_info_preinstalled: bool @classmethod def setUpClass(cls): @@ -216,7 +219,7 @@ def test_acquire_release(self): class TableLockTests(TransactionTestCase): - databases = ["default", "other"] + databases = {"default", "other"} def tearDown(self): Alphabet.objects.all().delete() diff --git a/tests/testapp/test_operations.py b/tests/testapp/test_operations.py index 2e5201be..12bfe9ab 100644 --- a/tests/testapp/test_operations.py +++ b/tests/testapp/test_operations.py @@ -4,6 +4,7 @@ import pytest from django.db import connection, migrations, models, transaction +from django.db.migrations.operations.base import Operation from django.db.migrations.state import ProjectState from django.test import TransactionTestCase from django.test.utils import CaptureQueriesContext @@ -12,29 +13,31 @@ from django_mysql.test.utils import override_mysql_variables -def plugin_exists(plugin_name): +def plugin_exists(plugin_name: str) -> bool: with connection.cursor() as cursor: cursor.execute( """SELECT COUNT(*) FROM INFORMATION_SCHEMA.PLUGINS WHERE PLUGIN_NAME = %s""", (plugin_name,), ) - return cursor.fetchone()[0] > 0 + count: int = cursor.fetchone()[0] + return count > 0 -def table_storage_engine(table_name): +def table_storage_engine(table_name: str) -> str: with connection.cursor() as cursor: cursor.execute( """SELECT ENGINE FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA = DATABASE() AND TABLE_NAME = %s""", (table_name,), ) - return cursor.fetchone()[0] + engine: str = cursor.fetchone()[0] + return engine class PluginOperationTests(TransactionTestCase): - databases = ["default", "other"] + databases = {"default", "other"} @classmethod def setUpClass(cls): @@ -168,21 +171,18 @@ def test_running_without_changes(self): operation.database_backwards("test_arstd", editor, new_state, project_state) assert table_storage_engine("test_arstd_pony") == "MyISAM" - # Copied from django core migration tests + # Adapted from django core migration tests: def set_up_test_model( self, - app_label, - second_model=False, - third_model=False, - related_model=False, - mti_model=False, - proxy_model=False, - unique_together=False, - options=False, - db_table=None, - index_together=False, - ): # pragma: no cover + app_label: str, + *, + proxy_model: bool = False, + unique_together: bool = False, + options: bool = False, + db_table: str | None = None, + index_together: bool = False, + ) -> ProjectState: # pragma: no cover """ Creates a test model state and database table. """ @@ -220,7 +220,7 @@ def set_up_test_model( model_options["permissions"] = [("can_groom", "Can groom")] if db_table: model_options["db_table"] = db_table - operations = [ + operations: list[Operation] = [ migrations.CreateModel( "Pony", [ @@ -231,49 +231,6 @@ def set_up_test_model( options=model_options, ) ] - if second_model: - operations.append( - migrations.CreateModel( - "Stable", [("id", models.AutoField(primary_key=True))] - ) - ) - if third_model: - operations.append( - migrations.CreateModel( - "Van", [("id", models.AutoField(primary_key=True))] - ) - ) - if related_model: - operations.append( - migrations.CreateModel( - "Rider", - [ - ("id", models.AutoField(primary_key=True)), - ("pony", models.ForeignKey("Pony")), - ("friend", models.ForeignKey("self")), - ], - ) - ) - if mti_model: - operations.append( - migrations.CreateModel( - "ShetlandPony", - fields=[ - ( - "pony_ptr", - models.OneToOneField( - auto_created=True, - primary_key=True, - to_field="id", - serialize=False, - to="Pony", - ), - ), - ("cuteness", models.IntegerField(default=1)), - ], - bases=["%s.Pony" % app_label], - ) - ) if proxy_model: operations.append( migrations.CreateModel( @@ -286,7 +243,12 @@ def set_up_test_model( return self.apply_operations(app_label, ProjectState(), operations) - def apply_operations(self, app_label, project_state, operations): + def apply_operations( + self, + app_label: str, + project_state: ProjectState, + operations: list[Operation], + ) -> ProjectState: migration = migrations.Migration("name", app_label) migration.operations = operations with connection.schema_editor() as editor: diff --git a/tests/testapp/test_status.py b/tests/testapp/test_status.py index 27ccf242..adce7318 100644 --- a/tests/testapp/test_status.py +++ b/tests/testapp/test_status.py @@ -36,10 +36,11 @@ def test_cast_string(self): class GlobalStatusTests(TestCase): - databases = ["default", "other"] + databases = {"default", "other"} def test_get(self): running = global_status.get("Threads_running") + assert isinstance(running, int) assert running >= 1 and isinstance(running, int) def test_get_bad_name(self): @@ -126,18 +127,21 @@ def test_other_databases(self): class SessionStatusTests(TestCase): - databases = ["default", "other"] + databases = {"default", "other"} def test_get_bytes_received(self): bytes_received = session_status.get("Bytes_received") - assert bytes_received >= 0 and isinstance(bytes_received, int) + assert isinstance(bytes_received, int) + assert bytes_received >= 0 bytes_received_2 = session_status.get("Bytes_received") + assert isinstance(bytes_received_2, int) assert bytes_received_2 >= bytes_received def test_get_last_query_cost(self): cost = session_status.get("Last_query_cost") - assert cost >= 0.0 and isinstance(cost, float) + assert isinstance(cost, float) + assert cost >= 0.0 def test_get_bad_name(self): with pytest.raises(ValueError) as excinfo: diff --git a/tests/testapp/test_test_utils.py b/tests/testapp/test_test_utils.py index 61e75c6b..4531d124 100644 --- a/tests/testapp/test_test_utils.py +++ b/tests/testapp/test_test_utils.py @@ -23,7 +23,7 @@ def check_timestamp(self, expected, using="default"): @override_mysql_variables(TIMESTAMP=123) class OverrideVarsClassTest(OverrideVarsMethodTest): - databases = ["default", "other"] + databases = {"default", "other"} def test_class_decorator(self): self.check_timestamp(123) diff --git a/tests/testapp/test_utils.py b/tests/testapp/test_utils.py index 0582eb6b..95fbf5e1 100644 --- a/tests/testapp/test_utils.py +++ b/tests/testapp/test_utils.py @@ -65,7 +65,7 @@ def test_hours(self): class IndexNameTests(TestCase): - databases = ["default", "other"] + databases = {"default", "other"} def test_requires_field_names(self): with pytest.raises(ValueError) as excinfo: diff --git a/tests/testapp/utils.py b/tests/testapp/utils.py index 0eee19f0..a621ae35 100644 --- a/tests/testapp/utils.py +++ b/tests/testapp/utils.py @@ -1,9 +1,11 @@ from __future__ import annotations from contextlib import contextmanager +from typing import Any import pytest from django.db import DEFAULT_DB_ALIAS, connection, connections +from django.db.backends.utils import CursorWrapper from django.test.utils import CaptureQueriesContext @@ -72,6 +74,6 @@ def used_indexes(query, using=DEFAULT_DB_ALIAS): return {row["key"] for row in fetchall_dicts(cursor) if row["key"] is not None} -def fetchall_dicts(cursor): +def fetchall_dicts(cursor: CursorWrapper) -> list[dict[str, Any]]: columns = [x[0] for x in cursor.description] return [dict(zip(columns, row)) for row in cursor.fetchall()] diff --git a/tests/urls.py b/tests/urls.py index 2aaf7837..a39253fb 100644 --- a/tests/urls.py +++ b/tests/urls.py @@ -1,5 +1,5 @@ from __future__ import annotations -from django.urls import path +from django.urls import URLPattern -urlpatterns: list[path] = [] +urlpatterns: list[URLPattern] = []