Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions statemachine/engines/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@
from typing import TYPE_CHECKING
from weakref import proxy

from statemachine.event import BoundEvent

from ..event import BoundEvent
from ..event_data import TriggerData
from ..state import State
from ..transition import Transition
Expand Down
20 changes: 17 additions & 3 deletions statemachine/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
from typing import List
from uuid import uuid4

from statemachine.utils import run_async_from_sync

from .callbacks import CallbackGroup
from .event_data import TriggerData
from .exceptions import InvalidDefinition
from .i18n import _
from .transition_mixin import AddCallbacksMixin
from .utils import run_async_from_sync

if TYPE_CHECKING:
from .statemachine import StateMachine
Expand All @@ -25,7 +27,7 @@
}


class Event(str):
class Event(AddCallbacksMixin, str):
"""An event is triggers a signal that something has happened.

They are send to a state machine and allow the state machine to react.
Expand Down Expand Up @@ -82,6 +84,18 @@ def __repr__(self):
def is_same_event(self, *_args, event: "str | None" = None, **_kwargs) -> bool:
return self == event

def _add_callback(self, callback, grouper: CallbackGroup, is_event=False, **kwargs):
if self._transitions is None:
raise InvalidDefinition(
_("Cannot add callback '{}' to an event with no transitions.").format(callback)
)
return self._transitions._add_callback(
callback=callback,
grouper=grouper,
is_event=is_event,
**kwargs,
)

def __get__(self, instance, owner):
"""By implementing this method `Event` can be used as a property descriptor

Expand Down
3 changes: 1 addition & 2 deletions statemachine/events.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from statemachine.event import Event

from .event import Event
from .utils import ensure_iterable


Expand Down
78 changes: 2 additions & 76 deletions statemachine/transition_list.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
from typing import TYPE_CHECKING
from typing import Callable
from typing import Iterable
from typing import List

from .callbacks import CallbackGroup
from .transition import Transition
from .transition_mixin import AddCallbacksMixin
from .utils import ensure_iterable

if TYPE_CHECKING:
from .events import Event
from .state import State


class TransitionList:
class TransitionList(AddCallbacksMixin):
"""A list-like container of :ref:`transitions` with callback functions."""

def __init__(self, transitions: "Iterable[Transition] | None" = None):
Expand Down Expand Up @@ -97,80 +97,6 @@ def _add_callback(self, callback, grouper: CallbackGroup, is_event=False, **kwar
)
return callback

def __call__(self, f):
return self._add_callback(f, CallbackGroup.ON, is_event=True)

def before(self, f: Callable):
"""Adds a ``before`` :ref:`transition actions` callback to every :ref:`transition` in the
:ref:`TransitionList` instance.

Args:
f: The ``before`` :ref:`transition actions` callback function to be added.

Returns:
The `f` callable.
"""
return self._add_callback(f, CallbackGroup.BEFORE)

def after(self, f: Callable):
"""Adds a ``after`` :ref:`transition actions` callback to every :ref:`transition` in the
:ref:`TransitionList` instance.

Args:
f: The ``after`` :ref:`transition actions` callback function to be added.

Returns:
The `f` callable.
"""
return self._add_callback(f, CallbackGroup.AFTER)

def on(self, f: Callable):
"""Adds a ``on`` :ref:`transition actions` callback to every :ref:`transition` in the
:ref:`TransitionList` instance.

Args:
f: The ``on`` :ref:`transition actions` callback function to be added.

Returns:
The `f` callable.
"""
return self._add_callback(f, CallbackGroup.ON)

def cond(self, f: Callable):
"""Adds a ``cond`` :ref:`guards` callback to every :ref:`transition` in the
:ref:`TransitionList` instance.

Args:
f: The ``cond`` :ref:`guards` callback function to be added.

Returns:
The `f` callable.
"""
return self._add_callback(f, CallbackGroup.COND, expected_value=True)

def unless(self, f: Callable):
"""Adds a ``unless`` :ref:`guards` callback with expected value ``False`` to every
:ref:`transition` in the :ref:`TransitionList` instance.

Args:
f: The ``unless`` :ref:`guards` callback function to be added.

Returns:
The `f` callable.
"""
return self._add_callback(f, CallbackGroup.COND, expected_value=False)

def validators(self, f: Callable):
"""Adds a :ref:`validators` callback to the :ref:`TransitionList` instance.

Args:
f: The ``validators`` callback function to be added.
Returns:
The callback function.

"""
return self._add_callback(f, CallbackGroup.VALIDATOR)

def add_event(self, event: str):
"""
Adds an event to all transitions in the :ref:`TransitionList` instance.
Expand Down
82 changes: 82 additions & 0 deletions statemachine/transition_mixin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
from typing import Callable

from .callbacks import CallbackGroup


class AddCallbacksMixin:
def _add_callback(self, callback, grouper: CallbackGroup, is_event=False, **kwargs):
raise NotImplementedError

def __call__(self, f):
return self._add_callback(f, CallbackGroup.ON, is_event=True)

def before(self, f: Callable):
"""Adds a ``before`` :ref:`transition actions` callback to every :ref:`transition` in the
:ref:`TransitionList` instance.

Args:
f: The ``before`` :ref:`transition actions` callback function to be added.

Returns:
The `f` callable.
"""
return self._add_callback(f, CallbackGroup.BEFORE)

def after(self, f: Callable):
"""Adds a ``after`` :ref:`transition actions` callback to every :ref:`transition` in the
:ref:`TransitionList` instance.

Args:
f: The ``after`` :ref:`transition actions` callback function to be added.

Returns:
The `f` callable.
"""
return self._add_callback(f, CallbackGroup.AFTER)

def on(self, f: Callable):
"""Adds a ``on`` :ref:`transition actions` callback to every :ref:`transition` in the
:ref:`TransitionList` instance.

Args:
f: The ``on`` :ref:`transition actions` callback function to be added.

Returns:
The `f` callable.
"""
return self._add_callback(f, CallbackGroup.ON)

def cond(self, f: Callable):
"""Adds a ``cond`` :ref:`guards` callback to every :ref:`transition` in the
:ref:`TransitionList` instance.

Args:
f: The ``cond`` :ref:`guards` callback function to be added.

Returns:
The `f` callable.
"""
return self._add_callback(f, CallbackGroup.COND, expected_value=True)

def unless(self, f: Callable):
"""Adds a ``unless`` :ref:`guards` callback with expected value ``False`` to every
:ref:`transition` in the :ref:`TransitionList` instance.

Args:
f: The ``unless`` :ref:`guards` callback function to be added.

Returns:
The `f` callable.
"""
return self._add_callback(f, CallbackGroup.COND, expected_value=False)

def validators(self, f: Callable):
"""Adds a :ref:`validators` callback to the :ref:`TransitionList` instance.

Args:
f: The ``validators`` callback function to be added.
Returns:
The callback function.

"""
return self._add_callback(f, CallbackGroup.VALIDATOR)
51 changes: 51 additions & 0 deletions tests/test_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,57 @@ def on_cycle(self, event_data, event: str):
assert sm.send("cycle") == "Running cycle from yellow to red"
assert sm.send("cycle") == "Running cycle from red to green"

def test_allow_registering_callbacks_using_decorator(self):
class TrafficLightMachine(StateMachine):
"A traffic light machine"

green = State(initial=True)
yellow = State()
red = State()

cycle = Event(
green.to(yellow, event="slow_down")
| yellow.to(red, event=["stop"])
| red.to(green, event=["go"]),
name="Loop",
)

@cycle.on
def do_cycle(self, event_data, event: str):
assert event_data.event == event
return (
f"Running {event} from {event_data.transition.source.id} to "
f"{event_data.transition.target.id}"
)

sm = TrafficLightMachine()

assert sm.send("cycle") == "Running cycle from green to yellow"

def test_raise_registering_callbacks_using_decorator_if_no_transitions(self):
with pytest.raises(InvalidDefinition, match="event with no transitions"):

class TrafficLightMachine(StateMachine):
"A traffic light machine"

green = State(initial=True)
yellow = State()
red = State()

cycle = Event(name="Loop")
slow_down = Event()
green.to(yellow, event=[cycle, slow_down])
yellow.to(red, event=[cycle, "stop"])
red.to(green, event=[cycle, "go"])

@cycle.on
def do_cycle(self, event_data, event: str):
assert event_data.event == event
return (
f"Running {event} from {event_data.transition.source.id} to "
f"{event_data.transition.target.id}"
)

def test_allow_using_events_as_commands(self):
class StartMachine(StateMachine):
created = State(initial=True)
Expand Down