diff --git a/.gitignore b/.gitignore index 7b4e7d3..93e32cd 100644 --- a/.gitignore +++ b/.gitignore @@ -133,3 +133,6 @@ test.db # django fsm command tests exports/* + +# codex +.codex/* diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 87ca8ef..9874da3 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -45,3 +45,11 @@ repos: hooks: - id: ruff-format - id: ruff-check + + - repo: https://github.com/pre-commit/mirrors-mypy + rev: v1.19.1 + hooks: + - id: mypy + additional_dependencies: + - django-stubs==5.2.9 + - django-guardian diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 32517e7..b3d61d4 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -1,6 +1,13 @@ Changelog ========= +UNRELEASED +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +- Add typing +- Add logging solution + + django-fsm-2 4.1.0 2025-11-03 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/README.md b/README.md index 2b15058..6f54bdb 100644 --- a/README.md +++ b/README.md @@ -418,6 +418,45 @@ executed transitions, make sure: Following these recommendations, `ConcurrentTransitionMixin` will cause a rollback of all changes executed in an inconsistent state. +## Transition tracking + +Use `@django_fsm.track()` to write state changes to a log table. +By default, it writes to `django_fsm.StateLog` (single table). +If you prefer one table per model, define your own log model and pass it in. +You can also capture `author` and `description` for each transition. + +```python +import django_fsm +from django_fsm.log import fsm_log_by +from django_fsm.log import fsm_log_description +from django.db import models + + +@django_fsm.track() +class BlogPost(models.Model): + state = django_fsm.FSMField(default="new") + + @fsm_log_by + @fsm_log_description + @django_fsm.transition(field=state, source="new", target="published") + def publish(self): + pass +``` + +```python +import django_fsm +from django.db import models + + +class BlogPostLog(django_fsm.TransitionLogBase): + post = models.ForeignKey("BlogPost", on_delete=models.CASCADE, related_name="transition_logs") + + +@django_fsm.track(log_model=BlogPostLog, relation_field="post") +class BlogPost(models.Model): + state = django_fsm.FSMField(default="new") +``` + ## Drawing transitions Render a graphical overview of your model transitions. @@ -460,7 +499,6 @@ INSTALLED_APPS = ( ## Extensions - Admin integration: -- Transition logging: ## Contributing diff --git a/django_fsm/__init__.py b/django_fsm/__init__.py index 9ef3454..d2553f1 100644 --- a/django_fsm/__init__.py +++ b/django_fsm/__init__.py @@ -5,6 +5,7 @@ from __future__ import annotations import inspect +import typing from functools import partialmethod from functools import wraps @@ -12,12 +13,48 @@ from django.apps import apps as django_apps from django.db import models from django.db.models import Field +from django.db.models import QuerySet from django.db.models.query_utils import DeferredAttribute from django.db.models.signals import class_prepared from django_fsm.signals import post_transition from django_fsm.signals import pre_transition +if typing.TYPE_CHECKING: # pragma: no cover + from collections.abc import Callable + from collections.abc import Collection + from collections.abc import Generator + from collections.abc import Iterable + from collections.abc import Sequence + from typing import Self + + from django.contrib.auth.models import PermissionsMixin as UserWithPermissions + from django.utils.functional import _StrOrPromise + + _Field = models.Field[typing.Any, typing.Any] + CharField = models.CharField[str, str] + IntegerField = models.IntegerField[int, int] + ForeignKey = models.ForeignKey[typing.Any, typing.Any] + + _FSMModel = models.Model + _StateValue: typing.TypeAlias = str | int + _Permission: typing.TypeAlias = str | Callable[[_FSMModel, typing.Any], bool] + _Condition: typing.TypeAlias = Callable[[models.Model], bool] + +else: + _FSMModel = object + _Field = object + CharField = models.CharField + IntegerField = models.IntegerField + ForeignKey = models.ForeignKey + Self = typing.Any + +try: + from typing import override +except ImportError: # pragma: no cover + # Py<3.12 + from typing_extensions import override + __all__ = [ "GET_STATE", "RETURN_VALUE", @@ -37,7 +74,8 @@ class TransitionNotAllowed(Exception): # noqa: N818 """Raised when a transition is not allowed""" - def __init__(self, *args, **kwargs): + @override + def __init__(self, *args: typing.Any, **kwargs: typing.Any) -> None: self.object = kwargs.pop("object", None) self.method = kwargs.pop("method", None) super().__init__(*args, **kwargs) @@ -56,7 +94,16 @@ class ConcurrentTransition(Exception): # noqa: N818 class Transition: - def __init__(self, method, source, target, on_error, conditions, permission, custom): + def __init__( + self, + method: Callable[..., _StateValue], + source: _StateValue, + target: _StateValue, + on_error: _StateValue | None, + conditions: list[_Condition] | None, + permission: _Permission | None, + custom: dict[str, _StrOrPromise] | None, + ) -> None: self.method = method self.source = source self.target = target @@ -66,10 +113,10 @@ def __init__(self, method, source, target, on_error, conditions, permission, cus self.custom = custom @property - def name(self): + def name(self) -> str: return self.method.__name__ - def has_perm(self, instance, user): + def has_perm(self, instance: _FSMModel, user: UserWithPermissions) -> bool: if not self.permission: return True if callable(self.permission): @@ -80,10 +127,10 @@ def has_perm(self, instance, user): return True return False - def __hash__(self): + def __hash__(self) -> int: return hash(self.name) - def __eq__(self, other): + def __eq__(self, other: object) -> bool: if isinstance(other, str): return other == self.name if isinstance(other, Transition): @@ -92,7 +139,9 @@ def __eq__(self, other): return False -def get_available_FIELD_transitions(instance, field): # noqa: N802 +def get_available_FIELD_transitions( # noqa: N802 + instance: _FSMModel, field: FSMFieldMixin +) -> Generator[Transition]: """ List of transitions available in current model state with all conditions met @@ -106,14 +155,16 @@ def get_available_FIELD_transitions(instance, field): # noqa: N802 yield meta.get_transition(curr_state) -def get_all_FIELD_transitions(instance, field): # noqa: N802 +def get_all_FIELD_transitions(instance: _FSMModel, field: FSMFieldMixin) -> Generator[Transition]: # noqa: N802 """ List of all transitions available in current model state """ return field.get_all_transitions(instance.__class__) -def get_available_user_FIELD_transitions(instance, user, field): # noqa: N802 +def get_available_user_FIELD_transitions( # noqa: N802 + instance: _FSMModel, user: UserWithPermissions, field: FSMFieldMixin +) -> Generator[Transition]: """ List of transitions available in current model state with all conditions met and user have rights on it @@ -128,11 +179,11 @@ class FSMMeta: Models methods transitions meta information """ - def __init__(self, field, method): + def __init__(self, field: FSMFieldMixin | str, method: bool) -> None: # noqa: FBT001 self.field = field - self.transitions = {} # source -> Transition + self.transitions: dict[_StateValue, Transition] = {} # source -> Transition - def get_transition(self, source): + def get_transition(self, source: _StateValue) -> Transition | None: transition = self.transitions.get(source, None) if transition is None: transition = self.transitions.get("*", None) @@ -141,8 +192,15 @@ def get_transition(self, source): return transition def add_transition( - self, method, source, target, on_error=None, conditions=[], permission=None, custom={} - ): + self, + method: Callable[..., _StateValue], + source: _StateValue, + target: _StateValue, + on_error: _StateValue | None = None, + conditions: list[_Condition] | None = None, + permission: str | Callable[[_FSMModel, UserWithPermissions], bool] | None = None, + custom: dict[str, _StrOrPromise] | None = None, + ) -> None: if source in self.transitions: raise AssertionError(f"Duplicate transition for {source} state") @@ -156,7 +214,7 @@ def add_transition( custom=custom, ) - def has_transition(self, state): + def has_transition(self, state: _StateValue) -> bool: """ Lookup if any transition exists from current model state using current method """ @@ -171,7 +229,7 @@ def has_transition(self, state): return False - def conditions_met(self, instance, state): + def conditions_met(self, instance: _FSMModel, state: _StateValue) -> bool: """ Check if all conditions have been met """ @@ -185,7 +243,9 @@ def conditions_met(self, instance, state): return all(condition(instance) for condition in transition.conditions) - def has_transition_perm(self, instance, state, user): + def has_transition_perm( + self, instance: _FSMModel, state: _StateValue, user: UserWithPermissions + ) -> bool: transition = self.get_transition(state) if not transition: @@ -193,7 +253,7 @@ def has_transition_perm(self, instance, state, user): return transition.has_perm(instance, user) - def next_state(self, current_state): + def next_state(self, current_state: _StateValue) -> _StateValue: transition = self.get_transition(current_state) if transition is None: @@ -201,7 +261,7 @@ def next_state(self, current_state): return transition.target - def exception_state(self, current_state): + def exception_state(self, current_state: _StateValue) -> _StateValue | None: transition = self.get_transition(current_state) if transition is None: @@ -211,15 +271,15 @@ def exception_state(self, current_state): class FSMFieldDescriptor: - def __init__(self, field): + def __init__(self, field: FSMFieldMixin) -> None: self.field = field - def __get__(self, instance, instance_type=None): + def __get__(self, instance: _FSMModel, type: typing.Any | None = None) -> typing.Any: # noqa: A002 if instance is None: return self return self.field.get_state(instance) - def __set__(self, instance, value): + def __set__(self, instance: _FSMModel, value: typing.Any) -> None: if self.field.protected and self.field.name in instance.__dict__: raise AttributeError(f"Direct {self.field.name} modification is not allowed") @@ -228,13 +288,16 @@ def __set__(self, instance, value): self.field.set_state(instance, value) -class FSMFieldMixin: +class FSMFieldMixin(_Field): descriptor_class = FSMFieldDescriptor - def __init__(self, *args, **kwargs): + @override + def __init__(self, *args: typing.Any, **kwargs: typing.Any) -> None: self.protected = kwargs.pop("protected", False) - self.transitions = {} # cls -> (transitions name -> method) - self.state_proxy = {} # state -> ProxyClsRef + self.transitions: dict[ + type[_FSMModel], dict[str, typing.Any] + ] = {} # cls -> (transitions name -> method) + self.state_proxy: dict[_StateValue, str] = {} # state -> ProxyClsRef state_choices = kwargs.pop("state_choices", None) choices = kwargs.get("choices") @@ -250,21 +313,22 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - def deconstruct(self): + @override + def deconstruct(self) -> tuple[str, str, Sequence[typing.Any], dict[str, typing.Any]]: name, path, args, kwargs = super().deconstruct() if self.protected: kwargs["protected"] = self.protected return name, path, args, kwargs - def get_state(self, instance): + def get_state(self, instance: _FSMModel) -> typing.Any: # The state field may be deferred. We delegate the logic of figuring this out # and loading the deferred field on-demand to Django's built-in DeferredAttribute class. return DeferredAttribute(self).__get__(instance) - def set_state(self, instance, state): + def set_state(self, instance: _FSMModel, state: _StateValue) -> None: instance.__dict__[self.name] = state - def set_proxy(self, instance, state): + def set_proxy(self, instance: _FSMModel, state: _StateValue) -> None: """ Change class """ @@ -285,7 +349,9 @@ def set_proxy(self, instance, state): instance.__class__ = model - def change_state(self, instance, method, *args, **kwargs): + def change_state( + self, instance: _FSMModel, method: typing.Any, *args: typing.Any, **kwargs: typing.Any + ) -> typing.Any: meta = method._django_fsm method_name = method.__name__ current_state = self.get_state(instance) @@ -342,21 +408,26 @@ def change_state(self, instance, method, *args, **kwargs): return result - def get_all_transitions(self, instance_cls): + def get_all_transitions(self, instance_cls: type[_FSMModel]) -> Generator[Transition]: """ Returns [(source, target, name, method)] for all field transitions """ transitions = self.transitions[instance_cls] for transition in transitions.values(): - meta = transition._django_fsm - - yield from meta.transitions.values() - - def contribute_to_class(self, cls, name, **kwargs): + yield from transition._django_fsm.transitions.values() + + @override + def contribute_to_class( + self, + cls: type[_FSMModel], + name: str, + private_only: bool = False, + **kwargs: typing.Any, + ) -> None: self.base_cls = cls - super().contribute_to_class(cls, name, **kwargs) + super().contribute_to_class(cls, name, private_only=private_only, **kwargs) setattr(cls, self.name, self.descriptor_class(self)) setattr( cls, @@ -376,13 +447,13 @@ def contribute_to_class(self, cls, name, **kwargs): class_prepared.connect(self._collect_transitions) - def _collect_transitions(self, *args, **kwargs): + def _collect_transitions(self, *args: typing.Any, **kwargs: typing.Any) -> None: sender = kwargs["sender"] if not issubclass(sender, self.base_cls): return - def is_field_transition_method(attr): + def is_field_transition_method(attr: Callable[[typing.Any], typing.Any]) -> bool: return ( (inspect.ismethod(attr) or inspect.isfunction(attr)) and hasattr(attr, "_django_fsm") @@ -396,7 +467,7 @@ def is_field_transition_method(attr): ) ) - sender_transitions = {} + sender_transitions: dict[str, typing.Any] = {} transitions = inspect.getmembers(sender, predicate=is_field_transition_method) for method_name, method in transitions: method._django_fsm.field = self @@ -405,59 +476,61 @@ def is_field_transition_method(attr): self.transitions[sender] = sender_transitions -class FSMField(FSMFieldMixin, models.CharField): +class FSMField(FSMFieldMixin, CharField): """ State Machine support for Django model as CharField """ - def __init__(self, *args, **kwargs): + @override + def __init__(self, *args: typing.Any, **kwargs: typing.Any) -> None: kwargs.setdefault("max_length", 50) super().__init__(*args, **kwargs) -class FSMIntegerField(FSMFieldMixin, models.IntegerField): +class FSMIntegerField(FSMFieldMixin, IntegerField): """ Same as FSMField, but stores the state value in an IntegerField. """ -class FSMKeyField(FSMFieldMixin, models.ForeignKey): +class FSMKeyField(FSMFieldMixin, ForeignKey): """ State Machine support for Django model """ - def get_state(self, instance): + def get_state(self, instance: _FSMModel) -> typing.Any: return instance.__dict__[self.attname] - def set_state(self, instance, state): + def set_state(self, instance: _FSMModel, state: _StateValue) -> None: instance.__dict__[self.attname] = self.to_python(state) -class FSMModelMixin: +class FSMModelMixin(_FSMModel): """ Mixin that allows refresh_from_db for models with fsm protected fields """ - def _get_protected_fsm_fields(self): - def is_fsm_and_protected(f): + def _get_protected_fsm_fields(self) -> set[str]: + def is_fsm_and_protected(f: object) -> typing.Any: return isinstance(f, FSMFieldMixin) and f.protected protected_fields = filter(is_fsm_and_protected, self._meta.concrete_fields) return {f.attname for f in protected_fields} - def refresh_from_db(self, *args, **kwargs): + @override + def refresh_from_db(self, *args: typing.Any, **kwargs: typing.Any) -> None: protected_fields = self._get_protected_fsm_fields() for f in protected_fields: - self._meta.get_field(f).protected = False + setattr(self._meta.get_field(f), "protected", False) super().refresh_from_db(*args, **kwargs) for f in protected_fields: - self._meta.get_field(f).protected = True + setattr(self._meta.get_field(f), "protected", True) -class ConcurrentTransitionMixin: +class ConcurrentTransitionMixin(_FSMModel): """ Protects a Model from undesirable effects caused by concurrently executed transitions, e.g. running the same transition multiple times at the same time, or running different @@ -468,13 +541,13 @@ class ConcurrentTransitionMixin: This scheme is not that strict as true *optimistic locking* mechanism, it is however more lightweight - leveraging the specifics of FSM models. - Instance of a model based on this Mixin will be prevented from saving into DB if any + Instance of a model based on this Mixin will be prevented from saving into DB if typing.Any of its state fields (instances of FSMFieldMixin) has been changed since the object was fetched from the database. *ConcurrentTransition* exception will be raised in such cases. For guaranteed protection against such race conditions, make sure: - * Your transitions do not have any side effects except for changes in the database, + * Your transitions do not have typing.Any side effects except for changes in the database, * You always run the save() method on the object within django.db.transaction.atomic() block. @@ -483,17 +556,26 @@ class ConcurrentTransitionMixin: state, thus practically negating their effect. """ - def __init__(self, *args, **kwargs): + @override + def __init__(self, *args: typing.Any, **kwargs: typing.Any) -> None: super().__init__(*args, **kwargs) self._update_initial_state() @property - def state_fields(self): - return filter(lambda field: isinstance(field, FSMFieldMixin), self._meta.fields) + def state_fields(self) -> Iterable[FSMFieldMixin]: + return filter(lambda field: isinstance(field, FSMFieldMixin), self._meta.fields) # type: ignore[arg-type] + @override def _do_update( - self, base_qs, using, pk_val, values, update_fields, forced_update, returning_fields=None - ): + self, + base_qs: QuerySet[Self], + using: str | None, + pk_val: typing.Any, + values: Collection[tuple[_Field, type[models.Model] | None, typing.Any]], + update_fields: Iterable[str] | None, + forced_update: bool, + returning_fields: bool | None = None, + ) -> bool: # _do_update is called once for each model class in the inheritance hierarchy. We can only # filter the base_qs on state fields (can be more than one!) present in this specific model. @@ -505,7 +587,7 @@ def _do_update( # Django 6.0+ added returning_fields parameter to _do_update if DJANGO_VERSION >= (6, 0): - updated = super()._do_update( + updated = super()._do_update( # type: ignore[call-arg] base_qs=base_qs.filter(**state_filter), using=using, pk_val=pk_val, @@ -539,23 +621,31 @@ def _do_update( return updated - def _update_initial_state(self): + def _update_initial_state(self) -> None: self.__initial_states = { field.attname: field.value_from_object(self) for field in self.state_fields } - def refresh_from_db(self, *args, **kwargs): + @override + def refresh_from_db(self, *args: typing.Any, **kwargs: typing.Any) -> None: super().refresh_from_db(*args, **kwargs) self._update_initial_state() - def save(self, *args, **kwargs): + @override + def save(self, *args: typing.Any, **kwargs: typing.Any) -> None: super().save(*args, **kwargs) self._update_initial_state() def transition( - field, source="*", target=None, on_error=None, conditions=[], permission=None, custom={} -): + field: FSMFieldMixin | str, + source: _StateValue | Sequence[_StateValue] = "*", + target: _StateValue | State | None = None, + on_error: _StateValue | None = None, + conditions: list[_Condition] | None = None, + permission: _Permission | None = None, + custom: dict[str, _StrOrPromise] | None = None, +) -> Callable[[typing.Any], typing.Any]: """ Method decorator to mark allowed transitions. @@ -563,7 +653,7 @@ def transition( has not changed after the function call. """ - def inner_transition(func): + def inner_transition(func: typing.Any) -> typing.Any: wrapper_installed, fsm_meta = True, getattr(func, "_django_fsm", None) if not fsm_meta: wrapper_installed = False @@ -581,7 +671,10 @@ def inner_transition(func): ) @wraps(func) - def _change_state(instance, *args, **kwargs): + def _change_state( + instance: _FSMModel, *args: typing.Any, **kwargs: typing.Any + ) -> typing.Any: + assert isinstance(fsm_meta.field, FSMFieldMixin) return fsm_meta.field.change_state(instance, func, *args, **kwargs) if not wrapper_installed: @@ -592,7 +685,7 @@ def _change_state(instance, *args, **kwargs): return inner_transition -def can_proceed(bound_method, check_conditions=True): # noqa: FBT002 +def can_proceed(bound_method: typing.Any, check_conditions: bool = True) -> bool: # noqa: FBT001, FBT002 """ Returns True if model in state allows to call bound_method @@ -611,7 +704,7 @@ def can_proceed(bound_method, check_conditions=True): # noqa: FBT002 ) -def has_transition_perm(bound_method, user): +def has_transition_perm(bound_method: typing.Any, user: UserWithPermissions) -> bool: """ Returns True if model in state allows to call bound_method and user have rights on it """ @@ -622,7 +715,7 @@ def has_transition_perm(bound_method, user): self = bound_method.__self__ current_state = meta.field.get_state(self) - return ( + return bool( meta.has_transition(current_state) and meta.conditions_met(self, current_state) and meta.has_transition_perm(self, current_state, user) @@ -630,15 +723,31 @@ def has_transition_perm(bound_method, user): class State: - def get_state(self, model, transition, result, args=[], kwargs={}): + allowed_states: Sequence[_StateValue] + + def get_state( + self, + model: _FSMModel, + transition: Transition, + result: typing.Any, + args: Sequence[typing.Any] | None = None, + kwargs: dict[str, typing.Any] | None = None, + ) -> typing.Any: raise NotImplementedError class RETURN_VALUE(State): # noqa: N801 - def __init__(self, *allowed_states): - self.allowed_states = allowed_states if allowed_states else None - - def get_state(self, model, transition, result, args=[], kwargs={}): + def __init__(self, *allowed_states: _StateValue) -> None: + self.allowed_states = allowed_states or [] + + def get_state( + self, + model: _FSMModel, + transition: Transition, + result: typing.Any, + args: Sequence[typing.Any] | None = None, + kwargs: dict[str, typing.Any] | None = None, + ) -> typing.Any: if self.allowed_states is not None and result not in self.allowed_states: raise InvalidResultState( f"{result} is not in list of allowed states\n{self.allowed_states}" @@ -647,11 +756,26 @@ def get_state(self, model, transition, result, args=[], kwargs={}): class GET_STATE(State): # noqa: N801 - def __init__(self, func, states=None): + def __init__( + self, + func: Callable[..., _StateValue], + states: Sequence[_StateValue] | None = None, + ) -> None: self.func = func - self.allowed_states = states - - def get_state(self, model, transition, result, args=[], kwargs={}): + self.allowed_states = states or [] + + def get_state( + self, + model: _FSMModel, + transition: Transition, + result: _StateValue, + args: Sequence[typing.Any] | None = None, + kwargs: dict[str, typing.Any] | None = None, + ) -> typing.Any: + if args is None: + args = () + if kwargs is None: + kwargs = {} result_state = self.func(model, *args, **kwargs) if self.allowed_states is not None and result_state not in self.allowed_states: raise InvalidResultState( diff --git a/django_fsm/admin.py b/django_fsm/admin.py new file mode 100644 index 0000000..c697201 --- /dev/null +++ b/django_fsm/admin.py @@ -0,0 +1,60 @@ +from __future__ import annotations + +import typing +from warnings import warn + +from django.contrib.contenttypes.admin import GenericTabularInline +from django.db.models import F + +from .models import StateLog + +if typing.TYPE_CHECKING: + from django.db.models import QuerySet + from django.http import HttpRequest + + from .models import TransitionLogBase + + +class FSMTransitionInline(GenericTabularInline): + model: type[TransitionLogBase] = None # type: ignore[assignment] + + can_delete = False + + def has_add_permission( + self, request: HttpRequest, obj: TransitionLogBase | None = None + ) -> bool: + return False + + def has_change_permission( + self, request: HttpRequest, obj: TransitionLogBase | None = None + ) -> bool: + return True + + fields = ( + "transition", + "source_state", + "state", + "by", + "description", + "timestamp", + ) + + def get_readonly_fields( + self, request: HttpRequest, obj: TransitionLogBase | None = None + ) -> list[str] | tuple[str, ...] | tuple[()]: + return self.fields + + def get_queryset(self, request: HttpRequest) -> QuerySet[TransitionLogBase]: + return super().get_queryset(request).order_by(F("timestamp").desc()) + + +class StateLogInline(FSMTransitionInline): + model = StateLog + + def __init__(self, parent_model: typing.Any, admin_site: typing.Any) -> None: + warn( + "StateLogInline has been deprecated by PersistedTransitionInline.", + DeprecationWarning, + stacklevel=2, + ) + super().__init__(parent_model, admin_site) diff --git a/django_fsm/apps.py b/django_fsm/apps.py new file mode 100644 index 0000000..2ce1321 --- /dev/null +++ b/django_fsm/apps.py @@ -0,0 +1,8 @@ +from __future__ import annotations + +from django.apps import AppConfig + + +class DjangoFSMAppConfig(AppConfig): + name = "django_fsm" + verbose_name = "Django FSM" diff --git a/django_fsm/log.py b/django_fsm/log.py new file mode 100644 index 0000000..6c3daed --- /dev/null +++ b/django_fsm/log.py @@ -0,0 +1,198 @@ +from __future__ import annotations + +import contextlib +import typing +from dataclasses import dataclass +from functools import partial +from functools import wraps + +from django.contrib.contenttypes.models import ContentType +from django.db import models + +from .models import StateLog +from .models import TransitionLogBase +from .signals import post_transition + +if typing.TYPE_CHECKING: # pragma: no cover + from collections.abc import Callable + + from . import _Field + + +__all__ = [ + "StateLog", + "TransitionLogBase", + "fsm_log_by", + "fsm_log_description", + "track", +] + + +@dataclass(frozen=True) +class TrackConfig: + log_model: type[TransitionLogBase] | None + relation_field: str | None + + +_registry: dict[type[models.Model], TrackConfig] = {} +NOTSET = object() + + +def track( + *, + log_model: type[TransitionLogBase] | None = None, + relation_field: str | None = None, +) -> Callable[[type[models.Model]], type[models.Model]]: + def decorator(model_cls: type[models.Model]) -> type[models.Model]: + if model_cls._meta.abstract: + raise TypeError("django_fsm.track cannot be used with abstract models") + config = TrackConfig(log_model=log_model, relation_field=relation_field) + _registry[model_cls] = config + + post_transition.connect( + _log_transition, + sender=model_cls, + dispatch_uid=f"django_fsm.track.{model_cls._meta.label_lower}", + weak=False, + ) + return model_cls + + return decorator + + +class FSMLogDescriptor: + ATTR_PREFIX = "__django_fsm_log_attr_" + + def __init__(self, instance: models.Model, attribute: str, value: typing.Any = NOTSET): + self.instance = instance + self.attribute = attribute + if value is not NOTSET: + self.set(value) + + def get(self) -> typing.Any: + return getattr(self.instance, self.ATTR_PREFIX + self.attribute) + + def set(self, value: typing.Any) -> None: + setattr(self.instance, self.ATTR_PREFIX + self.attribute, value) + + def __enter__(self) -> typing.Self: + return self + + def __exit__(self, *args: object) -> None: + with contextlib.suppress(AttributeError): + delattr(self.instance, self.ATTR_PREFIX + self.attribute) + + +def fsm_log_by(func: typing.Callable[..., typing.Any]) -> typing.Callable[..., typing.Any]: + @wraps(func) + def wrapped(instance: models.Model, *args: typing.Any, **kwargs: typing.Any) -> typing.Any: + if "by" in kwargs: + author = kwargs.pop("by") + else: + return func(instance, *args, **kwargs) + + with FSMLogDescriptor(instance, "by", author): + return func(instance, *args, **kwargs) + + return wrapped + + +def fsm_log_description( + func: typing.Callable[..., typing.Any] | None = None, + *, + description: str | None = None, +) -> typing.Callable[..., typing.Any]: + if func is None: + return partial(fsm_log_description, description=description) + + @wraps(func) + def wrapped(instance: models.Model, *args: typing.Any, **kwargs: typing.Any) -> typing.Any: + with FSMLogDescriptor(instance, "description") as descriptor: + if "description" in kwargs: + descriptor.set(kwargs.pop("description")) + else: + descriptor.set(description) + return func(instance, *args, **kwargs) + + return wrapped + + +def _log_transition( + sender: type[models.Model], + instance: models.Model, + name: str, + source: typing.Any, + target: typing.Any, + field: _Field, + **kwargs: typing.Any, +) -> None: + config = _registry.get(sender) + if not config or instance.pk is None: + return + + log_model = config.log_model or StateLog + log_kwargs: dict[str, typing.Any] = { + "transition": name, + "state_field": field.name, + "source_state": _coerce_state(source), + "state": _coerce_state(target), + "by": _extract_log_value(instance, "by"), + "description": _extract_log_value(instance, "description"), + } + + if issubclass(log_model, StateLog): + log_kwargs["content_type"] = ContentType.objects.get_for_model(sender) + log_kwargs["object_id"] = str(instance.pk) + else: + relation_field = config.relation_field or _resolve_relation_field(log_model, sender) + log_kwargs[relation_field] = instance + + log_model._default_manager.using(instance._state.db).create(**log_kwargs) + + +def _resolve_relation_field( + log_model: type[TransitionLogBase], model_cls: type[models.Model] +) -> str: + relation_fields = [ + field.name + for field in log_model._meta.fields + if isinstance(field, models.ForeignKey) + and _matches_model(field.remote_field.model, model_cls) + ] + if len(relation_fields) == 1: + return relation_fields[0] + + if not relation_fields: + raise ValueError( + f"{log_model.__name__} does not define a ForeignKey to {model_cls.__name__}" + ) + raise ValueError( + f"{log_model.__name__} has multiple ForeignKey fields to {model_cls.__name__}; " + "set relation_field when calling track()" + ) + + +def _coerce_state(value: typing.Any) -> str | None: + if value is None: + return None + if isinstance(value, models.Model): + return str(value.pk) + return str(value) + + +def _matches_model(remote_model: typing.Any, model_cls: type[models.Model]) -> bool: + if remote_model == model_cls: + return True + if isinstance(remote_model, str): + return remote_model == model_cls.__name__ or remote_model.endswith(f".{model_cls.__name__}") + return False + + +def _extract_log_value( + instance: models.Model, + attribute: str, +) -> typing.Any: + try: + return FSMLogDescriptor(instance, attribute).get() + except AttributeError: + return None diff --git a/django_fsm/management/commands/graph_transitions.py b/django_fsm/management/commands/graph_transitions.py index 0bfd095..23a7d48 100644 --- a/django_fsm/management/commands/graph_transitions.py +++ b/django_fsm/management/commands/graph_transitions.py @@ -1,5 +1,6 @@ from __future__ import annotations +import typing from itertools import chain import graphviz @@ -11,36 +12,59 @@ from django_fsm import RETURN_VALUE from django_fsm import FSMFieldMixin +if typing.TYPE_CHECKING: # pragma: no cover + from argparse import ArgumentParser + from collections.abc import Sequence -def all_fsm_fields_data(model): + from django.db import models + + from django_fsm import _StateValue + + +def all_fsm_fields_data( + model: type[models.Model], +) -> list[tuple[FSMFieldMixin, type[models.Model]]]: return [ (field, model) for field in model._meta.get_fields() if isinstance(field, FSMFieldMixin) ] -def one_fsm_fields_data(model, field_name): - return (model._meta.get_field(field_name), model) +def one_fsm_fields_data( + model: type[models.Model], field_name: str +) -> tuple[FSMFieldMixin, type[models.Model]]: + field = model._meta.get_field(field_name) + if not isinstance(field, FSMFieldMixin): + raise LookupError(f"{field_name} is not an FSMField") # noqa: TRY004 + return (field, model) -def node_name(field, state) -> str: +def node_name(field: FSMFieldMixin, state: _StateValue) -> str: opts = field.model._meta + assert opts.verbose_name return "{}.{}.{}.{}".format( opts.app_label, opts.verbose_name.replace(" ", "_"), field.name, state ) -def node_label(field, state: str | None) -> str: - if isinstance(state, (int, bool)) and hasattr(field, "choices") and field.choices: +def node_label(field: FSMFieldMixin, state: _StateValue | None) -> str: + if hasattr(field, "choices") and field.choices: state = dict(field.choices).get(state) return force_str(state) -def generate_dot(fields_data, ignore_transitions: list[str] | None = None): # noqa: C901, PLR0912 +def generate_dot( # noqa: C901, PLR0912 + fields_data: Sequence[tuple[FSMFieldMixin, type[models.Model]]], + ignore_transitions: Sequence[str] | None = None, +) -> graphviz.Digraph: ignore_transitions = ignore_transitions or [] result = graphviz.Digraph() for field, model in fields_data: - sources, targets, edges, any_targets, any_except_targets = set(), set(), set(), set(), set() + sources: set[tuple[(str, str)]] = set() + targets: set[tuple[str, str]] = set() + edges: set[tuple[str, str, tuple[tuple[str, str]]]] = set() + any_targets: set[tuple[_StateValue, str]] = set() + any_except_targets: set[tuple[_StateValue, str]] = set() # dump nodes and edges for transition in field.get_all_transitions(model): @@ -57,6 +81,7 @@ def generate_dot(fields_data, ignore_transitions: list[str] | None = None): # n if isinstance(transition.source, (GET_STATE, RETURN_VALUE)) else ((transition.source, node_name(field, transition.source)),) ) + for source, source_name in source_name_pair: if transition.on_error: on_error_name = node_name(field, transition.on_error) @@ -69,16 +94,10 @@ def generate_dot(fields_data, ignore_transitions: list[str] | None = None): # n elif transition.source == "+": any_except_targets.add((target, transition.name)) else: - add_transition( - source, - target, - transition.name, - source_name, - field, - sources, - targets, - edges, - ) + target_name = node_name(field, target) + sources.add((source_name, node_label(field, source))) + targets.add((target_name, node_label(field, target))) + edges.add((source_name, target_name, (("label", transition.name),))) targets.update( { @@ -111,50 +130,27 @@ def generate_dot(fields_data, ignore_transitions: list[str] | None = None): # n final_states = targets - sources for name, label in final_states: subgraph.node(name, label=label, shape="doublecircle") + for name, label in (sources | targets) - final_states: subgraph.node(name, label=label, shape="circle") # Adding initial state notation if field.default and label == field.default: initial_name = node_name(field, "_initial") subgraph.node(name=initial_name, label="", shape="point") - subgraph.edge(initial_name, name) + subgraph.edge(tail_name=initial_name, head_name=name) + for source_name, target_name, attrs in edges: - subgraph.edge(source_name, target_name, **dict(attrs)) + subgraph.edge(tail_name=source_name, head_name=target_name, **dict(attrs)) result.subgraph(subgraph) return result -def add_transition( - transition_source, - transition_target, - transition_name, - source_name, - field, - sources, - targets, - edges, -): - target_name = node_name(field, transition_target) - sources.add((source_name, node_label(field, transition_source))) - targets.add((target_name, node_label(field, transition_target))) - edges.add((source_name, target_name, (("label", transition_name),))) - - -def get_graphviz_layouts(): - try: - import graphviz - except ModuleNotFoundError: - return {"sfdp", "circo", "twopi", "dot", "neato", "fdp", "osage", "patchwork"} - else: - return graphviz.ENGINES - - class Command(BaseCommand): help = "Creates a GraphViz dot file with transitions for selected fields" - def add_arguments(self, parser): + def add_arguments(self, parser: ArgumentParser) -> None: parser.add_argument( "--output", "-o", @@ -169,7 +165,7 @@ def add_arguments(self, parser): action="store", dest="layout", default="dot", - help=f"Layout to be used by GraphViz for visualization: {get_graphviz_layouts()}.", + help=f"Layout to be used by GraphViz for visualization: {graphviz.ENGINES}.", ) parser.add_argument( "--exclude", @@ -181,22 +177,15 @@ def add_arguments(self, parser): ) parser.add_argument("args", nargs="*", help=("[appname[.model[.field]]]")) - def render_output(self, graph, **options): - filename, graph_format = options["outputfile"].rsplit(".", 1) - - graph.engine = options["layout"] - graph.format = graph_format - graph.render(filename) - - def handle(self, *args, **options): - fields_data = [] + def handle(self, *args: str, **options: typing.Any) -> None: + fields_data: list[tuple[FSMFieldMixin, type[models.Model]]] = [] if len(args) != 0: for arg in args: field_spec = arg.split(".") if len(field_spec) == 1: app = apps.get_app_config(field_spec[0]) - for model in apps.get_models(app): + for model in app.get_models(): fields_data += all_fsm_fields_data(model) if len(field_spec) == 2: # noqa: PLR2004 model = apps.get_model(field_spec[0], field_spec[1]) @@ -210,7 +199,11 @@ def handle(self, *args, **options): dotdata = generate_dot(fields_data, ignore_transitions=options["exclude"].split(",")) - if options["outputfile"]: - self.render_output(dotdata, **options) + if outputfile := options["outputfile"]: + filename, graph_format = outputfile.rsplit(".", 1) + + dotdata.engine = options["layout"] + dotdata.format = graph_format + dotdata.render(filename) else: self.stdout.write(str(dotdata)) diff --git a/django_fsm/migrations/0001_initial.py b/django_fsm/migrations/0001_initial.py new file mode 100644 index 0000000..2f6f77a --- /dev/null +++ b/django_fsm/migrations/0001_initial.py @@ -0,0 +1,66 @@ +# Generated by Django 4.2.16 on 2026-01-29 08:32 +from __future__ import annotations + +import django.db.models.deletion +import django.utils.timezone +from django.conf import settings +from django.db import migrations +from django.db import models + + +class Migration(migrations.Migration): + initial = True + + dependencies = [ + ("contenttypes", "0002_remove_content_type_name"), + migrations.swappable_dependency(settings.AUTH_USER_MODEL), + ] + + operations = [ + migrations.CreateModel( + name="StateLog", + fields=[ + ( + "id", + models.BigAutoField( + auto_created=True, primary_key=True, serialize=False, verbose_name="ID" + ), + ), + ("timestamp", models.DateTimeField(default=django.utils.timezone.now)), + ("state_field", models.CharField(max_length=255)), + ( + "source_state", + models.CharField(blank=True, default=None, max_length=255, null=True), + ), + ("state", models.CharField(max_length=255, verbose_name="Target state")), + ("transition", models.CharField(max_length=255)), + ("description", models.TextField(blank=True, null=True)), + ("object_id", models.TextField()), + ( + "by", + models.ForeignKey( + blank=True, + null=True, + on_delete=django.db.models.deletion.SET_NULL, + to=settings.AUTH_USER_MODEL, + ), + ), + ( + "content_type", + models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, to="contenttypes.contenttype" + ), + ), + ], + options={ + "indexes": [ + models.Index( + fields=["source_state", "state"], name="django_fsm__source__bb71ee_idx" + ), + models.Index( + fields=["content_type", "object_id"], name="django_fsm__content_9593e3_idx" + ), + ], + }, + ), + ] diff --git a/django_fsm/migrations/__init__.py b/django_fsm/migrations/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/django_fsm/models.py b/django_fsm/models.py new file mode 100644 index 0000000..202a0a5 --- /dev/null +++ b/django_fsm/models.py @@ -0,0 +1,61 @@ +from __future__ import annotations + +import typing +from warnings import warn + +from django.conf import settings +from django.contrib.contenttypes.fields import GenericForeignKey +from django.contrib.contenttypes.models import ContentType +from django.db import models +from django.utils.timezone import now + + +class TransitionLogBase(models.Model): + timestamp = models.DateTimeField(default=now) + by = models.ForeignKey( + settings.AUTH_USER_MODEL, + null=True, + blank=True, + on_delete=models.SET_NULL, + ) + state_field = models.CharField(max_length=255) + source_state = models.CharField(max_length=255, null=True, blank=True, default=None) # noqa: DJ001 + state = models.CharField("Target state", max_length=255) + transition = models.CharField(max_length=255) + + description = models.TextField(null=True, blank=True) # noqa: DJ001 + + class Meta: + abstract = True + get_latest_by = "timestamp" + + +class StateLogQuerySet(models.QuerySet["StateLog"]): + def _get_content_type(self, obj: models.Model) -> ContentType: + return ContentType.objects.get_for_model(obj) + + def for_(self, obj: models.Model) -> typing.Self: + return self.filter(content_type=self._get_content_type(obj), object_id=obj.pk) + + +class StateLog(TransitionLogBase): # noqa: DJ008 + content_type = models.ForeignKey(ContentType, on_delete=models.CASCADE) + object_id = models.TextField() + content_object = GenericForeignKey("content_type", "object_id") + + objects = StateLogQuerySet.as_manager() + + class Meta: + indexes = [ + models.Index(fields=["source_state", "state"]), + models.Index(fields=["content_type", "object_id"]), + ] + + def __init__(self, *args: typing.Any, **kwargs: typing.Any) -> None: + warn( + "StateLog model has been deprecated, you should now bring your own model." + "\nPlease check the documentation to know how to.", + DeprecationWarning, + stacklevel=2, + ) + super().__init__(*args, **kwargs) diff --git a/django_fsm/py.typed b/django_fsm/py.typed new file mode 100644 index 0000000..1841ff3 --- /dev/null +++ b/django_fsm/py.typed @@ -0,0 +1 @@ +# Marker file for PEP 561 typed packages. diff --git a/django_fsm/signals.py b/django_fsm/signals.py index 60a8e02..d63a2c2 100644 --- a/django_fsm/signals.py +++ b/django_fsm/signals.py @@ -2,5 +2,5 @@ from django.db.models.signals import ModelSignal -pre_transition = ModelSignal() -post_transition = ModelSignal() +pre_transition: ModelSignal = ModelSignal() +post_transition: ModelSignal = ModelSignal() diff --git a/pyproject.toml b/pyproject.toml index 20bdcec..5537a04 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,7 +3,7 @@ name = "django-fsm-2" version = "4.1.0" description = "Django friendly finite state machine support." authors = [{ name = "Mikhail Podgurskiy", email = "kmmbvnr@gmail.com" }] -requires-python = "~=3.8" +requires-python = ">=3.8" readme = "README.md" license = "MIT" classifiers = [ @@ -29,7 +29,10 @@ classifiers = [ "Programming Language :: Python :: 3.14", "Topic :: Software Development :: Libraries :: Python Modules", ] -dependencies = ["django>=4.2"] +dependencies = [ + "django>=4.2", + "typing-extensions>=4.13.2", +] [project.urls] Homepage = "http://github.com/django-commons/django-fsm-2" @@ -59,10 +62,10 @@ default-groups = [ ] [tool.hatch.build.targets.sdist] -include = ["django_fsm"] +include = ["django_fsm", "django_fsm/py.typed"] [tool.hatch.build.targets.wheel] -include = ["django_fsm"] +include = ["django_fsm", "django_fsm/py.typed"] [build-system] requires = ["hatchling"] @@ -100,6 +103,7 @@ extend-ignore = [ ] fixable = [ "I", # isort + "RUF022", "RUF100", # Unused `noqa` directive "E501", ] @@ -112,3 +116,46 @@ fixable = [ [tool.ruff.lint.isort] force-single-line = true required-imports = ["from __future__ import annotations"] + +[tool.django-stubs] +django_settings_module = "tests.settings" + +[tool.mypy] +python_version = 3.12 +plugins = ["mypy_django_plugin.main"] + +# Start off with these +warn_unused_configs = true +warn_redundant_casts = true +warn_unused_ignores = true + +# Getting these passing should be easy +strict_equality = true +extra_checks = true + +# Strongly recommend enabling this one as soon as you can +check_untyped_defs = true + +# These shouldn't be too much additional work, but may be tricky to +# get passing if you use a lot of untyped libraries +disallow_subclassing_any = true +disallow_untyped_decorators = true +disallow_any_generics = true + +# These next few are various gradations of forcing use of type annotations +disallow_untyped_calls = true +disallow_incomplete_defs = true +disallow_untyped_defs = true + +# This one isn't too hard to get passing, but return on investment is lower +no_implicit_reexport = true + +# This one can be tricky to get passing if you use a lot of untyped libraries +warn_return_any = true + +[[tool.mypy.overrides]] +module = [ + "tests.*", + "django_fsm.tests.*" +] +disallow_untyped_defs = false diff --git a/tests/manage.py b/tests/manage.py index d120277..69a4fbf 100755 --- a/tests/manage.py +++ b/tests/manage.py @@ -7,7 +7,7 @@ import sys -def main(): +def main() -> None: """Run administrative tasks.""" os.environ.setdefault("DJANGO_SETTINGS_MODULE", "tests.settings") try: diff --git a/tests/settings.py b/tests/settings.py index 295ae84..87c9d0b 100644 --- a/tests/settings.py +++ b/tests/settings.py @@ -27,7 +27,7 @@ # SECURITY WARNING: don't run with debug turned on in production! DEBUG = True -ALLOWED_HOSTS = [] +ALLOWED_HOSTS: list[str] = [] # Application definition diff --git a/tests/testapp/models.py b/tests/testapp/models.py index ba07a18..dff13fc 100644 --- a/tests/testapp/models.py +++ b/tests/testapp/models.py @@ -1,5 +1,7 @@ from __future__ import annotations +import typing + from django.db import models from django_fsm import GET_STATE @@ -7,6 +9,13 @@ from django_fsm import FSMField from django_fsm import FSMKeyField from django_fsm import transition +from django_fsm.log import fsm_log_by +from django_fsm.log import fsm_log_description +from django_fsm.log import track +from django_fsm.models import TransitionLogBase + +if typing.TYPE_CHECKING: + from django.contrib.auth.models import AbstractUser class Application(models.Model): @@ -17,20 +26,20 @@ class Application(models.Model): state = FSMField(default="new") - @transition(field=state, source="new", target="published") - def standard(self): + @transition(field=state, source="new", target="published", on_error="failed") + def standard(self) -> None: pass @transition(field=state, source="published") - def no_target(self): + def no_target(self) -> None: pass @transition(field=state, source="*", target="blocked") - def any_source(self): + def any_source(self) -> None: pass @transition(field=state, source="+", target="hidden") - def any_source_except_target(self): + def any_source_except_target(self) -> None: pass @transition( @@ -41,7 +50,7 @@ def any_source_except_target(self): states=["published", "rejected"], ), ) - def get_state(self, *, allowed: bool): + def get_state(self, *, allowed: bool) -> None: pass @transition( @@ -52,7 +61,7 @@ def get_state(self, *, allowed: bool): states=["published", "rejected"], ), ) - def get_state_any_source(self, *, allowed: bool): + def get_state_any_source(self, *, allowed: bool) -> None: pass @transition( @@ -63,23 +72,23 @@ def get_state_any_source(self, *, allowed: bool): states=["published", "rejected"], ), ) - def get_state_any_source_except_target(self, *, allowed: bool): + def get_state_any_source_except_target(self, *, allowed: bool) -> None: pass @transition(field=state, source="new", target=RETURN_VALUE("moderated", "blocked")) - def return_value(self): + def return_value(self) -> str: return "published" @transition(field=state, source="*", target=RETURN_VALUE("moderated", "blocked")) - def return_value_any_source(self): + def return_value_any_source(self) -> str: return "published" @transition(field=state, source="+", target=RETURN_VALUE("moderated", "blocked")) - def return_value_any_source_except_target(self): + def return_value_any_source_except_target(self) -> str: return "published" @transition(field=state, source="new", target="published", on_error="failed") - def on_error(self): + def on_error(self) -> None: pass @@ -105,19 +114,19 @@ class FKApplication(models.Model): state = FSMKeyField(DbState, default="new", on_delete=models.CASCADE) @transition(field=state, source="new", target="published") - def standard(self): + def standard(self) -> None: pass @transition(field=state, source="published") - def no_target(self): + def no_target(self) -> None: pass @transition(field=state, source="*", target="blocked") - def any_source(self): + def any_source(self) -> None: pass @transition(field=state, source="+", target="hidden") - def any_source_except_target(self): + def any_source_except_target(self) -> None: pass @transition( @@ -128,7 +137,7 @@ def any_source_except_target(self): states=["published", "rejected"], ), ) - def get_state(self, *, allowed: bool): + def get_state(self, *, allowed: bool) -> None: pass @transition( @@ -139,7 +148,7 @@ def get_state(self, *, allowed: bool): states=["published", "rejected"], ), ) - def get_state_any_source(self, *, allowed: bool): + def get_state_any_source(self, *, allowed: bool) -> None: pass @transition( @@ -150,23 +159,23 @@ def get_state_any_source(self, *, allowed: bool): states=["published", "rejected"], ), ) - def get_state_any_source_except_target(self, *, allowed: bool): + def get_state_any_source_except_target(self, *, allowed: bool) -> None: pass @transition(field=state, source="new", target=RETURN_VALUE("moderated", "blocked")) - def return_value(self): + def return_value(self) -> str: return "published" @transition(field=state, source="*", target=RETURN_VALUE("moderated", "blocked")) - def return_value_any_source(self): + def return_value_any_source(self) -> str: return "published" @transition(field=state, source="+", target=RETURN_VALUE("moderated", "blocked")) - def return_value_any_source_except_target(self): + def return_value_any_source_except_target(self) -> str: return "published" @transition(field=state, source="new", target="published", on_error="failed") - def on_error(self): + def on_error(self) -> None: pass @@ -174,7 +183,7 @@ class MultiStateApplication(Application): another_state = FSMKeyField(DbState, default="new", on_delete=models.CASCADE) @transition(field=another_state, source="new", target="published") - def another_state_standard(self): + def another_state_standard(self) -> None: pass @@ -202,8 +211,8 @@ class Meta: ("can_remove_post", "Can remove post"), ] - def can_restore(self, user): - return user.is_superuser or user.is_staff + def can_restore(self: models.Model, user: AbstractUser) -> bool: + return bool(user.is_superuser or user.is_staff) @transition( field=state, @@ -212,11 +221,11 @@ def can_restore(self, user): on_error=BlogPostState.FAILED, permission="testapp.can_publish_post", ) - def publish(self): + def publish(self) -> None: pass @transition(field=state, source=BlogPostState.PUBLISHED) - def notify_all(self): + def notify_all(self) -> None: pass @transition( @@ -225,7 +234,7 @@ def notify_all(self): target=BlogPostState.HIDDEN, on_error=BlogPostState.FAILED, ) - def hide(self): + def hide(self) -> None: pass @transition( @@ -235,7 +244,7 @@ def hide(self): on_error=BlogPostState.FAILED, permission=lambda _, u: u.has_perm("testapp.can_remove_post"), ) - def remove(self): + def remove(self) -> None: raise Exception(f"No rights to delete {self}") @transition( @@ -245,7 +254,7 @@ def remove(self): on_error=BlogPostState.FAILED, permission=can_restore, ) - def restore(self): + def restore(self) -> None: pass @transition( @@ -253,9 +262,39 @@ def restore(self): source=[BlogPostState.PUBLISHED, BlogPostState.HIDDEN], target=BlogPostState.STOLEN, ) - def steal(self): + def steal(self) -> None: pass @transition(field=state, source="*", target=BlogPostState.MODERATED) - def moderate(self): + def moderate(self) -> None: + pass + + +class TrackedPostLog(TransitionLogBase): + post = models.ForeignKey( + "TrackedPost", + on_delete=models.CASCADE, + related_name="transition_logs", + ) + + +@track(log_model=TrackedPostLog, relation_field="post") +class TrackedPost(models.Model): + state = FSMField(default="new") + + @fsm_log_by + @fsm_log_description + @transition(field=state, source="new", target="published") + def publish(self, by=None, description=None, **kwargs): + pass + + +@track() +class GenericTrackedPost(models.Model): + state = FSMField(default="new") + + @fsm_log_by + @fsm_log_description + @transition(field=state, source="new", target="published") + def publish(self, by=None, description=None, **kwargs): pass diff --git a/tests/testapp/tests/test_abstract_inheritance.py b/tests/testapp/tests/test_abstract_inheritance.py index 1d40616..0e7a9af 100644 --- a/tests/testapp/tests/test_abstract_inheritance.py +++ b/tests/testapp/tests/test_abstract_inheritance.py @@ -53,11 +53,11 @@ def test_known_transition_should_succeed(self): def test_field_available_transitions_works(self): self.model.publish() assert self.model.state == "published" - transitions = self.model.get_available_state_transitions() + transitions = self.model.get_available_state_transitions() # type: ignore[attr-defined] assert [data.target for data in transitions] == ["sticked"] def test_field_all_transitions_works(self): - transitions = self.model.get_all_state_transitions() + transitions = self.model.get_all_state_transitions() # type: ignore[attr-defined] assert {("new", "published"), ("published", "sticked")} == { (data.source, data.target) for data in transitions } diff --git a/tests/testapp/tests/test_access_deferred_fsm_field.py b/tests/testapp/tests/test_access_deferred_fsm_field.py index 9c147b7..ed4a516 100644 --- a/tests/testapp/tests/test_access_deferred_fsm_field.py +++ b/tests/testapp/tests/test_access_deferred_fsm_field.py @@ -11,6 +11,8 @@ class DeferrableModel(models.Model): state = FSMField(default="new") + objects: models.Manager[DeferrableModel] = models.Manager() + @transition(field=state, source="new", target="published") def publish(self): pass diff --git a/tests/testapp/tests/test_basic_transitions.py b/tests/testapp/tests/test_basic_transitions.py index 33beb9b..0bedb78 100644 --- a/tests/testapp/tests/test_basic_transitions.py +++ b/tests/testapp/tests/test_basic_transitions.py @@ -167,7 +167,7 @@ def setUp(self): def test_in_operator_for_available_transitions(self): # store the generator in a list, so we can reuse the generator and do multiple asserts - transitions = list(self.model.get_available_state_transitions()) + transitions = list(self.model.get_available_state_transitions()) # type: ignore[attr-defined] assert "publish" in transitions assert "xyz" not in transitions @@ -181,15 +181,15 @@ def publish(): source="", target="", on_error="", - conditions="", + conditions=None, permission="", - custom="", + custom=None, ) assert obj in transitions def test_available_conditions_from_new(self): - transitions = self.model.get_available_state_transitions() + transitions = self.model.get_available_state_transitions() # type: ignore[attr-defined] actual = {(transition.source, transition.target) for transition in transitions} expected = { ("*", "moderated"), @@ -202,7 +202,7 @@ def test_available_conditions_from_new(self): def test_available_conditions_from_published(self): self.model.publish() - transitions = self.model.get_available_state_transitions() + transitions = self.model.get_available_state_transitions() # type: ignore[attr-defined] actual = {(transition.source, transition.target) for transition in transitions} expected = { ("*", "moderated"), @@ -217,7 +217,7 @@ def test_available_conditions_from_published(self): def test_available_conditions_from_hidden(self): self.model.publish() self.model.hide() - transitions = self.model.get_available_state_transitions() + transitions = self.model.get_available_state_transitions() # type: ignore[attr-defined] actual = {(transition.source, transition.target) for transition in transitions} expected = {("*", "moderated"), ("hidden", "stolen"), ("*", ""), ("+", "blocked")} assert actual == expected @@ -225,27 +225,27 @@ def test_available_conditions_from_hidden(self): def test_available_conditions_from_stolen(self): self.model.publish() self.model.steal() - transitions = self.model.get_available_state_transitions() + transitions = self.model.get_available_state_transitions() # type: ignore[attr-defined] actual = {(transition.source, transition.target) for transition in transitions} expected = {("*", "moderated"), ("*", ""), ("+", "blocked")} assert actual == expected def test_available_conditions_from_blocked(self): self.model.block() - transitions = self.model.get_available_state_transitions() + transitions = self.model.get_available_state_transitions() # type: ignore[attr-defined] actual = {(transition.source, transition.target) for transition in transitions} expected = {("*", "moderated"), ("*", "")} assert actual == expected def test_available_conditions_from_empty(self): self.model.empty() - transitions = self.model.get_available_state_transitions() + transitions = self.model.get_available_state_transitions() # type: ignore[attr-defined] actual = {(transition.source, transition.target) for transition in transitions} expected = {("*", "moderated"), ("*", ""), ("+", "blocked")} assert actual == expected def test_all_conditions(self): - transitions = self.model.get_all_state_transitions() + transitions = self.model.get_all_state_transitions() # type: ignore[attr-defined] actual = {(transition.source, transition.target) for transition in transitions} expected = { diff --git a/tests/testapp/tests/test_conditions.py b/tests/testapp/tests/test_conditions.py index e605ecb..5562e3b 100644 --- a/tests/testapp/tests/test_conditions.py +++ b/tests/testapp/tests/test_conditions.py @@ -10,17 +10,17 @@ from django_fsm import transition -def condition_func(instance): +def condition_func(instance: models.Model) -> bool: return True class BlogPostWithConditions(models.Model): state = FSMField(default="new") - def model_condition(self): + def model_condition(self: models.Model) -> bool: return True - def unmet_condition(self): + def unmet_condition(self: models.Model) -> bool: return False @transition( diff --git a/tests/testapp/tests/test_custom_data.py b/tests/testapp/tests/test_custom_data.py index f4397af..7ccd687 100644 --- a/tests/testapp/tests/test_custom_data.py +++ b/tests/testapp/tests/test_custom_data.py @@ -45,13 +45,13 @@ def setUp(self): def test_initial_state(self): assert self.model.state == "new" - transitions = list(self.model.get_available_state_transitions()) + transitions = list(self.model.get_available_state_transitions()) # type: ignore[attr-defined] assert len(transitions) == 1 assert transitions[0].target == "published" assert transitions[0].custom == {"label": "Publish", "type": "*"} def test_all_transitions_have_custom_data(self): - transitions = self.model.get_all_state_transitions() + transitions = self.model.get_all_state_transitions() # type: ignore[attr-defined] for t in transitions: assert t.custom["label"] is not None assert t.custom["type"] is not None diff --git a/tests/testapp/tests/test_exception_transitions.py b/tests/testapp/tests/test_exception_transitions.py index 7165eda..97a26e1 100644 --- a/tests/testapp/tests/test_exception_transitions.py +++ b/tests/testapp/tests/test_exception_transitions.py @@ -26,7 +26,7 @@ class FSMFieldExceptionTest(TestCase): def setUp(self): self.model = ExceptionalBlogPost() post_transition.connect(self.on_post_transition, sender=ExceptionalBlogPost) - self.post_transition_data = None + self.post_transition_data = {} def on_post_transition(self, **kwargs): self.post_transition_data = kwargs @@ -44,4 +44,4 @@ def test_state_not_changed_after_fail(self): with pytest.raises(Exception, match="Upss"): self.model.delete() assert self.model.state == "new" - assert self.post_transition_data is None + assert self.post_transition_data == {} diff --git a/tests/testapp/tests/test_graph_transitions.py b/tests/testapp/tests/test_graph_transitions.py index 777454f..dd90fa9 100644 --- a/tests/testapp/tests/test_graph_transitions.py +++ b/tests/testapp/tests/test_graph_transitions.py @@ -2,18 +2,23 @@ import os import tempfile +import typing from io import StringIO from pathlib import Path +import graphviz import pytest from django.core.exceptions import FieldDoesNotExist from django.core.management import call_command from django.test import TestCase -from django_fsm.management.commands.graph_transitions import get_graphviz_layouts from django_fsm.management.commands.graph_transitions import node_label +from django_fsm.management.commands.graph_transitions import node_name +from tests.testapp.models import Application from tests.testapp.models import BlogPost from tests.testapp.models import BlogPostState +from tests.testapp.tests.test_model_create_with_generic import Task +from tests.testapp.tests.test_model_create_with_generic import TaskState class GraphTransitionsCommandTest(TestCase): @@ -24,13 +29,20 @@ class GraphTransitionsCommandTest(TestCase): EXTENSIONS_TO_TEST = ["png", "jpg", "jpeg"] + def test_node_name(self): + assert node_name(Task.state.field, TaskState.DONE) == "testapp.task.state.done" + assert node_name(BlogPost.state.field, BlogPostState.NEW) == "testapp.blog_post.state.0" + def test_node_label(self): + assert node_label(Application.state.field, "new") == "new" assert ( node_label(BlogPost.state.field, BlogPostState.PUBLISHED.value) == BlogPostState.PUBLISHED.label ) + # choices is not declared, fallbacking to the value instead + assert node_label(Task.state.field, TaskState.DONE.value) == TaskState.DONE.value - def _call_command(self, *args, **kwargs): + def _call_command(self, *args: typing.Any, **kwargs: typing.Any) -> str: out = StringIO() call_command("graph_transitions", *args, **kwargs, stdout=out) return out.getvalue() @@ -59,7 +71,7 @@ def test_single_model_fail(self): def test_single_model_with_layouts(self): for model in self.MODELS_TO_TEST: - for layout in get_graphviz_layouts(): + for layout in graphviz.ENGINES: self._call_command("-l", layout, model) def test_single_model_with_output(self): @@ -95,3 +107,358 @@ def test_single_field(self): def test_single_field_fail(self): with pytest.raises((LookupError, FieldDoesNotExist)): self._call_command("testapp.MultiStateApplication.unknown_field") + + with pytest.raises(LookupError): + self._call_command("testapp.MultiStateApplication.id") + + def test_output_contains_subgraph_label(self): # noqa: PLR0915 + output = self._call_command("testapp.Application") + + assert "subgraph cluster_testapp_Application_state {" in output + assert 'graph [label="testapp.Application.state"]' in output + assert '"testapp.application.state.new" [label=new shape=circle]' in output + assert '"testapp.application.state._initial" [label="" shape=point]' in output + assert '"testapp.application.state._initial" -> "testapp.application.state.new"' in output + assert '"testapp.application.state.failed" [label=failed shape=circle]' in output + assert '"testapp.application.state.None" [label=None shape=circle]' in output + assert '"testapp.application.state.blocked" [label=blocked shape=circle]' in output + assert '"testapp.application.state.hidden" [label=hidden shape=circle]' in output + assert '"testapp.application.state.rejected" [label=rejected shape=circle]' in output + assert '"testapp.application.state.moderated" [label=moderated shape=circle]' in output + assert '"testapp.application.state.published" [label=published shape=circle]' in output + assert ( + '"testapp.application.state.new" -> "testapp.application.state.rejected" [label=get_state]' # noqa: E501 + in output + ) + assert ( + '"testapp.application.state.hidden" -> "testapp.application.state.blocked" [label=return_value_any_source]' # noqa: E501 + in output + ) + assert ( + '"testapp.application.state.published" -> "testapp.application.state.rejected" [label=get_state_any_source_except_target]' # noqa: E501 + in output + ) + assert ( + '"testapp.application.state.hidden" -> "testapp.application.state.published" [label=get_state_any_source_except_target]' # noqa: E501 + in output + ) + assert ( + '"testapp.application.state.new" -> "testapp.application.state.published" [label=on_error]' # noqa: E501 + in output + ) + assert ( + '"testapp.application.state.new" -> "testapp.application.state.published" [label=get_state_any_source_except_target]' # noqa: E501 + in output + ) + assert ( + '"testapp.application.state.rejected" -> "testapp.application.state.rejected" [label=get_state_any_source]' # noqa: E501 + in output + ) + assert ( + '"testapp.application.state.None" -> "testapp.application.state.published" [label=get_state_any_source_except_target]' # noqa: E501 + in output + ) + assert ( + '"testapp.application.state.moderated" -> "testapp.application.state.hidden" [label=any_source_except_target]' # noqa: E501 + in output + ) + assert ( + '"testapp.application.state.published" -> "testapp.application.state.blocked" [label=return_value_any_source]' # noqa: E501 + in output + ) + assert ( + '"testapp.application.state.moderated" -> "testapp.application.state.moderated" [label=return_value_any_source]' # noqa: E501 + in output + ) + assert ( + '"testapp.application.state.moderated" -> "testapp.application.state.rejected" [label=get_state_any_source_except_target]' # noqa: E501 + in output + ) + assert ( + '"testapp.application.state.failed" -> "testapp.application.state.blocked" [label=return_value_any_source_except_target]' # noqa: E501 + in output + ) + assert ( + '"testapp.application.state.rejected" -> "testapp.application.state.published" [label=get_state_any_source_except_target]' # noqa: E501 + in output + ) + assert ( + '"testapp.application.state.failed" -> "testapp.application.state.moderated" [label=return_value_any_source]' # noqa: E501 + in output + ) + assert ( + '"testapp.application.state.hidden" -> "testapp.application.state.rejected" [label=get_state_any_source_except_target]' # noqa: E501 + in output + ) + assert ( + '"testapp.application.state.new" -> "testapp.application.state.hidden" [label=any_source_except_target]' # noqa: E501 + in output + ) + assert ( + '"testapp.application.state.new" -> "testapp.application.state.moderated" [label=return_value]' # noqa: E501 + in output + ) + assert ( + '"testapp.application.state.blocked" -> "testapp.application.state.published" [label=get_state_any_source_except_target]' # noqa: E501 + in output + ) + assert ( + '"testapp.application.state.moderated" -> "testapp.application.state.published" [label=get_state_any_source]' # noqa: E501 + in output + ) + assert ( + '"testapp.application.state.new" -> "testapp.application.state.moderated" [label=return_value_any_source]' # noqa: E501 + in output + ) + assert ( + '"testapp.application.state.published" -> "testapp.application.state.moderated" [label=return_value_any_source]' # noqa: E501 + in output + ) + assert ( + '"testapp.application.state.new" -> "testapp.application.state.blocked" [label=return_value]' # noqa: E501 + in output + ) + assert ( + '"testapp.application.state.None" -> "testapp.application.state.hidden" [label=any_source_except_target]' # noqa: E501 + in output + ) + assert ( + '"testapp.application.state.failed" -> "testapp.application.state.blocked" [label=any_source]' # noqa: E501 + in output + ) + assert ( + '"testapp.application.state.new" -> "testapp.application.state.blocked" [label=return_value_any_source]' # noqa: E501 + in output + ) + assert ( + '"testapp.application.state.failed" -> "testapp.application.state.published" [label=get_state_any_source]' # noqa: E501 + in output + ) + assert ( + '"testapp.application.state.None" -> "testapp.application.state.moderated" [label=return_value_any_source]' # noqa: E501 + in output + ) + assert ( + '"testapp.application.state.published" -> "testapp.application.state.blocked" [label=return_value_any_source_except_target]' # noqa: E501 + in output + ) + assert ( + '"testapp.application.state.published" -> "testapp.application.state.None" [label=no_target]' # noqa: E501 + in output + ) + assert ( + '"testapp.application.state.blocked" -> "testapp.application.state.rejected" [label=get_state_any_source]' # noqa: E501 + in output + ) + assert ( + '"testapp.application.state.None" -> "testapp.application.state.blocked" [label=return_value_any_source]' # noqa: E501 + in output + ) + assert ( + '"testapp.application.state.rejected" -> "testapp.application.state.hidden" [label=any_source_except_target]' # noqa: E501 + in output + ) + assert ( + '"testapp.application.state.new" -> "testapp.application.state.published" [label=get_state_any_source]' # noqa: E501 + in output + ) + assert ( + '"testapp.application.state.failed" -> "testapp.application.state.moderated" [label=return_value_any_source_except_target]' # noqa: E501 + in output + ) + assert ( + '"testapp.application.state.published" -> "testapp.application.state.published" [label=get_state_any_source]' # noqa: E501 + in output + ) + assert ( + '"testapp.application.state.moderated" -> "testapp.application.state.blocked" [label=return_value_any_source_except_target]' # noqa: E501 + in output + ) + assert ( + '"testapp.application.state.None" -> "testapp.application.state.published" [label=get_state_any_source]' # noqa: E501 + in output + ) + assert ( + '"testapp.application.state.blocked" -> "testapp.application.state.hidden" [label=any_source_except_target]' # noqa: E501 + in output + ) + assert ( + '"testapp.application.state.published" -> "testapp.application.state.blocked" [label=any_source]' # noqa: E501 + in output + ) + assert ( + '"testapp.application.state.new" -> "testapp.application.state.failed" [style=dotted]' + in output + ) + assert ( + '"testapp.application.state.blocked" -> "testapp.application.state.moderated" [label=return_value_any_source]' # noqa: E501 + in output + ) + assert ( + '"testapp.application.state.failed" -> "testapp.application.state.rejected" [label=get_state_any_source]' # noqa: E501 + in output + ) + assert ( + '"testapp.application.state.rejected" -> "testapp.application.state.blocked" [label=return_value_any_source]' # noqa: E501 + in output + ) + assert ( + '"testapp.application.state.published" -> "testapp.application.state.moderated" [label=return_value_any_source_except_target]' # noqa: E501 + in output + ) + assert ( + '"testapp.application.state.new" -> "testapp.application.state.rejected" [label=get_state_any_source_except_target]' # noqa: E501 + in output + ) + assert ( + '"testapp.application.state.hidden" -> "testapp.application.state.blocked" [label=return_value_any_source_except_target]' # noqa: E501 + in output + ) + assert ( + '"testapp.application.state.hidden" -> "testapp.application.state.moderated" [label=return_value_any_source]' # noqa: E501 + in output + ) + assert ( + '"testapp.application.state.new" -> "testapp.application.state.published" [label=get_state]' # noqa: E501 + in output + ) + assert ( + '"testapp.application.state.None" -> "testapp.application.state.rejected" [label=get_state_any_source_except_target]' # noqa: E501 + in output + ) + assert ( + '"testapp.application.state.blocked" -> "testapp.application.state.blocked" [label=return_value_any_source]' # noqa: E501 + in output + ) + assert ( + '"testapp.application.state.moderated" -> "testapp.application.state.blocked" [label=any_source]' # noqa: E501 + in output + ) + assert ( + '"testapp.application.state.hidden" -> "testapp.application.state.blocked" [label=any_source]' # noqa: E501 + in output + ) + assert ( + '"testapp.application.state.published" -> "testapp.application.state.rejected" [label=get_state_any_source]' # noqa: E501 + in output + ) + assert ( + '"testapp.application.state.hidden" -> "testapp.application.state.published" [label=get_state_any_source]' # noqa: E501 + in output + ) + assert ( + '"testapp.application.state.rejected" -> "testapp.application.state.moderated" [label=return_value_any_source]' # noqa: E501 + in output + ) + assert ( + '"testapp.application.state.new" -> "testapp.application.state.blocked" [label=any_source]' # noqa: E501 + in output + ) + assert ( + '"testapp.application.state.new" -> "testapp.application.state.failed" [style=dotted]' + in output + ) + assert ( + '"testapp.application.state.failed" -> "testapp.application.state.published" [label=get_state_any_source_except_target]' # noqa: E501 + in output + ) + assert ( + '"testapp.application.state.rejected" -> "testapp.application.state.blocked" [label=return_value_any_source_except_target]' # noqa: E501 + in output + ) + assert ( + '"testapp.application.state.blocked" -> "testapp.application.state.rejected" [label=get_state_any_source_except_target]' # noqa: E501 + in output + ) + assert ( + '"testapp.application.state.None" -> "testapp.application.state.blocked" [label=any_source]' # noqa: E501 + in output + ) + assert ( + '"testapp.application.state.hidden" -> "testapp.application.state.moderated" [label=return_value_any_source_except_target]' # noqa: E501 + in output + ) + assert ( + '"testapp.application.state.new" -> "testapp.application.state.moderated" [label=return_value_any_source_except_target]' # noqa: E501 + in output + ) + assert ( + '"testapp.application.state.moderated" -> "testapp.application.state.rejected" [label=get_state_any_source]' # noqa: E501 + in output + ) + assert ( + '"testapp.application.state.new" -> "testapp.application.state.blocked" [label=return_value_any_source_except_target]' # noqa: E501 + in output + ) + assert ( + '"testapp.application.state.None" -> "testapp.application.state.moderated" [label=return_value_any_source_except_target]' # noqa: E501 + in output + ) + assert ( + '"testapp.application.state.rejected" -> "testapp.application.state.published" [label=get_state_any_source]' # noqa: E501 + in output + ) + assert ( + '"testapp.application.state.new" -> "testapp.application.state.published" [label=standard]' # noqa: E501 + in output + ) + assert ( + '"testapp.application.state.hidden" -> "testapp.application.state.rejected" [label=get_state_any_source]' # noqa: E501 + in output + ) + assert ( + '"testapp.application.state.None" -> "testapp.application.state.blocked" [label=return_value_any_source_except_target]' # noqa: E501 + in output + ) + assert ( + '"testapp.application.state.rejected" -> "testapp.application.state.blocked" [label=any_source]' # noqa: E501 + in output + ) + assert ( + '"testapp.application.state.new" -> "testapp.application.state.rejected" [label=get_state_any_source]' # noqa: E501 + in output + ) + assert ( + '"testapp.application.state.failed" -> "testapp.application.state.rejected" [label=get_state_any_source_except_target]' # noqa: E501 + in output + ) + assert ( + '"testapp.application.state.blocked" -> "testapp.application.state.published" [label=get_state_any_source]' # noqa: E501 + in output + ) + assert ( + '"testapp.application.state.rejected" -> "testapp.application.state.moderated" [label=return_value_any_source_except_target]' # noqa: E501 + in output + ) + assert ( + '"testapp.application.state.moderated" -> "testapp.application.state.published" [label=get_state_any_source_except_target]' # noqa: E501 + in output + ) + assert ( + '"testapp.application.state.failed" -> "testapp.application.state.hidden" [label=any_source_except_target]' # noqa: E501 + in output + ) + assert ( + '"testapp.application.state.None" -> "testapp.application.state.rejected" [label=get_state_any_source]' # noqa: E501 + in output + ) + assert ( + '"testapp.application.state.blocked" -> "testapp.application.state.blocked" [label=any_source]' # noqa: E501 + in output + ) + assert ( + '"testapp.application.state.failed" -> "testapp.application.state.blocked" [label=return_value_any_source]' # noqa: E501 + in output + ) + assert ( + '"testapp.application.state.moderated" -> "testapp.application.state.blocked" [label=return_value_any_source]' # noqa: E501 + in output + ) + assert ( + '"testapp.application.state.blocked" -> "testapp.application.state.moderated" [label=return_value_any_source_except_target]' # noqa: E501 + in output + ) + assert ( + '"testapp.application.state.published" -> "testapp.application.state.hidden" [label=any_source_except_target]' # noqa: E501 + in output + ) diff --git a/tests/testapp/tests/test_lock_mixin.py b/tests/testapp/tests/test_lock_mixin.py index fc3cb6b..8dc17a1 100644 --- a/tests/testapp/tests/test_lock_mixin.py +++ b/tests/testapp/tests/test_lock_mixin.py @@ -14,6 +14,8 @@ class LockedBlogPost(ConcurrentTransitionMixin, models.Model): state = FSMField(default="new") text = models.CharField(max_length=50) + objects: models.Manager[LockedBlogPost] = models.Manager() + @transition(field=state, source="new", target="published") def publish(self): pass @@ -27,6 +29,8 @@ class ExtendedBlogPost(LockedBlogPost): review_state = FSMField(default="waiting", protected=True) notes = models.CharField(max_length=50) + objects: models.Manager[ExtendedBlogPost] = models.Manager() + @transition(field=review_state, source="waiting", target="rejected") def reject(self): pass diff --git a/tests/testapp/tests/test_model_create_with_generic.py b/tests/testapp/tests/test_model_create_with_generic.py index b811211..2b71792 100644 --- a/tests/testapp/tests/test_model_create_with_generic.py +++ b/tests/testapp/tests/test_model_create_with_generic.py @@ -9,7 +9,8 @@ from django_fsm import transition -class Ticket(models.Model): ... +class Ticket(models.Model): + objects: models.Manager[Ticket] = models.Manager() class TaskState(models.TextChoices): @@ -23,6 +24,8 @@ class Task(models.Model): causality = GenericForeignKey("content_type", "object_id") state = FSMField(default=TaskState.NEW) + objects: models.Manager[Task] = models.Manager() + @transition(field=state, source=TaskState.NEW, target=TaskState.DONE) def do(self): pass diff --git a/tests/testapp/tests/test_object_permissions.py b/tests/testapp/tests/test_object_permissions.py index f168d0c..41fac3b 100644 --- a/tests/testapp/tests/test_object_permissions.py +++ b/tests/testapp/tests/test_object_permissions.py @@ -14,6 +14,8 @@ class ObjectPermissionTestModel(models.Model): state = FSMField(default="new") + objects: models.Manager[ObjectPermissionTestModel] = models.Manager() + class Meta: permissions = [ ("can_publish_objectpermissiontestmodel", "Can publish ObjectPermissionTestModel"), @@ -26,7 +28,7 @@ class Meta: on_error="failed", permission="testapp.can_publish_objectpermissiontestmodel", ) - def publish(self): + def publish(self) -> None: pass diff --git a/tests/testapp/tests/test_permissions.py b/tests/testapp/tests/test_permissions.py index db1f163..84ddf05 100644 --- a/tests/testapp/tests/test_permissions.py +++ b/tests/testapp/tests/test_permissions.py @@ -26,14 +26,14 @@ def test_privileged_access_succeed(self): assert has_transition_perm(self.model.publish, self.privileged) assert has_transition_perm(self.model.remove, self.privileged) - transitions = self.model.get_available_user_state_transitions(self.privileged) + transitions = self.model.get_available_user_state_transitions(self.privileged) # type: ignore[attr-defined] assert {"publish", "remove", "moderate"} == {transition.name for transition in transitions} def test_unprivileged_access_prohibited(self): assert not has_transition_perm(self.model.publish, self.unprivileged) assert not has_transition_perm(self.model.remove, self.unprivileged) - transitions = self.model.get_available_user_state_transitions(self.unprivileged) + transitions = self.model.get_available_user_state_transitions(self.unprivileged) # type: ignore[attr-defined] assert {"moderate"} == {transition.name for transition in transitions} def test_permission_instance_method(self): diff --git a/tests/testapp/tests/test_protected_field.py b/tests/testapp/tests/test_protected_field.py index f595818..1ccbf44 100644 --- a/tests/testapp/tests/test_protected_field.py +++ b/tests/testapp/tests/test_protected_field.py @@ -11,6 +11,8 @@ class ProtectedAccessModel(models.Model): status = FSMField(default="new", protected=True) + objects: models.Manager[ProtectedAccessModel] = models.Manager() + @transition(field=status, source="new", target="published") def publish(self): pass @@ -20,6 +22,8 @@ class MultiProtectedAccessModel(models.Model): status1 = FSMField(default="new", protected=True) status2 = FSMField(default="new", protected=True) + objects: models.Manager[MultiProtectedAccessModel] = models.Manager() + class TestDirectAccessModels(TestCase): def test_multi_protected_field_create(self): @@ -31,7 +35,7 @@ def test_no_direct_access(self): instance = ProtectedAccessModel() assert instance.status == "new" - def try_change(): + def try_change() -> None: instance.status = "change" with pytest.raises(AttributeError): diff --git a/tests/testapp/tests/test_protected_fields.py b/tests/testapp/tests/test_protected_fields.py index dcdb5b1..a8a784d 100644 --- a/tests/testapp/tests/test_protected_fields.py +++ b/tests/testapp/tests/test_protected_fields.py @@ -12,6 +12,8 @@ class RefreshableProtectedAccessModel(models.Model): status = FSMField(default="new", protected=True) + objects: models.Manager[RefreshableProtectedAccessModel] = models.Manager() + @transition(field=status, source="new", target="published") def publish(self): pass diff --git a/tests/testapp/tests/test_proxy_inheritance.py b/tests/testapp/tests/test_proxy_inheritance.py index 5376be1..82d81f8 100644 --- a/tests/testapp/tests/test_proxy_inheritance.py +++ b/tests/testapp/tests/test_proxy_inheritance.py @@ -41,15 +41,15 @@ def test_known_transition_should_succeed(self): def test_field_available_transitions_works(self): self.model.publish() assert self.model.state == "published" - transitions = self.model.get_available_state_transitions() + transitions = self.model.get_available_state_transitions() # type: ignore[attr-defined] assert [data.target for data in transitions] == ["sticked"] def test_field_all_transitions_base_model(self): - transitions = BaseModel().get_all_state_transitions() + transitions = BaseModel().get_all_state_transitions() # type: ignore[attr-defined] assert {("new", "published")} == {(data.source, data.target) for data in transitions} def test_field_all_transitions_works(self): - transitions = self.model.get_all_state_transitions() + transitions = self.model.get_all_state_transitions() # type: ignore[attr-defined] assert {("new", "published"), ("published", "sticked")} == { (data.source, data.target) for data in transitions } diff --git a/tests/testapp/tests/test_state_transitions.py b/tests/testapp/tests/test_state_transitions.py index c1bf1f6..5d233d0 100644 --- a/tests/testapp/tests/test_state_transitions.py +++ b/tests/testapp/tests/test_state_transitions.py @@ -19,6 +19,8 @@ class STATE: state = FSMField(default=STATE.CATERPILLAR, state_choices=STATE_CHOICES) + objects: models.Manager[Insect] = models.Manager() + @transition(field=state, source=STATE.CATERPILLAR, target=STATE.BUTTERFLY) def cocoon(self): pass diff --git a/tests/testapp/tests/test_transition_tracking.py b/tests/testapp/tests/test_transition_tracking.py new file mode 100644 index 0000000..41f0918 --- /dev/null +++ b/tests/testapp/tests/test_transition_tracking.py @@ -0,0 +1,39 @@ +from __future__ import annotations + +from django.contrib.auth import get_user_model +from django.test import TestCase + +from django_fsm.log import StateLog +from tests.testapp.models import GenericTrackedPost +from tests.testapp.models import TrackedPost +from tests.testapp.models import TrackedPostLog + + +class TransitionTrackingTests(TestCase): + def test_default_tracking_uses_generic_log(self) -> None: + user = get_user_model().objects.create_user(username="author") + post = GenericTrackedPost.objects.create() + post.publish(by=user, description="published via generic log") + + log = StateLog.objects.for_(post).get(object_id=str(post.pk)) + + assert log.transition == "publish" + assert log.state_field == "state" + assert log.source_state == "new" + assert log.state == "published" + assert log.by == user + assert log.description == "published via generic log" + + def test_custom_tracking_writes_to_model_log(self) -> None: + user = get_user_model().objects.create_user(username="author") + post = TrackedPost.objects.create() + post.publish(by=user, description="published via custom log") + + log = TrackedPostLog.objects.get(post=post) + + assert log.transition == "publish" + assert log.state_field == "state" + assert log.source_state == "new" + assert log.state == "published" + assert log.by == user + assert log.description == "published via custom log" diff --git a/uv.lock b/uv.lock index fa49136..2b8880c 100644 --- a/uv.lock +++ b/uv.lock @@ -1,6 +1,6 @@ version = 1 revision = 3 -requires-python = ">=3.8, <4" +requires-python = ">=3.8" resolution-markers = [ "python_full_version >= '3.10'", "python_full_version == '3.9.*'", @@ -179,6 +179,13 @@ version = "4.1.0" source = { editable = "." } dependencies = [ { name = "django" }, + { name = "typing-extensions", version = "4.13.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9'" }, + { name = "typing-extensions", version = "4.15.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9'" }, +] + +[package.optional-dependencies] +graphviz = [ + { name = "graphviz" }, ] [package.dev-dependencies] @@ -194,12 +201,14 @@ dev = [ { name = "tox", version = "4.30.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version == '3.9.*'" }, { name = "tox", version = "4.33.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.10'" }, ] -graphviz = [ - { name = "graphviz" }, -] [package.metadata] -requires-dist = [{ name = "django", specifier = ">=4.2" }] +requires-dist = [ + { name = "django", specifier = ">=4.2" }, + { name = "graphviz", marker = "extra == 'graphviz'" }, + { name = "typing-extensions", specifier = ">=4.13.2" }, +] +provides-extras = ["graphviz"] [package.metadata.requires-dev] dev = [ @@ -212,7 +221,6 @@ dev = [ { name = "pytest-django" }, { name = "tox" }, ] -graphviz = [{ name = "graphviz" }] [[package]] name = "django-guardian"