Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ jobs:
python-version: '3.13'

- name: Setup Graphviz
uses: ts-graphviz/setup-graphviz@v1
uses: ts-graphviz/setup-graphviz@v2

- name: Install uv
uses: astral-sh/setup-uv@v3
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ Easily iterate over all states:

```py
>>> [s.id for s in sm.states]
['green', 'red', 'yellow']
['green', 'yellow', 'red']

```

Expand Down
27 changes: 14 additions & 13 deletions statemachine/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,13 +71,15 @@ def __init__(
func,
group: CallbackGroup,
is_convention=False,
is_event: bool = False,
cond=None,
priority: CallbackPriority = CallbackPriority.NAMING,
expected_value=None,
):
self.func = func
self.group = group
self.is_convention = is_convention
self.is_event = is_event
self.cond = cond
self.expected_value = expected_value
self.priority = priority
Expand All @@ -88,7 +90,12 @@ def __init__(
elif callable(func):
self.reference = SpecReference.CALLABLE
self.is_bounded = hasattr(func, "__self__")
self.attr_name = func.__name__
self.attr_name = (
func.__name__ if not self.is_event or self.is_bounded else f"_{func.__name__}_"
)
if not self.is_bounded:
func.attr_name = self.attr_name
func.is_event = is_event
else:
self.reference = SpecReference.NAME
self.attr_name = func
Expand All @@ -114,11 +121,6 @@ def __eq__(self, other):
def __hash__(self):
return id(self)

def _update_func(self, func: Callable, attr_name: str):
self.func = func
self.reference = SpecReference.CALLABLE
self.attr_name = attr_name


class SpecListGrouper:
def __init__(self, list: "CallbackSpecList", group: CallbackGroup) -> None:
Expand Down Expand Up @@ -158,7 +160,7 @@ def __init__(self, factory=CallbackSpec):
def __repr__(self):
return f"{type(self).__name__}({self.items!r}, factory={self.factory!r})"

def _add_unbounded_callback(self, func, is_event=False, transitions=None, **kwargs):
def _add_unbounded_callback(self, func, transitions=None, **kwargs):
"""This list was a target for adding a func using decorator
`@<state|event>[.on|before|after|enter|exit]` syntax.

Expand All @@ -181,11 +183,7 @@ def _add_unbounded_callback(self, func, is_event=False, transitions=None, **kwar
event.

"""
spec = self._add(func, **kwargs)
if not getattr(func, "_specs_to_update", None):
func._specs_to_update = set()
if is_event:
func._specs_to_update.add(spec._update_func)
self._add(func, **kwargs)
func._transitions = transitions

return func
Expand All @@ -202,7 +200,10 @@ def grouper(self, group: CallbackGroup) -> SpecListGrouper:
return self._groupers[group]

def _add(self, func, group: CallbackGroup, **kwargs):
spec = self.factory(func, group, **kwargs)
if isinstance(func, CallbackSpec):
spec = func
else:
spec = self.factory(func, group, **kwargs)

if spec in self.items:
return
Expand Down
3 changes: 2 additions & 1 deletion statemachine/engines/async_.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from ..exceptions import InvalidDefinition
from ..exceptions import TransitionNotAllowed
from ..i18n import _
from ..state import State
from ..transition import Transition

if TYPE_CHECKING:
Expand Down Expand Up @@ -82,7 +83,7 @@ async def processing_loop(self):
async def _trigger(self, trigger_data: TriggerData):
event_data = None
if trigger_data.event == "__initial__":
transition = Transition(None, self.sm._get_initial_state(), event="__initial__")
transition = Transition(State(), self.sm._get_initial_state(), event="__initial__")
transition._specs.clear()
event_data = EventData(trigger_data=trigger_data, transition=transition)
await self._activate(event_data)
Expand Down
3 changes: 2 additions & 1 deletion statemachine/engines/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from ..event_data import EventData
from ..event_data import TriggerData
from ..exceptions import TransitionNotAllowed
from ..state import State
from ..transition import Transition

if TYPE_CHECKING:
Expand Down Expand Up @@ -85,7 +86,7 @@ def processing_loop(self):
def _trigger(self, trigger_data: TriggerData):
event_data = None
if trigger_data.event == "__initial__":
transition = Transition(None, self.sm._get_initial_state(), event="__initial__")
transition = Transition(State(), self.sm._get_initial_state(), event="__initial__")
transition._specs.clear()
event_data = EventData(trigger_data=trigger_data, transition=transition)
self._activate(event_data)
Expand Down
18 changes: 7 additions & 11 deletions statemachine/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from typing import Dict
from typing import List
from typing import Tuple
from uuid import uuid4

from . import registry
from .event import Event
Expand Down Expand Up @@ -179,7 +178,7 @@ def add_inherited(cls, bases):
cls.add_event(event=Event(id=event.id, name=event.name))

def add_from_attributes(cls, attrs): # noqa: C901
for key, value in sorted(attrs.items(), key=lambda pair: pair[0]):
for key, value in attrs.items():
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The library is pre-python3.7 where the dictionary order was not guaranteed. This order hack is not required since 3.7.

if isinstance(value, States):
cls._add_states_from_dict(value)
if isinstance(value, State):
Expand All @@ -195,7 +194,7 @@ def add_from_attributes(cls, attrs): # noqa: C901
),
old_event=value,
)
elif getattr(value, "_specs_to_update", None):
elif getattr(value, "attr_name", None):
cls._add_unbounded_callback(key, value)

def _add_states_from_dict(cls, states):
Expand All @@ -205,13 +204,10 @@ def _add_states_from_dict(cls, states):
def _add_unbounded_callback(cls, attr_name, func):
# if func is an event, the `attr_name` will be replaced by an event trigger,
# so we'll also give the ``func`` a new unique name to be used by the callback
# machinery.
cls.add_event(event=Event(func._transitions, id=attr_name, name=attr_name))
attr_name = f"_{attr_name}_{uuid4().hex}"
setattr(cls, attr_name, func)

for ref in func._specs_to_update:
ref(getattr(cls, attr_name), attr_name)
# machinery that is stored at ``func.attr_name``
setattr(cls, func.attr_name, func)
if func.is_event:
cls.add_event(event=Event(func._transitions, id=attr_name, name=attr_name))

def add_state(cls, id, state: State):
state._set_id(id)
Expand All @@ -236,7 +232,7 @@ def add_event(

transitions = event._transitions
if transitions is not None:
transitions.add_event(event)
transitions._on_event_defined(event=event, states=list(cls.states))

if event not in cls._events:
cls._events[event] = None
Expand Down
85 changes: 58 additions & 27 deletions statemachine/state.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import TYPE_CHECKING
from typing import Any
from typing import Dict
from typing import List
from weakref import ref

from .callbacks import CallbackGroup
Expand All @@ -15,6 +16,37 @@
from .statemachine import StateMachine


class _TransitionBuilder:
def __init__(self, state: "State"):
self._state = state

def itself(self, **kwargs):
return self.__call__(self._state, **kwargs)

def __call__(self, *states: "State", **kwargs):
raise NotImplementedError


class _ToState(_TransitionBuilder):
def __call__(self, *states: "State", **kwargs):
transitions = TransitionList(Transition(self._state, state, **kwargs) for state in states)
self._state.transitions.add_transitions(transitions)
return transitions


class _FromState(_TransitionBuilder):
def any(self, **kwargs):
return self.__call__(AnyState(), **kwargs)

def __call__(self, *states: "State", **kwargs):
transitions = TransitionList()
for origin in states:
transition = Transition(origin, self._state, **kwargs)
origin.transitions.add_transitions(transition)
transitions.add_transitions(transition)
return transitions


class State:
"""
A State in a :ref:`StateMachine` describes a particular behavior of the machine.
Expand Down Expand Up @@ -136,6 +168,12 @@ def _setup(self):
self.exit.add("on_exit_state", priority=CallbackPriority.GENERIC, is_convention=True)
self.exit.add(f"on_exit_{self.id}", priority=CallbackPriority.NAMING, is_convention=True)

def _on_event_defined(self, event: str, transition: Transition, states: List["State"]):
"""Called by statemachine factory when an event is defined having a transition
starting from this state.
"""
pass

def __repr__(self):
return (
f"{type(self).__name__}({self.name!r}, id={self.id!r}, value={self.value!r}, "
Expand Down Expand Up @@ -172,38 +210,15 @@ def _set_id(self, id: str):
if not self.name:
self.name = self._id.replace("_", " ").capitalize()

def _to_(self, *states: "State", **kwargs):
transitions = TransitionList(Transition(self, state, **kwargs) for state in states)
self.transitions.add_transitions(transitions)
return transitions

def _from_(self, *states: "State", **kwargs):
transitions = TransitionList()
for origin in states:
transition = Transition(origin, self, **kwargs)
origin.transitions.add_transitions(transition)
transitions.add_transitions(transition)
return transitions

def _get_proxy_method_to_itself(self, method):
def proxy(*states: "State", **kwargs):
return method(*states, **kwargs)

def proxy_to_itself(**kwargs):
return proxy(self, **kwargs)

proxy.itself = proxy_to_itself
return proxy

@property
def to(self):
def to(self) -> _ToState:
"""Create transitions to the given target states."""
return self._get_proxy_method_to_itself(self._to_)
return _ToState(self)

@property
def from_(self):
def from_(self) -> _FromState:
"""Create transitions from the given target states (reversed)."""
return self._get_proxy_method_to_itself(self._from_)
return _FromState(self)

@property
def initial(self):
Expand Down Expand Up @@ -269,3 +284,19 @@ def id(self) -> str:
@property
def is_active(self):
return self._machine().current_state == self


class AnyState(State):
"""A special state that works as a "ANY" placeholder.

It is used as the "From" state of a transtion,
until the state machine class is evaluated.
"""

def _on_event_defined(self, event: str, transition: Transition, states: List[State]):
for state in states:
if state.final:
continue
new_transition = transition._copy_with_args(source=state, event=event)

state.transitions.add_transitions(new_transition)
24 changes: 22 additions & 2 deletions statemachine/transition.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
from copy import deepcopy
from typing import TYPE_CHECKING

from .callbacks import CallbackGroup
from .callbacks import CallbackPriority
from .callbacks import CallbackSpecList
from .events import Events
from .exceptions import InvalidDefinition

if TYPE_CHECKING:
from .statemachine import State


class Transition:
"""A transition holds reference to the source and target state.
Expand Down Expand Up @@ -32,8 +38,8 @@ class Transition:

def __init__(
self,
source,
target,
source: "State",
target: "State",
event=None,
internal=False,
validators=None,
Expand Down Expand Up @@ -125,3 +131,17 @@ def events(self):

def add_event(self, value):
self._events.add(value)

def _copy_with_args(self, **kwargs):
source = kwargs.pop("source", self.source)
target = kwargs.pop("target", self.target)
event = kwargs.pop("event", self.event)
internal = kwargs.pop("internal", self.internal)
new_transition = Transition(
source=source, target=target, event=event, internal=internal, **kwargs
)
for spec in self._specs:
new_spec = deepcopy(spec)
new_transition._specs.add(new_spec, new_spec.group)

return new_transition
Loading