Skip to content

Commit 1f91cc4

Browse files
authored
feat: Allow using event as decorator to register callbacks (#506)
1 parent ad0e92e commit 1f91cc4

File tree

6 files changed

+154
-83
lines changed

6 files changed

+154
-83
lines changed

statemachine/engines/base.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,7 @@
33
from typing import TYPE_CHECKING
44
from weakref import proxy
55

6-
from statemachine.event import BoundEvent
7-
6+
from ..event import BoundEvent
87
from ..event_data import TriggerData
98
from ..state import State
109
from ..transition import Transition

statemachine/event.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,12 @@
33
from typing import List
44
from uuid import uuid4
55

6-
from statemachine.utils import run_async_from_sync
7-
6+
from .callbacks import CallbackGroup
87
from .event_data import TriggerData
8+
from .exceptions import InvalidDefinition
99
from .i18n import _
10+
from .transition_mixin import AddCallbacksMixin
11+
from .utils import run_async_from_sync
1012

1113
if TYPE_CHECKING:
1214
from .statemachine import StateMachine
@@ -25,7 +27,7 @@
2527
}
2628

2729

28-
class Event(str):
30+
class Event(AddCallbacksMixin, str):
2931
"""An event is triggers a signal that something has happened.
3032
3133
They are send to a state machine and allow the state machine to react.
@@ -82,6 +84,18 @@ def __repr__(self):
8284
def is_same_event(self, *_args, event: "str | None" = None, **_kwargs) -> bool:
8385
return self == event
8486

87+
def _add_callback(self, callback, grouper: CallbackGroup, is_event=False, **kwargs):
88+
if self._transitions is None:
89+
raise InvalidDefinition(
90+
_("Cannot add callback '{}' to an event with no transitions.").format(callback)
91+
)
92+
return self._transitions._add_callback(
93+
callback=callback,
94+
grouper=grouper,
95+
is_event=is_event,
96+
**kwargs,
97+
)
98+
8599
def __get__(self, instance, owner):
86100
"""By implementing this method `Event` can be used as a property descriptor
87101

statemachine/events.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
from statemachine.event import Event
2-
1+
from .event import Event
32
from .utils import ensure_iterable
43

54

statemachine/transition_list.py

Lines changed: 2 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,18 @@
11
from typing import TYPE_CHECKING
2-
from typing import Callable
32
from typing import Iterable
43
from typing import List
54

65
from .callbacks import CallbackGroup
76
from .transition import Transition
7+
from .transition_mixin import AddCallbacksMixin
88
from .utils import ensure_iterable
99

1010
if TYPE_CHECKING:
1111
from .events import Event
1212
from .state import State
1313

1414

15-
class TransitionList:
15+
class TransitionList(AddCallbacksMixin):
1616
"""A list-like container of :ref:`transitions` with callback functions."""
1717

1818
def __init__(self, transitions: "Iterable[Transition] | None" = None):
@@ -97,80 +97,6 @@ def _add_callback(self, callback, grouper: CallbackGroup, is_event=False, **kwar
9797
)
9898
return callback
9999

100-
def __call__(self, f):
101-
return self._add_callback(f, CallbackGroup.ON, is_event=True)
102-
103-
def before(self, f: Callable):
104-
"""Adds a ``before`` :ref:`transition actions` callback to every :ref:`transition` in the
105-
:ref:`TransitionList` instance.
106-
107-
Args:
108-
f: The ``before`` :ref:`transition actions` callback function to be added.
109-
110-
Returns:
111-
The `f` callable.
112-
"""
113-
return self._add_callback(f, CallbackGroup.BEFORE)
114-
115-
def after(self, f: Callable):
116-
"""Adds a ``after`` :ref:`transition actions` callback to every :ref:`transition` in the
117-
:ref:`TransitionList` instance.
118-
119-
Args:
120-
f: The ``after`` :ref:`transition actions` callback function to be added.
121-
122-
Returns:
123-
The `f` callable.
124-
"""
125-
return self._add_callback(f, CallbackGroup.AFTER)
126-
127-
def on(self, f: Callable):
128-
"""Adds a ``on`` :ref:`transition actions` callback to every :ref:`transition` in the
129-
:ref:`TransitionList` instance.
130-
131-
Args:
132-
f: The ``on`` :ref:`transition actions` callback function to be added.
133-
134-
Returns:
135-
The `f` callable.
136-
"""
137-
return self._add_callback(f, CallbackGroup.ON)
138-
139-
def cond(self, f: Callable):
140-
"""Adds a ``cond`` :ref:`guards` callback to every :ref:`transition` in the
141-
:ref:`TransitionList` instance.
142-
143-
Args:
144-
f: The ``cond`` :ref:`guards` callback function to be added.
145-
146-
Returns:
147-
The `f` callable.
148-
"""
149-
return self._add_callback(f, CallbackGroup.COND, expected_value=True)
150-
151-
def unless(self, f: Callable):
152-
"""Adds a ``unless`` :ref:`guards` callback with expected value ``False`` to every
153-
:ref:`transition` in the :ref:`TransitionList` instance.
154-
155-
Args:
156-
f: The ``unless`` :ref:`guards` callback function to be added.
157-
158-
Returns:
159-
The `f` callable.
160-
"""
161-
return self._add_callback(f, CallbackGroup.COND, expected_value=False)
162-
163-
def validators(self, f: Callable):
164-
"""Adds a :ref:`validators` callback to the :ref:`TransitionList` instance.
165-
166-
Args:
167-
f: The ``validators`` callback function to be added.
168-
Returns:
169-
The callback function.
170-
171-
"""
172-
return self._add_callback(f, CallbackGroup.VALIDATOR)
173-
174100
def add_event(self, event: str):
175101
"""
176102
Adds an event to all transitions in the :ref:`TransitionList` instance.

statemachine/transition_mixin.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
from typing import Callable
2+
3+
from .callbacks import CallbackGroup
4+
5+
6+
class AddCallbacksMixin:
7+
def _add_callback(self, callback, grouper: CallbackGroup, is_event=False, **kwargs):
8+
raise NotImplementedError
9+
10+
def __call__(self, f):
11+
return self._add_callback(f, CallbackGroup.ON, is_event=True)
12+
13+
def before(self, f: Callable):
14+
"""Adds a ``before`` :ref:`transition actions` callback to every :ref:`transition` in the
15+
:ref:`TransitionList` instance.
16+
17+
Args:
18+
f: The ``before`` :ref:`transition actions` callback function to be added.
19+
20+
Returns:
21+
The `f` callable.
22+
"""
23+
return self._add_callback(f, CallbackGroup.BEFORE)
24+
25+
def after(self, f: Callable):
26+
"""Adds a ``after`` :ref:`transition actions` callback to every :ref:`transition` in the
27+
:ref:`TransitionList` instance.
28+
29+
Args:
30+
f: The ``after`` :ref:`transition actions` callback function to be added.
31+
32+
Returns:
33+
The `f` callable.
34+
"""
35+
return self._add_callback(f, CallbackGroup.AFTER)
36+
37+
def on(self, f: Callable):
38+
"""Adds a ``on`` :ref:`transition actions` callback to every :ref:`transition` in the
39+
:ref:`TransitionList` instance.
40+
41+
Args:
42+
f: The ``on`` :ref:`transition actions` callback function to be added.
43+
44+
Returns:
45+
The `f` callable.
46+
"""
47+
return self._add_callback(f, CallbackGroup.ON)
48+
49+
def cond(self, f: Callable):
50+
"""Adds a ``cond`` :ref:`guards` callback to every :ref:`transition` in the
51+
:ref:`TransitionList` instance.
52+
53+
Args:
54+
f: The ``cond`` :ref:`guards` callback function to be added.
55+
56+
Returns:
57+
The `f` callable.
58+
"""
59+
return self._add_callback(f, CallbackGroup.COND, expected_value=True)
60+
61+
def unless(self, f: Callable):
62+
"""Adds a ``unless`` :ref:`guards` callback with expected value ``False`` to every
63+
:ref:`transition` in the :ref:`TransitionList` instance.
64+
65+
Args:
66+
f: The ``unless`` :ref:`guards` callback function to be added.
67+
68+
Returns:
69+
The `f` callable.
70+
"""
71+
return self._add_callback(f, CallbackGroup.COND, expected_value=False)
72+
73+
def validators(self, f: Callable):
74+
"""Adds a :ref:`validators` callback to the :ref:`TransitionList` instance.
75+
76+
Args:
77+
f: The ``validators`` callback function to be added.
78+
Returns:
79+
The callback function.
80+
81+
"""
82+
return self._add_callback(f, CallbackGroup.VALIDATOR)

tests/test_events.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,57 @@ def on_cycle(self, event_data, event: str):
233233
assert sm.send("cycle") == "Running cycle from yellow to red"
234234
assert sm.send("cycle") == "Running cycle from red to green"
235235

236+
def test_allow_registering_callbacks_using_decorator(self):
237+
class TrafficLightMachine(StateMachine):
238+
"A traffic light machine"
239+
240+
green = State(initial=True)
241+
yellow = State()
242+
red = State()
243+
244+
cycle = Event(
245+
green.to(yellow, event="slow_down")
246+
| yellow.to(red, event=["stop"])
247+
| red.to(green, event=["go"]),
248+
name="Loop",
249+
)
250+
251+
@cycle.on
252+
def do_cycle(self, event_data, event: str):
253+
assert event_data.event == event
254+
return (
255+
f"Running {event} from {event_data.transition.source.id} to "
256+
f"{event_data.transition.target.id}"
257+
)
258+
259+
sm = TrafficLightMachine()
260+
261+
assert sm.send("cycle") == "Running cycle from green to yellow"
262+
263+
def test_raise_registering_callbacks_using_decorator_if_no_transitions(self):
264+
with pytest.raises(InvalidDefinition, match="event with no transitions"):
265+
266+
class TrafficLightMachine(StateMachine):
267+
"A traffic light machine"
268+
269+
green = State(initial=True)
270+
yellow = State()
271+
red = State()
272+
273+
cycle = Event(name="Loop")
274+
slow_down = Event()
275+
green.to(yellow, event=[cycle, slow_down])
276+
yellow.to(red, event=[cycle, "stop"])
277+
red.to(green, event=[cycle, "go"])
278+
279+
@cycle.on
280+
def do_cycle(self, event_data, event: str):
281+
assert event_data.event == event
282+
return (
283+
f"Running {event} from {event_data.transition.source.id} to "
284+
f"{event_data.transition.target.id}"
285+
)
286+
236287
def test_allow_using_events_as_commands(self):
237288
class StartMachine(StateMachine):
238289
created = State(initial=True)

0 commit comments

Comments
 (0)