diff --git a/statemachine/engines/base.py b/statemachine/engines/base.py index 9abc1fe5..0aa2f131 100644 --- a/statemachine/engines/base.py +++ b/statemachine/engines/base.py @@ -3,8 +3,7 @@ from typing import TYPE_CHECKING from weakref import proxy -from statemachine.event import BoundEvent - +from ..event import BoundEvent from ..event_data import TriggerData from ..state import State from ..transition import Transition diff --git a/statemachine/event.py b/statemachine/event.py index f80f8112..d8fa511d 100644 --- a/statemachine/event.py +++ b/statemachine/event.py @@ -3,10 +3,12 @@ from typing import List from uuid import uuid4 -from statemachine.utils import run_async_from_sync - +from .callbacks import CallbackGroup from .event_data import TriggerData +from .exceptions import InvalidDefinition from .i18n import _ +from .transition_mixin import AddCallbacksMixin +from .utils import run_async_from_sync if TYPE_CHECKING: from .statemachine import StateMachine @@ -25,7 +27,7 @@ } -class Event(str): +class Event(AddCallbacksMixin, str): """An event is triggers a signal that something has happened. They are send to a state machine and allow the state machine to react. @@ -82,6 +84,18 @@ def __repr__(self): def is_same_event(self, *_args, event: "str | None" = None, **_kwargs) -> bool: return self == event + def _add_callback(self, callback, grouper: CallbackGroup, is_event=False, **kwargs): + if self._transitions is None: + raise InvalidDefinition( + _("Cannot add callback '{}' to an event with no transitions.").format(callback) + ) + return self._transitions._add_callback( + callback=callback, + grouper=grouper, + is_event=is_event, + **kwargs, + ) + def __get__(self, instance, owner): """By implementing this method `Event` can be used as a property descriptor diff --git a/statemachine/events.py b/statemachine/events.py index d7a15c6f..052d053a 100644 --- a/statemachine/events.py +++ b/statemachine/events.py @@ -1,5 +1,4 @@ -from statemachine.event import Event - +from .event import Event from .utils import ensure_iterable diff --git a/statemachine/transition_list.py b/statemachine/transition_list.py index 265a7b6c..e1aa6504 100644 --- a/statemachine/transition_list.py +++ b/statemachine/transition_list.py @@ -1,10 +1,10 @@ from typing import TYPE_CHECKING -from typing import Callable from typing import Iterable from typing import List from .callbacks import CallbackGroup from .transition import Transition +from .transition_mixin import AddCallbacksMixin from .utils import ensure_iterable if TYPE_CHECKING: @@ -12,7 +12,7 @@ from .state import State -class TransitionList: +class TransitionList(AddCallbacksMixin): """A list-like container of :ref:`transitions` with callback functions.""" def __init__(self, transitions: "Iterable[Transition] | None" = None): @@ -97,80 +97,6 @@ def _add_callback(self, callback, grouper: CallbackGroup, is_event=False, **kwar ) return callback - def __call__(self, f): - return self._add_callback(f, CallbackGroup.ON, is_event=True) - - def before(self, f: Callable): - """Adds a ``before`` :ref:`transition actions` callback to every :ref:`transition` in the - :ref:`TransitionList` instance. - - Args: - f: The ``before`` :ref:`transition actions` callback function to be added. - - Returns: - The `f` callable. - """ - return self._add_callback(f, CallbackGroup.BEFORE) - - def after(self, f: Callable): - """Adds a ``after`` :ref:`transition actions` callback to every :ref:`transition` in the - :ref:`TransitionList` instance. - - Args: - f: The ``after`` :ref:`transition actions` callback function to be added. - - Returns: - The `f` callable. - """ - return self._add_callback(f, CallbackGroup.AFTER) - - def on(self, f: Callable): - """Adds a ``on`` :ref:`transition actions` callback to every :ref:`transition` in the - :ref:`TransitionList` instance. - - Args: - f: The ``on`` :ref:`transition actions` callback function to be added. - - Returns: - The `f` callable. - """ - return self._add_callback(f, CallbackGroup.ON) - - def cond(self, f: Callable): - """Adds a ``cond`` :ref:`guards` callback to every :ref:`transition` in the - :ref:`TransitionList` instance. - - Args: - f: The ``cond`` :ref:`guards` callback function to be added. - - Returns: - The `f` callable. - """ - return self._add_callback(f, CallbackGroup.COND, expected_value=True) - - def unless(self, f: Callable): - """Adds a ``unless`` :ref:`guards` callback with expected value ``False`` to every - :ref:`transition` in the :ref:`TransitionList` instance. - - Args: - f: The ``unless`` :ref:`guards` callback function to be added. - - Returns: - The `f` callable. - """ - return self._add_callback(f, CallbackGroup.COND, expected_value=False) - - def validators(self, f: Callable): - """Adds a :ref:`validators` callback to the :ref:`TransitionList` instance. - - Args: - f: The ``validators`` callback function to be added. - Returns: - The callback function. - - """ - return self._add_callback(f, CallbackGroup.VALIDATOR) - def add_event(self, event: str): """ Adds an event to all transitions in the :ref:`TransitionList` instance. diff --git a/statemachine/transition_mixin.py b/statemachine/transition_mixin.py new file mode 100644 index 00000000..0d3a292e --- /dev/null +++ b/statemachine/transition_mixin.py @@ -0,0 +1,82 @@ +from typing import Callable + +from .callbacks import CallbackGroup + + +class AddCallbacksMixin: + def _add_callback(self, callback, grouper: CallbackGroup, is_event=False, **kwargs): + raise NotImplementedError + + def __call__(self, f): + return self._add_callback(f, CallbackGroup.ON, is_event=True) + + def before(self, f: Callable): + """Adds a ``before`` :ref:`transition actions` callback to every :ref:`transition` in the + :ref:`TransitionList` instance. + + Args: + f: The ``before`` :ref:`transition actions` callback function to be added. + + Returns: + The `f` callable. + """ + return self._add_callback(f, CallbackGroup.BEFORE) + + def after(self, f: Callable): + """Adds a ``after`` :ref:`transition actions` callback to every :ref:`transition` in the + :ref:`TransitionList` instance. + + Args: + f: The ``after`` :ref:`transition actions` callback function to be added. + + Returns: + The `f` callable. + """ + return self._add_callback(f, CallbackGroup.AFTER) + + def on(self, f: Callable): + """Adds a ``on`` :ref:`transition actions` callback to every :ref:`transition` in the + :ref:`TransitionList` instance. + + Args: + f: The ``on`` :ref:`transition actions` callback function to be added. + + Returns: + The `f` callable. + """ + return self._add_callback(f, CallbackGroup.ON) + + def cond(self, f: Callable): + """Adds a ``cond`` :ref:`guards` callback to every :ref:`transition` in the + :ref:`TransitionList` instance. + + Args: + f: The ``cond`` :ref:`guards` callback function to be added. + + Returns: + The `f` callable. + """ + return self._add_callback(f, CallbackGroup.COND, expected_value=True) + + def unless(self, f: Callable): + """Adds a ``unless`` :ref:`guards` callback with expected value ``False`` to every + :ref:`transition` in the :ref:`TransitionList` instance. + + Args: + f: The ``unless`` :ref:`guards` callback function to be added. + + Returns: + The `f` callable. + """ + return self._add_callback(f, CallbackGroup.COND, expected_value=False) + + def validators(self, f: Callable): + """Adds a :ref:`validators` callback to the :ref:`TransitionList` instance. + + Args: + f: The ``validators`` callback function to be added. + Returns: + The callback function. + + """ + return self._add_callback(f, CallbackGroup.VALIDATOR) diff --git a/tests/test_events.py b/tests/test_events.py index 51015eea..a746237a 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -233,6 +233,57 @@ def on_cycle(self, event_data, event: str): assert sm.send("cycle") == "Running cycle from yellow to red" assert sm.send("cycle") == "Running cycle from red to green" + def test_allow_registering_callbacks_using_decorator(self): + class TrafficLightMachine(StateMachine): + "A traffic light machine" + + green = State(initial=True) + yellow = State() + red = State() + + cycle = Event( + green.to(yellow, event="slow_down") + | yellow.to(red, event=["stop"]) + | red.to(green, event=["go"]), + name="Loop", + ) + + @cycle.on + def do_cycle(self, event_data, event: str): + assert event_data.event == event + return ( + f"Running {event} from {event_data.transition.source.id} to " + f"{event_data.transition.target.id}" + ) + + sm = TrafficLightMachine() + + assert sm.send("cycle") == "Running cycle from green to yellow" + + def test_raise_registering_callbacks_using_decorator_if_no_transitions(self): + with pytest.raises(InvalidDefinition, match="event with no transitions"): + + class TrafficLightMachine(StateMachine): + "A traffic light machine" + + green = State(initial=True) + yellow = State() + red = State() + + cycle = Event(name="Loop") + slow_down = Event() + green.to(yellow, event=[cycle, slow_down]) + yellow.to(red, event=[cycle, "stop"]) + red.to(green, event=[cycle, "go"]) + + @cycle.on + def do_cycle(self, event_data, event: str): + assert event_data.event == event + return ( + f"Running {event} from {event_data.transition.source.id} to " + f"{event_data.transition.target.id}" + ) + def test_allow_using_events_as_commands(self): class StartMachine(StateMachine): created = State(initial=True)