Skip to content
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 20 additions & 26 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -198,34 +198,28 @@ extend-select = [
"RUF", # Ruff-specific rules
]
ignore = [
"D100", # Missing docstring in public module
"D101", # Missing docstring in public class
"D102", # Missing docstring in public method
"D103", # Missing docstring in public function
"D104", # Missing docstring in public package
"D105", # Missing docstring in magic method
"D107", # Missing docstring in __init__
"D200", # One-line docstring should fit on one line
"D202", # No blank lines allowed after function docstring
"D203", # Class definitions that are not preceded by a blank line
"D205", # 1 blank line required between summary line and description
"D209", # Multi-line docstring closing quotes should be on a separate line
"D212", # Multi-line docstring summary should start at the first line
"D213", # Multi-line docstring summary should start at the second line
"D400", # First line should end with a period
"D401", # First line of docstring should be in imperative mood
"D404", # First word of the docstring should not be "This"
"D415", # First line should end with a period, question mark, or exclamation point
"S101", # Use of `assert` detected
"ANN401", # Dynamically typed expressions (typing.Any) are disallowed
"D100", # Missing docstring in public module
"D101", # Missing docstring in public class
"D102", # Missing docstring in public method
"D103", # Missing docstring in public function
"D104", # Missing docstring in public package
"D105", # Missing docstring in magic method
"D107", # Missing docstring in __init__
"D200", # One-line docstring should fit on one line
"D202", # No blank lines allowed after function docstring
"D203", # Class definitions that are not preceded by a blank line
"D205", # 1 blank line required between summary line and description
"D209", # Multi-line docstring closing quotes should be on a separate line
"D212", # Multi-line docstring summary should start at the first line
"D213", # Multi-line docstring summary should start at the second line
"D400", # First line should end with a period
"D401", # First line of docstring should be in imperative mood
"D404", # First word of the docstring should not be "This"
"D415", # First line should end with a period, question mark, or exclamation point
"S101", # Use of `assert` detected

# TODO - need to fix these
"ANN001", # Missing type annotation for function argument
"ANN002", # Missing type annotation for public function
"ANN003", # Missing type annotation for public method
"ANN201", # Missing return type annotation for public function
"ANN202", # Missing return type annotation for private function
"ANN204", # Missing return type annotation for special method
"ANN401", # Dynamically typed expressions .. are disallowed
"ARG001", # Unused function argument
"ARG002", # Unused method argument
"C901", # .. is too complex
Expand Down
64 changes: 39 additions & 25 deletions pytest_django/asserts.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

from __future__ import annotations

from collections.abc import Sequence
from functools import wraps
from typing import TYPE_CHECKING, Any, Callable

Expand All @@ -26,11 +25,11 @@ class MessagesTestCase(MessagesTestMixin, TestCase):
test_case = TestCase("run")


def _wrapper(name: str):
def _wrapper(name: str) -> Callable[..., Any]:
func = getattr(test_case, name)

@wraps(func)
def assertion_func(*args, **kwargs):
def assertion_func(*args: Any, **kwargs: Any) -> Any:
return func(*args, **kwargs)

return assertion_func
Expand All @@ -56,7 +55,12 @@ def assertion_func(*args, **kwargs):


if TYPE_CHECKING:
from collections.abc import Collection, Iterator, Sequence
from contextlib import AbstractContextManager
from typing import overload

from django import forms
from django.db.models import Model, QuerySet, RawQuerySet
from django.http.response import HttpResponseBase

def assertRedirects(
Expand Down Expand Up @@ -111,34 +115,34 @@ def assertTemplateUsed(
template_name: str | None = ...,
msg_prefix: str = ...,
count: int | None = ...,
): ...
) -> None: ...

def assertTemplateNotUsed(
response: HttpResponseBase | str | None = ...,
template_name: str | None = ...,
msg_prefix: str = ...,
): ...
) -> None: ...

def assertRaisesMessage(
expected_exception: type[Exception],
expected_message: str,
*args,
**kwargs,
): ...
*args: Any,
**kwargs: Any,
) -> None: ...

def assertWarnsMessage(
expected_warning: Warning,
expected_message: str,
*args,
**kwargs,
): ...
*args: Any,
**kwargs: Any,
) -> None: ...

def assertFieldOutput(
fieldclass,
valid,
invalid,
field_args=...,
field_kwargs=...,
fieldclass: type[forms.Field],
valid: Any,
invalid: Any,
field_args: Any = ...,
field_kwargs: Any = ...,
empty_value: str = ...,
) -> None: ...

Expand Down Expand Up @@ -194,34 +198,44 @@ def assertXMLNotEqual(

# Removed in Django 5.1: use assertQuerySetEqual.
def assertQuerysetEqual(
qs,
values,
transform=...,
qs: Iterator[Any] | list[Model] | QuerySet | RawQuerySet,
values: Collection[Any],
transform: Callable[[Model], Any] | type[str] | None = ...,
ordered: bool = ...,
msg: str | None = ...,
) -> None: ...

def assertQuerySetEqual(
qs,
values,
transform=...,
qs: Iterator[Any] | list[Model] | QuerySet | RawQuerySet,
values: Collection[Any],
transform: Callable[[Model], Any] | type[str] | None = ...,
ordered: bool = ...,
msg: str | None = ...,
) -> None: ...

@overload
def assertNumQueries(
num: int, func: None = None, *, using: str = ...
) -> AbstractContextManager[None]: ...

@overload
def assertNumQueries(
num: int, func: Callable[..., Any], *args: Any, using: str = ..., **kwargs: Any
) -> None: ...

def assertNumQueries(
num: int,
func=...,
*args,
*args: Any,
using: str = ...,
**kwargs,
**kwargs: Any,
): ...

# Added in Django 5.0.
def assertMessages(
response: HttpResponseBase,
expected_messages: Sequence[Message],
*args,
*args: Any,
ordered: bool = ...,
) -> None: ...

Expand Down
14 changes: 14 additions & 0 deletions pytest_django/django_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,23 @@
# this is the case before you call them.
from __future__ import annotations

from typing import TYPE_CHECKING

import pytest


if TYPE_CHECKING:
from typing import TypeAlias

from django.contrib.auth.models import AbstractBaseUser

_User: TypeAlias = AbstractBaseUser

_UserModel: TypeAlias = type[_User]

__all__ = ("_User", "_UserModel")


def is_django_unittest(request_or_item: pytest.FixtureRequest | pytest.Item) -> bool:
"""Returns whether the request or item is a Django test case."""
from django.test import SimpleTestCase
Expand Down
29 changes: 17 additions & 12 deletions pytest_django/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from collections.abc import Generator, Iterable, Sequence
from contextlib import AbstractContextManager, contextmanager
from functools import partial
from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Protocol, Union
from typing import TYPE_CHECKING, Callable, Literal, Optional, Protocol, Union

import pytest

Expand All @@ -16,10 +16,13 @@


if TYPE_CHECKING:
from typing import Any, Callable

import django
import django.test

from . import DjangoDbBlocker
from .django_compat import _User, _UserModel


_DjangoDbDatabases = Optional[Union[Literal["__all__"], Iterable[str]]]
Expand Down Expand Up @@ -337,7 +340,7 @@ def __getitem__(self, item: str) -> None:
settings.MIGRATION_MODULES = DisableMigrations()

class MigrateSilentCommand(migrate.Command):
def handle(self, *args, **kwargs):
def handle(self, *args: Any, **kwargs: Any) -> Any:
kwargs["verbosity"] = 0
return super().handle(*args, **kwargs)

Expand Down Expand Up @@ -456,15 +459,15 @@ def async_client() -> django.test.AsyncClient:


@pytest.fixture
def django_user_model(db: None):
def django_user_model(db: None) -> _UserModel:
"""The class of Django's user model."""
from django.contrib.auth import get_user_model

return get_user_model()
return get_user_model() # type: ignore[no-any-return]


@pytest.fixture
def django_username_field(django_user_model) -> str:
def django_username_field(django_user_model: _UserModel) -> str:
"""The fieldname for the username used with Django's user model."""
field: str = django_user_model.USERNAME_FIELD
return field
Expand All @@ -473,9 +476,9 @@ def django_username_field(django_user_model) -> str:
@pytest.fixture
def admin_user(
db: None,
django_user_model,
django_user_model: _User,
django_username_field: str,
):
) -> _User:
"""A Django admin user.

This uses an existing user with username "admin", or creates a new one with
Expand Down Expand Up @@ -504,7 +507,7 @@ def admin_user(
@pytest.fixture
def admin_client(
db: None,
admin_user,
admin_user: _User,
) -> django.test.Client:
"""A Django test client logged in as an admin user."""
from django.test import Client
Expand Down Expand Up @@ -550,14 +553,14 @@ def __delattr__(self, attr: str) -> None:

self._to_restore.append(override)

def __setattr__(self, attr: str, value) -> None:
def __setattr__(self, attr: str, value: Any) -> None:
from django.test import override_settings

override = override_settings(**{attr: value})
override.enable()
self._to_restore.append(override)

def __getattr__(self, attr: str):
def __getattr__(self, attr: str) -> Any:
from django.conf import settings

return getattr(settings, attr)
Expand All @@ -570,7 +573,7 @@ def finalize(self) -> None:


@pytest.fixture
def settings():
def settings() -> Generator[SettingsWrapper, None, None]:
"""A Django settings object which restores changes after the testrun"""
skip_if_no_django()

Expand All @@ -580,7 +583,9 @@ def settings():


@pytest.fixture(scope="session")
def live_server(request: pytest.FixtureRequest):
def live_server(
request: pytest.FixtureRequest,
) -> Generator[live_server_helper.LiveServer, None, None]:
"""Run a live Django server in the background during tests

The address the server is started from is taken from the
Expand Down
2 changes: 1 addition & 1 deletion pytest_django/live_server_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def url(self) -> str:
def __str__(self) -> str:
return self.url

def __add__(self, other) -> str:
def __add__(self, other: str) -> str:
return f"{self}{other}"

def __repr__(self) -> str:
Expand Down
12 changes: 7 additions & 5 deletions pytest_django/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from collections.abc import Generator
from contextlib import AbstractContextManager
from functools import reduce
from typing import TYPE_CHECKING, NoReturn
from typing import TYPE_CHECKING

import pytest

Expand Down Expand Up @@ -54,6 +54,8 @@


if TYPE_CHECKING:
from typing import Any, NoReturn

import django


Expand Down Expand Up @@ -186,7 +188,7 @@ def _handle_import_error(extra_message: str) -> Generator[None, None, None]:
raise ImportError(msg) from None


def _add_django_project_to_path(args) -> str:
def _add_django_project_to_path(args: list[str]) -> str:
def is_django_project(path: pathlib.Path) -> bool:
try:
return path.is_dir() and (path / "manage.py").exists()
Expand All @@ -198,7 +200,7 @@ def arg_to_path(arg: str) -> pathlib.Path:
arg = arg.split("::", 1)[0]
return pathlib.Path(arg)

def find_django_path(args) -> pathlib.Path | None:
def find_django_path(args: list[str]) -> pathlib.Path | None:
str_args = (str(arg) for arg in args)
path_args = [arg_to_path(x) for x in str_args if not x.startswith("-")]

Expand Down Expand Up @@ -571,7 +573,7 @@ def _django_setup_unittest(

original_runtest = TestCaseFunction.runtest

def non_debugging_runtest(self) -> None:
def non_debugging_runtest(self) -> None: # noqa: ANN001
self._testcase(result=self)

from django.test import SimpleTestCase
Expand Down Expand Up @@ -831,7 +833,7 @@ def _dj_db_wrapper(self) -> django.db.backends.base.base.BaseDatabaseWrapper:
def _save_active_wrapper(self) -> None:
self._history.append(self._dj_db_wrapper.ensure_connection)

def _blocking_wrapper(*args, **kwargs) -> NoReturn:
def _blocking_wrapper(*args: Any, **kwargs: Any) -> NoReturn:
__tracebackhide__ = True
raise RuntimeError(
"Database access not allowed, "
Expand Down
Loading