diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d1822d3..26bc820 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -48,3 +48,11 @@ repos: hooks: - id: ruff-format - id: ruff + + - repo: https://github.com/pre-commit/mirrors-mypy + rev: v1.10.0 + hooks: + - id: mypy + additional_dependencies: + - django-stubs==5.0.0 + - django-guardian diff --git a/django_fsm/__init__.py b/django_fsm/__init__.py index aa0779e..2bf51f9 100644 --- a/django_fsm/__init__.py +++ b/django_fsm/__init__.py @@ -7,10 +7,13 @@ import inspect from functools import partialmethod from functools import wraps +from typing import TYPE_CHECKING +from typing import Any from django.apps import apps as django_apps from django.db import models from django.db.models import Field +from django.db.models import QuerySet from django.db.models.query_utils import DeferredAttribute from django.db.models.signals import class_prepared @@ -32,11 +35,41 @@ "RETURN_VALUE", ] +if TYPE_CHECKING: + from collections.abc import Callable + from collections.abc import Collection + from collections.abc import Generator + from collections.abc import Iterable + from collections.abc import Sequence + from typing import Self + + from _typeshed import Incomplete + from django.contrib.auth.models import PermissionsMixin as UserWithPermissions + from django.utils.functional import _StrOrPromise + + _FSMModel = models.Model + _Field = models.Field[Any, Any] + CharField = models.CharField[Any, Any] + IntegerField = models.IntegerField[Any, Any] + ForeignKey = models.ForeignKey[Any, Any] + + _StateValue = str | int + _Permission = str | Callable[[_FSMModel, UserWithPermissions], bool] + _Instance = models.Model # TODO: use real type + +else: + _FSMModel = object + _Field = object + CharField = models.CharField + IntegerField = models.IntegerField + ForeignKey = models.ForeignKey + Self = Any + class TransitionNotAllowed(Exception): """Raised when a transition is not allowed""" - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: self.object = kwargs.pop("object", None) self.method = kwargs.pop("method", None) super().__init__(*args, **kwargs) @@ -55,7 +88,16 @@ class ConcurrentTransition(Exception): class Transition: - def __init__(self, method, source, target, on_error, conditions, permission, custom): + def __init__( + self, + method: Callable[..., _StateValue | Any], + source: _StateValue | Sequence[_StateValue] | State, + target: _StateValue, + on_error: _StateValue | None, + conditions: list[Callable[[_Instance], bool]], + permission: str | Callable[[_Instance, UserWithPermissions], bool] | None, + custom: dict[str, _StrOrPromise], + ) -> None: self.method = method self.source = source self.target = target @@ -65,10 +107,10 @@ def __init__(self, method, source, target, on_error, conditions, permission, cus self.custom = custom @property - def name(self): + def name(self) -> str: return self.method.__name__ - def has_perm(self, instance, user): + def has_perm(self, instance: _Instance, user: UserWithPermissions) -> bool: if not self.permission: return True if callable(self.permission): @@ -79,10 +121,10 @@ def has_perm(self, instance, user): return True return False - def __hash__(self): + def __hash__(self) -> int: return hash(self.name) - def __eq__(self, other): + def __eq__(self, other: object) -> bool: if isinstance(other, str): return other == self.name if isinstance(other, Transition): @@ -91,7 +133,7 @@ def __eq__(self, other): return False -def get_available_FIELD_transitions(instance, field): +def get_available_FIELD_transitions(instance: _Instance, field: FSMFieldMixin) -> Generator[Transition, None, None]: """ List of transitions available in current model state with all conditions met @@ -105,14 +147,16 @@ def get_available_FIELD_transitions(instance, field): yield meta.get_transition(curr_state) -def get_all_FIELD_transitions(instance, field): +def get_all_FIELD_transitions(instance: _Instance, field: FSMFieldMixin) -> Generator[Transition, None, None]: """ List of all transitions available in current model state """ return field.get_all_transitions(instance.__class__) -def get_available_user_FIELD_transitions(instance, user, field): +def get_available_user_FIELD_transitions( + instance: _Instance, user: UserWithPermissions, field: FSMFieldMixin +) -> Generator[Transition, None, None]: """ List of transitions available in current model state with all conditions met and user have rights on it @@ -127,11 +171,11 @@ class FSMMeta: Models methods transitions meta information """ - def __init__(self, field, method): + def __init__(self, field: FSMFieldMixin, method: Any) -> None: self.field = field - self.transitions = {} # source -> Transition + self.transitions: dict[str, Transition] = {} # source -> Transition - def get_transition(self, source): + def get_transition(self, source: str) -> Transition | None: transition = self.transitions.get(source, None) if transition is None: transition = self.transitions.get("*", None) @@ -139,7 +183,16 @@ def get_transition(self, source): transition = self.transitions.get("+", None) return transition - def add_transition(self, method, source, target, on_error=None, conditions=[], permission=None, custom={}): + def add_transition( + self, + method: Callable[..., _StateValue | Any], + source: str, + target: _StateValue, + on_error: _StateValue | None = None, + conditions: list[Callable[[_Instance], bool]] = [], + permission: str | Callable[[_Instance, UserWithPermissions], bool] | None = None, + custom: dict[str, _StrOrPromise] = {}, + ) -> None: if source in self.transitions: raise AssertionError(f"Duplicate transition for {source} state") @@ -153,7 +206,7 @@ def add_transition(self, method, source, target, on_error=None, conditions=[], p custom=custom, ) - def has_transition(self, state): + def has_transition(self, state: str) -> bool: """ Lookup if any transition exists from current model state using current method """ @@ -168,7 +221,7 @@ def has_transition(self, state): return False - def conditions_met(self, instance, state): + def conditions_met(self, instance: _Instance, state: str) -> bool: """ Check if all conditions have been met """ @@ -182,15 +235,15 @@ def conditions_met(self, instance, state): return all(condition(instance) for condition in transition.conditions) - def has_transition_perm(self, instance, state, user): + def has_transition_perm(self, instance: _Instance, state: str, user: UserWithPermissions) -> bool: transition = self.get_transition(state) if not transition: return False - return transition.has_perm(instance, user) + return bool(transition.has_perm(instance, user)) - def next_state(self, current_state): + def next_state(self, current_state: str) -> _StateValue: transition = self.get_transition(current_state) if transition is None: @@ -198,7 +251,7 @@ def next_state(self, current_state): return transition.target - def exception_state(self, current_state): + def exception_state(self, current_state: str) -> _StateValue | None: transition = self.get_transition(current_state) if transition is None: @@ -208,15 +261,15 @@ def exception_state(self, current_state): class FSMFieldDescriptor: - def __init__(self, field): + def __init__(self, field: FSMFieldMixin) -> None: self.field = field - def __get__(self, instance, type=None): + def __get__(self, instance: _Instance, type: Any | None = None) -> Any: if instance is None: return self return self.field.get_state(instance) - def __set__(self, instance, value): + def __set__(self, instance: _Instance, value: Any) -> None: if self.field.protected and self.field.name in instance.__dict__: raise AttributeError(f"Direct {self.field.name} modification is not allowed") @@ -225,12 +278,12 @@ def __set__(self, instance, value): self.field.set_state(instance, value) -class FSMFieldMixin: +class FSMFieldMixin(_Field): descriptor_class = FSMFieldDescriptor - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: self.protected = kwargs.pop("protected", False) - self.transitions = {} # cls -> (transitions name -> method) + self.transitions: dict[type[_FSMModel], dict[str, Any]] = {} # cls -> (transitions name -> method) self.state_proxy = {} # state -> ProxyClsRef state_choices = kwargs.pop("state_choices", None) @@ -247,21 +300,21 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - def deconstruct(self): + def deconstruct(self) -> Any: name, path, args, kwargs = super().deconstruct() if self.protected: kwargs["protected"] = self.protected return name, path, args, kwargs - def get_state(self, instance): + def get_state(self, instance: _Instance) -> Any: # The state field may be deferred. We delegate the logic of figuring this out # and loading the deferred field on-demand to Django's built-in DeferredAttribute class. return DeferredAttribute(self).__get__(instance) - def set_state(self, instance, state): + def set_state(self, instance: _Instance, state: str) -> None: instance.__dict__[self.name] = state - def set_proxy(self, instance, state): + def set_proxy(self, instance: _Instance, state: str) -> None: """ Change class """ @@ -282,7 +335,7 @@ def set_proxy(self, instance, state): instance.__class__ = model - def change_state(self, instance, method, *args, **kwargs): + def change_state(self, instance: _Instance, method: Incomplete, *args: Any, **kwargs: Any) -> Any: meta = method._django_fsm method_name = method.__name__ current_state = self.get_state(instance) @@ -335,7 +388,7 @@ def change_state(self, instance, method, *args, **kwargs): return result - def get_all_transitions(self, instance_cls): + def get_all_transitions(self, instance_cls: type[_FSMModel]) -> Generator[Transition, None, None]: """ Returns [(source, target, name, method)] for all field transitions """ @@ -347,10 +400,10 @@ def get_all_transitions(self, instance_cls): for transition in meta.transitions.values(): yield transition - def contribute_to_class(self, cls, name, **kwargs): + def contribute_to_class(self, cls: type[_FSMModel], name: str, private_only: bool = False, **kwargs: Any) -> None: self.base_cls = cls - super().contribute_to_class(cls, name, **kwargs) + super().contribute_to_class(cls, name, private_only=private_only, **kwargs) setattr(cls, self.name, self.descriptor_class(self)) setattr(cls, f"get_all_{self.name}_transitions", partialmethod(get_all_FIELD_transitions, field=self)) setattr(cls, f"get_available_{self.name}_transitions", partialmethod(get_available_FIELD_transitions, field=self)) @@ -362,13 +415,13 @@ def contribute_to_class(self, cls, name, **kwargs): class_prepared.connect(self._collect_transitions) - def _collect_transitions(self, *args, **kwargs): + def _collect_transitions(self, *args: Any, **kwargs: Any) -> None: sender = kwargs["sender"] if not issubclass(sender, self.base_cls): return - def is_field_transition_method(attr): + def is_field_transition_method(attr: Incomplete) -> bool: return ( (inspect.ismethod(attr) or inspect.isfunction(attr)) and hasattr(attr, "_django_fsm") @@ -391,17 +444,17 @@ def is_field_transition_method(attr): self.transitions[sender] = sender_transitions -class FSMField(FSMFieldMixin, models.CharField): +class FSMField(FSMFieldMixin, CharField): """ State Machine support for Django model as CharField """ - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: kwargs.setdefault("max_length", 50) super().__init__(*args, **kwargs) -class FSMIntegerField(FSMFieldMixin, models.IntegerField): +class FSMIntegerField(FSMFieldMixin, IntegerField): """ Same as FSMField, but stores the state value in an IntegerField. """ @@ -409,31 +462,31 @@ class FSMIntegerField(FSMFieldMixin, models.IntegerField): pass -class FSMKeyField(FSMFieldMixin, models.ForeignKey): +class FSMKeyField(FSMFieldMixin, ForeignKey): """ State Machine support for Django model """ - def get_state(self, instance): + def get_state(self, instance: _Instance) -> Incomplete: return instance.__dict__[self.attname] - def set_state(self, instance, state): + def set_state(self, instance: _Instance, state: str) -> None: instance.__dict__[self.attname] = self.to_python(state) -class FSMModelMixin: +class FSMModelMixin(_FSMModel): """ Mixin that allows refresh_from_db for models with fsm protected fields """ - def _get_protected_fsm_fields(self): - def is_fsm_and_protected(f): + def _get_protected_fsm_fields(self) -> set[str]: + def is_fsm_and_protected(f: object) -> Any: return isinstance(f, FSMFieldMixin) and f.protected - protected_fields = filter(is_fsm_and_protected, self._meta.concrete_fields) + protected_fields: Iterable[Any] = filter(is_fsm_and_protected, self._meta.concrete_fields) # type: ignore[attr-defined, arg-type] return {f.attname for f in protected_fields} - def refresh_from_db(self, *args, **kwargs): + def refresh_from_db(self, *args: Any, **kwargs: Any) -> None: fields = kwargs.pop("fields", None) # Use provided fields, if not set then reload all non-deferred fields.0 @@ -442,13 +495,13 @@ def refresh_from_db(self, *args, **kwargs): protected_fields = self._get_protected_fsm_fields() skipped_fields = deferred_fields.union(protected_fields) - fields = [f.attname for f in self._meta.concrete_fields if f.attname not in skipped_fields] + fields = [f.attname for f in self._meta.concrete_fields if f.attname not in skipped_fields] # type: ignore[attr-defined] kwargs["fields"] = fields super().refresh_from_db(*args, **kwargs) -class ConcurrentTransitionMixin: +class ConcurrentTransitionMixin(_FSMModel): """ Protects a Model from undesirable effects caused by concurrently executed transitions, e.g. running the same transition multiple times at the same time, or running different @@ -474,15 +527,23 @@ class ConcurrentTransitionMixin: state, thus practically negating their effect. """ - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) self._update_initial_state() @property - def state_fields(self): + def state_fields(self) -> Iterable[Any]: return filter(lambda field: isinstance(field, FSMFieldMixin), self._meta.fields) - def _do_update(self, base_qs, using, pk_val, values, update_fields, forced_update): + def _do_update( + self, + base_qs: QuerySet[Self], + using: str | None, + pk_val: Any, + values: Collection[tuple[_Field, type[models.Model] | None, Any]], + update_fields: Iterable[str] | None, + forced_update: bool, + ) -> bool: # _do_update is called once for each model class in the inheritance hierarchy. # We can only filter the base_qs on state fields (can be more than one!) present in this particular model. @@ -492,7 +553,7 @@ def _do_update(self, base_qs, using, pk_val, values, update_fields, forced_updat # state filter will be used to narrow down the standard filter checking only PK state_filter = {field.attname: self.__initial_states[field.attname] for field in filter_on} - updated = super()._do_update( + updated: bool = super()._do_update( base_qs=base_qs.filter(**state_filter), using=using, pk_val=pk_val, @@ -512,19 +573,27 @@ def _do_update(self, base_qs, using, pk_val, values, update_fields, forced_updat return updated - def _update_initial_state(self): + def _update_initial_state(self) -> None: self.__initial_states = {field.attname: field.value_from_object(self) for field in self.state_fields} - def refresh_from_db(self, *args, **kwargs): + def refresh_from_db(self, *args: Any, **kwargs: Any) -> None: super().refresh_from_db(*args, **kwargs) self._update_initial_state() - def save(self, *args, **kwargs): + def save(self, *args: Any, **kwargs: Any) -> None: super().save(*args, **kwargs) self._update_initial_state() -def transition(field, source="*", target=None, on_error=None, conditions=[], permission=None, custom={}): +def transition( + field: FSMFieldMixin, + source: _StateValue | Sequence[_StateValue] = "*", + target: _StateValue | State | None = None, + on_error: _StateValue | None = None, + conditions: list[Callable[[Any], bool]] = [], + permission: _Permission | None = None, + custom: dict[str, _StrOrPromise] = {}, +) -> Callable[[Any], Any]: """ Method decorator to mark allowed transitions. @@ -532,13 +601,14 @@ def transition(field, source="*", target=None, on_error=None, conditions=[], per has not changed after the function call. """ - def inner_transition(func): + def inner_transition(func: Incomplete) -> Incomplete: wrapper_installed, fsm_meta = True, getattr(func, "_django_fsm", None) if not fsm_meta: wrapper_installed = False fsm_meta = FSMMeta(field=field, method=func) setattr(func, "_django_fsm", fsm_meta) + # if isinstance(source, Iterable): if isinstance(source, (list, tuple, set)): for state in source: func._django_fsm.add_transition(func, state, target, on_error, conditions, permission, custom) @@ -546,7 +616,7 @@ def inner_transition(func): func._django_fsm.add_transition(func, source, target, on_error, conditions, permission, custom) @wraps(func) - def _change_state(instance, *args, **kwargs): + def _change_state(instance: _Instance, *args: Any, **kwargs: Any) -> Incomplete: return fsm_meta.field.change_state(instance, func, *args, **kwargs) if not wrapper_installed: @@ -557,7 +627,7 @@ def _change_state(instance, *args, **kwargs): return inner_transition -def can_proceed(bound_method, check_conditions=True): +def can_proceed(bound_method: Incomplete, check_conditions: bool = True) -> bool: """ Returns True if model in state allows to call bound_method @@ -574,7 +644,7 @@ def can_proceed(bound_method, check_conditions=True): return meta.has_transition(current_state) and (not check_conditions or meta.conditions_met(self, current_state)) -def has_transition_perm(bound_method, user): +def has_transition_perm(bound_method: Incomplete, user: UserWithPermissions) -> bool: """ Returns True if model in state allows to call bound_method and user have rights on it """ @@ -585,7 +655,7 @@ def has_transition_perm(bound_method, user): self = bound_method.__self__ current_state = meta.field.get_state(self) - return ( + return bool( meta.has_transition(current_state) and meta.conditions_met(self, current_state) and meta.has_transition_perm(self, current_state, user) @@ -593,15 +663,15 @@ def has_transition_perm(bound_method, user): class State: - def get_state(self, model, transition, result, args=[], kwargs={}): + def get_state(self, model: _FSMModel, transition: Transition, result: Any, args: Any = [], kwargs: Any = {}) -> Incomplete: raise NotImplementedError class RETURN_VALUE(State): - def __init__(self, *allowed_states): + def __init__(self, *allowed_states: Sequence[_StateValue]) -> None: self.allowed_states = allowed_states if allowed_states else None - def get_state(self, model, transition, result, args=[], kwargs={}): + def get_state(self, model: _FSMModel, transition: Transition, result: Any, args: Any = [], kwargs: Any = {}) -> Incomplete: if self.allowed_states is not None: if result not in self.allowed_states: raise InvalidResultState(f"{result} is not in list of allowed states\n{self.allowed_states}") @@ -609,11 +679,13 @@ def get_state(self, model, transition, result, args=[], kwargs={}): class GET_STATE(State): - def __init__(self, func, states=None): + def __init__(self, func: Callable[..., _StateValue | Any], states: Sequence[_StateValue] | None = None) -> None: self.func = func self.allowed_states = states - def get_state(self, model, transition, result, args=[], kwargs={}): + def get_state( + self, model: _FSMModel, transition: Transition, result: _StateValue | Any, args: Any = [], kwargs: Any = {} + ) -> Incomplete: result_state = self.func(model, *args, **kwargs) if self.allowed_states is not None: if result_state not in self.allowed_states: diff --git a/pyproject.toml b/pyproject.toml index 065a619..6beb3ed 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -67,14 +67,88 @@ extend-select = [ "RET", "C", # "B", + "TCH", # trailing comma ] -fixable = ["I"] +fixable = ["I", "TCH"] [tool.ruff.lint.isort] force-single-line = true required-imports = ["from __future__ import annotations"] +[tool.django-stubs] +django_settings_module = "tests.settings" + +[tool.mypy] +python_version = 3.11 +plugins = ["mypy_django_plugin.main"] + +# Start off with these +warn_unused_configs = true +warn_redundant_casts = true +warn_unused_ignores = true + +# Getting these passing should be easy +strict_equality = true +extra_checks = true + +# Strongly recommend enabling this one as soon as you can +check_untyped_defs = true + +# These shouldn't be too much additional work, but may be tricky to +# get passing if you use a lot of untyped libraries +disallow_subclassing_any = true +disallow_untyped_decorators = true +disallow_any_generics = true + +# These next few are various gradations of forcing use of type annotations +disallow_untyped_calls = true +disallow_incomplete_defs = true +disallow_untyped_defs = true + +# This one isn't too hard to get passing, but return on investment is lower +no_implicit_reexport = true + +# This one can be tricky to get passing if you use a lot of untyped libraries +warn_return_any = true + +[[tool.mypy.overrides]] +module = [ + "tests.*", + "django_fsm.tests.*" +] +ignore_errors = true + +# Start off with these +warn_unused_ignores = true + +# Getting these passing should be easy +strict_equality = false +extra_checks = false + +# Strongly recommend enabling this one as soon as you can +check_untyped_defs = false +# These shouldn't be too much additional work, but may be tricky to +# get passing if you use a lot of untyped libraries +disallow_subclassing_any = false +disallow_untyped_decorators = false +disallow_any_generics = false + +# These next few are various gradations of forcing use of type annotations +disallow_untyped_calls = false +disallow_incomplete_defs = false +disallow_untyped_defs = false + +# This one isn't too hard to get passing, but return on investment is lower +no_implicit_reexport = false + +# This one can be tricky to get passing if you use a lot of untyped libraries +warn_return_any = false + +[[tool.mypy.overrides]] +module = "django_fsm.management.commands.graph_transitions" +ignore_errors = true + [build-system] requires = ["poetry-core"] build-backend = "poetry.core.masonry.api" diff --git a/tests/testapp/models.py b/tests/testapp/models.py index 844a4f4..a6e35e7 100644 --- a/tests/testapp/models.py +++ b/tests/testapp/models.py @@ -93,7 +93,7 @@ class BlogPost(models.Model): state = FSMField(default="new", protected=True) - def can_restore(self, user): + def can_restore(self, user) -> bool: return user.is_superuser or user.is_staff @transition(field=state, source="new", target="published", on_error="failed", permission="testapp.can_publish_post") diff --git a/tests/testapp/tests/test_multidecorators.py b/tests/testapp/tests/test_multidecorators.py index eea9617..697eac3 100644 --- a/tests/testapp/tests/test_multidecorators.py +++ b/tests/testapp/tests/test_multidecorators.py @@ -8,7 +8,7 @@ from django_fsm.signals import post_transition -class TestModel(models.Model): +class MultipletransitionsModel(models.Model): counter = models.IntegerField(default=0) signal_counter = models.IntegerField(default=0) state = FSMField(default="SUBMITTED_BY_USER") @@ -27,12 +27,12 @@ def count_calls(sender, instance, name, source, target, **kwargs): instance.signal_counter += 1 -post_transition.connect(count_calls, sender=TestModel) +post_transition.connect(count_calls, sender=MultipletransitionsModel) class TestStateProxy(TestCase): def test_transition_method_called_once(self): - model = TestModel() + model = MultipletransitionsModel() model.review() self.assertEqual(1, model.counter) self.assertEqual(1, model.signal_counter) diff --git a/tests/testapp/tests/test_transition_all_except_target.py b/tests/testapp/tests/test_transition_all_except_target.py index a7765bf..331fa75 100644 --- a/tests/testapp/tests/test_transition_all_except_target.py +++ b/tests/testapp/tests/test_transition_all_except_target.py @@ -8,7 +8,7 @@ from django_fsm import transition -class TestExceptTargetTransitionShortcut(models.Model): +class ExceptTargetTransitionShortcutModel(models.Model): state = FSMField(default="new") @transition(field=state, source="new", target="published") @@ -25,7 +25,7 @@ class Meta: class Test(TestCase): def setUp(self): - self.model = TestExceptTargetTransitionShortcut() + self.model = ExceptTargetTransitionShortcutModel() def test_usecase(self): self.assertEqual(self.model.state, "new")